/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.connection;

import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.Responsibilities;
import dmonner.xlbp.UniformWeightInitializer;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdater;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.connection.LayerConnection;
import dmonner.xlbp.layer.WeightReceiverLayer;
import dmonner.xlbp.layer.WeightSenderLayer;
import dmonner.xlbp.util.MatrixTools;

/**
 * A one-to-one connection between two equally-sized layers: unit {@code j} of the sending
 * layer feeds only unit {@code j} of the receiving layer through a single weight {@code w[j]}.
 * The (conceptually diagonal) weight matrix is stored as a vector of length {@code to.size()}.
 */
public class DiagonalConnection
extends LayerConnection {
    private static final long serialVersionUID = 1L;
    /** Initializer used to draw fresh weights in {@link #build()} or when copying without weights. */
    private WeightInitializer win;
    /** Factory for this connection's updater; must be set before {@link #build()}. */
    private WeightUpdaterType wut;
    /** Updater created from {@link #wut}; {@code null} until built. */
    private WeightUpdater updater;
    /** Diagonal weights; {@code w[j]} connects sending unit {@code j} to receiving unit {@code j}. */
    private float[] w;
    /** Per-weight eligibilities maintained by {@link #updateEligibilities}. */
    private float[] e;
    /** Copy of the sending layer's activations cached by {@link #activateTrain()}. */
    private float[] in;
    /** True when {@link #in} holds no valid cached input (set by {@link #clear()}). */
    private boolean cleared;
    /** True when the next eligibility update should overwrite {@link #e} instead of accumulating. */
    private boolean overwrite;
    /** When true, {@link #setWeightInitializer} rejects initializers without full connectivity. */
    private boolean fullOnly;

    /**
     * Copy constructor used by {@link NetworkCopier}.
     *
     * <p>The cached input and the {@code cleared}/{@code overwrite}/{@code fullOnly} flags are
     * copied only when {@code copier.copyState()} is true; otherwise they are reset. Weights and
     * eligibilities are copied when {@code copier.copyWeights()} is true; otherwise fresh weights
     * are drawn from the shared {@link WeightInitializer}.
     *
     * @param that the connection to copy
     * @param copier decides which parts of {@code that} are copied
     */
    public DiagonalConnection(DiagonalConnection that, NetworkCopier copier) {
        super(that, copier);
        this.win = that.win;
        this.wut = that.wut;
        this.updater = that.updater != null ? this.wut.make(this) : null;
        boolean copyState = copier.copyState();
        this.cleared = copyState ? that.cleared : true;
        this.overwrite = copyState ? that.overwrite : true;
        // NOTE(review): fullOnly looks like configuration rather than state, yet it is only
        // preserved when state is copied; kept as-is for compatibility.
        this.fullOnly = copyState ? that.fullOnly : true;
        if (that.built) {
            this.in = copyState ? MatrixTools.copy(that.in) : MatrixTools.empty(that.in);
            if (copier.copyWeights()) {
                this.w = MatrixTools.copy(that.w);
                this.e = MatrixTools.copy(that.e);
                // The freshly made updater still needs sizing, exactly as build() does after
                // wut.make(this); previously it was left uninitialized on this path.
                if (this.updater != null) {
                    this.initializeAlphas(this.updater);
                }
            } else {
                this.initializeAlphas(this.updater);
                this.initializeWeights(this.win);
            }
        }
    }

    /**
     * Creates a diagonal connection between two layers of identical size.
     *
     * @param name connection name
     * @param to receiving layer
     * @param from sending layer
     * @throws IllegalArgumentException if the two layers differ in size
     */
    public DiagonalConnection(String name, WeightReceiverLayer to, WeightSenderLayer from) {
        super(name, to, from);
        if (from.size() != to.size()) {
            throw new IllegalArgumentException("Sending and receiving layers of a DiagonalConnection must be the same size: " + from.size() + " != " + to.size());
        }
        this.win = new UniformWeightInitializer();
        this.cleared = true;
        this.overwrite = true;
        this.fullOnly = true;
    }

    /**
     * Creates a diagonal connection named {@code <from>DiagonalTo<to>}.
     *
     * @param to receiving layer
     * @param from sending layer
     * @throws IllegalArgumentException if the two layers differ in size
     */
    public DiagonalConnection(WeightReceiverLayer to, WeightSenderLayer from) {
        // The delegated constructor already sets cleared, overwrite and fullOnly to true.
        this(from.getName() + "DiagonalTo" + to.getName(), to, from);
    }

    /** Writes {@code w[j] * x[j]} into each receiving activation {@code y[j]}. */
    @Override
    public void activateTest() {
        int toSize = this.to.size();
        float[] y = this.to.getActivations();
        float[] x = this.from.getActivations();
        for (int j = 0; j < toSize; ++j) {
            y[j] = this.w[j] * x[j];
        }
    }

    /** Caches the sending layer's activations in {@link #in}, then activates as in testing. */
    @Override
    public void activateTrain() {
        this.cleared = false;
        System.arraycopy(this.from.getActivations(), 0, this.in, 0, this.from.size());
        this.activateTest();
    }

    /**
     * Allocates buffers, creates and sizes the updater, and draws initial weights. Calling this
     * again after a successful build has no effect.
     *
     * @throws IllegalStateException if no weight initializer or updater type has been set
     */
    @Override
    public void build() {
        if (!this.built) {
            super.build();
            if (this.win == null) {
                throw new IllegalStateException("Missing a WeightInitializer in " + this.name);
            }
            if (this.wut == null) {
                throw new IllegalStateException("Missing a WeightUpdaterType in " + this.name);
            }
            this.in = new float[this.from.size()];
            this.updater = this.wut.make(this);
            this.initializeWeights(this.win);
            this.initializeAlphas(this.updater);
            this.built = true;
        }
    }

    /** Marks the cached input as empty and makes the next eligibility update overwrite. */
    @Override
    public void clear() {
        this.cleared = true;
        this.overwrite = true;
    }

    @Override
    public DiagonalConnection copy(NetworkCopier copier) {
        return new DiagonalConnection(this, copier);
    }

    /**
     * Returns the internal weight vector. This is not a copy, so changes are visible to this
     * connection.
     *
     * @return the live diagonal weights
     */
    public float[] get() {
        return this.w;
    }

    @Override
    public float[] getCachedInput() {
        return this.in;
    }

    /**
     * Returns the weight into receiving unit {@code j}. {@code i} is ignored because only the
     * diagonal is stored.
     */
    @Override
    public float getWeight(int j, int i) {
        return this.w[j];
    }

    /** Sizes the given updater for one entry per receiving unit. */
    @Override
    public void initializeAlphas(WeightUpdater lrs) {
        lrs.initialize(this.to.size());
    }

    /**
     * Allocates fresh weight and eligibility vectors. Each weight {@code w[j]} is drawn with
     * {@code wi.randomWeight(j, j)}.
     */
    @Override
    public void initializeWeights(WeightInitializer wi) {
        int toSize = this.to.size();
        this.w = new float[toSize];
        this.e = new float[toSize];
        for (int j = 0; j < toSize; ++j) {
            this.w[j] = wi.randomWeight(j, j);
        }
    }

    /** One weight per receiving unit. */
    @Override
    public int nWeights() {
        return this.to.size();
    }

    @Override
    public void processBatch() {
        this.updater.processBatch();
    }

    /**
     * Overwrites the weights with the contents of {@code w}.
     *
     * @param w new weights; must have length {@code to.size()}
     * @throws IllegalArgumentException if the length does not match
     */
    public void set(float[] w) {
        if (w.length != this.to.size()) {
            throw new IllegalArgumentException("Incompatible number of weights: " + this.to.size() + " != " + w.length);
        }
        System.arraycopy(w, 0, this.w, 0, w.length);
    }

    /**
     * Controls whether {@link #setWeightInitializer} rejects initializers that lack full
     * connectivity.
     */
    public void setFullOnly(boolean fullOnly) {
        this.fullOnly = fullOnly;
    }

    /**
     * Replaces the weight initializer. When {@code fullOnly} is set and {@code win} does not
     * report full connectivity, the initializer is ignored and a warning is printed.
     */
    @Override
    public void setWeightInitializer(WeightInitializer win) {
        if (this.fullOnly && !win.fullConnectivity()) {
            System.out.println("WARNING: Cannot use a DiagonalConnection with anything less than full connectivity; ignoring new WeightInitializer.");
        } else {
            this.win = win;
        }
    }

    @Override
    public void setWeightUpdater(WeightUpdaterType wut) {
        this.wut = wut;
    }

    /**
     * Returns a square matrix whose diagonal holds the eligibilities (if the receiving layer has
     * a downstream copy layer) or the cached inputs (otherwise). The matrix is all zeros while
     * the cached input is cleared.
     */
    @Override
    public float[][] toEligibilitiesMatrix() {
        int size = this.to.size();
        float[][] m = new float[size][size];
        if (!this.cleared) {
            float[] diag = this.to.getDownstreamCopyLayer() != null ? this.e : this.in;
            for (int i = 0; i < size; ++i) {
                m[i][i] = diag[i];
            }
        }
        return m;
    }

    /** Returns the weights expanded into a square diagonal matrix. */
    @Override
    public float[][] toMatrix() {
        int size = this.to.size();
        float[][] m = new float[size][size];
        for (int i = 0; i < size; ++i) {
            m[i][i] = this.w[i];
        }
        return m;
    }

    @Override
    public void toString(NetworkStringBuilder sb) {
        if (sb.showName()) {
            sb.indent();
            sb.append(this.name);
            sb.append(" : ");
            sb.append(this.getClass().getSimpleName());
            sb.appendln();
        }
        sb.pushIndent();
        if (sb.showExtra()) {
            sb.appendln("Cached Inputs:");
            sb.pushIndent();
            if (this.cleared) {
                sb.appendln("Empty");
            } else {
                sb.appendln(MatrixTools.toString(this.in));
            }
            sb.popIndent();
        }
        if (sb.showWeights()) {
            sb.appendln("Weights:");
            sb.pushIndent();
            sb.appendln(MatrixTools.toString(this.w));
            sb.popIndent();
        }
        if (sb.showEligibilities()) {
            sb.appendln("Eligibilities:");
            sb.pushIndent();
            if (this.overwrite) {
                sb.appendln("Empty");
            } else {
                sb.appendln(MatrixTools.toString(this.e));
            }
            sb.popIndent();
        }
        // The updater only exists after build(); avoid an NPE when printing an unbuilt network.
        if (this.updater != null) {
            this.updater.toString(sb);
        }
        sb.popIndent();
    }

    /**
     * Updates eligibilities. On the first call after {@link #clear()}, it sets
     * {@code e[j] = d[j] * in[j]}. Afterwards it accumulates
     * {@code e[j] = e[j] * p[j] + d[j] * in[j]}, where {@code d = resp.get()} and
     * {@code p = prev.get()}.
     */
    @Override
    public void updateEligibilities(Responsibilities resp, Responsibilities prev) {
        int toSize = this.to.size();
        float[] d = resp.get();
        if (this.overwrite) {
            for (int j = 0; j < toSize; ++j) {
                this.e[j] = d[j] * this.in[j];
            }
            this.overwrite = false;
        } else {
            float[] p = prev.get();
            for (int j = 0; j < toSize; ++j) {
                this.e[j] = this.e[j] * p[j] + d[j] * this.in[j];
            }
        }
    }

    /**
     * Back-propagates responsibilities: {@code fromD[j] = w[j] * toD[j]}. Does nothing if the
     * receiving layer exposes no responsibilities.
     */
    @Override
    public void updateResponsibilities() {
        float[] toD = this.getToLayerResponsibilities();
        if (toD == null) {
            return;
        }
        float[] fromD = this.getFromLayerResponsibilities();
        int toSize = this.to.size();
        for (int j = 0; j < toSize; ++j) {
            fromD[j] = this.w[j] * toD[j];
        }
    }

    /**
     * Adds {@code dw[0][i]} to each weight {@code w[i]}. Only the first row of {@code dw} is
     * read.
     */
    @Override
    public void updateWeights(float[][] dw) {
        int toSize = this.to.size();
        float[] dwj = dw[0];
        for (int i = 0; i < toSize; ++i) {
            this.w[i] += dwj[i];
        }
    }

    /** Applies updater-scaled steps computed from eligibilities times copy-layer responsibilities. */
    @Override
    public void updateWeightsFromEligibilities(Responsibilities copyresp) {
        int toSize = this.to.size();
        float[] d = copyresp.get();
        this.updater.updateFromVector(this.e, d);
        for (int j = 0; j < toSize; ++j) {
            this.w[j] += this.updater.getUpdate(j, this.e[j] * d[j]);
        }
    }

    /** Applies updater-scaled steps computed from cached inputs times responsibilities. */
    @Override
    public void updateWeightsFromInputs(Responsibilities resp) {
        int toSize = this.to.size();
        float[] d = resp.get();
        this.updater.updateFromVector(this.in, d);
        for (int j = 0; j < toSize; ++j) {
            this.w[j] += this.updater.getUpdate(j, this.in[j] * d[j]);
        }
    }
}

