package dmonner.xlbp.connection;

import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.Responsibilities;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdater;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.layer.WeightReceiverLayer;
import dmonner.xlbp.layer.WeightSenderLayer;
import dmonner.xlbp.util.MatrixTools;
import java.util.Arrays;

/* loaded from: input_file:dmonner/xlbp/connection/AdjacencyListConnection.class */
/**
 * A layer-to-layer connection that stores only the weights that actually exist, as a
 * per-receiving-unit adjacency list.
 *
 * <p>For receiving unit {@code j} (an index into {@code to}), the first {@code n[j]} entries of
 * {@code c[j]} are the indices of the sending units (in {@code from}) that feed it; {@code w[j][k]}
 * and {@code e[j][k]} are the weight and eligibility of the connection from {@code c[j][k]} to
 * {@code j}. {@code r[j][i]} maps sending unit {@code i} back to its slot {@code k}, or holds
 * {@code -1} when there is no such connection. Which connections exist is decided by the
 * {@link WeightInitializer#newWeight} calls made in {@link #initializeWeights}.
 */
public class AdjacencyListConnection extends LayerConnection {
    private static final long serialVersionUID = 1;
    /** Decides which connections exist and supplies their initial values. */
    private WeightInitializer win;
    /** Factory used to create {@link #updater}. */
    private WeightUpdaterType wut;
    /** Produces the per-weight changes applied by the update methods; created in {@link #build()}. */
    private WeightUpdater updater;
    /** Compact weights: {@code w[j][k]} is the weight from unit {@code c[j][k]} to unit {@code j}. */
    private float[][] w;
    /** Compact eligibilities, indexed like {@link #w}. */
    private float[][] e;
    /** Copy of the sending layer's activations taken in {@link #activateTrain()}. */
    private float[] in;
    /** {@code c[j][k]}: sending-unit index stored in slot {@code k} of receiving unit {@code j}. */
    private int[][] c;
    /** {@code r[j][i]}: slot of sending unit {@code i} in row {@code j}, or -1 if unconnected. */
    private int[][] r;
    /** {@code n[j]}: number of used slots (incoming connections) of receiving unit {@code j}. */
    private int[] n;
    /** Total number of existing connections. */
    private int nw;
    /** True while {@link #in} holds no input cached since the last {@link #clear()}. */
    private boolean cleared;
    /** True when the next eligibility update should overwrite rather than accumulate. */
    private boolean overwrite;

    /**
     * Copy constructor used by {@link #copy(NetworkCopier)}.
     *
     * <p>If the source connection is built, the cached input is copied or emptied according to
     * {@code networkCopier.copyState()}, and the connectivity and weights are either copied or
     * freshly initialized according to {@code networkCopier.copyWeights()}. In both cases the new
     * updater is initialized, as in {@link #build()}.
     *
     * @param adjacencyListConnection the connection to copy
     * @param networkCopier controls which parts of the state are copied
     */
    public AdjacencyListConnection(AdjacencyListConnection adjacencyListConnection, NetworkCopier networkCopier) {
        super(adjacencyListConnection, networkCopier);
        final boolean copyState = networkCopier.copyState();
        this.win = adjacencyListConnection.win;
        this.wut = adjacencyListConnection.wut;
        this.updater = adjacencyListConnection.updater != null ? this.wut.make(this) : null;
        this.cleared = copyState ? adjacencyListConnection.cleared : true;
        this.overwrite = copyState ? adjacencyListConnection.overwrite : true;
        if (!adjacencyListConnection.built) {
            return;
        }
        if (copyState) {
            this.in = MatrixTools.copy(adjacencyListConnection.in);
        } else {
            this.in = MatrixTools.empty(adjacencyListConnection.in);
        }
        if (networkCopier.copyWeights()) {
            this.w = MatrixTools.copy(adjacencyListConnection.w);
            this.e = MatrixTools.copy(adjacencyListConnection.e);
            this.c = MatrixTools.copy(adjacencyListConnection.c);
            // The reverse index must travel with the connectivity; getWeight() reads it.
            this.r = MatrixTools.copy(adjacencyListConnection.r);
            this.n = MatrixTools.copy(adjacencyListConnection.n);
            this.nw = adjacencyListConnection.nw;
        } else {
            initializeWeights(this.win);
        }
        // Mirror build(): a freshly made updater is always initialized after the weights exist.
        if (this.updater != null) {
            initializeAlphas(this.updater);
        }
    }

    /**
     * Creates a named, unbuilt connection from {@code weightSenderLayer} to
     * {@code weightReceiverLayer}.
     */
    public AdjacencyListConnection(String str, WeightReceiverLayer weightReceiverLayer, WeightSenderLayer weightSenderLayer) {
        super(str, weightReceiverLayer, weightSenderLayer);
        this.cleared = true;
        this.overwrite = true;
    }

    /** Creates an unbuilt connection from {@code weightSenderLayer} to {@code weightReceiverLayer}. */
    public AdjacencyListConnection(WeightReceiverLayer weightReceiverLayer, WeightSenderLayer weightSenderLayer) {
        super(weightReceiverLayer, weightSenderLayer);
        this.cleared = true;
        this.overwrite = true;
    }

    /**
     * Writes, for every receiving unit, the weighted sum of its connected sending units'
     * activations into the receiving layer's activation array (overwriting, not adding).
     */
    @Override // dmonner.xlbp.connection.Connection
    public void activateTest() {
        final int size = this.to.size();
        final float[] out = this.to.getActivations();
        final float[] src = this.from.getActivations();
        for (int j = 0; j < size; j++) {
            final float[] wj = this.w[j];
            final int[] cj = this.c[j];
            final int nj = this.n[j];
            float sum = 0.0f;
            for (int k = 0; k < nj; k++) {
                sum += wj[k] * src[cj[k]];
            }
            out[j] = sum;
        }
    }

    /**
     * Caches the sending layer's activations in {@link #in} (for later eligibility and weight
     * updates), marks the cache non-empty, and then activates as in {@link #activateTest()}.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void activateTrain() {
        this.cleared = false;
        System.arraycopy(this.from.getActivations(), 0, this.in, 0, this.from.size());
        activateTest();
    }

    /**
     * Allocates the input cache, creates the updater, and initializes weights and updater state.
     * Does nothing if already built.
     *
     * @throws IllegalStateException if no weight initializer or weight-updater type has been set
     */
    @Override // dmonner.xlbp.connection.LayerConnection, dmonner.xlbp.connection.Connection
    public void build() {
        if (this.built) {
            return;
        }
        super.build();
        if (this.win == null) {
            throw new IllegalStateException("Missing a WeightInitializer in " + this.name);
        }
        if (this.wut == null) {
            throw new IllegalStateException("Missing a WeightUpdaterType in " + this.name);
        }
        this.in = new float[this.from.size()];
        this.updater = this.wut.make(this);
        initializeWeights(this.win);
        initializeAlphas(this.updater);
        this.built = true;
    }

    /** Marks both the cached input and the eligibilities as empty. */
    @Override // dmonner.xlbp.connection.Connection
    public void clear() {
        this.cleared = true;
        this.overwrite = true;
    }

    /** Returns a copy of this connection as directed by {@code networkCopier}. */
    @Override // dmonner.xlbp.connection.LayerConnection, dmonner.xlbp.connection.Connection
    public AdjacencyListConnection copy(NetworkCopier networkCopier) {
        return new AdjacencyListConnection(this, networkCopier);
    }

    /** Returns the internal input-cache array (not a copy). */
    @Override // dmonner.xlbp.connection.LayerConnection
    public float[] getCachedInput() {
        return this.in;
    }

    /**
     * Returns the weight from sending unit {@code i2} to receiving unit {@code i}, or {@code 0}
     * when that connection does not exist (matching what {@link #toMatrix()} reports).
     */
    @Override // dmonner.xlbp.connection.Connection
    public float getWeight(int i, int i2) {
        final int k = this.r[i][i2];
        // -1 marks a missing connection; the c-check also rejects a stale default slot of 0.
        if (k < 0 || k >= this.n[i] || this.c[i][k] != i2) {
            return 0.0f;
        }
        return this.w[i][k];
    }

    /** Initializes the updater with the dense dimensions (receiving size x sending size). */
    @Override // dmonner.xlbp.connection.Connection
    public void initializeAlphas(WeightUpdater weightUpdater) {
        weightUpdater.initialize(this.to.size(), this.from.size());
    }

    /**
     * Rebuilds the connectivity and weights. For each (receiving, sending) pair that
     * {@code weightInitializer.newWeight} accepts, a slot is allocated and filled with
     * {@code weightInitializer.randomWeight}. Eligibilities are reset to zero.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void initializeWeights(WeightInitializer weightInitializer) {
        this.nw = 0;
        final int size = this.to.size();
        final int size2 = this.from.size();
        this.w = new float[size][size2];
        this.e = new float[size][size2];
        this.c = new int[size][size2];
        this.r = new int[size][size2];
        this.n = new int[size];
        for (int j = 0; j < size; j++) {
            final int[] cj = this.c[j];
            final int[] rj = this.r[j];
            final float[] wj = this.w[j];
            // Unconnected sending units must not alias slot 0.
            Arrays.fill(rj, -1);
            int used = 0;
            for (int i = 0; i < size2; i++) {
                if (weightInitializer.newWeight(j, i)) {
                    cj[used] = i;
                    rj[i] = used;
                    wj[used] = weightInitializer.randomWeight(j, i);
                    used++;
                }
            }
            this.n[j] = used;
            this.nw += used;
        }
    }

    /** Returns the number of existing connections. */
    @Override // dmonner.xlbp.connection.Connection
    public int nWeights() {
        return this.nw;
    }

    /** Delegates to the updater's {@code processBatch}. */
    @Override // dmonner.xlbp.connection.Connection
    public void processBatch() {
        this.updater.processBatch();
    }

    /** Sets the weight initializer; {@link #build()} requires one. */
    @Override // dmonner.xlbp.connection.Connection
    public void setWeightInitializer(WeightInitializer weightInitializer) {
        this.win = weightInitializer;
    }

    /** Sets the weight-updater type; {@link #build()} requires one. */
    @Override // dmonner.xlbp.connection.Connection
    public void setWeightUpdater(WeightUpdaterType weightUpdaterType) {
        this.wut = weightUpdaterType;
    }

    /**
     * Returns a dense (receiving x sending) matrix of eligibilities, all zero when the input cache
     * is cleared. If the receiving layer has a downstream copy layer, the stored eligibilities
     * {@link #e} are used; otherwise each entry is the cached input of the sending unit.
     */
    @Override // dmonner.xlbp.connection.Connection
    public float[][] toEligibilitiesMatrix() {
        final int size = this.to.size();
        final float[][] dense = new float[size][this.from.size()];
        if (this.cleared) {
            return dense;
        }
        // Loop-invariant: which source to report depends only on the receiving layer.
        final boolean useStored = this.to.getDownstreamCopyLayer() != null;
        for (int j = 0; j < size; j++) {
            final int[] cj = this.c[j];
            final float[] ej = this.e[j];
            final int nj = this.n[j];
            for (int k = 0; k < nj; k++) {
                final int i = cj[k];
                dense[j][i] = useStored ? ej[k] : this.in[i];
            }
        }
        return dense;
    }

    /** Returns a dense (receiving x sending) weight matrix with 0 for missing connections. */
    @Override // dmonner.xlbp.connection.Connection
    public float[][] toMatrix() {
        final int size = this.to.size();
        final float[][] dense = new float[size][this.from.size()];
        for (int j = 0; j < size; j++) {
            final int[] cj = this.c[j];
            final float[] wj = this.w[j];
            final int nj = this.n[j];
            for (int k = 0; k < nj; k++) {
                dense[j][cj[k]] = wj[k];
            }
        }
        return dense;
    }

    /**
     * Appends a description of this connection to {@code networkStringBuilder}. Weights and
     * eligibilities are printed in their compact (slot-indexed) layout.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void toString(NetworkStringBuilder networkStringBuilder) {
        if (networkStringBuilder.showName()) {
            networkStringBuilder.indent();
            networkStringBuilder.append(this.name);
            networkStringBuilder.append(" : ");
            networkStringBuilder.append(getClass().getSimpleName());
            networkStringBuilder.appendln();
        }
        networkStringBuilder.pushIndent();
        if (networkStringBuilder.showExtra()) {
            networkStringBuilder.appendln("Cached Inputs:");
            networkStringBuilder.pushIndent();
            if (this.cleared) {
                networkStringBuilder.appendln("Empty");
            } else {
                networkStringBuilder.appendln(MatrixTools.toString(this.in));
            }
            networkStringBuilder.popIndent();
        }
        if (networkStringBuilder.showWeights()) {
            networkStringBuilder.appendln("Weights:");
            networkStringBuilder.pushIndent();
            networkStringBuilder.appendln(MatrixTools.toString(this.w));
            networkStringBuilder.popIndent();
        }
        if (networkStringBuilder.showEligibilities()) {
            networkStringBuilder.appendln("Eligibilities:");
            networkStringBuilder.pushIndent();
            if (this.overwrite) {
                networkStringBuilder.appendln("Empty");
            } else {
                networkStringBuilder.appendln(MatrixTools.toString(this.e));
            }
            networkStringBuilder.popIndent();
        }
        // The updater only exists after build(); printing an unbuilt connection must not throw.
        if (this.updater != null) {
            this.updater.toString(networkStringBuilder);
        }
        networkStringBuilder.popIndent();
    }

    /**
     * Updates the eligibility of every existing connection.
     *
     * <p>On the first update after {@link #clear()}: {@code e[j][k] = a[j] * in[c[j][k]]}.
     * Afterwards: {@code e[j][k] = e[j][k] * b[j] + a[j] * in[c[j][k]]}, where {@code a} is
     * {@code responsibilities.get()} and {@code b} is {@code responsibilities2.get()}.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void updateEligibilities(Responsibilities responsibilities, Responsibilities responsibilities2) {
        final int size = this.to.size();
        final float[] a = responsibilities.get();
        if (this.overwrite) {
            for (int j = 0; j < size; j++) {
                final float[] ej = this.e[j];
                final float aj = a[j];
                final int nj = this.n[j];
                final int[] cj = this.c[j];
                for (int k = 0; k < nj; k++) {
                    ej[k] = aj * this.in[cj[k]];
                }
            }
            this.overwrite = false;
            return;
        }
        final float[] b = responsibilities2.get();
        for (int j = 0; j < size; j++) {
            final float[] ej = this.e[j];
            final float aj = a[j];
            final float bj = b[j];
            final int nj = this.n[j];
            final int[] cj = this.c[j];
            for (int k = 0; k < nj; k++) {
                ej[k] = (ej[k] * bj) + (aj * this.in[cj[k]]);
            }
        }
    }

    /**
     * Back-propagates the receiving layer's responsibilities through the existing weights into the
     * sending layer's responsibility array (which is zeroed first). Does nothing when the
     * receiving layer supplies no responsibilities.
     */
    @Override // dmonner.xlbp.connection.LayerConnection
    public void updateResponsibilities() {
        final float[] toResp = getToLayerResponsibilities();
        if (toResp == null) {
            return;
        }
        final float[] fromResp = getFromLayerResponsibilities();
        Arrays.fill(fromResp, 0.0f);
        final int size = this.to.size();
        for (int j = 0; j < size; j++) {
            final float[] wj = this.w[j];
            final int[] cj = this.c[j];
            final int nj = this.n[j];
            final float resp = toResp[j];
            for (int k = 0; k < nj; k++) {
                fromResp[cj[k]] += wj[k] * resp;
            }
        }
    }

    /**
     * Adds {@code fArr[j][k]} to the weight in slot {@code k} of receiving unit {@code j}.
     *
     * <p>NOTE(review): the delta is indexed by compact slot {@code k}, not by sending-unit index,
     * whereas {@link #toMatrix()} reports weights by sending-unit index. Confirm which layout
     * callers pass in.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void updateWeights(float[][] fArr) {
        final int size = this.to.size();
        for (int j = 0; j < size; j++) {
            final float[] wj = this.w[j];
            final float[] dj = fArr[j];
            final int nj = this.n[j];
            for (int k = 0; k < nj; k++) {
                wj[k] += dj[k];
            }
        }
    }

    /**
     * Updates each weight by {@code updater.getUpdate(j, k, e[j][k] * resp[j])}, after first
     * handing the eligibilities and responsibilities to the updater.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void updateWeightsFromEligibilities(Responsibilities responsibilities) {
        final int size = this.to.size();
        final float[] resp = responsibilities.get();
        this.updater.updateFromEligibilities(this.e, resp);
        for (int j = 0; j < size; j++) {
            final float[] ej = this.e[j];
            final float rj = resp[j];
            final float[] wj = this.w[j];
            final int nj = this.n[j];
            for (int k = 0; k < nj; k++) {
                wj[k] += this.updater.getUpdate(j, k, ej[k] * rj);
            }
        }
    }

    /**
     * Updates each weight by {@code updater.getUpdate(j, k, in[c[j][k]] * resp[j])}, after first
     * handing the cached inputs and responsibilities to the updater.
     */
    @Override // dmonner.xlbp.connection.Connection
    public void updateWeightsFromInputs(Responsibilities responsibilities) {
        final int size = this.to.size();
        final float[] resp = responsibilities.get();
        this.updater.updateFromInputs(this.in, resp);
        for (int j = 0; j < size; j++) {
            final float rj = resp[j];
            final float[] wj = this.w[j];
            final int nj = this.n[j];
            final int[] cj = this.c[j];
            for (int k = 0; k < nj; k++) {
                wj[k] += this.updater.getUpdate(j, k, this.in[cj[k]] * rj);
            }
        }
    }
}
