/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.compound;

import dmonner.xlbp.Component;
import dmonner.xlbp.DownstreamComponent;
import dmonner.xlbp.Function;
import dmonner.xlbp.InternalComponent;
import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.UpstreamComponent;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.compound.AbstractWeightedCompound;
import dmonner.xlbp.compound.DiagonalWeightBank;
import dmonner.xlbp.compound.FunctionCompound;
import dmonner.xlbp.compound.LinearCompound;
import dmonner.xlbp.compound.SingletonCompound;
import dmonner.xlbp.compound.WeightBank;
import dmonner.xlbp.compound.WeightedCompound;
import dmonner.xlbp.layer.AbstractUpstreamLayer;
import dmonner.xlbp.layer.CopyDestinationLayer;
import dmonner.xlbp.layer.CopySourceLayer;
import dmonner.xlbp.layer.FanOutLayer;
import dmonner.xlbp.layer.FunctionLayer;
import dmonner.xlbp.layer.Layer;
import dmonner.xlbp.layer.LogisticLayer;
import dmonner.xlbp.layer.PiLayer;
import dmonner.xlbp.layer.SigmaLayer;
import java.util.ArrayList;
import java.util.Iterator;

/**
 * A layer of LSTM-style memory cells assembled from smaller components.
 *
 * <p>The main constructor wires the following pipeline (quoted names are suffixes appended to the
 * compound's name):
 * <ul>
 * <li>an input squash ({@code is}): the supplied compound, or a {@link LinearCompound} "In" when
 * none is supplied;</li>
 * <li>optional input gates ({@code ig}): when present, a {@link PiLayer} "MCInGated" combines the
 * input squash with the gates;</li>
 * <li>the cell state "MCState", a {@link SingletonCompound} around a {@link SigmaLayer}, fed by the
 * (gated) input and, when {@code memory} is true, by the previous state recovered through a
 * {@link CopySourceLayer} "MCStateCopier" / {@link CopyDestinationLayer} "MCPrev" pair, optionally
 * combined with forget gates ({@code fg}) in a {@link PiLayer} "MCPrevGated";</li>
 * <li>an optional state squash ({@code mc}), replaced by a {@link FanOutLayer} "MCFan" when
 * absent;</li>
 * <li>optional output gates ({@code og}): when present, a {@link PiLayer} "MCGated" followed by a
 * {@link FanOutLayer} "MCGatedFan".</li>
 * </ul>
 *
 * <p>All internal components are stored in forward order. The compound's input layer is taken
 * from the first of them and its output layer from the last. Backward-pass methods walk them in
 * reverse.
 */
public class MemoryCellCompound extends AbstractWeightedCompound {
    private static final long serialVersionUID = 1L;

    /** Number of entries the {@code String[]} function-spec constructor expects. */
    private static final int N_FUNCTIONS = 5;

    /** Number of cells; used as the width of every internal layer the constructor creates. */
    private final int size;
    /** Internal components in forward (activation) order. */
    private Component[] activate;
    /**
     * Last value passed to {@link #truncateGates(boolean)}; {@code null} until first set. Kept as a
     * boxed {@code Boolean} because changing a field's type would break deserialization of
     * previously serialized instances.
     */
    private Boolean truncateGates;
    /** Cell-state compound ("MCState") that receives the (gated) input and, with memory, the previous state. */
    private final SingletonCompound ug;
    /** Input squash: the supplied compound, or a LinearCompound "In". Never null. */
    private final WeightedCompound is;
    /** Input gates, or {@code null} when absent. */
    private final FunctionCompound ig;
    /** Forget gates, or {@code null} when absent (always {@code null} without memory). */
    private final FunctionCompound fg;
    /** Output gates, or {@code null} when absent. */
    private final FunctionCompound og;
    /** State squash, or {@code null} when absent. */
    private final FunctionCompound mc;
    /** Input after optional gating: the "MCInGated" PiLayer, or {@link #is} itself. */
    private final InternalComponent mc_in_gated;
    /** Cell state as seen downstream: the copy-source layer with memory, otherwise {@link #ug}. */
    private final InternalComponent mc_state;
    /** Squashed state: {@link #mc}, or the "MCFan" fan-out when there is no state squash. */
    private final InternalComponent mc_state_squashed;
    /** Final output: the "MCGatedFan" fan-out with output gates, otherwise {@link #mc_state_squashed}. */
    private final InternalComponent mc_state_gated;

    /**
     * Copy constructor used by {@link NetworkCopier}; every internal component is mapped to its
     * copy through {@code copier}.
     *
     * @param that the compound to copy
     * @param copier the copier that tracks original-to-copy correspondences
     */
    public MemoryCellCompound(MemoryCellCompound that, NetworkCopier copier) {
        super(that, copier);
        this.size = that.size;
        this.truncateGates = that.truncateGates;
        this.ug = copier.getCopyOf(that.ug);
        this.is = copier.getCopyOf(that.is);
        this.ig = copier.getCopyOf(that.ig);
        this.fg = copier.getCopyOf(that.fg);
        this.og = copier.getCopyOf(that.og);
        this.mc = copier.getCopyOf(that.mc);
        this.mc_in_gated = copier.getCopyOf(that.mc_in_gated);
        this.mc_state = copier.getCopyOf(that.mc_state);
        this.mc_state_squashed = copier.getCopyOf(that.mc_state_squashed);
        this.mc_state_gated = copier.getCopyOf(that.mc_state_gated);
        this.activate = new Component[that.activate.length];
        for (int i = 0; i < that.activate.length; ++i) {
            this.activate[i] = copier.getCopyOf(that.activate[i]);
        }
        this.in = copier.getCopyOf(that.in);
        this.out = copier.getCopyOf(that.out);
    }

    /**
     * Builds memory cells whose optional stages are all {@link LogisticLayer}s.
     *
     * @param name base name for all internal components
     * @param size number of cells
     * @param memory whether the state feeds back into itself on the next step
     * @param inputGates whether to add input gates ("IGAct")
     * @param forgetGates whether to add forget gates ("FGAct")
     * @param outputGates whether to add output gates ("OGAct")
     * @param squashState whether to squash the state ("MCAct")
     * @param squashInput whether to squash the input ("InAct")
     */
    public MemoryCellCompound(String name, int size, boolean memory, boolean inputGates, boolean forgetGates, boolean outputGates, boolean squashState, boolean squashInput) {
        this(name, size, memory,
                squashInput ? new LogisticLayer(name + "InAct", size) : null,
                inputGates ? new LogisticLayer(name + "IGAct", size) : null,
                forgetGates ? new LogisticLayer(name + "FGAct", size) : null,
                squashState ? new LogisticLayer(name + "MCAct", size) : null,
                outputGates ? new LogisticLayer(name + "OGAct", size) : null);
    }

    /**
     * Main constructor: wires the cell pipeline described in the class documentation. Any
     * {@code null} stage is omitted (or, for the input and state squash, replaced by a
     * linear / fan-out stand-in).
     *
     * @param name base name for all internal components
     * @param size number of cells
     * @param memory whether the state feeds back into itself on the next step
     * @param inputSquash input squash, or {@code null} for a LinearCompound "In"
     * @param inputGates input gates, or {@code null}
     * @param forgetGates forget gates, or {@code null}; ignored when {@code memory} is false
     * @param stateSquash state squash, or {@code null} for a fan-out "MCFan"
     * @param outputGates output gates, or {@code null}
     */
    public MemoryCellCompound(String name, int size, boolean memory, FunctionCompound inputSquash, FunctionCompound inputGates, FunctionCompound forgetGates, FunctionCompound stateSquash, FunctionCompound outputGates) {
        super(name);
        this.size = size;
        ArrayList<UpstreamComponent> act = new ArrayList<UpstreamComponent>();

        // Input squash: fall back to a linear compound so there is always an input stage.
        if (inputSquash != null) {
            this.is = inputSquash;
        } else {
            this.is = new LinearCompound(name + "In", size, false);
        }
        act.add(this.is);

        // Optional input gating of the squashed input.
        if (inputGates != null) {
            this.ig = inputGates;
            PiLayer inGated = new PiLayer(name + "MCInGated", size);
            inGated.addUpstream(this.is);
            inGated.addUpstream(this.ig);
            this.mc_in_gated = inGated;
            act.add(this.ig);
            act.add(inGated);
        } else {
            this.ig = null;
            this.mc_in_gated = this.is;
        }

        if (memory) {
            // The state is copied out after each step and fed back in through "MCPrev".
            AbstractUpstreamLayer mc_prev_gated;
            CopySourceLayer mc_state_copy_src = new CopySourceLayer(name + "MCStateCopier", size);
            CopyDestinationLayer mc_state_copy_dest = new CopyDestinationLayer(name + "MCPrev", mc_state_copy_src);
            SingletonCompound st = new SingletonCompound(name + "MCState", new SigmaLayer(name + "MCStateLayer", size));
            if (forgetGates != null) {
                this.fg = forgetGates;
                PiLayer prevGated = new PiLayer(name + "MCPrevGated", size);
                prevGated.addUpstream(mc_state_copy_dest);
                prevGated.addUpstream(this.fg);
                mc_prev_gated = prevGated;
                act.add(this.fg);
                act.add(mc_state_copy_dest);
                act.add(prevGated);
            } else {
                this.fg = null;
                mc_prev_gated = mc_state_copy_dest;
                act.add(mc_state_copy_dest);
            }
            st.addUpstream(this.mc_in_gated);
            st.addUpstream(mc_prev_gated);
            mc_state_copy_src.addUpstream(st);
            this.ug = st;
            this.mc_state = mc_state_copy_src;
            act.add(this.ug);
            act.add(this.mc_state);
        } else {
            // No recurrence: the state only sees the current (gated) input; forget gates are unused.
            SingletonCompound st = new SingletonCompound(name + "MCState", new SigmaLayer(name + "MCStateLayer", size));
            st.addUpstream(this.mc_in_gated);
            this.fg = null;
            this.ug = st;
            this.mc_state = st;
            act.add(this.mc_state);
        }

        // Optional state squash; otherwise a fan-out so the state can feed several consumers.
        if (stateSquash != null) {
            this.mc = stateSquash;
            this.mc.addUpstream(this.mc_state);
            this.mc_state_squashed = this.mc;
            act.add(this.mc);
        } else {
            this.mc = null;
            FanOutLayer fan = new FanOutLayer(name + "MCFan", size);
            fan.addUpstream(this.mc_state);
            this.mc_state_squashed = fan;
            act.add(fan);
        }

        // Optional output gating of the squashed state.
        if (outputGates != null) {
            this.og = outputGates;
            PiLayer gated = new PiLayer(name + "MCGated", size);
            FanOutLayer fan = new FanOutLayer(name + "MCGatedFan", size);
            gated.addUpstream(this.mc_state_squashed);
            gated.addUpstream(this.og);
            fan.addUpstream(gated);
            this.mc_state_gated = fan;
            act.add(this.og);
            act.add(gated);
            act.add(fan);
        } else {
            this.og = null;
            this.mc_state_gated = this.mc_state_squashed;
        }

        this.activate = act.toArray(new Component[act.size()]);
        this.bindInOutToActivateEnds();
    }

    /**
     * Builds memory cells from individual function layers, each wrapped in a
     * {@link FunctionCompound}; {@code null} layers omit that stage.
     */
    public MemoryCellCompound(String name, int size, boolean memory, FunctionLayer inputSquash, FunctionLayer inputGates, FunctionLayer forgetGates, FunctionLayer stateSquash, FunctionLayer outputGates) {
        this(name, size, memory,
                inputSquash == null ? null : new FunctionCompound(name + "In", inputSquash),
                inputGates == null ? null : new FunctionCompound(name + "IG", inputGates),
                forgetGates == null ? null : new FunctionCompound(name + "FG", forgetGates),
                stateSquash == null ? null : new FunctionCompound(name + "MC", stateSquash),
                outputGates == null ? null : new FunctionCompound(name + "OG", outputGates));
    }

    /**
     * Builds memory cells whose stages are created by name via {@code Function.fcompound}. How
     * unknown or "none" names are handled is defined by {@code Function}, not here.
     */
    public MemoryCellCompound(String name, int size, boolean memory, String inFcn, String igFcn, String fgFcn, String mcFcn, String ogFcn) {
        this(name, size, memory,
                Function.fcompound(inFcn, name + "In", size),
                Function.fcompound(igFcn, name + "IG", size),
                Function.fcompound(fgFcn, name + "FG", size),
                Function.fcompound(mcFcn, name + "MC", size),
                Function.fcompound(ogFcn, name + "OG", size));
    }

    /**
     * Builds memory cells from exactly five function names, created via {@code Function.layer}.
     *
     * @param fcns function names in order: input squash, input gates, forget gates, state squash,
     *     output gates
     * @throws NullPointerException if {@code fcns} is null
     * @throws IllegalArgumentException if {@code fcns} does not have exactly five entries
     */
    public MemoryCellCompound(String name, int size, boolean memory, String[] fcns) {
        // The length check must run before any element is read; Java evaluates constructor
        // arguments left to right, so validating inside the first argument achieves that.
        this(name, size, memory,
                Function.layer(requireFunctionCount(fcns)[0], name + "InAct", size),
                Function.layer(fcns[1], name + "IGAct", size),
                Function.layer(fcns[2], name + "FGAct", size),
                Function.layer(fcns[3], name + "MCAct", size),
                Function.layer(fcns[4], name + "OGAct", size));
    }

    /**
     * Builds memory cells described by a type code, using logistic functions.
     *
     * @see #MemoryCellCompound(String, int, String, String)
     */
    public MemoryCellCompound(String name, int size, String type) {
        this(name, size, type, "logistic");
    }

    /**
     * Builds memory cells described by a type code. Letters in {@code type}:
     * <ul>
     * <li>{@code N}: no memory (no state recurrence);</li>
     * <li>{@code S}: squash the input with {@code fcn};</li>
     * <li>{@code I}, {@code F}, {@code O}: add input, forget, output gates using {@code fcn};</li>
     * <li>{@code P}: add peephole connections;</li>
     * <li>{@code L} or {@code G}: add gated lateral connections;</li>
     * <li>{@code U}: add ungated lateral connections;</li>
     * <li>{@code T}: truncate the gates.</li>
     * </ul>
     * The state squash always uses {@code fcn}; absent stages are requested as "none".
     */
    public MemoryCellCompound(String name, int size, String type, String fcn) {
        this(name, size, !type.contains("N"), new String[]{
                type.contains("S") ? fcn : "none",
                type.contains("I") ? fcn : "none",
                type.contains("F") ? fcn : "none",
                fcn,
                type.contains("O") ? fcn : "none"});
        if (type.contains("P")) {
            this.addPeepholeConnections();
        }
        if (type.contains("L") || type.contains("G")) {
            this.addGatedLateralConnections();
        }
        if (type.contains("U")) {
            this.addUngatedLateralConnections();
        }
        if (type.contains("T")) {
            this.truncateGates(true);
        }
    }

    /**
     * Checks that a function-spec array has exactly {@link #N_FUNCTIONS} entries.
     *
     * @return {@code fcns} unchanged, so the call can be used inline
     */
    private static String[] requireFunctionCount(String[] fcns) {
        if (fcns == null) {
            throw new NullPointerException("fcns");
        }
        if (fcns.length != N_FUNCTIONS) {
            throw new IllegalArgumentException("Expected exactly " + N_FUNCTIONS
                    + " functions (in, ig, fg, mc, og), got: " + fcns.length);
        }
        return fcns;
    }

    /** Sets the compound's input/output layers from the first/last entries of {@link #activate}. */
    private void bindInOutToActivateEnds() {
        this.in = ((DownstreamComponent)this.activate[0]).asDownstreamLayer();
        this.out = ((UpstreamComponent)this.activate[this.activate.length - 1]).asUpstreamLayer();
    }

    /** Activates this compound, then every internal component in forward order (test mode). */
    @Override
    public void activateTest() {
        super.activateTest();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].activateTest();
        }
    }

    /** Activates this compound, then every internal component in forward order (train mode). */
    @Override
    public void activateTrain() {
        super.activateTrain();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].activateTrain();
        }
    }

    /** Adds weights from this compound's output into every gate that exists (input, forget, output). */
    public void addGatedLateralConnections() {
        if (this.ig != null) {
            this.ig.addUpstreamWeights(this);
        }
        if (this.fg != null) {
            this.fg.addUpstreamWeights(this);
        }
        if (this.og != null) {
            this.og.addUpstreamWeights(this);
        }
    }

    /**
     * Adds diagonal weight banks ("IGPeephole", "FGPeephole", "OGPeephole") from the squashed state
     * into each existing gate's input, using this compound's weight initializer and updater type.
     */
    public void addPeepholeConnections() {
        if (this.ig != null) {
            this.ig.addUpstreamWeights(new DiagonalWeightBank(this.name + "IGPeephole", this.mc_state_squashed.asUpstreamLayer(), this.ig.getInput(), this.win, this.wut));
        }
        if (this.fg != null) {
            this.fg.addUpstreamWeights(new DiagonalWeightBank(this.name + "FGPeephole", this.mc_state_squashed.asUpstreamLayer(), this.fg.getInput(), this.win, this.wut));
        }
        if (this.og != null) {
            this.og.addUpstreamWeights(new DiagonalWeightBank(this.name + "OGPeephole", this.mc_state_squashed.asUpstreamLayer(), this.og.getInput(), this.win, this.wut));
        }
    }

    /** Adds weights from the squashed (pre-output-gate) state into every gate that exists. */
    public void addUngatedLateralConnections() {
        if (this.ig != null) {
            this.ig.addUpstreamWeights(this.mc_state_squashed);
        }
        if (this.fg != null) {
            this.fg.addUpstreamWeights(this.mc_state_squashed);
        }
        if (this.og != null) {
            this.og.addUpstreamWeights(this.mc_state_squashed);
        }
    }

    /** Connects {@code upstream} to the input squash and to every gate that exists. */
    @Override
    public void addUpstream(UpstreamComponent upstream) {
        this.is.addUpstream(upstream);
        if (this.ig != null) {
            this.ig.addUpstream(upstream);
        }
        if (this.fg != null) {
            this.fg.addUpstream(upstream);
        }
        if (this.og != null) {
            this.og.addUpstream(upstream);
        }
    }

    /**
     * Like {@link #addGatedLateralConnections()} but only for the input and forget gates; the
     * output gates are not connected.
     */
    public void addUpstreamGatedLateralConnections() {
        if (this.ig != null) {
            this.ig.addUpstreamWeights(this);
        }
        if (this.fg != null) {
            this.fg.addUpstreamWeights(this);
        }
    }

    /** Adds weights from {@code upstream} to this compound (via super) and to every gate that exists. */
    @Override
    public void addUpstreamWeights(UpstreamComponent upstream) {
        super.addUpstreamWeights(upstream);
        if (this.ig != null) {
            this.ig.addUpstreamWeights(upstream);
        }
        if (this.fg != null) {
            this.fg.addUpstreamWeights(upstream);
        }
        if (this.og != null) {
            this.og.addUpstreamWeights(upstream);
        }
    }

    /** Builds this compound and all internal components once; later calls are no-ops while built. */
    @Override
    public void build() {
        if (!this.built) {
            super.build();
            for (Component component : this.activate) {
                component.build();
            }
            this.built = true;
        }
    }

    /** Clears activations here and in every internal component. */
    @Override
    public void clearActivations() {
        super.clearActivations();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].clearActivations();
        }
    }

    /** Clears eligibilities here and in every internal component. */
    @Override
    public void clearEligibilities() {
        super.clearEligibilities();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].clearEligibilities();
        }
    }

    /** Clears responsibilities here and in every internal component. */
    @Override
    public void clearResponsibilities() {
        super.clearResponsibilities();
        for (int i = 0; i < this.activate.length; ++i) {
            this.activate[i].clearResponsibilities();
        }
    }

    @Override
    public MemoryCellCompound copy(NetworkCopier copier) {
        return new MemoryCellCompound(this, copier);
    }

    /** Copies this compound with a fresh {@link NetworkCopier} and builds the copy's connectivity. */
    @Override
    public MemoryCellCompound copy(String nameSuffix) {
        NetworkCopier copier = new NetworkCopier(nameSuffix);
        MemoryCellCompound copy = this.copy(copier);
        copier.build();
        return copy;
    }

    /**
     * Copies connectivity from {@code comp}. When {@code comp} is also a memory-cell compound, each
     * gate and the state squash present on both sides copy their own connectivity too.
     */
    @Override
    public void copyConnectivityFrom(Component comp, NetworkCopier copier) {
        super.copyConnectivityFrom(comp, copier);
        if (comp instanceof MemoryCellCompound) {
            MemoryCellCompound that = (MemoryCellCompound)comp;
            if (this.ig != null && that.ig != null) {
                this.ig.copyConnectivityFrom(that.ig, copier);
            }
            if (this.fg != null && that.fg != null) {
                this.fg.copyConnectivityFrom(that.fg, copier);
            }
            if (this.mc != null && that.mc != null) {
                this.mc.copyConnectivityFrom(that.mc, copier);
            }
            if (this.og != null && that.og != null) {
                this.og.copyConnectivityFrom(that.og, copier);
            }
        }
    }

    /** Returns a copy of the internal component array, in forward order. */
    @Override
    public Component[] getComponents() {
        return this.activate.clone();
    }

    /** Returns the forget gates, or {@code null} if absent. */
    public FunctionCompound getForgetGates() {
        return this.fg;
    }

    /** Returns the layer carrying the input after optional input gating. */
    public Layer getGatedInput() {
        return this.mc_in_gated.asUpstreamLayer();
    }

    /** Returns the input gates, or {@code null} if absent. */
    public FunctionCompound getInputGates() {
        return this.ig;
    }

    /** Returns the state squash, or {@code null} if absent. */
    public FunctionCompound getMemoryCells() {
        return this.mc;
    }

    /**
     * Returns the input-squash layer.
     *
     * <p>NOTE(review): when no input squash was supplied, {@code is} is a LinearCompound; the cast
     * below assumes that type is a SingletonCompound — confirm.
     */
    public Layer getNetInputLayer() {
        if (this.is instanceof FunctionCompound) {
            return ((FunctionCompound)this.is).getActLayer();
        }
        return ((SingletonCompound)this.is).getLayer();
    }

    /** Returns the output gates, or {@code null} if absent. */
    public FunctionCompound getOutputGates() {
        return this.og;
    }

    /** Returns the state layer: the copy-source layer with memory, else the state compound's layer. */
    public Layer getStateLayer() {
        if (this.mc_state instanceof CopySourceLayer) {
            return (CopySourceLayer)this.mc_state;
        }
        return ((SingletonCompound)this.mc_state).getLayer();
    }

    /** Returns the cell-state compound ("MCState"). */
    public SingletonCompound getUngatedInput() {
        return this.ug;
    }

    /** Returns the squashed state before any output gating. */
    public InternalComponent getUngatedOutput() {
        return this.mc_state_squashed;
    }

    /** Counts weights here plus in every existing sub-compound (is, mc, ug, ig, fg, og). */
    @Override
    public int nWeights() {
        int sum = super.nWeights();
        sum += this.is.nWeights();
        if (this.mc != null) {
            sum += this.mc.nWeights();
        }
        if (this.ug != null) {
            sum += this.ug.nWeights();
        }
        if (this.ig != null) {
            sum += this.ig.nWeights();
        }
        if (this.fg != null) {
            sum += this.fg.nWeights();
        }
        if (this.og != null) {
            sum += this.og.nWeights();
        }
        return sum;
    }

    /**
     * Optimizes this compound and then each internal component in forward order, dropping those
     * whose {@code optimize()} returns false, and re-derives the input/output layers.
     *
     * <p>NOTE(review): if every internal component is dropped, the endpoint lookup indexes an empty
     * array and throws ArrayIndexOutOfBoundsException.
     *
     * @return false if the superclass optimization rejects this compound, true otherwise
     */
    @Override
    public boolean optimize() {
        if (!super.optimize()) {
            return false;
        }
        ArrayList<Component> kept = new ArrayList<Component>(this.activate.length);
        for (Component comp : this.activate) {
            if (comp.optimize()) {
                kept.add(comp);
            }
        }
        this.activate = kept.toArray(new Component[kept.size()]);
        this.bindInOutToActivateEnds();
        return true;
    }

    /** Forwards batch processing to this compound and every existing sub-compound. */
    @Override
    public void processBatch() {
        super.processBatch();
        if (this.is != null) {
            this.is.processBatch();
        }
        if (this.ig != null) {
            this.ig.processBatch();
        }
        if (this.fg != null) {
            this.fg.processBatch();
        }
        if (this.mc != null) {
            this.mc.processBatch();
        }
        if (this.ug != null) {
            this.ug.processBatch();
        }
        if (this.og != null) {
            this.og.processBatch();
        }
    }

    /**
     * Calls {@code setFullOnly(fullOnly)} on every peephole weight bank created by
     * {@link #addPeepholeConnections()}, identified by type and name.
     */
    public void setPeepholeFullOnly(boolean fullOnly) {
        WeightBank bank;
        int i;
        if (this.ig != null) {
            for (i = 0; i < this.ig.nUpstreamWeights(); ++i) {
                bank = this.ig.getUpstreamWeights(i);
                if (!(bank instanceof DiagonalWeightBank) || !bank.getName().equals(this.name + "IGPeephole")) continue;
                ((DiagonalWeightBank)bank).setFullOnly(fullOnly);
            }
        }
        if (this.fg != null) {
            for (i = 0; i < this.fg.nUpstreamWeights(); ++i) {
                bank = this.fg.getUpstreamWeights(i);
                if (!(bank instanceof DiagonalWeightBank) || !bank.getName().equals(this.name + "FGPeephole")) continue;
                ((DiagonalWeightBank)bank).setFullOnly(fullOnly);
            }
        }
        if (this.og != null) {
            for (i = 0; i < this.og.nUpstreamWeights(); ++i) {
                bank = this.og.getUpstreamWeights(i);
                if (!(bank instanceof DiagonalWeightBank) || !bank.getName().equals(this.name + "OGPeephole")) continue;
                ((DiagonalWeightBank)bank).setFullOnly(fullOnly);
            }
        }
    }

    /** Sets the weight initializer here and on every existing sub-compound. */
    @Override
    public void setWeightInitializer(WeightInitializer win) {
        super.setWeightInitializer(win);
        if (this.is != null) {
            this.is.setWeightInitializer(win);
        }
        if (this.ig != null) {
            this.ig.setWeightInitializer(win);
        }
        if (this.fg != null) {
            this.fg.setWeightInitializer(win);
        }
        if (this.ug != null) {
            this.ug.setWeightInitializer(win);
        }
        if (this.mc != null) {
            this.mc.setWeightInitializer(win);
        }
        if (this.og != null) {
            this.og.setWeightInitializer(win);
        }
    }

    /** Sets the weight updater type here and on every existing sub-compound. */
    @Override
    public void setWeightUpdaterType(WeightUpdaterType wut) {
        super.setWeightUpdaterType(wut);
        if (this.is != null) {
            this.is.setWeightUpdaterType(wut);
        }
        if (this.ig != null) {
            this.ig.setWeightUpdaterType(wut);
        }
        if (this.fg != null) {
            this.fg.setWeightUpdaterType(wut);
        }
        if (this.mc != null) {
            this.mc.setWeightUpdaterType(wut);
        }
        if (this.ug != null) {
            this.ug.setWeightUpdaterType(wut);
        }
        if (this.og != null) {
            this.og.setWeightUpdaterType(wut);
        }
    }

    /** Returns the number of cells. */
    public int size() {
        return this.size;
    }

    /**
     * Appends a description. With intermediate output, all internal components are listed in
     * reverse order followed by the weight banks; otherwise only the main stages and connections.
     */
    @Override
    public void toString(NetworkStringBuilder sb) {
        if (sb.showIntermediate()) {
            super.toString(sb);
            sb.pushIndent();
            for (int i = this.activate.length - 1; i >= 0; --i) {
                this.activate[i].toString(sb);
            }
            for (WeightBank bank : this.conn) {
                bank.toString(sb);
            }
            sb.popIndent();
        } else {
            super.toString(sb);
            sb.pushIndent();
            if (this.og != null) {
                this.og.toString(sb);
            }
            if (this.mc != null) {
                this.mc.toString(sb);
            }
            if (this.mc_state != null) {
                this.mc_state.toString(sb);
            }
            for (WeightBank bank : this.conn) {
                bank.getConnection().toString(sb);
            }
            if (this.fg != null) {
                this.fg.toString(sb);
            }
            if (this.ig != null) {
                this.ig.toString(sb);
            }
            sb.popIndent();
        }
    }

    /**
     * Applies {@code truncate} to this compound, the state compound, the state squash and all
     * gates; the input squash is not touched here.
     */
    @Override
    public void truncate(boolean truncate) {
        super.truncate(truncate);
        if (this.ug != null) {
            this.ug.truncate(truncate);
        }
        if (this.mc != null) {
            this.mc.truncate(truncate);
        }
        this.truncateGates(truncate);
    }

    /** Records {@code truncate} and applies it to every existing gate. */
    public void truncateGates(boolean truncate) {
        this.truncateGates = truncate;
        if (this.ig != null) {
            this.ig.truncate(this.truncateGates);
        }
        if (this.fg != null) {
            this.fg.truncate(this.truncateGates);
        }
        if (this.og != null) {
            this.og.truncate(this.truncateGates);
        }
    }

    /** Unbuilds this compound and every internal component. */
    @Override
    public void unbuild() {
        super.unbuild();
        for (Component component : this.activate) {
            component.unbuild();
        }
    }

    /** Updates eligibilities of internal components in reverse order, then of this compound. */
    @Override
    public void updateEligibilities() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateEligibilities();
        }
        super.updateEligibilities();
    }

    /** Updates responsibilities of internal components in reverse order, then of this compound. */
    @Override
    public void updateResponsibilities() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateResponsibilities();
        }
        super.updateResponsibilities();
    }

    /** Updates weights of internal components in reverse order, then of this compound. */
    @Override
    public void updateWeights() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateWeights();
        }
        super.updateWeights();
    }
}

