package dmonner.xlbp.connection;

import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.Responsibilities;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdater;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.layer.BiasLayer;
import dmonner.xlbp.util.MatrixTools;

/* loaded from: input_file:dmonner/xlbp/connection/BiasConnection.class */
public class BiasConnection implements Connection {
    private static final long serialVersionUID = 1;

    /** Name of this connection; used in error messages and string output. */
    private final String name;
    /** The layer whose units each receive one bias weight from this connection. */
    private final BiasLayer to;
    /** Supplies initial weight values in {@link #build()}; must be set before building. */
    private WeightInitializer win;
    /** Factory for {@link #updater}; must be set before building. */
    private WeightUpdaterType wut;
    /** Computes per-weight updates; {@code null} until {@link #build()} runs (or when copied from an unbuilt connection). */
    private WeightUpdater updater;
    /** Bias weights, one per unit of {@link #to}. */
    private float[] w;
    /** Per-unit eligibility values maintained by {@link #updateEligibilities}. */
    private float[] e;
    /** Set by {@link #clear()}, reset by {@link #activateTrain()}; when true, {@link #toEligibilitiesMatrix()} is all zeros. */
    private boolean cleared;
    /** When true, the next {@link #updateEligibilities} call overwrites {@link #e} instead of accumulating into it. */
    private boolean overwrite;
    /** Whether {@link #build()} has completed since construction or the last {@link #unbuild()}. */
    private boolean built;

    /**
     * Copy constructor used by {@link #copy(NetworkCopier)}.
     *
     * <p>Weights and eligibilities are deep-copied when {@code copier.copyWeights()} is true and
     * replaced by {@code MatrixTools.empty(...)} otherwise. The {@code cleared}/{@code overwrite}
     * flags are copied when {@code copier.copyState()} is true and reset to {@code true} otherwise.
     *
     * @param that   the connection to copy
     * @param copier controls naming and what state is carried over
     */
    public BiasConnection(BiasConnection that, NetworkCopier copier) {
        this.name = copier.getCopyNameFrom(that);
        this.to = (BiasLayer) copier.getCopyOf(that.to);
        this.win = that.win;
        this.wut = that.wut;
        // NOTE(review): build() follows wut.make() with initializeAlphas(updater), but this copy
        // path does not (and build() will early-return when 'built' is copied as true); confirm
        // that make() alone leaves the copied updater in a usable state.
        this.updater = that.updater != null ? this.wut.make(this) : null;
        this.built = that.built;
        this.cleared = copier.copyState() ? that.cleared : true;
        this.overwrite = copier.copyState() ? that.overwrite : true;
        this.w = copier.copyWeights() ? MatrixTools.copy(that.w) : MatrixTools.empty(that.w);
        this.e = copier.copyWeights() ? MatrixTools.copy(that.e) : MatrixTools.empty(that.e);
    }

    /**
     * Creates an unbuilt bias connection named after its target layer.
     *
     * @param layer the layer receiving the biases
     */
    public BiasConnection(BiasLayer layer) {
        this(layer.getName(), layer);
    }

    /**
     * Creates an unbuilt bias connection. A {@link WeightInitializer} and a
     * {@link WeightUpdaterType} must be supplied before {@link #build()} is called.
     *
     * @param name  name of this connection
     * @param layer the layer receiving the biases
     */
    public BiasConnection(String name, BiasLayer layer) {
        this.name = name;
        this.to = layer;
        this.cleared = true;
        this.overwrite = true;
    }

    /** No-op: bias connections keep no per-step state during test-mode activation. */
    @Override // dmonner.xlbp.connection.Connection
    public void activateTest() {
    }

    /** Marks the connection as active in training mode (no longer cleared). */
    @Override // dmonner.xlbp.connection.Connection
    public void activateTrain() {
        this.cleared = false;
    }

    /**
     * Makes this connection use {@code weights} directly as its weight storage (no copy), so
     * later changes through either reference are shared.
     *
     * @param weights replacement weight array; must have one entry per target unit
     * @throws IllegalArgumentException if the length does not match the target layer size
     */
    public void alias(float[] weights) {
        requireMatchingSize(weights);
        this.w = weights;
    }

    /**
     * Allocates and initializes weights, eligibilities and the weight updater. Does nothing if
     * already built.
     *
     * @throws IllegalStateException if no initializer or updater type has been set
     */
    @Override // dmonner.xlbp.connection.Connection
    public void build() {
        if (this.built) {
            return;
        }
        if (this.win == null) {
            throw new IllegalStateException("Missing a WeightInitializer in " + this.name);
        }
        if (this.wut == null) {
            throw new IllegalStateException("Missing a WeightUpdaterType in " + this.name);
        }
        this.updater = this.wut.make(this);
        initializeWeights(this.win);
        initializeAlphas(this.updater);
        this.built = true;
    }

    /** Resets the cleared flag and makes the next eligibility update overwrite rather than accumulate. */
    @Override // dmonner.xlbp.connection.Connection
    public void clear() {
        this.cleared = true;
        this.overwrite = true;
    }

    /** Returns a copy of this connection as configured by {@code copier}. */
    @Override // dmonner.xlbp.connection.Connection
    public BiasConnection copy(NetworkCopier copier) {
        return new BiasConnection(this, copier);
    }

    /** Returns the live weight array (not a copy); may be {@code null} before {@link #build()}. */
    public float[] get() {
        return this.w;
    }

    @Override // dmonner.xlbp.connection.Connection
    public String getName() {
        return this.name;
    }

    /**
     * Returns the bias weight for target unit {@code toUnit}.
     *
     * <p>NOTE(review): the second index is ignored here, whereas {@link #toMatrix()} places the
     * weights on the diagonal (zero off-diagonal); confirm which indexing convention callers expect.
     */
    @Override // dmonner.xlbp.connection.Connection
    public float getWeight(int toUnit, int fromUnit) {
        return this.w[toUnit];
    }

    /** Initializes the updater for one weight per target unit. */
    @Override // dmonner.xlbp.connection.Connection
    public void initializeAlphas(WeightUpdater weightUpdater) {
        weightUpdater.initialize(this.to.size());
    }

    /**
     * Allocates fresh weight and eligibility arrays sized to the target layer and fills the
     * weights from {@code initializer}. Any previously aliased weight array is replaced.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void initializeWeights(WeightInitializer initializer) {
        final int size = this.to.size();
        this.w = new float[size];
        this.e = new float[size];
        for (int unit = 0; unit < size; unit++) {
            this.w[unit] = initializer.randomWeight(unit, unit);
        }
    }

    @Override // dmonner.xlbp.connection.Connection
    public int nWeights() {
        return this.to.size();
    }

    @Override // dmonner.xlbp.connection.Connection
    public int nWeightsPossible() {
        return this.to.size();
    }

    /** Delegates batch processing to the updater; requires {@link #build()} to have run. */
    @Override // dmonner.xlbp.connection.Connection
    public void processBatch() {
        this.updater.processBatch();
    }

    /**
     * Copies {@code weights} into this connection's existing weight array.
     *
     * @param weights new values; must have one entry per target unit
     * @throws IllegalArgumentException if the length does not match the target layer size
     */
    public void set(float[] weights) {
        requireMatchingSize(weights);
        System.arraycopy(weights, 0, this.w, 0, weights.length);
    }

    @Override // dmonner.xlbp.connection.Connection
    public void setWeightInitializer(WeightInitializer weightInitializer) {
        this.win = weightInitializer;
    }

    @Override // dmonner.xlbp.connection.Connection
    public void setWeightUpdater(WeightUpdaterType weightUpdaterType) {
        this.wut = weightUpdaterType;
    }

    /**
     * Returns a square matrix whose diagonal holds the eligibilities. It is all zeros while
     * cleared. Diagonal entries are {@code e[i]} if the target layer has a downstream copy
     * layer, and {@code 1.0f} otherwise.
     */
    @Override // dmonner.xlbp.connection.Connection
    public float[][] toEligibilitiesMatrix() {
        final int size = this.to.size();
        final float[][] matrix = new float[size][size];
        if (!this.cleared) {
            final boolean useTraces = this.to.getDownstreamCopyLayer() != null;
            for (int unit = 0; unit < size; unit++) {
                matrix[unit][unit] = useTraces ? this.e[unit] : 1.0f;
            }
        }
        return matrix;
    }

    /** Returns a square matrix with the bias weights on the diagonal and zeros elsewhere. */
    @Override // dmonner.xlbp.connection.Connection
    public float[][] toMatrix() {
        final int size = this.to.size();
        final float[][] matrix = new float[size][size];
        for (int unit = 0; unit < size; unit++) {
            matrix[unit][unit] = this.w[unit];
        }
        return matrix;
    }

    /**
     * Appends a human-readable description (name, weights, eligibilities, updater) to
     * {@code sb}, according to its display flags. Safe to call before {@link #build()}: the
     * updater section is skipped when no updater exists.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void toString(NetworkStringBuilder sb) {
        if (sb.showName()) {
            sb.indent();
            sb.append(this.name);
            sb.append(" : ");
            sb.append(getClass().getSimpleName());
            sb.appendln();
        }
        sb.pushIndent();
        if (sb.showWeights()) {
            sb.appendln("Biases:");
            sb.pushIndent();
            sb.appendln(MatrixTools.toString(this.w));
            sb.popIndent();
        }
        if (sb.showEligibilities()) {
            sb.appendln("Eligibilities:");
            sb.pushIndent();
            // While 'overwrite' is set, 'e' holds stale values that the next update will discard.
            sb.appendln(this.overwrite ? "Empty" : MatrixTools.toString(this.e));
            sb.popIndent();
        }
        // Previously dereferenced unconditionally, throwing NPE for unbuilt connections.
        if (this.updater != null) {
            this.updater.toString(sb);
        }
        sb.popIndent();
    }

    @Override // dmonner.xlbp.connection.Connection
    public String toString(String str) {
        final NetworkStringBuilder sb = new NetworkStringBuilder(str);
        toString(sb);
        return sb.toString();
    }

    /** Marks the connection unbuilt so the next {@link #build()} reinitializes it. */
    @Override // dmonner.xlbp.connection.Connection
    public void unbuild() {
        this.built = false;
    }

    /**
     * Updates the per-unit eligibilities. If {@code overwrite} is set, {@code e} becomes a copy
     * of {@code additive} and the flag is cleared. Otherwise each entry becomes
     * {@code e[i] * multiplicative[i] + additive[i]}.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void updateEligibilities(Responsibilities additive, Responsibilities multiplicative) {
        final int size = this.to.size();
        final float[] add = additive.get();
        if (this.overwrite) {
            System.arraycopy(add, 0, this.e, 0, size);
            this.overwrite = false;
            return;
        }
        final float[] mul = multiplicative.get();
        for (int unit = 0; unit < size; unit++) {
            this.e[unit] = this.e[unit] * mul[unit] + add[unit];
        }
    }

    /** Adds {@code updates[0][i]} directly to each bias weight (only row 0 is read). */
    @Override // dmonner.xlbp.connection.Connection
    public void updateWeights(float[][] updates) {
        final int size = this.to.size();
        final float[] delta = updates[0];
        for (int unit = 0; unit < size; unit++) {
            this.w[unit] += delta[unit];
        }
    }

    /**
     * Feeds the eligibilities and {@code responsibilities} to the updater. It then adds to
     * each weight the updater's step for the gradient {@code e[i] * responsibilities[i]}.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void updateWeightsFromEligibilities(Responsibilities responsibilities) {
        final int size = this.to.size();
        final float[] resp = responsibilities.get();
        this.updater.updateFromVector(this.e, resp);
        for (int unit = 0; unit < size; unit++) {
            this.w[unit] += this.updater.getUpdate(unit, this.e[unit] * resp[unit]);
        }
    }

    /**
     * Feeds {@code responsibilities} to the updater as bias gradients. It then adds to each
     * weight the updater's step for gradient {@code responsibilities[i]}.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void updateWeightsFromInputs(Responsibilities responsibilities) {
        final int size = this.to.size();
        final float[] resp = responsibilities.get();
        this.updater.updateFromBiases(resp);
        for (int unit = 0; unit < size; unit++) {
            this.w[unit] += this.updater.getUpdate(unit, resp[unit]);
        }
    }

    /** Throws if {@code weights} does not hold exactly one entry per target unit. */
    private void requireMatchingSize(float[] weights) {
        if (weights.length != this.to.size()) {
            throw new IllegalArgumentException("Incompatible number of weights: " + this.to.size() + " != " + weights.length);
        }
    }
}
