package dmonner.xlbp;

import dmonner.xlbp.connection.Connection;
import dmonner.xlbp.util.MatrixTools;

/* loaded from: input_file:dmonner/xlbp/MomentumWeightUpdater.class */
/**
 * A {@link WeightUpdater} that keeps one running "trace" value per weight and
 * derives each weight update from that trace.
 *
 * <p>Every incoming contribution {@code c} for a weight is folded into its
 * trace as {@code trace = trace * m + c}, and the update reported for that
 * weight is {@code a * trace}. The trace matrix is laid out as
 * {@code [to][from]}; the single-index variants use row {@code 0} only.
 *
 * <p>The trace matrix does not exist until one of the {@code initialize}
 * methods has been called; every other method that touches it assumes that
 * has already happened.
 *
 * <p>NOTE(review): the instance field names are deliberately left as they
 * were. A {@code serialVersionUID} is declared, so this type is presumably
 * serialized somewhere — confirm before renaming any field.
 */
public class MomentumWeightUpdater implements WeightUpdater {
    public static final long serialVersionUID = 1L;

    /** Value used for {@code a} by the single-argument constructor. */
    private static final float DEFAULT_LEARNING_RATE = 0.1f;

    /** Value used for {@code m} by the single-argument constructor. */
    private static final float DEFAULT_MOMENTUM = 0.9f;

    /** The connection handed to the constructor; returned by {@link #getConnection()}. */
    private final Connection parent;

    /** Running per-weight accumulator, dimensioned {@code [to][from]}. */
    private float[][] trace;

    /** Scale applied to a trace value when an update is requested. */
    private final float a;

    /** Factor the previous trace value is multiplied by before a new contribution is added. */
    private final float m;

    /** Number of rows in {@link #trace}. */
    private int to;

    /** Number of columns in {@link #trace}. */
    private int from;

    /**
     * Creates an updater with {@code a = 0.1} and {@code m = 0.9}.
     *
     * @param connection the connection this updater is attached to
     */
    public MomentumWeightUpdater(Connection connection) {
        this(connection, DEFAULT_LEARNING_RATE, DEFAULT_MOMENTUM);
    }

    /**
     * Creates an updater with explicit coefficients.
     *
     * @param connection the connection this updater is attached to
     * @param f          the scale {@code a} applied to the trace when reporting an update
     * @param f2         the factor {@code m} applied to the old trace on every accumulation
     */
    public MomentumWeightUpdater(Connection connection, float f, float f2) {
        parent = connection;
        a = f;
        m = f2;
    }

    /**
     * @return the connection supplied at construction time
     */
    @Override
    public Connection getConnection() {
        return parent;
    }

    /**
     * Reports the update for a weight addressed by a single index.
     *
     * @param i column index into row {@code 0} of the trace matrix
     * @param f not used by this implementation
     * @return {@code a * trace[0][i]}
     */
    @Override
    public float getUpdate(int i, float f) {
        final float[] firstRow = trace[0];
        return a * firstRow[i];
    }

    /**
     * Reports the update for a weight addressed by row and column.
     *
     * @param i  row index into the trace matrix
     * @param i2 column index into the trace matrix
     * @param f  not used by this implementation
     * @return {@code a * trace[i][i2]}
     */
    @Override
    public float getUpdate(int i, int i2, float f) {
        final float[] row = trace[i];
        return a * row[i2];
    }

    /**
     * Sets up a single-row trace matrix of width {@code i}, with every entry zero.
     *
     * @param i number of columns
     */
    @Override
    public void initialize(int i) {
        // Dimensions are recorded before allocating, as the allocation reads them back.
        to = 1;
        from = i;
        trace = new float[to][from];
    }

    /**
     * Sets up an {@code i x i2} trace matrix, with every entry zero.
     *
     * @param i  number of rows
     * @param i2 number of columns
     */
    @Override
    public void initialize(int i, int i2) {
        to = i;
        from = i2;
        trace = new float[i][i2];
    }

    /** Does nothing in this implementation. */
    @Override
    public void processBatch() {
        // Intentionally empty.
    }

    /**
     * Appends a description of this updater to the given builder: the value of
     * {@code a} when learning rates are requested, and the full trace matrix
     * when extra detail is requested.
     *
     * @param networkStringBuilder the builder to append to
     */
    @Override
    public void toString(NetworkStringBuilder networkStringBuilder) {
        final boolean wantsRates = networkStringBuilder.showLearningRates();
        if (wantsRates) {
            final String rateLine = "Learning Rates: " + a;
            networkStringBuilder.appendln(rateLine);
        }

        final boolean wantsExtra = networkStringBuilder.showExtra();
        if (wantsExtra) {
            networkStringBuilder.appendln("Traces:");
            networkStringBuilder.pushIndent();
            final String renderedTraces = MatrixTools.toString(trace);
            networkStringBuilder.appendln(renderedTraces);
            networkStringBuilder.popIndent();
        }
    }

    /**
     * Folds {@code fArr[k]} into {@code trace[0][k]} for every column {@code k}.
     *
     * @param fArr one contribution per column; must hold at least as many
     *             elements as the trace has columns
     */
    @Override
    public void updateFromBiases(float[] fArr) {
        final float[] firstRow = trace[0];
        final int columns = from;
        for (int col = 0; col < columns; col++) {
            accumulate(firstRow, col, fArr[col]);
        }
    }

    /**
     * Folds {@code fArr[r][c] * fArr2[r]} into {@code trace[r][c]} for every
     * row {@code r} and column {@code c}.
     *
     * @param fArr  matrix with the same layout as the trace
     * @param fArr2 one multiplier per trace row
     */
    @Override
    public void updateFromEligibilities(float[][] fArr, float[] fArr2) {
        final int rows = to;
        final int columns = from;
        for (int row = 0; row < rows; row++) {
            final float[] sourceRow = fArr[row];
            final float rowFactor = fArr2[row];
            final float[] traceRow = trace[row];
            for (int col = 0; col < columns; col++) {
                accumulate(traceRow, col, sourceRow[col] * rowFactor);
            }
        }
    }

    /**
     * Folds {@code fArr[c] * fArr2[r]} into {@code trace[r][c]} for every
     * row {@code r} and column {@code c}.
     *
     * @param fArr  one value per trace column
     * @param fArr2 one multiplier per trace row
     */
    @Override
    public void updateFromInputs(float[] fArr, float[] fArr2) {
        final int rows = to;
        final int columns = from;
        for (int row = 0; row < rows; row++) {
            final float rowFactor = fArr2[row];
            final float[] traceRow = trace[row];
            for (int col = 0; col < columns; col++) {
                accumulate(traceRow, col, fArr[col] * rowFactor);
            }
        }
    }

    /**
     * Folds the element-wise product {@code fArr[k] * fArr2[k]} into
     * {@code trace[0][k]} for every column {@code k}.
     *
     * @param fArr  first factor, one value per trace column
     * @param fArr2 second factor, one value per trace column
     */
    @Override
    public void updateFromVector(float[] fArr, float[] fArr2) {
        final float[] firstRow = trace[0];
        final int columns = from;
        for (int col = 0; col < columns; col++) {
            accumulate(firstRow, col, fArr[col] * fArr2[col]);
        }
    }

    /**
     * Scales the stored value at {@code row[index]} by {@code m} and adds the
     * new contribution to it, in place.
     */
    private void accumulate(float[] row, int index, float contribution) {
        final float scaledPrevious = row[index] * m;
        row[index] = scaledPrevious + contribution;
    }
}
