/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.connection;

import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.Responsibilities;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdater;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.connection.Connection;
import dmonner.xlbp.layer.BiasLayer;
import dmonner.xlbp.util.MatrixTools;

public class BiasConnection
implements Connection {
    private static final long serialVersionUID = 1L;
    private final String name;
    private final BiasLayer to;
    private WeightInitializer win;
    private WeightUpdaterType wut;
    private WeightUpdater updater;
    private float[] w;
    private float[] e;
    private boolean cleared;
    private boolean overwrite;
    private boolean built;

    public BiasConnection(BiasConnection that, NetworkCopier copier) {
        this.name = copier.getCopyNameFrom(that);
        this.to = copier.getCopyOf(that.to);
        this.win = that.win;
        this.wut = that.wut;
        this.updater = that.updater != null ? this.wut.make(this) : null;
        this.built = that.built;
        this.cleared = copier.copyState() ? that.cleared : true;
        this.overwrite = copier.copyState() ? that.overwrite : true;
        this.w = copier.copyWeights() ? MatrixTools.copy(that.w) : MatrixTools.empty(that.w);
        this.e = copier.copyWeights() ? MatrixTools.copy(that.e) : MatrixTools.empty(that.e);
    }

    public BiasConnection(BiasLayer to) {
        this(to.getName(), to);
    }

    public BiasConnection(String name, BiasLayer to) {
        this.name = name;
        this.to = to;
        this.cleared = true;
        this.overwrite = true;
    }

    @Override
    public void activateTest() {
    }

    @Override
    public void activateTrain() {
        this.cleared = false;
    }

    public void alias(float[] w) {
        if (w.length != this.to.size()) {
            throw new IllegalArgumentException("Incompatible number of weights: " + this.to.size() + " != " + w.length);
        }
        this.w = w;
    }

    @Override
    public void build() {
        if (!this.built) {
            if (this.win == null) {
                throw new IllegalStateException("Missing a WeightInitializer in " + this.name);
            }
            if (this.wut == null) {
                throw new IllegalStateException("Missing a WeightUpdaterType in " + this.name);
            }
            this.updater = this.wut.make(this);
            this.initializeWeights(this.win);
            this.initializeAlphas(this.updater);
            this.built = true;
        }
    }

    @Override
    public void clear() {
        this.cleared = true;
        this.overwrite = true;
    }

    @Override
    public BiasConnection copy(NetworkCopier copier) {
        return new BiasConnection(this, copier);
    }

    public float[] get() {
        return this.w;
    }

    @Override
    public String getName() {
        return this.name;
    }

    @Override
    public float getWeight(int j, int i) {
        return this.w[j];
    }

    @Override
    public void initializeAlphas(WeightUpdater lrs) {
        lrs.initialize(this.to.size());
    }

    @Override
    public void initializeWeights(WeightInitializer wi) {
        int toSize = this.to.size();
        this.w = new float[toSize];
        this.e = new float[toSize];
        for (int j = 0; j < toSize; ++j) {
            this.w[j] = wi.randomWeight(j, j);
        }
    }

    @Override
    public int nWeights() {
        return this.to.size();
    }

    @Override
    public int nWeightsPossible() {
        return this.to.size();
    }

    @Override
    public void processBatch() {
        this.updater.processBatch();
    }

    public void set(float[] w) {
        if (w.length != this.to.size()) {
            throw new IllegalArgumentException("Incompatible number of weights: " + this.to.size() + " != " + w.length);
        }
        System.arraycopy(w, 0, this.w, 0, w.length);
    }

    @Override
    public void setWeightInitializer(WeightInitializer win) {
        this.win = win;
    }

    @Override
    public void setWeightUpdater(WeightUpdaterType wut) {
        this.wut = wut;
    }

    @Override
    public float[][] toEligibilitiesMatrix() {
        float[][] m = new float[this.to.size()][this.to.size()];
        if (!this.cleared) {
            for (int i = 0; i < this.to.size(); ++i) {
                m[i][i] = this.to.getDownstreamCopyLayer() != null ? this.e[i] : 1.0f;
            }
        }
        return m;
    }

    @Override
    public float[][] toMatrix() {
        float[][] m = new float[this.to.size()][this.to.size()];
        for (int i = 0; i < this.to.size(); ++i) {
            m[i][i] = this.w[i];
        }
        return m;
    }

    @Override
    public void toString(NetworkStringBuilder sb) {
        if (sb.showName()) {
            sb.indent();
            sb.append(this.name);
            sb.append(" : ");
            sb.append(this.getClass().getSimpleName());
            sb.appendln();
        }
        sb.pushIndent();
        if (sb.showWeights()) {
            sb.appendln("Biases:");
            sb.pushIndent();
            sb.appendln(MatrixTools.toString(this.w));
            sb.popIndent();
        }
        if (sb.showEligibilities()) {
            sb.appendln("Eligibilities:");
            sb.pushIndent();
            if (this.overwrite) {
                sb.appendln("Empty");
            } else {
                sb.appendln(MatrixTools.toString(this.e));
            }
            sb.popIndent();
        }
        this.updater.toString(sb);
        sb.popIndent();
    }

    @Override
    public String toString(String show) {
        NetworkStringBuilder sb = new NetworkStringBuilder(show);
        this.toString(sb);
        return sb.toString();
    }

    @Override
    public void unbuild() {
        this.built = false;
    }

    @Override
    public void updateEligibilities(Responsibilities resp, Responsibilities prev) {
        int toSize = this.to.size();
        float[] d = resp.get();
        if (this.overwrite) {
            System.arraycopy(d, 0, this.e, 0, toSize);
            this.overwrite = false;
        } else {
            float[] p = prev.get();
            for (int j = 0; j < toSize; ++j) {
                this.e[j] = this.e[j] * p[j] + d[j];
            }
        }
    }

    @Override
    public void updateWeights(float[][] dw) {
        int toSize = this.to.size();
        float[] dwj = dw[0];
        for (int i = 0; i < toSize; ++i) {
            int n = i;
            this.w[n] = this.w[n] + dwj[i];
        }
    }

    @Override
    public void updateWeightsFromEligibilities(Responsibilities copyresp) {
        int toSize = this.to.size();
        float[] d = copyresp.get();
        this.updater.updateFromVector(this.e, d);
        for (int j = 0; j < toSize; ++j) {
            int n = j;
            this.w[n] = this.w[n] + this.updater.getUpdate(j, this.e[j] * d[j]);
        }
    }

    @Override
    public void updateWeightsFromInputs(Responsibilities resp) {
        int toSize = this.to.size();
        float[] d = resp.get();
        this.updater.updateFromBiases(d);
        for (int j = 0; j < toSize; ++j) {
            int n = j;
            this.w[n] = this.w[n] + this.updater.getUpdate(j, d[j]);
        }
    }
}

