/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.layer;

import dmonner.xlbp.Component;
import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.Responsibilities;
import dmonner.xlbp.layer.AbstractFanInLayer;
import dmonner.xlbp.util.MatrixTools;
import java.util.Arrays;

/**
 * Fan-in layer whose output combines the activations of all upstream layers multiplicatively,
 * unit by unit.
 *
 * <p>The output {@code y} starts as a copy of the first upstream layer's activations. Every other
 * upstream layer is then folded in with {@code MatrixTools.multiply}.
 *
 * <p>During training the layer also caches one per-unit factor row per upstream input in
 * {@code prod}. That row is later used to scale this layer's responsibilities when they are
 * handed back to the corresponding upstream layer.
 */
public class PiLayer extends AbstractFanInLayer {
    private static final long serialVersionUID = 1L;

    /**
     * Per-input backward factors, indexed {@code [upstreamIndex][unit]}.
     *
     * <p>Allocated by {@link #build()} or {@link #copyConnectivityFrom}, and filled by
     * {@link #updateProd()}.
     */
    private float[][] prod;

    /** Scratch rows of running prefix products used by {@link #updateProd()}. */
    private float[][] buf;

    /**
     * Copy constructor used by the network-copying machinery.
     *
     * @param that the layer being copied
     * @param copier the copier coordinating the copy
     */
    public PiLayer(final PiLayer that, final NetworkCopier copier) {
        super(that, copier);
    }

    /**
     * Creates a new, unbuilt product layer.
     *
     * @param name the layer's name
     * @param size the number of units in the layer
     */
    public PiLayer(final String name, final int size) {
        super(name, size);
    }

    /**
     * Computes the layer output without caching anything needed for training.
     *
     * <p>The first upstream layer's activations are copied into {@code y}. Each remaining upstream
     * layer is then folded into {@code y} with {@code MatrixTools.multiply}.
     */
    @Override
    public void activateTest() {
        final float[] seed = this.upstream[0].getActivations();
        System.arraycopy(seed, 0, this.y, 0, this.size);

        int input = 1;
        while (input < this.nUpstream) {
            MatrixTools.multiply(this.upstream[input].getActivations(), this.y, this.size);
            input++;
        }
    }

    /**
     * Computes the layer output and then refreshes the per-input backward factors in
     * {@code prod}.
     */
    @Override
    public void activateTrain() {
        this.activateTest();
        this.updateProd();
    }

    /**
     * Allocates the output vector, the responsibilities, and the {@code buf}/{@code prod} work
     * arrays after the superclass has been built.
     *
     * <p>Repeated calls after the first are no-ops.
     */
    @Override
    public void build() {
        if (this.built) {
            return;
        }
        super.build();

        final int units = this.size;
        final int inputs = this.nUpstream;
        this.y = new float[units];
        this.d = new Responsibilities(units);
        this.buf = new float[inputs][units];
        this.prod = new float[inputs][units];
        this.built = true;
    }

    @Override
    public PiLayer copy(final NetworkCopier copier) {
        return new PiLayer(this, copier);
    }

    @Override
    public PiLayer copy(final String nameSuffix) {
        final NetworkCopier copier = new NetworkCopier(nameSuffix);
        return this.copy(copier);
    }

    /**
     * Copies connectivity from {@code comp} and then sets up this layer's work arrays.
     *
     * <p>Nothing beyond the superclass copy happens unless {@code comp} is a {@code PiLayer} whose
     * work arrays are both non-null. In that case:
     * <ul>
     *   <li>If the copier requests state copying and both layers have the same number of upstream
     *       inputs, the arrays are duplicated via {@code MatrixTools.copy}.</li>
     *   <li>Otherwise, fresh zeroed arrays are allocated.</li>
     * </ul>
     *
     * @param comp the component to copy from
     * @param copier the copier coordinating the copy
     */
    @Override
    public void copyConnectivityFrom(final Component comp, final NetworkCopier copier) {
        super.copyConnectivityFrom(comp, copier);
        if (!(comp instanceof PiLayer)) {
            return;
        }

        final PiLayer source = (PiLayer) comp;
        if (source.prod == null || source.buf == null) {
            // Source has no cached work arrays (e.g. it was never built): leave ours untouched.
            return;
        }

        final boolean duplicateState = copier.copyState() && this.nUpstream == source.nUpstream;
        if (duplicateState) {
            this.prod = MatrixTools.copy(source.prod);
            this.buf = MatrixTools.copy(source.buf);
        } else {
            this.prod = new float[this.nUpstream][this.size];
            this.buf = new float[this.nUpstream][this.size];
        }
    }

    /**
     * Appends the superclass description, followed by an indented dump of {@code prod} when the
     * builder asks for extra detail.
     */
    @Override
    public void toString(final NetworkStringBuilder sb) {
        super.toString(sb);
        sb.pushIndent();
        if (sb.showExtra()) {
            this.appendProd(sb);
        }
        sb.popIndent();
    }

    /** Writes a "Prod:" header and an indented rendering of {@code prod} to {@code sb}. */
    private void appendProd(final NetworkStringBuilder sb) {
        sb.appendln("Prod:");
        sb.pushIndent();
        sb.appendln(MatrixTools.toString(this.prod));
        sb.popIndent();
    }

    /**
     * Asks the downstream layer for this layer's responsibilities, but only when a downstream
     * copy layer is attached.
     *
     * <p>Otherwise the request happens in {@link #updateResponsibilities()}.
     */
    @Override
    public void updateEligibilities() {
        if (this.downstreamCopyLayer == null) {
            return;
        }
        this.requestResponsibilitiesFromDownstream();
    }

    /**
     * Recomputes the per-input backward factors in {@code prod}.
     *
     * <p>For each unit {@code u}, two passes are made over the upstream layers:
     * <ul>
     *   <li>A forward pass stores in {@code buf[k][u]} the product of the activations of upstream
     *       layers {@code 0 .. k-1}.</li>
     *   <li>A backward pass stores in {@code prod[k][u]} the product of the activations of
     *       upstream layers {@code k+1 .. n-1}.</li>
     * </ul>
     * Both rows are empty products (1) at their boundary. The two arrays are then combined with
     * {@code MatrixTools.multiplyElementwise}.
     *
     * <p>NOTE(review): this presumably leaves in {@code prod[k]} the product of every upstream
     * activation except the k-th. That assumes {@code multiplyElementwise} writes into its second
     * argument; confirm against {@code MatrixTools}.
     */
    private void updateProd() {
        final int inputs = this.nUpstream;
        final int units = this.size;

        // Forward pass: running product of everything before input k.
        Arrays.fill(this.buf[0], 1.0f);
        for (int k = 1; k < inputs; ++k) {
            final float[] before = this.upstream[k - 1].getActivations();
            final float[] current = this.buf[k];
            final float[] previous = this.buf[k - 1];
            for (int u = 0; u < units; ++u) {
                current[u] = previous[u] * before[u];
            }
        }

        // Backward pass: running product of everything after input k.
        final int last = inputs - 1;
        Arrays.fill(this.prod[last], 1.0f);
        for (int k = last - 1; k >= 0; --k) {
            final float[] after = this.upstream[k + 1].getActivations();
            final float[] current = this.prod[k];
            final float[] next = this.prod[k + 1];
            for (int u = 0; u < units; ++u) {
                current[u] = next[u] * after[u];
            }
        }

        MatrixTools.multiplyElementwise(this.buf, this.prod, inputs, units);
    }

    /**
     * Asks the downstream layer for this layer's responsibilities, but only when no downstream
     * copy layer is attached.
     *
     * <p>Otherwise the request happens in {@link #updateEligibilities()}.
     */
    @Override
    public void updateResponsibilities() {
        if (this.downstreamCopyLayer != null) {
            return;
        }
        this.requestResponsibilitiesFromDownstream();
    }

    /**
     * Fills the responsibilities of upstream input {@code index}.
     *
     * <p>This calls {@code copyMul} with this layer's responsibilities {@code d} and the cached
     * factor row {@code prod[index]}.
     *
     * @param index which upstream input to update
     */
    @Override
    public void updateUpstreamResponsibilities(final int index) {
        this.upstream[index]
                .getResponsibilities(this.myIndexInUpstream[index])
                .copyMul(this.d, this.prod[index]);
    }

    /** Asks the downstream layer to compute the responsibilities it passes back to this layer. */
    private void requestResponsibilitiesFromDownstream() {
        this.downstream.updateUpstreamResponsibilities(this.myIndexInDownstream);
    }
}

