/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.connection;

import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.Responsibilities;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdater;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.connection.LayerConnection;
import dmonner.xlbp.layer.WeightReceiverLayer;
import dmonner.xlbp.layer.WeightSenderLayer;
import dmonner.xlbp.util.MatrixTools;
import java.util.Arrays;

/**
 * A sparse {@link LayerConnection} that stores, for each unit {@code j} of the receiving layer,
 * only the incoming connections that the {@link WeightInitializer} chose to create.
 *
 * <p>Storage layout for receiving unit {@code j}:
 * <ul>
 *   <li>{@code n[j]}: number of incoming connections.</li>
 *   <li>{@code c[j][k]}: sender index of the {@code k}-th connection, for {@code 0 <= k < n[j]}.</li>
 *   <li>{@code w[j][k]} / {@code e[j][k]}: weight / eligibility of that connection.</li>
 *   <li>{@code r[j][i]}: reverse map from sender index {@code i} to its slot {@code k}. It is only
 *       meaningful when {@code i} is actually connected to {@code j}; otherwise it holds 0.</li>
 * </ul>
 * Slots {@code k >= n[j]} are unused padding.
 */
public class AdjacencyListConnection
extends LayerConnection {
    private static final long serialVersionUID = 1L;
    /** Decides connectivity ({@code newWeight}) and initial values ({@code randomWeight}). */
    private WeightInitializer win;
    /** Factory used to create {@link #updater}. */
    private WeightUpdaterType wut;
    /** Computes weight changes; null until {@link #build()} (or on copies without one). */
    private WeightUpdater updater;
    /** w[j][k]: weight of the k-th incoming connection of receiving unit j. */
    private float[][] w;
    /** e[j][k]: eligibility of the k-th incoming connection of receiving unit j. */
    private float[][] e;
    /** Copy of the sender activations captured by the last {@link #activateTrain()}. */
    private float[] in;
    /** c[j][k]: sender index of the k-th incoming connection of receiving unit j. */
    private int[][] c;
    /** r[j][i]: slot k such that c[j][k] == i (valid only if i is connected to j). */
    private int[][] r;
    /** n[j]: number of incoming connections of receiving unit j. */
    private int[] n;
    /** Total number of connections (sum of n[j]). */
    private int nw;
    /** True when {@link #in} holds no valid cached input. */
    private boolean cleared;
    /** True when the next eligibility update should overwrite rather than accumulate. */
    private boolean overwrite;

    /**
     * Copy constructor.
     *
     * @param that the connection to copy
     * @param copier decides whether runtime state (cached input, flags) and weights are copied.
     *     When weights are not copied, fresh connectivity and weights are drawn from the
     *     initializer.
     */
    public AdjacencyListConnection(AdjacencyListConnection that, NetworkCopier copier) {
        super(that, copier);
        this.win = that.win;
        this.wut = that.wut;
        this.updater = that.updater != null ? this.wut.make(this) : null;
        this.cleared = copier.copyState() ? that.cleared : true;
        this.overwrite = copier.copyState() ? that.overwrite : true;
        if (that.built) {
            this.in = copier.copyState() ? MatrixTools.copy(that.in) : MatrixTools.empty(that.in);
            if (copier.copyWeights()) {
                // NOTE(review): unlike build(), the freshly made updater is not passed to
                // initializeAlphas() on this path; confirm that is intended.
                this.w = MatrixTools.copy(that.w);
                this.e = MatrixTools.copy(that.e);
                this.c = MatrixTools.copy(that.c);
                // The reverse index must travel with c. Without it, getWeight() on the copy
                // would dereference a null array.
                this.r = MatrixTools.copy(that.r);
                this.n = MatrixTools.copy(that.n);
                this.nw = that.nw;
            } else {
                this.initializeAlphas(this.updater);
                this.initializeWeights(this.win);
            }
        }
    }

    /**
     * Creates a named, unbuilt connection.
     *
     * @param name connection name
     * @param to receiving layer
     * @param from sending layer
     */
    public AdjacencyListConnection(String name, WeightReceiverLayer to, WeightSenderLayer from) {
        super(name, to, from);
        this.cleared = true;
        this.overwrite = true;
    }

    /**
     * Creates an unbuilt connection with the superclass's default name.
     *
     * @param to receiving layer
     * @param from sending layer
     */
    public AdjacencyListConnection(WeightReceiverLayer to, WeightSenderLayer from) {
        super(to, from);
        this.cleared = true;
        this.overwrite = true;
    }

    /**
     * Writes {@code y[j] = sum_k w[j][k] * x[c[j][k]]} into the receiving layer's activations.
     * The result overwrites y[j]; it is not accumulated.
     */
    @Override
    public void activateTest() {
        int toSize = this.to.size();
        float[] y = this.to.getActivations();
        float[] x = this.from.getActivations();
        for (int j = 0; j < toSize; ++j) {
            float sum = 0.0f;
            float[] wj = this.w[j];
            int[] cj = this.c[j];
            int nj = this.n[j];
            for (int k = 0; k < nj; ++k) {
                sum += wj[k] * x[cj[k]];
            }
            y[j] = sum;
        }
    }

    /** Caches the sender activations for later learning, then activates as in test mode. */
    @Override
    public void activateTrain() {
        this.cleared = false;
        System.arraycopy(this.from.getActivations(), 0, this.in, 0, this.from.size());
        this.activateTest();
    }

    /**
     * Allocates the input cache, creates the updater and initializes weights and alphas.
     * Does nothing if the connection is already built.
     *
     * @throws IllegalStateException if no weight initializer or weight updater type was set
     */
    @Override
    public void build() {
        if (!this.built) {
            super.build();
            if (this.win == null) {
                throw new IllegalStateException("Missing a WeightInitializer in " + this.name);
            }
            if (this.wut == null) {
                throw new IllegalStateException("Missing a WeightUpdaterType in " + this.name);
            }
            this.in = new float[this.from.size()];
            this.updater = this.wut.make(this);
            this.initializeWeights(this.win);
            this.initializeAlphas(this.updater);
            this.built = true;
        }
    }

    /** Marks the cached input as empty and resets eligibility accumulation. */
    @Override
    public void clear() {
        this.cleared = true;
        this.overwrite = true;
    }

    @Override
    public AdjacencyListConnection copy(NetworkCopier copier) {
        return new AdjacencyListConnection(this, copier);
    }

    /** Returns the internal input cache (not a copy). */
    @Override
    public float[] getCachedInput() {
        return this.in;
    }

    /**
     * Returns the weight from sender unit {@code i} to receiving unit {@code j}.
     *
     * <p>Returns {@code 0} when the two units are not connected, consistent with
     * {@link #toMatrix()}.
     *
     * @param j receiving unit index
     * @param i sender unit index
     * @return the connection weight, or 0 if no such connection exists
     */
    @Override
    public float getWeight(int j, int i) {
        int k = this.r[j][i];
        // r[j][i] stays 0 for unconnected senders, so confirm the slot really belongs to i.
        if (k < this.n[j] && this.c[j][k] == i) {
            return this.w[j][k];
        }
        return 0.0f;
    }

    /** Initializes the updater with the dense (to.size() x from.size()) dimensions. */
    @Override
    public void initializeAlphas(WeightUpdater wu) {
        wu.initialize(this.to.size(), this.from.size());
    }

    /**
     * Rebuilds the adjacency lists.
     *
     * <p>For each pair (j, i), a connection is created when {@code wi.newWeight(j, i)} is true,
     * and its initial value is {@code wi.randomWeight(j, i)}. Eligibilities restart at zero.
     *
     * @param wi the initializer deciding connectivity and initial weights
     */
    @Override
    public void initializeWeights(WeightInitializer wi) {
        this.nw = 0;
        int toSize = this.to.size();
        int fromSize = this.from.size();
        this.w = new float[toSize][fromSize];
        this.e = new float[toSize][fromSize];
        this.c = new int[toSize][fromSize];
        this.r = new int[toSize][fromSize];
        this.n = new int[toSize];
        for (int j = 0; j < toSize; ++j) {
            int nj = 0;
            int[] cj = this.c[j];
            int[] rj = this.r[j];
            float[] wj = this.w[j];
            for (int i = 0; i < fromSize; ++i) {
                if (!wi.newWeight(j, i)) {
                    continue;
                }
                cj[nj] = i;
                rj[i] = nj;
                wj[nj] = wi.randomWeight(j, i);
                ++nj;
            }
            this.n[j] = nj;
            this.nw += nj;
        }
    }

    /** Returns the total number of existing connections. */
    @Override
    public int nWeights() {
        return this.nw;
    }

    /** Delegates end-of-batch processing to the updater. */
    @Override
    public void processBatch() {
        this.updater.processBatch();
    }

    @Override
    public void setWeightInitializer(WeightInitializer win) {
        this.win = win;
    }

    @Override
    public void setWeightUpdater(WeightUpdaterType wut) {
        this.wut = wut;
    }

    /**
     * Expands eligibilities into a dense (to.size() x from.size()) matrix.
     *
     * <p>If the receiving layer has a downstream copy layer, the stored eligibilities
     * {@code e} are used; otherwise the cached input is used. The matrix is all zeros when the
     * input cache is cleared, and unconnected entries are zero.
     */
    @Override
    public float[][] toEligibilitiesMatrix() {
        int toSize = this.to.size();
        float[][] m = new float[toSize][this.from.size()];
        if (!this.cleared) {
            // Loop-invariant: decided once instead of per connection.
            boolean useEligibilities = this.to.getDownstreamCopyLayer() != null;
            for (int j = 0; j < toSize; ++j) {
                int[] cj = this.c[j];
                int nj = this.n[j];
                for (int k = 0; k < nj; ++k) {
                    int i = cj[k];
                    m[j][i] = useEligibilities ? this.e[j][k] : this.in[i];
                }
            }
        }
        return m;
    }

    /** Expands weights into a dense (to.size() x from.size()) matrix; unconnected entries are 0. */
    @Override
    public float[][] toMatrix() {
        int toSize = this.to.size();
        float[][] m = new float[toSize][this.from.size()];
        for (int j = 0; j < toSize; ++j) {
            int[] cj = this.c[j];
            float[] wj = this.w[j];
            int nj = this.n[j];
            for (int k = 0; k < nj; ++k) {
                m[j][cj[k]] = wj[k];
            }
        }
        return m;
    }

    /** Appends a human-readable description, honoring the builder's visibility options. */
    @Override
    public void toString(NetworkStringBuilder sb) {
        if (sb.showName()) {
            sb.indent();
            sb.append(this.name);
            sb.append(" : ");
            sb.append(this.getClass().getSimpleName());
            sb.appendln();
        }
        sb.pushIndent();
        if (sb.showExtra()) {
            sb.appendln("Cached Inputs:");
            sb.pushIndent();
            if (this.cleared) {
                sb.appendln("Empty");
            } else {
                sb.appendln(MatrixTools.toString(this.in));
            }
            sb.popIndent();
        }
        if (sb.showWeights()) {
            sb.appendln("Weights:");
            sb.pushIndent();
            sb.appendln(MatrixTools.toString(this.w));
            sb.popIndent();
        }
        if (sb.showEligibilities()) {
            sb.appendln("Eligibilities:");
            sb.pushIndent();
            if (this.overwrite) {
                sb.appendln("Empty");
            } else {
                sb.appendln(MatrixTools.toString(this.e));
            }
            sb.popIndent();
        }
        // The updater does not exist before build(), so skip it rather than throwing NPE.
        if (this.updater != null) {
            this.updater.toString(sb);
        }
        sb.popIndent();
    }

    /**
     * Updates per-connection eligibilities from the current responsibilities.
     *
     * <ul>
     *   <li>First call after a clear: {@code e[j][k] = d[j] * in[c[j][k]]}.</li>
     *   <li>Later calls: {@code e[j][k] = e[j][k] * p[j] + d[j] * in[c[j][k]]}, where
     *       {@code p = prev.get()}.</li>
     * </ul>
     *
     * @param resp current responsibilities (indexed by receiving unit)
     * @param prev values used as the decay factor (only read when accumulating)
     */
    @Override
    public void updateEligibilities(Responsibilities resp, Responsibilities prev) {
        int toSize = this.to.size();
        float[] d = resp.get();
        float[] x = this.in;
        if (this.overwrite) {
            for (int j = 0; j < toSize; ++j) {
                float[] ej = this.e[j];
                float dj = d[j];
                int nj = this.n[j];
                int[] cj = this.c[j];
                for (int k = 0; k < nj; ++k) {
                    ej[k] = dj * x[cj[k]];
                }
            }
            this.overwrite = false;
        } else {
            float[] p = prev.get();
            for (int j = 0; j < toSize; ++j) {
                float[] ej = this.e[j];
                float dj = d[j];
                float pj = p[j];
                int nj = this.n[j];
                int[] cj = this.c[j];
                for (int k = 0; k < nj; ++k) {
                    ej[k] = ej[k] * pj + dj * x[cj[k]];
                }
            }
        }
    }

    /**
     * Back-propagates responsibilities: {@code fromD[i] = sum over (j,k) with c[j][k]==i of
     * w[j][k] * toD[j]}. Does nothing when the receiving side has no responsibilities.
     */
    @Override
    public void updateResponsibilities() {
        float[] toD = this.getToLayerResponsibilities();
        if (toD == null) {
            return;
        }
        float[] fromD = this.getFromLayerResponsibilities();
        Arrays.fill(fromD, 0.0f);
        int toSize = this.to.size();
        for (int j = 0; j < toSize; ++j) {
            float[] wj = this.w[j];
            int[] cj = this.c[j];
            int nj = this.n[j];
            float dj = toD[j];
            for (int k = 0; k < nj; ++k) {
                fromD[cj[k]] += wj[k] * dj;
            }
        }
    }

    /**
     * Adds {@code dw[j][k]} to the k-th connection of receiving unit j.
     *
     * <p>NOTE(review): {@code dw} is indexed by compact slot k, not by sender index. Confirm
     * callers never pass a dense [to][from] matrix here.
     */
    @Override
    public void updateWeights(float[][] dw) {
        int toSize = this.to.size();
        for (int j = 0; j < toSize; ++j) {
            float[] wj = this.w[j];
            float[] dwj = dw[j];
            int nj = this.n[j];
            for (int k = 0; k < nj; ++k) {
                wj[k] += dwj[k];
            }
        }
    }

    /**
     * Applies eligibility-based updates: {@code w[j][k] += updater.getUpdate(j, k, e[j][k] * d[j])}.
     *
     * <p>NOTE(review): the updater is initialized with dense dimensions but is queried here with
     * the compact slot k; confirm that is intended.
     *
     * @param copyresp responsibilities indexed by receiving unit
     */
    @Override
    public void updateWeightsFromEligibilities(Responsibilities copyresp) {
        int toSize = this.to.size();
        float[] d = copyresp.get();
        this.updater.updateFromEligibilities(this.e, d);
        for (int j = 0; j < toSize; ++j) {
            float[] ej = this.e[j];
            float dj = d[j];
            float[] wj = this.w[j];
            int nj = this.n[j];
            for (int k = 0; k < nj; ++k) {
                wj[k] += this.updater.getUpdate(j, k, ej[k] * dj);
            }
        }
    }

    /**
     * Applies input-based updates:
     * {@code w[j][k] += updater.getUpdate(j, k, in[c[j][k]] * d[j])}.
     *
     * @param resp responsibilities indexed by receiving unit
     */
    @Override
    public void updateWeightsFromInputs(Responsibilities resp) {
        int toSize = this.to.size();
        float[] d = resp.get();
        float[] x = this.in;
        this.updater.updateFromInputs(x, d);
        for (int j = 0; j < toSize; ++j) {
            float dj = d[j];
            float[] wj = this.w[j];
            int nj = this.n[j];
            int[] cj = this.c[j];
            for (int k = 0; k < nj; ++k) {
                wj[k] += this.updater.getUpdate(j, k, x[cj[k]] * dj);
            }
        }
    }
}

