/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.compound;

import dmonner.xlbp.Component;
import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.compound.AbstractWeightedCompound;
import dmonner.xlbp.compound.WeightBank;
import dmonner.xlbp.layer.BiasLayer;
import dmonner.xlbp.layer.FanOutLayer;
import dmonner.xlbp.layer.FunctionLayer;
import dmonner.xlbp.layer.RepulsionLayer;
import dmonner.xlbp.layer.SigmaLayer;
import dmonner.xlbp.layer.UpstreamLayer;
import dmonner.xlbp.util.MatrixTools;
import java.util.ArrayList;

/**
 * A compound component built around a caller-supplied {@link FunctionLayer}.
 *
 * <p>Internal wiring, upstream to downstream:
 * <ul>
 *   <li>An optional {@link BiasLayer} feeds a {@link SigmaLayer} ("net").</li>
 *   <li>The net layer feeds the activation {@link FunctionLayer}.</li>
 *   <li>The activation layer feeds a {@link FanOutLayer}, optionally through a
 *       {@link RepulsionLayer}.</li>
 * </ul>
 * The net layer is the compound's input and the fan-out layer is its output. After
 * {@link #optimize()}, either may be replaced by the nearest surviving layer.
 *
 * <p>Per-step operations are delegated to the internal layers in the order held in
 * {@code activate}. Activation and clearing walk it front to back, after the superclass.
 * Eligibility, responsibility and weight updates walk it back to front, before the superclass.
 */
public class FunctionCompound extends AbstractWeightedCompound {
    private static final long serialVersionUID = 1L;
    /** Size of every internal layer, taken from {@code act.size()} at construction. */
    private final int size;
    /** Optional bias source feeding {@link #net}; {@code null} when disabled or optimized out. */
    private BiasLayer bias;
    /** Summation layer and normally the compound's input; may be {@code null} after {@link #optimize()}. */
    private SigmaLayer net;
    /** The activation layer supplied by the caller; {@link #optimize()} refuses to drop it. */
    private final FunctionLayer act;
    /** Optional layer between {@link #act} and {@link #fan}; created by {@link #addRepulsion}. */
    private RepulsionLayer repel;
    /** Fan-out layer and normally the compound's output; may be {@code null} after {@link #optimize()}. */
    private FanOutLayer fan;
    /** Non-null internal layers in forward (upstream-to-downstream) order. */
    private Component[] activate;

    /**
     * Copy constructor. Every internal layer, the {@code activate} ordering and the
     * {@code in}/{@code out} endpoints are obtained through {@code copier.getCopyOf}.
     *
     * @param that   the compound being copied
     * @param copier the copier that tracks copies made during this network copy
     */
    public FunctionCompound(FunctionCompound that, NetworkCopier copier) {
        super(that, copier);
        this.size = that.size;
        this.bias = copier.getCopyOf(that.bias);
        this.net = copier.getCopyOf(that.net);
        this.act = copier.getCopyOf(that.act);
        this.repel = copier.getCopyOf(that.repel);
        this.fan = copier.getCopyOf(that.fan);
        this.activate = new Component[that.activate.length];
        for (int i = 0; i < that.activate.length; ++i) {
            this.activate[i] = copier.getCopyOf(that.activate[i]);
        }
        this.in = copier.getCopyOf(that.in);
        this.out = copier.getCopyOf(that.out);
    }

    /**
     * Creates a compound with a bias layer. Equivalent to {@code FunctionCompound(name, act, true)}.
     *
     * @param name name of the compound; also used as a prefix for the internal layer names
     * @param act  the activation layer; its {@code size()} sets the size of every internal layer
     */
    public FunctionCompound(String name, FunctionLayer act) {
        this(name, act, true);
    }

    /**
     * Creates a compound wrapping {@code act}.
     *
     * @param name   name of the compound; also used as a prefix for the internal layer names
     * @param act    the activation layer; its {@code size()} sets the size of every internal layer
     * @param biases whether to create a {@link BiasLayer} feeding the net layer
     */
    public FunctionCompound(String name, FunctionLayer act, boolean biases) {
        super(name);
        this.size = act.size();
        this.bias = biases ? new BiasLayer(name + "Biases", this.size) : null;
        this.net = new SigmaLayer(name + "Net", this.size);
        this.act = act;
        this.fan = new FanOutLayer(name + "Fanout", this.size);
        this.fan.addUpstream(act);
        act.addUpstream(this.net);
        if (this.bias != null) {
            this.net.addUpstream(this.bias);
        }
        this.activate = this.activationOrder();
        this.in = this.net;
        this.out = this.fan;
    }

    /**
     * Collects the currently present (non-null) internal layers in forward order:
     * bias, net, act, repel, fan.
     *
     * @return a fresh array of the present layers, upstream first
     */
    private Component[] activationOrder() {
        final ArrayList<Component> order = new ArrayList<Component>(5);
        if (this.bias != null) {
            order.add(this.bias);
        }
        if (this.net != null) {
            order.add(this.net);
        }
        if (this.act != null) {
            order.add(this.act);
        }
        if (this.repel != null) {
            order.add(this.repel);
        }
        if (this.fan != null) {
            order.add(this.fan);
        }
        return order.toArray(new Component[order.size()]);
    }

    @Override
    public void activateTest() {
        super.activateTest();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].activateTest();
        }
    }

    @Override
    public void activateTrain() {
        super.activateTrain();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].activateTrain();
        }
    }

    /**
     * Inserts a new {@link RepulsionLayer} between the activation layer and the fan-out layer.
     *
     * <p>NOTE(review): this removes upstream index 0 of the fan-out layer, which assumes that
     * slot holds the activation layer (or a repulsion layer added earlier). It also requires
     * the fan-out layer not to have been optimized out.
     *
     * @param retain forwarded to the {@link RepulsionLayer} constructor
     * @param amount forwarded to the {@link RepulsionLayer} constructor
     */
    public void addRepulsion(float retain, float amount) {
        this.repel = new RepulsionLayer(this.name + "Repel", this.size, retain, amount);
        this.fan.removeUpstream(0);
        this.fan.addUpstream(this.repel);
        this.repel.addUpstream(this.act);
        this.activate = this.activationOrder();
    }

    /**
     * Builds the superclass and then every internal layer, at most once until
     * {@link #unbuild()} is called.
     */
    @Override
    public void build() {
        if (!this.built) {
            super.build();
            for (Component component : this.activate) {
                component.build();
            }
            this.built = true;
        }
    }

    @Override
    public void clearActivations() {
        super.clearActivations();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].clearActivations();
        }
    }

    @Override
    public void clearEligibilities() {
        super.clearEligibilities();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].clearEligibilities();
        }
    }

    @Override
    public void clearResponsibilities() {
        super.clearResponsibilities();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].clearResponsibilities();
        }
    }

    @Override
    public FunctionCompound copy(NetworkCopier copier) {
        return new FunctionCompound(this, copier);
    }

    /**
     * Makes a standalone copy of this compound using a fresh {@link NetworkCopier}.
     *
     * @param nameSuffix passed to the {@link NetworkCopier} constructor
     * @return the copy, after {@code copier.build()} has run
     */
    @Override
    public FunctionCompound copy(String nameSuffix) {
        NetworkCopier copier = new NetworkCopier(nameSuffix);
        FunctionCompound copy = this.copy(copier);
        copier.build();
        return copy;
    }

    /** @return the activation layer supplied at construction */
    public FunctionLayer getActLayer() {
        return this.act;
    }

    /** @return the bias layer, or {@code null} if there is none */
    public BiasLayer getBiasInput() {
        return this.bias;
    }

    /** @return a copy of the internal layers in forward order */
    @Override
    public Component[] getComponents() {
        return this.activate.clone();
    }

    /** @return the net (summation) layer, or {@code null} if it was optimized out */
    public SigmaLayer getNetLayer() {
        return this.net;
    }

    /** @return the superclass weight count plus the bias layer's weights, if present */
    @Override
    public int nWeights() {
        int n = super.nWeights();
        if (this.bias != null) {
            n += this.bias.nWeights();
        }
        return n;
    }

    /**
     * Optimizes the superclass, then each internal layer in forward order. Any optional layer
     * whose {@code optimize()} returns {@code false} is dropped. The activation layer must
     * survive.
     *
     * @return {@code false} if the superclass optimized itself away, otherwise {@code true}
     * @throws IllegalStateException if the activation layer is missing or reports it should be
     *                               removed
     */
    @Override
    public boolean optimize() {
        if (!super.optimize()) {
            return false;
        }
        if (this.bias != null && !this.bias.optimize()) {
            this.bias = null;
        }
        if (this.net != null && !this.net.optimize()) {
            this.net = null;
        }
        if (this.act == null || !this.act.optimize()) {
            throw new IllegalStateException("Optimized out the activation FunctionLayer in " + this.name);
        }
        if (this.repel != null && !this.repel.optimize()) {
            this.repel = null;
        }
        if (this.fan != null && !this.fan.optimize()) {
            this.fan = null;
        }
        this.activate = this.activationOrder();
        // The first surviving layer receives input; the last surviving layer provides output.
        this.in = this.net == null ? this.act : this.net;
        this.out = (UpstreamLayer) this.activate[this.activate.length - 1];
        return true;
    }

    @Override
    public void setWeightInitializer(WeightInitializer win) {
        super.setWeightInitializer(win);
        if (this.bias != null) {
            this.bias.setWeightInitializer(win);
        }
    }

    @Override
    public void setWeightUpdaterType(WeightUpdaterType wut) {
        super.setWeightUpdaterType(wut);
        if (this.bias != null) {
            this.bias.setWeightUpdaterType(wut);
        }
    }

    /** @return the size shared by all internal layers */
    public int size() {
        return this.size;
    }

    /**
     * Appends a description of this compound to {@code sb}.
     *
     * <p>In intermediate mode, every internal layer (downstream first) and every weight bank is
     * printed. Otherwise, the compound prints the selected summaries and the weighted
     * connections. Layers that are absent (bias disabled, or net optimized out) are skipped.
     *
     * @param sb the builder to append to
     */
    @Override
    public void toString(NetworkStringBuilder sb) {
        if (sb.showIntermediate()) {
            super.toString(sb);
            sb.pushIndent();
            for (int i = this.activate.length - 1; i >= 0; --i) {
                this.activate[i].toString(sb);
            }
            for (WeightBank bank : this.conn) {
                bank.toString(sb);
            }
            sb.popIndent();
        } else {
            super.toString(sb);
            sb.pushIndent();
            if (sb.showActivations()) {
                sb.appendln("Activations:");
                sb.pushIndent();
                sb.appendln(MatrixTools.toString(this.act.getActivations()));
                sb.popIndent();
            }
            if (sb.showStates() && this.net != null) {
                sb.appendln("States:");
                sb.pushIndent();
                sb.appendln(MatrixTools.toString(this.net.getActivations()));
                sb.popIndent();
            }
            // Guard on net: optimize() may have removed it.
            if (sb.showResponsibilities() && this.net != null) {
                sb.appendln("Responsibilities:");
                sb.pushIndent();
                sb.appendln(this.net.getResponsibilities().toString());
                sb.popIndent();
            }
            // Guard on bias: it is null when disabled at construction or optimized out.
            if (this.bias != null) {
                this.bias.getConnection().toString(sb);
            }
            for (WeightBank bank : this.conn) {
                bank.getConnection().toString(sb);
            }
            sb.popIndent();
        }
    }

    /**
     * Unbuilds the superclass and every internal layer. It then clears the {@code built} flag
     * that {@link #build()} sets, so a later {@link #build()} actually rebuilds.
     */
    @Override
    public void unbuild() {
        super.unbuild();
        for (Component component : this.activate) {
            component.unbuild();
        }
        this.built = false;
    }

    @Override
    public void updateEligibilities() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateEligibilities();
        }
        super.updateEligibilities();
    }

    @Override
    public void updateResponsibilities() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateResponsibilities();
        }
        super.updateResponsibilities();
    }

    @Override
    public void updateWeights() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateWeights();
        }
        super.updateWeights();
    }
}

