/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.compound;

import dmonner.xlbp.Component;
import dmonner.xlbp.DownstreamComponent;
import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.UpstreamComponent;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.compound.AbstractInternalCompound;
import dmonner.xlbp.compound.FunctionCompound;
import dmonner.xlbp.compound.LinearCompound;
import dmonner.xlbp.compound.LogisticCompound;
import dmonner.xlbp.compound.TanhCompound;
import dmonner.xlbp.layer.CopyDestinationLayer;
import dmonner.xlbp.layer.CopySourceLayer;
import dmonner.xlbp.layer.FanOutLayer;
import dmonner.xlbp.layer.PiLayer;
import dmonner.xlbp.layer.SigmaLayer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;

/**
 * A compound component of width {@code size} built around a single memory cell whose input,
 * memory and output paths are each gated by a configurable list of "binds".
 *
 * <p>Internal wiring (which component feeds which), as set up by the main constructor:
 * <ul>
 *   <li>input binds feed {@code inpi} (a {@link PiLayer}), which feeds {@code state}
 *       (a {@link SigmaLayer});</li>
 *   <li>{@code memdst} and the memory binds feed {@code mempi} (a {@link PiLayer}), which also
 *       feeds {@code state};</li>
 *   <li>{@code state} feeds {@code memsrc}, which feeds {@code squash} (logistic), which feeds
 *       {@code outpi} (a {@link PiLayer}), which feeds {@code fan} (a {@link FanOutLayer});</li>
 *   <li>output binds also feed {@code outpi}.</li>
 * </ul>
 * {@code memdst} is a {@link CopyDestinationLayer} constructed from {@code memsrc}; presumably it
 * replays the previously copied state (NOTE(review): confirm against CopyDestinationLayer).
 *
 * <p>Each bind is a {@link FunctionCompound} of width {@code size}. Bind codes:
 * {@code i/I} input, {@code m/M/f/F} memory, {@code o/O} output. Function codes:
 * {@code l/L/s/S} logistic, {@code /} linear, {@code t/T} tanh.
 *
 * <p>Lifecycle: the activation list is {@code null} for a freshly constructed instance and is
 * populated by {@link #optimize()}. The activation, clearing, update and build/unbuild methods
 * iterate that list, so they throw {@link NullPointerException} if it has not been populated.
 */
public class MultiBindCompound extends AbstractInternalCompound {
    private static final long serialVersionUID = 1L;

    /** Width of every layer and bind in this compound. */
    private final int size;
    /** Surviving sub-components in forward activation order; {@code null} until optimize(). */
    private Component[] activate;
    /** Binds feeding the input gate {@code inpi}. */
    private FunctionCompound[] inbind;
    /** Binds feeding the memory gate {@code mempi}. */
    private FunctionCompound[] membind;
    /** Binds feeding the output gate {@code outpi}. */
    private FunctionCompound[] outbind;
    /** Squashing function between {@code memsrc} and {@code outpi} (logistic when built fresh). */
    private final FunctionCompound squash;
    /** Copy source fed by {@code state}; {@code memdst} is constructed from it. */
    private final CopySourceLayer memsrc;
    /** Copy destination paired with {@code memsrc}; feeds the memory gate. */
    private final CopyDestinationLayer memdst;
    /** Cell state, fed by both the input gate and the memory gate. */
    private final SigmaLayer state;
    /** Input gate. */
    private final PiLayer inpi;
    /** Memory gate. */
    private final PiLayer mempi;
    /** Output gate. */
    private final PiLayer outpi;
    /** Fan-out layer fed by the output gate; the compound's output before optimization. */
    private final FanOutLayer fan;

    /**
     * Copy constructor used by {@link NetworkCopier}. Every sub-component, and the
     * {@code in}/{@code out} endpoints, is obtained through {@code copier.getCopyOf}.
     *
     * @param that   the compound to copy
     * @param copier the copier that supplies copies of sub-components
     */
    public MultiBindCompound(MultiBindCompound that, NetworkCopier copier) {
        super(that, copier);
        this.size = that.size;
        this.inbind = new FunctionCompound[that.inbind.length];
        for (int i = 0; i < that.inbind.length; ++i) {
            this.inbind[i] = copier.getCopyOf(that.inbind[i]);
        }
        this.membind = new FunctionCompound[that.membind.length];
        for (int i = 0; i < that.membind.length; ++i) {
            this.membind[i] = copier.getCopyOf(that.membind[i]);
        }
        this.outbind = new FunctionCompound[that.outbind.length];
        for (int i = 0; i < that.outbind.length; ++i) {
            this.outbind[i] = copier.getCopyOf(that.outbind[i]);
        }
        this.squash = copier.getCopyOf(that.squash);
        this.memsrc = copier.getCopyOf(that.memsrc);
        this.memdst = copier.getCopyOf(that.memdst);
        this.state = copier.getCopyOf(that.state);
        this.inpi = copier.getCopyOf(that.inpi);
        this.mempi = copier.getCopyOf(that.mempi);
        this.outpi = copier.getCopyOf(that.outpi);
        this.fan = copier.getCopyOf(that.fan);
        // The activation list only exists once the source has been optimized; copying an
        // unoptimized source yields an unoptimized copy instead of a NullPointerException.
        if (that.activate != null) {
            this.activate = new Component[that.activate.length];
            for (int i = 0; i < that.activate.length; ++i) {
                this.activate[i] = copier.getCopyOf(that.activate[i]);
            }
        }
        this.in = copier.getCopyOf(that.in);
        this.out = copier.getCopyOf(that.out);
    }

    /**
     * Creates a compound whose binds all use the logistic function.
     *
     * @param name  base name; sub-component names are formed by appending suffixes to it
     * @param size  width of every layer and bind
     * @param binds bind codes, one character per bind
     * @throws IllegalArgumentException if {@code binds} contains an unrecognized code
     */
    public MultiBindCompound(String name, int size, String binds) {
        this(name, size, binds, null);
    }

    /**
     * Creates a compound and wires its fixed internal layers, then adds the requested binds.
     *
     * @param name  base name; sub-component names are formed by appending suffixes to it
     * @param size  width of every layer and bind
     * @param binds bind codes, one character per bind
     * @param fcns  function codes aligned with {@code binds}, or {@code null} for all-logistic
     * @throws IllegalArgumentException if {@code fcns} is shorter than {@code binds} or any code
     *         is unrecognized
     */
    public MultiBindCompound(String name, int size, String binds, String fcns) {
        super(name);
        this.size = size;
        this.inbind = new FunctionCompound[0];
        this.membind = new FunctionCompound[0];
        this.outbind = new FunctionCompound[0];
        this.squash = new LogisticCompound(name + "Squash", size);
        this.memsrc = new CopySourceLayer(name + "MemSrc", size);
        this.memdst = new CopyDestinationLayer(name + "MemDst", this.memsrc);
        this.state = new SigmaLayer(name + "State", size);
        this.inpi = new PiLayer(name + "InPi", size);
        this.mempi = new PiLayer(name + "MemPi", size);
        this.outpi = new PiLayer(name + "OutPi", size);
        this.fan = new FanOutLayer(name + "FanOut", size);
        this.fan.addUpstream(this.outpi);
        this.outpi.addUpstream(this.squash);
        this.squash.addUpstream(this.memsrc);
        this.memsrc.addUpstream(this.state);
        this.state.addUpstream(this.mempi);
        this.state.addUpstream(this.inpi);
        this.mempi.addUpstream(this.memdst);
        // NOTE(review): addBinds is overridable and is invoked from the constructor.
        this.addBinds(binds, fcns);
        this.in = this.inpi;
        this.out = this.fan;
    }

    /** Runs test-mode activation on each component of the activation list, in forward order. */
    @Override
    public void activateTest() {
        for (Component component : this.activate) {
            component.activateTest();
        }
    }

    /** Runs training-mode activation on each component of the activation list, in forward order. */
    @Override
    public void activateTrain() {
        for (Component component : this.activate) {
            component.activateTrain();
        }
    }

    /**
     * Adds a single bind.
     *
     * @param bind bind code: {@code i/I} input, {@code m/M/f/F} memory, {@code o/O} output
     * @param fcn  function code: {@code l/L/s/S} logistic, {@code /} linear, {@code t/T} tanh
     * @throws IllegalArgumentException if {@code bind} or {@code fcn} is not a recognized code
     */
    public void addBind(char bind, char fcn) {
        switch (bind) {
            case 'I':
            case 'i':
                this.addInputBind(fcn);
                break;
            case 'F':
            case 'M':
            case 'f':
            case 'm':
                this.addMemoryBind(fcn);
                break;
            case 'O':
            case 'o':
                this.addOutputBind(fcn);
                break;
            default:
                // Report the offending bind code, not the function code.
                throw new IllegalArgumentException("Unhandled binding type: " + bind);
        }
    }

    /**
     * Adds one logistic bind per character of {@code binds}.
     *
     * @param binds bind codes
     * @throws IllegalArgumentException if any code is unrecognized
     */
    public void addBinds(String binds) {
        this.addBinds(binds, null);
    }

    /**
     * Adds one bind per character of {@code binds}, using the function code at the same position
     * of {@code fcns} (or logistic when {@code fcns} is {@code null}). Every bind code accepted by
     * {@link #addBind(char, char)} is accepted here, including the {@code f/F} memory alias.
     *
     * @param binds bind codes
     * @param fcns  function codes aligned with {@code binds}, or {@code null}; extra trailing
     *              characters are ignored
     * @throws IllegalArgumentException if {@code fcns} is shorter than {@code binds} (checked
     *         before anything is added), or if any code is unrecognized (binds preceding the
     *         unrecognized code remain added)
     */
    public void addBinds(String binds, String fcns) {
        if (fcns != null && fcns.length() < binds.length()) {
            throw new IllegalArgumentException("Function string \"" + fcns
                    + "\" is shorter than binding string \"" + binds + "\"");
        }
        for (int i = 0; i < binds.length(); ++i) {
            this.addBind(binds.charAt(i), fcns == null ? 'L' : fcns.charAt(i));
        }
    }

    /**
     * Adds an input bind built from a function code; it is named {@code name + "InBind" + k}.
     *
     * @param fcn function code
     * @throws IllegalArgumentException if {@code fcn} is unrecognized
     */
    public void addInputBind(char fcn) {
        this.addInputBind(this.translateFcn(fcn, "InBind" + (this.inbind.length + 1)));
    }

    /**
     * Appends {@code fc} to the input binds and connects it upstream of the input gate.
     *
     * @param fc the bind to add
     */
    public void addInputBind(FunctionCompound fc) {
        this.inbind = Arrays.copyOf(this.inbind, this.inbind.length + 1);
        this.inbind[this.inbind.length - 1] = fc;
        this.inpi.addUpstream(fc);
    }

    /**
     * Adds a memory bind built from a function code; it is named {@code name + "MemBind" + k}.
     *
     * @param fcn function code
     * @throws IllegalArgumentException if {@code fcn} is unrecognized
     */
    public void addMemoryBind(char fcn) {
        this.addMemoryBind(this.translateFcn(fcn, "MemBind" + (this.membind.length + 1)));
    }

    /**
     * Appends {@code fc} to the memory binds and connects it upstream of the memory gate.
     *
     * @param fc the bind to add
     */
    public void addMemoryBind(FunctionCompound fc) {
        this.membind = Arrays.copyOf(this.membind, this.membind.length + 1);
        this.membind[this.membind.length - 1] = fc;
        this.mempi.addUpstream(fc);
    }

    /**
     * Adds an output bind built from a function code; it is named {@code name + "OutBind" + k}.
     *
     * @param fcn function code
     * @throws IllegalArgumentException if {@code fcn} is unrecognized
     */
    public void addOutputBind(char fcn) {
        this.addOutputBind(this.translateFcn(fcn, "OutBind" + (this.outbind.length + 1)));
    }

    /**
     * Appends {@code fc} to the output binds and connects it upstream of the output gate.
     *
     * @param fc the bind to add
     */
    public void addOutputBind(FunctionCompound fc) {
        this.outbind = Arrays.copyOf(this.outbind, this.outbind.length + 1);
        this.outbind[this.outbind.length - 1] = fc;
        this.outpi.addUpstream(fc);
    }

    /**
     * Builds the superclass state and then every component of the activation list, once.
     * NOTE(review): iterates the activation list, so this assumes {@link #optimize()} has already
     * run (possibly from within {@code super.build()}) — confirm.
     */
    @Override
    public void build() {
        if (!this.built) {
            super.build();
            for (Component component : this.activate) {
                component.build();
            }
            this.built = true;
        }
    }

    /** Clears activations of each component of the activation list. */
    @Override
    public void clearActivations() {
        for (Component component : this.activate) {
            component.clearActivations();
        }
    }

    /** Clears eligibilities of each component of the activation list. */
    @Override
    public void clearEligibilities() {
        for (Component component : this.activate) {
            component.clearEligibilities();
        }
    }

    /** Clears responsibilities of each component of the activation list. */
    @Override
    public void clearResponsibilities() {
        for (Component component : this.activate) {
            component.clearResponsibilities();
        }
    }

    /**
     * Creates a copy of this compound through {@code copier}.
     *
     * @param copier the copier that supplies copies of sub-components
     * @return the new copy
     */
    @Override
    public MultiBindCompound copy(NetworkCopier copier) {
        return new MultiBindCompound(this, copier);
    }

    /**
     * Creates a copy using a fresh {@link NetworkCopier} built with {@code nameSuffix}, and calls
     * {@code build()} on that copier after copying.
     *
     * @param nameSuffix suffix passed to the new copier
     * @return the new copy
     */
    @Override
    public MultiBindCompound copy(String nameSuffix) {
        NetworkCopier copier = new NetworkCopier(nameSuffix);
        MultiBindCompound copy = this.copy(copier);
        copier.build();
        return copy;
    }

    /**
     * Copies connectivity from {@code comp}. When {@code comp} is also a MultiBindCompound, the
     * squash compound and each bind copy connectivity from their positional counterparts.
     * NOTE(review): assumes {@code comp} has at least as many binds of each kind as this compound;
     * otherwise an ArrayIndexOutOfBoundsException is thrown.
     *
     * @param comp   the component to copy from
     * @param copier the copier that supplies copies of sub-components
     */
    @Override
    public void copyConnectivityFrom(Component comp, NetworkCopier copier) {
        super.copyConnectivityFrom(comp, copier);
        if (comp instanceof MultiBindCompound) {
            MultiBindCompound that = (MultiBindCompound) comp;
            this.squash.copyConnectivityFrom(that.squash, copier);
            for (int i = 0; i < this.inbind.length; ++i) {
                this.inbind[i].copyConnectivityFrom(that.inbind[i], copier);
            }
            for (int i = 0; i < this.membind.length; ++i) {
                this.membind[i].copyConnectivityFrom(that.membind[i], copier);
            }
            for (int i = 0; i < this.outbind.length; ++i) {
                this.outbind[i].copyConnectivityFrom(that.outbind[i], copier);
            }
        }
    }

    /**
     * Returns the sub-components: a copy of the activation list once {@link #optimize()} has run,
     * otherwise every sub-component in forward activation order.
     *
     * @return a freshly allocated array (callers may modify it)
     */
    @Override
    public Component[] getComponents() {
        if (this.activate != null) {
            return this.activate.clone();
        }
        ArrayList<UpstreamComponent> comps = this.collectComponents();
        return comps.toArray(new Component[comps.size()]);
    }

    /**
     * Lists every sub-component in forward activation order: input binds, input gate, memory
     * binds, memory copy destination, memory gate, state, memory copy source, squash, output
     * binds, output gate, fan-out.
     */
    private ArrayList<UpstreamComponent> collectComponents() {
        // 8 = the fixed (non-bind) sub-components added below.
        ArrayList<UpstreamComponent> comps = new ArrayList<UpstreamComponent>(
                this.inbind.length + this.membind.length + this.outbind.length + 8);
        comps.addAll(Arrays.asList(this.inbind));
        comps.add(this.inpi);
        comps.addAll(Arrays.asList(this.membind));
        comps.add(this.memdst);
        comps.add(this.mempi);
        comps.add(this.state);
        comps.add(this.memsrc);
        comps.add(this.squash);
        comps.addAll(Arrays.asList(this.outbind));
        comps.add(this.outpi);
        comps.add(this.fan);
        return comps;
    }

    /**
     * Returns the {@code i}-th input bind.
     *
     * @param i zero-based index
     * @return the bind
     */
    public FunctionCompound getInputBind(int i) {
        return this.inbind[i];
    }

    /**
     * Returns the {@code i}-th memory bind.
     *
     * @param i zero-based index
     * @return the bind
     */
    public FunctionCompound getMemoryBind(int i) {
        return this.membind[i];
    }

    /**
     * Returns the {@code i}-th output bind.
     *
     * @param i zero-based index
     * @return the bind
     */
    public FunctionCompound getOutputBind(int i) {
        return this.outbind[i];
    }

    /** @return the number of input binds */
    public int nInputBind() {
        return this.inbind.length;
    }

    /** @return the number of memory binds */
    public int nMemoryBind() {
        return this.membind.length;
    }

    /** @return the number of output binds */
    public int nOutputBind() {
        return this.outbind.length;
    }

    /**
     * Sums the weight counts of the squash compound and all binds. The internal gate, state,
     * copy and fan-out layers are not counted here.
     *
     * @return the total number of weights
     */
    @Override
    public int nWeights() {
        int sum = this.squash.nWeights();
        for (FunctionCompound bind : this.inbind) {
            sum += bind.nWeights();
        }
        for (FunctionCompound bind : this.membind) {
            sum += bind.nWeights();
        }
        for (FunctionCompound bind : this.outbind) {
            sum += bind.nWeights();
        }
        return sum;
    }

    /**
     * Optimizes this compound and its sub-components.
     *
     * <p>If the superclass optimization returns {@code false}, this returns {@code false} without
     * touching sub-components. Otherwise every sub-component is optimized in forward order.
     * Those whose {@code optimize()} returns {@code false} are dropped, and the survivors become
     * the activation list. {@code in} and {@code out} are then re-pointed at the first and last
     * survivor.
     *
     * @return {@code false} if the superclass optimization returned {@code false}, else {@code true}
     */
    @Override
    public boolean optimize() {
        if (!super.optimize()) {
            return false;
        }
        ArrayList<UpstreamComponent> act = this.collectComponents();
        for (Iterator<UpstreamComponent> it = act.iterator(); it.hasNext(); ) {
            if (!((Component) it.next()).optimize()) {
                it.remove();
            }
        }
        this.activate = act.toArray(new Component[act.size()]);
        this.in = ((DownstreamComponent) this.activate[0]).asDownstreamLayer();
        this.out = ((UpstreamComponent) this.activate[this.activate.length - 1]).asUpstreamLayer();
        return true;
    }

    /** Calls {@code processBatch()} on the squash compound and then on every bind. */
    @Override
    public void processBatch() {
        this.squash.processBatch();
        for (FunctionCompound bind : this.inbind) {
            bind.processBatch();
        }
        for (FunctionCompound bind : this.membind) {
            bind.processBatch();
        }
        for (FunctionCompound bind : this.outbind) {
            bind.processBatch();
        }
    }

    /**
     * Sets the weight initializer on the squash compound and every bind.
     *
     * @param win the initializer to apply
     */
    @Override
    public void setWeightInitializer(WeightInitializer win) {
        this.squash.setWeightInitializer(win);
        for (FunctionCompound bind : this.inbind) {
            bind.setWeightInitializer(win);
        }
        for (FunctionCompound bind : this.membind) {
            bind.setWeightInitializer(win);
        }
        for (FunctionCompound bind : this.outbind) {
            bind.setWeightInitializer(win);
        }
    }

    /**
     * Sets the weight updater type on the squash compound and every bind.
     *
     * @param wut the updater type to apply
     */
    @Override
    public void setWeightUpdaterType(WeightUpdaterType wut) {
        this.squash.setWeightUpdaterType(wut);
        for (FunctionCompound bind : this.inbind) {
            bind.setWeightUpdaterType(wut);
        }
        for (FunctionCompound bind : this.membind) {
            bind.setWeightUpdaterType(wut);
        }
        for (FunctionCompound bind : this.outbind) {
            bind.setWeightUpdaterType(wut);
        }
    }

    /** @return the width of every layer and bind in this compound */
    public int size() {
        return this.size;
    }

    /**
     * Appends the superclass description followed by an indented description of each
     * sub-component, in reverse activation order (output side first). Works before
     * {@link #optimize()} as well, by listing all sub-components.
     *
     * @param sb the builder to append to
     */
    @Override
    public void toString(NetworkStringBuilder sb) {
        super.toString(sb);
        sb.pushIndent();
        Component[] comps = this.getComponents();
        for (int i = comps.length - 1; i >= 0; --i) {
            comps[i].toString(sb);
        }
        sb.popIndent();
    }

    /**
     * Creates a bind of width {@code size} named {@code name + nameSuffix}.
     *
     * @param fcn        function code: {@code l/L/s/S} logistic, {@code /} linear, {@code t/T} tanh
     * @param nameSuffix suffix appended to this compound's name
     * @return the new bind
     * @throws IllegalArgumentException if {@code fcn} is unrecognized
     */
    private FunctionCompound translateFcn(char fcn, String nameSuffix) {
        switch (fcn) {
            case 'L':
            case 'S':
            case 'l':
            case 's':
                return new LogisticCompound(this.name + nameSuffix, this.size);
            case '/':
                return new LinearCompound(this.name + nameSuffix, this.size);
            case 'T':
            case 't':
                return new TanhCompound(this.name + nameSuffix, this.size);
            default:
                throw new IllegalArgumentException("Unhandled function type: " + fcn);
        }
    }

    /**
     * Forwards {@code truncate} to the squash compound and every bind.
     *
     * @param truncate the flag to forward
     */
    public void truncate(boolean truncate) {
        this.squash.truncate(truncate);
        for (FunctionCompound bind : this.inbind) {
            bind.truncate(truncate);
        }
        for (FunctionCompound bind : this.membind) {
            bind.truncate(truncate);
        }
        for (FunctionCompound bind : this.outbind) {
            bind.truncate(truncate);
        }
    }

    /** Unbuilds the superclass state and then every component of the activation list. */
    @Override
    public void unbuild() {
        super.unbuild();
        for (Component component : this.activate) {
            component.unbuild();
        }
    }

    /** Updates eligibilities of the activation list in reverse order (output side first). */
    @Override
    public void updateEligibilities() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateEligibilities();
        }
    }

    /** Updates responsibilities of the activation list in reverse order (output side first). */
    @Override
    public void updateResponsibilities() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateResponsibilities();
        }
    }

    /** Updates weights of the activation list in reverse order (output side first). */
    @Override
    public void updateWeights() {
        for (int i = this.activate.length - 1; i >= 0; --i) {
            this.activate[i].updateWeights();
        }
    }
}

