/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp;

import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.WeightUpdater;
import dmonner.xlbp.connection.Connection;
import dmonner.xlbp.util.MatrixTools;
import java.util.Arrays;

/**
 * A {@link WeightUpdater} that adapts a separate step size for every weight, in the style of
 * resilient backpropagation (Rprop). The per-sign-change handling matches the iRprop− variant.
 *
 * <p>Deltas are accumulated through the {@code updateFrom*} methods and applied in batch by
 * {@link #processBatch()}. For each weight, that method:
 * <ul>
 *   <li>grows the step size by {@code ETA_PLUS} (capped at {@code A_MAX}) and moves by the new
 *       step in the direction of the delta, if the delta kept its sign since the last batch;</li>
 *   <li>shrinks the step size by {@code ETA_MINUS} (floored at {@code A_MIN}), makes no change
 *       and forgets the stored delta, if the delta flipped sign;</li>
 *   <li>otherwise moves by the current step in the direction of the delta.</li>
 * </ul>
 * Only the sign of the accumulated delta is used, never its magnitude.
 */
public class ResilientWeightUpdater
implements WeightUpdater {
    public static final long serialVersionUID = 1L;

    /** Multiplier applied to a step size when its delta keeps the same sign. */
    private static final float ETA_PLUS = 1.2f;
    /** Multiplier applied to a step size when its delta changes sign. */
    private static final float ETA_MINUS = 0.5f;
    /** Upper bound on any per-weight step size. */
    private static final float A_MAX = 100.0f;
    /** Lower bound on any per-weight step size. */
    private static final float A_MIN = 1.0E-10f;
    /** Step size assigned to every weight by {@link #initialize(int, int)}. */
    private static final float A_INIT = 0.01f;

    private final Connection parent;
    /** Delta seen in the previous batch; reset to 0 after a sign change. */
    private float[][] pdw;
    /** Current per-weight step sizes. */
    private float[][] a;
    /** Deltas accumulated during the current batch; cleared by processBatch. */
    private float[][] dw;
    /** Weight changes computed by the most recent processBatch call. */
    private float[][] wc;
    private int to;
    private int from;

    /**
     * Creates an updater whose batch weight changes are handed to {@code parent}.
     *
     * @param parent the connection whose {@code updateWeights} receives each batch's changes
     */
    public ResilientWeightUpdater(Connection parent) {
        this.parent = parent;
    }

    /**
     * Returns the live (not copied) matrix of per-weight step sizes.
     *
     * @return the step-size matrix, or {@code null} before {@link #initialize(int, int)}
     */
    public float[][] get() {
        return this.a;
    }

    @Override
    public Connection getConnection() {
        return this.parent;
    }

    /**
     * Always returns 0. This updater produces weight changes only in batch, via
     * {@link #processBatch()}.
     */
    @Override
    public float getUpdate(int i, float dw) {
        return 0.0f;
    }

    /**
     * Always returns 0. This updater produces weight changes only in batch, via
     * {@link #processBatch()}.
     */
    @Override
    public float getUpdate(int j, int i, float dw) {
        return 0.0f;
    }

    /**
     * Allocates state for a single row of {@code size} weights. This is the shape that
     * {@link #updateFromBiases} and {@link #updateFromVector} write into (row 0 only).
     */
    @Override
    public void initialize(int size) {
        this.initialize(1, size);
    }

    /**
     * Allocates state for a {@code to} x {@code from} weight matrix. All step sizes start at
     * {@code A_INIT}; previous deltas, accumulated deltas and weight changes start at 0.
     */
    @Override
    public void initialize(int to, int from) {
        this.to = to;
        this.from = from;
        this.pdw = new float[to][from];
        this.dw = new float[to][from];
        this.a = new float[to][from];
        this.wc = new float[to][from];
        for (int j = 0; j < to; ++j) {
            Arrays.fill(this.a[j], A_INIT);
        }
    }

    /**
     * Adapts every step size from the sign of its accumulated delta and computes the weight
     * changes. It then passes those changes to the parent connection and clears the
     * accumulated deltas for the next batch.
     *
     * <p>NOTE(review): assumes the accumulated deltas already point in the direction the
     * weights should move. Confirm against {@code Connection.updateWeights}.
     */
    @Override
    public void processBatch() {
        for (int j = 0; j < this.to; ++j) {
            float[] dwj = this.dw[j];
            float[] pdwj = this.pdw[j];
            float[] aj = this.a[j];
            float[] wcj = this.wc[j];
            for (int i = 0; i < this.from; ++i) {
                float delta = dwj[i];
                float prod = pdwj[i] * delta;
                if (prod < 0.0f) {
                    // Sign flipped: the last step overshot a minimum. Shrink the step, skip this
                    // update, and forget the delta so the next batch is not compared against it.
                    aj[i] = Math.max(aj[i] * ETA_MINUS, A_MIN);
                    pdwj[i] = 0.0f;
                    wcj[i] = 0.0f;
                } else {
                    if (prod > 0.0f) {
                        // Same sign as last batch: accelerate. The step must be adapted BEFORE
                        // it is used for this batch's change.
                        aj[i] = Math.min(aj[i] * ETA_PLUS, A_MAX);
                    }
                    pdwj[i] = delta;
                    wcj[i] = aj[i] * this.sign(delta);
                }
            }
        }
        this.parent.updateWeights(this.wc);
        for (int j = 0; j < this.dw.length; ++j) {
            Arrays.fill(this.dw[j], 0.0f);
        }
    }

    /** Returns 1, -1 or 0 according to the sign of {@code x}; NaN maps to 0. */
    private float sign(float x) {
        if (x > 0.0f) {
            return 1.0f;
        }
        if (x < 0.0f) {
            return -1.0f;
        }
        return 0.0f;
    }

    /**
     * Appends the step sizes when the builder requests learning rates. When the builder
     * requests extra detail, it also appends the previous and accumulated deltas.
     */
    @Override
    public void toString(NetworkStringBuilder sb) {
        if (sb.showLearningRates()) {
            sb.appendln("Learning Rates:");
            sb.pushIndent();
            sb.appendln(MatrixTools.toString(this.a));
            sb.popIndent();
        }
        if (sb.showExtra()) {
            sb.appendln("Previous Weight Deltas:");
            sb.pushIndent();
            sb.appendln(MatrixTools.toString(this.pdw));
            sb.popIndent();
            sb.appendln("Weight Deltas:");
            sb.pushIndent();
            sb.appendln(MatrixTools.toString(this.dw));
            sb.popIndent();
        }
    }

    /**
     * Accumulates {@code d[i]} into row 0 of the batch deltas.
     *
     * @param d per-weight deltas; must have at least {@code from} entries
     */
    @Override
    public void updateFromBiases(float[] d) {
        float[] dwj = this.dw[0];
        for (int i = 0; i < this.from; ++i) {
            dwj[i] += d[i];
        }
    }

    /**
     * Accumulates {@code e[j][i] * d[j]} into the batch deltas.
     *
     * @param e eligibility matrix, at least {@code to} x {@code from}
     * @param d per-row factors, at least {@code to} entries
     */
    @Override
    public void updateFromEligibilities(float[][] e, float[] d) {
        for (int j = 0; j < this.to; ++j) {
            float[] ej = e[j];
            float dj = d[j];
            float[] dwj = this.dw[j];
            for (int i = 0; i < this.from; ++i) {
                dwj[i] += ej[i] * dj;
            }
        }
    }

    /**
     * Accumulates the outer product {@code in[i] * d[j]} into the batch deltas.
     *
     * @param in column factors, at least {@code from} entries
     * @param d row factors, at least {@code to} entries
     */
    @Override
    public void updateFromInputs(float[] in, float[] d) {
        for (int j = 0; j < this.to; ++j) {
            float dj = d[j];
            float[] dwj = this.dw[j];
            for (int i = 0; i < this.from; ++i) {
                dwj[i] += in[i] * dj;
            }
        }
    }

    /**
     * Accumulates the element-wise product {@code v[i] * d[i]} into row 0 of the batch deltas.
     *
     * @param v first factor, at least {@code from} entries
     * @param d second factor, at least {@code from} entries
     */
    @Override
    public void updateFromVector(float[] v, float[] d) {
        float[] dwj = this.dw[0];
        for (int i = 0; i < this.from; ++i) {
            dwj[i] += v[i] * d[i];
        }
    }
}

