/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp;

import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.WeightUpdater;
import dmonner.xlbp.connection.Connection;
import dmonner.xlbp.util.MatrixTools;

/**
 * A {@link WeightUpdater} that applies gradient descent with momentum.
 *
 * <p>For every weight (or bias) it keeps a trace {@code t}. Each time a raw change {@code dw}
 * arrives through one of the {@code updateFrom*} methods, the trace becomes
 * {@code m * t + dw}. The update reported by {@code getUpdate} is {@code a * t}, where {@code a}
 * is the learning rate and {@code m} the momentum factor.
 *
 * <p>One of the {@code initialize} overloads must be called before the traces are read or
 * updated; until then {@code trace} is {@code null}.
 */
public class MomentumWeightUpdater
implements WeightUpdater {
    public static final long serialVersionUID = 1L;
    /** The connection this updater was created for. */
    private final Connection parent;
    /** Momentum traces indexed as [to][from]; vector-shaped updates use row 0 only. */
    private float[][] trace;
    /** Learning rate multiplied into the trace when an update is requested. */
    private final float a;
    /** Momentum factor applied to the previous trace value on every accumulation. */
    private final float m;
    /** Number of trace rows iterated by the matrix-shaped update methods. */
    private int to;
    /** Number of trace columns iterated by every update method. */
    private int from;

    /**
     * Creates an updater with learning rate {@code 0.1f} and momentum {@code 0.9f}.
     *
     * @param parent the connection this updater belongs to
     */
    public MomentumWeightUpdater(Connection parent) {
        this(parent, 0.1f, 0.9f);
    }

    /**
     * Creates an updater with explicit hyper-parameters.
     *
     * @param parent the connection this updater belongs to
     * @param a the learning rate applied when reporting an update
     * @param m the momentum factor applied to the previous trace value
     */
    public MomentumWeightUpdater(Connection parent, float a, float m) {
        this.m = m;
        this.a = a;
        this.parent = parent;
    }

    /** @return the connection passed at construction time */
    @Override
    public Connection getConnection() {
        return parent;
    }

    /**
     * Returns the update for entry {@code i} of a vector-shaped (single-row) trace.
     * The {@code dw} argument is not used: the raw changes were already folded into the
     * trace by the {@code updateFrom*} methods.
     */
    @Override
    public float getUpdate(int i, float dw) {
        final float[] firstRow = trace[0];
        return a * firstRow[i];
    }

    /**
     * Returns the update for entry ({@code j}, {@code i}) of the trace matrix.
     * The {@code dw} argument is not used, for the same reason as the vector overload.
     */
    @Override
    public float getUpdate(int j, int i, float dw) {
        final float[] row = trace[j];
        return a * row[i];
    }

    /** Allocates a fresh single-row trace of {@code size} columns, all zero. */
    @Override
    public void initialize(int size) {
        allocate(1, size);
    }

    /** Allocates a fresh {@code to} x {@code from} trace, all zero. */
    @Override
    public void initialize(int to, int from) {
        allocate(to, from);
    }

    /** Records the trace dimensions and replaces the trace with a zero-filled matrix. */
    private void allocate(int rows, int cols) {
        this.to = rows;
        this.from = cols;
        this.trace = new float[rows][cols];
    }

    /** This updater keeps no per-batch state, so there is nothing to do here. */
    @Override
    public void processBatch() {
    }

    /**
     * Appends the learning rate and/or the trace matrix to {@code sb}, depending on which
     * sections the builder asks for.
     */
    @Override
    public void toString(NetworkStringBuilder sb) {
        if (sb.showLearningRates()) {
            sb.appendln("Learning Rates: " + a);
        }
        if (!sb.showExtra()) {
            return;
        }
        sb.appendln("Traces:");
        sb.pushIndent();
        sb.appendln(MatrixTools.toString(trace));
        sb.popIndent();
    }

    /** Folds the bias deltas {@code d} directly into trace row 0. */
    @Override
    public void updateFromBiases(float[] d) {
        final float[] row = trace[0];
        for (int col = 0; col < from; col++) {
            accumulate(row, col, d[col]);
        }
    }

    /** Folds {@code e[j][i] * d[j]} into trace entry ({@code j}, {@code i}). */
    @Override
    public void updateFromEligibilities(float[][] e, float[] d) {
        for (int r = 0; r < to; r++) {
            // Row-level reads happen before the inner loop, mirroring the original access order.
            final float[] eligibilityRow = e[r];
            final float delta = d[r];
            final float[] traceRow = trace[r];
            for (int c = 0; c < from; c++) {
                accumulate(traceRow, c, delta * eligibilityRow[c]);
            }
        }
    }

    /** Folds the outer product {@code in[i] * d[j]} into trace entry ({@code j}, {@code i}). */
    @Override
    public void updateFromInputs(float[] in, float[] d) {
        for (int r = 0; r < to; r++) {
            final float delta = d[r];
            final float[] traceRow = trace[r];
            for (int c = 0; c < from; c++) {
                accumulate(traceRow, c, delta * in[c]);
            }
        }
    }

    /** Folds the element-wise product {@code v[i] * d[i]} into trace row 0. */
    @Override
    public void updateFromVector(float[] v, float[] d) {
        final float[] row = trace[0];
        for (int col = 0; col < from; col++) {
            accumulate(row, col, d[col] * v[col]);
        }
    }

    /** Decays {@code row[col]} by the momentum factor and adds the new raw change. */
    private void accumulate(float[] row, int col, float delta) {
        row[col] = m * row[col] + delta;
    }
}

