package dmonner.xlbp.compound;

import dmonner.xlbp.Component;
import dmonner.xlbp.DownstreamComponent;
import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.UpstreamComponent;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.layer.CopyDestinationLayer;
import dmonner.xlbp.layer.CopySourceLayer;
import dmonner.xlbp.layer.DownstreamLayer;
import dmonner.xlbp.layer.FanOutLayer;
import dmonner.xlbp.layer.PiLayer;
import dmonner.xlbp.layer.SigmaLayer;
import dmonner.xlbp.layer.UpstreamLayer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:dmonner/xlbp/compound/MultiBindCompound.class */
/**
 * A compound unit whose internal state is gated by any number of multiplicative "binds" on its
 * input, memory, and output paths.
 *
 * <p>Internal wiring, as set up by {@link #MultiBindCompound(String, int, String, String)}
 * (each item lists what feeds it):
 *
 * <ul>
 *   <li>{@code InPi} (a PiLayer): the outputs of the input binds. It is also this compound's
 *       {@code in} layer until {@link #optimize()} re-points it.
 *   <li>{@code MemPi} (a PiLayer): {@code MemDst} plus the outputs of the memory binds.
 *       {@code MemDst} is a CopyDestinationLayer paired with {@code MemSrc}; presumably it carries
 *       the state copied from the previous step (confirm in CopyDestinationLayer).
 *   <li>{@code State} (a SigmaLayer): {@code InPi} and {@code MemPi}.
 *   <li>{@code MemSrc}: {@code State}.
 *   <li>{@code Squash} (a LogisticCompound): {@code MemSrc}.
 *   <li>{@code OutPi} (a PiLayer): {@code Squash} plus the outputs of the output binds.
 *   <li>{@code FanOut}: {@code OutPi}. It is this compound's {@code out} layer until
 *       {@link #optimize()} re-points it.
 * </ul>
 *
 * <p>Per-step operations ({@code activate*}, {@code clear*}, {@code update*}, {@link #build()},
 * {@link #unbuild()}) walk the activation order computed by {@link #optimize()}. They must
 * therefore not be called before the compound has been optimized.
 */
public class MultiBindCompound extends AbstractInternalCompound {
    private static final long serialVersionUID = 1L;

    /** Function character used for every bind when {@link #addBinds(String, String)} gets no function string. */
    private static final char DEFAULT_BIND_FUNCTION = 'L';

    /** Number of units in every internal layer and every bind. */
    private final int size;
    /**
     * Internal components in forward activation order. This is {@code null} until
     * {@link #optimize()} runs. Afterwards it holds only the components whose own
     * {@code optimize()} returned true.
     */
    private Component[] activate;
    /** Binds feeding {@code inpi}, in the order they were added. */
    private FunctionCompound[] inbind;
    /** Binds feeding {@code mempi}, in the order they were added. */
    private FunctionCompound[] membind;
    /** Binds feeding {@code outpi}, in the order they were added. */
    private FunctionCompound[] outbind;
    private final FunctionCompound squash;
    private final CopySourceLayer memsrc;
    private final CopyDestinationLayer memdst;
    private final SigmaLayer state;
    private final PiLayer inpi;
    private final PiLayer mempi;
    private final PiLayer outpi;
    private final FanOutLayer fan;

    /**
     * Copy constructor used by {@link NetworkCopier}.
     *
     * <p>Every internal component is obtained through {@code copier.getCopyOf}, in the same order
     * as before, so components shared with the rest of the network map to a single copy.
     *
     * @param that the compound to copy; it need not have been optimized yet
     * @param copier the copier tracking original-to-copy mappings
     */
    public MultiBindCompound(final MultiBindCompound that, final NetworkCopier copier) {
        super(that, copier);
        this.size = that.size;
        this.inbind = copyBinds(that.inbind, copier);
        this.membind = copyBinds(that.membind, copier);
        this.outbind = copyBinds(that.outbind, copier);
        this.squash = (FunctionCompound) copier.getCopyOf(that.squash);
        this.memsrc = (CopySourceLayer) copier.getCopyOf(that.memsrc);
        this.memdst = (CopyDestinationLayer) copier.getCopyOf(that.memdst);
        this.state = (SigmaLayer) copier.getCopyOf(that.state);
        this.inpi = (PiLayer) copier.getCopyOf(that.inpi);
        this.mempi = (PiLayer) copier.getCopyOf(that.mempi);
        this.outpi = (PiLayer) copier.getCopyOf(that.outpi);
        this.fan = (FanOutLayer) copier.getCopyOf(that.fan);
        // An un-optimized original has no activation order yet, so leave ours null too
        // instead of throwing a NullPointerException.
        if (that.activate != null) {
            this.activate = new Component[that.activate.length];
            for (int i = 0; i < that.activate.length; i++) {
                this.activate[i] = copier.getCopyOf(that.activate[i]);
            }
        }
        this.in = (DownstreamLayer) copier.getCopyOf(that.in);
        this.out = (UpstreamLayer) copier.getCopyOf(that.out);
    }

    /** Copies each bind through {@code copier}, preserving order. */
    private static FunctionCompound[] copyBinds(final FunctionCompound[] binds, final NetworkCopier copier) {
        final FunctionCompound[] copies = new FunctionCompound[binds.length];
        for (int i = 0; i < binds.length; i++) {
            copies[i] = (FunctionCompound) copier.getCopyOf(binds[i]);
        }
        return copies;
    }

    /**
     * Creates a compound whose binds all use the default (logistic) function.
     *
     * @param name prefix for the names of all internal components
     * @param size number of units in each internal layer and bind
     * @param binds bind-type string; see {@link #addBinds(String, String)}
     */
    public MultiBindCompound(final String name, final int size, final String binds) {
        this(name, size, binds, null);
    }

    /**
     * Creates and wires the internal layers, then adds the requested binds.
     *
     * @param name prefix for the names of all internal components
     * @param size number of units in each internal layer and bind
     * @param binds bind-type string; see {@link #addBinds(String, String)}
     * @param fcns optional per-bind function string; see {@link #addBinds(String, String)}
     */
    public MultiBindCompound(final String name, final int size, final String binds, final String fcns) {
        super(name);
        this.size = size;
        this.inbind = new FunctionCompound[0];
        this.membind = new FunctionCompound[0];
        this.outbind = new FunctionCompound[0];
        this.squash = new LogisticCompound(name + "Squash", size);
        this.memsrc = new CopySourceLayer(name + "MemSrc", size);
        this.memdst = new CopyDestinationLayer(name + "MemDst", this.memsrc);
        this.state = new SigmaLayer(name + "State", size);
        this.inpi = new PiLayer(name + "InPi", size);
        this.mempi = new PiLayer(name + "MemPi", size);
        this.outpi = new PiLayer(name + "OutPi", size);
        this.fan = new FanOutLayer(name + "FanOut", size);
        // Wire from the output end back toward the input end.
        this.fan.addUpstream(this.outpi);
        this.outpi.addUpstream(this.squash);
        this.squash.addUpstream(this.memsrc);
        this.memsrc.addUpstream(this.state);
        this.state.addUpstream(this.mempi);
        this.state.addUpstream(this.inpi);
        this.mempi.addUpstream(this.memdst);
        addBinds(binds, fcns);
        this.in = this.inpi;
        this.out = this.fan;
    }

    /** Activates every component in forward activation order (test mode). */
    @Override
    public void activateTest() {
        for (final Component component : this.activate) {
            component.activateTest();
        }
    }

    /** Activates every component in forward activation order (training mode). */
    @Override
    public void activateTrain() {
        for (final Component component : this.activate) {
            component.activateTrain();
        }
    }

    /**
     * Adds a single bind.
     *
     * @param bindType 'i'/'I' for an input bind; 'm'/'M'/'f'/'F' for a memory bind; 'o'/'O' for
     *     an output bind
     * @param fcn function character; see {@link #addBinds(String, String)}
     * @throws IllegalArgumentException if {@code bindType} or {@code fcn} is not recognized
     */
    public void addBind(final char bindType, final char fcn) {
        switch (bindType) {
            case 'F':
            case 'M':
            case 'f':
            case 'm':
                addMemoryBind(fcn);
                return;
            case 'I':
            case 'i':
                addInputBind(fcn);
                return;
            case 'O':
            case 'o':
                addOutputBind(fcn);
                return;
            default:
                // Report the offending *binding* character, not the function character.
                throw new IllegalArgumentException("Unhandled binding type: " + bindType);
        }
    }

    /**
     * Adds one bind per character of {@code binds}, each using the default function.
     *
     * @param binds bind-type string; see {@link #addBinds(String, String)}
     */
    public void addBinds(final String binds) {
        addBinds(binds, null);
    }

    /**
     * Adds one bind per character of {@code binds}.
     *
     * <p>Binds are added left to right. If a character is unrecognized, the binds before it
     * have already been added when the exception is thrown.
     *
     * @param binds one character per bind: 'i'/'I' input, 'm'/'M'/'f'/'F' memory, 'o'/'O' output
     * @param fcns optional. Character {@code k} selects the function of bind {@code k}: 'L', 'l',
     *     'S', 's' logistic; 'T', 't' tanh; '/' linear. When {@code null}, every bind uses 'L'.
     *     Extra characters beyond {@code binds.length()} are ignored.
     * @throws IllegalArgumentException if {@code fcns} is shorter than {@code binds} (checked
     *     before any bind is added), or a character is not a recognized bind or function type
     */
    public void addBinds(final String binds, final String fcns) {
        if (fcns != null && fcns.length() < binds.length()) {
            throw new IllegalArgumentException("Function string \"" + fcns + "\" has fewer characters than binding string \""
                + binds + "\"");
        }
        for (int k = 0; k < binds.length(); k++) {
            final char fcn = (fcns == null) ? DEFAULT_BIND_FUNCTION : fcns.charAt(k);
            // addBind validates the binding character itself.
            addBind(binds.charAt(k), fcn);
        }
    }

    /** Adds an input bind named {@code <name>InBind<n>} using the function selected by {@code fcn}. */
    public void addInputBind(final char fcn) {
        addInputBind(translateFcn(fcn, "InBind" + (this.inbind.length + 1)));
    }

    /** Appends {@code bind} to the input binds and connects it upstream of {@code InPi}. */
    public void addInputBind(final FunctionCompound bind) {
        this.inbind = Arrays.copyOf(this.inbind, this.inbind.length + 1);
        this.inbind[this.inbind.length - 1] = bind;
        this.inpi.addUpstream(bind);
    }

    /** Adds a memory bind named {@code <name>MemBind<n>} using the function selected by {@code fcn}. */
    public void addMemoryBind(final char fcn) {
        addMemoryBind(translateFcn(fcn, "MemBind" + (this.membind.length + 1)));
    }

    /** Appends {@code bind} to the memory binds and connects it upstream of {@code MemPi}. */
    public void addMemoryBind(final FunctionCompound bind) {
        this.membind = Arrays.copyOf(this.membind, this.membind.length + 1);
        this.membind[this.membind.length - 1] = bind;
        this.mempi.addUpstream(bind);
    }

    /** Adds an output bind named {@code <name>OutBind<n>} using the function selected by {@code fcn}. */
    public void addOutputBind(final char fcn) {
        addOutputBind(translateFcn(fcn, "OutBind" + (this.outbind.length + 1)));
    }

    /** Appends {@code bind} to the output binds and connects it upstream of {@code OutPi}. */
    public void addOutputBind(final FunctionCompound bind) {
        this.outbind = Arrays.copyOf(this.outbind, this.outbind.length + 1);
        this.outbind[this.outbind.length - 1] = bind;
        this.outpi.addUpstream(bind);
    }

    /**
     * Builds the superclass state and then each component in the activation order. Repeated
     * calls have no effect once built.
     */
    @Override
    public void build() {
        if (this.built) {
            return;
        }
        super.build();
        // NOTE(review): requires `activate` to be set by now, i.e. optimize() has run (possibly
        // from super.build() -- confirm in AbstractCompound).
        for (final Component component : this.activate) {
            component.build();
        }
        this.built = true;
    }

    /** Clears activations of every component in the activation order. */
    @Override
    public void clearActivations() {
        for (final Component component : this.activate) {
            component.clearActivations();
        }
    }

    /** Clears eligibilities of every component in the activation order. */
    @Override
    public void clearEligibilities() {
        for (final Component component : this.activate) {
            component.clearEligibilities();
        }
    }

    /** Clears responsibilities of every component in the activation order. */
    @Override
    public void clearResponsibilities() {
        for (final Component component : this.activate) {
            component.clearResponsibilities();
        }
    }

    @Override
    public MultiBindCompound copy(final NetworkCopier copier) {
        return new MultiBindCompound(this, copier);
    }

    /**
     * Copies this compound using a fresh {@link NetworkCopier} and builds that copier before
     * returning the copy.
     *
     * @param nameSuffix argument passed to the {@link NetworkCopier} constructor
     */
    @Override
    public MultiBindCompound copy(final String nameSuffix) {
        final NetworkCopier copier = new NetworkCopier(nameSuffix);
        final MultiBindCompound copy = copy(copier);
        copier.build();
        return copy;
    }

    /**
     * Copies connectivity from {@code component} into this compound. When it is also a
     * {@code MultiBindCompound}, connectivity is copied pairwise into the squash and each bind.
     * This assumes both compounds have the same number of binds of each kind, as copies made by
     * the copy constructor do.
     */
    @Override
    public void copyConnectivityFrom(final Component component, final NetworkCopier copier) {
        super.copyConnectivityFrom(component, copier);
        if (component instanceof MultiBindCompound) {
            final MultiBindCompound that = (MultiBindCompound) component;
            this.squash.copyConnectivityFrom(that.squash, copier);
            for (int i = 0; i < this.inbind.length; i++) {
                this.inbind[i].copyConnectivityFrom(that.inbind[i], copier);
            }
            for (int i = 0; i < this.membind.length; i++) {
                this.membind[i].copyConnectivityFrom(that.membind[i], copier);
            }
            for (int i = 0; i < this.outbind.length; i++) {
                this.outbind[i].copyConnectivityFrom(that.outbind[i], copier);
            }
        }
    }

    /**
     * Returns the internal components in activation order.
     *
     * <p>After {@link #optimize()}, this is a copy of the optimized activation order. Before it,
     * this is every internal component in the order {@link #optimize()} would consider them.
     */
    @Override
    public Component[] getComponents() {
        if (this.activate != null) {
            return this.activate.clone();
        }
        final List<Component> all = allComponents();
        return all.toArray(new Component[all.size()]);
    }

    /**
     * Lists every internal component in forward activation order: input binds, InPi, memory binds,
     * MemDst, MemPi, State, MemSrc, Squash, output binds, OutPi, FanOut.
     */
    private List<Component> allComponents() {
        final List<Component> all = new ArrayList<Component>();
        all.addAll(Arrays.asList(this.inbind));
        all.add(this.inpi);
        all.addAll(Arrays.asList(this.membind));
        all.add(this.memdst);
        all.add(this.mempi);
        all.add(this.state);
        all.add(this.memsrc);
        all.add(this.squash);
        all.addAll(Arrays.asList(this.outbind));
        all.add(this.outpi);
        all.add(this.fan);
        return all;
    }

    /**
     * Lists the components that weight-related calls are delegated to: the squash, then the
     * input, memory and output binds.
     */
    private List<FunctionCompound> weightedComponents() {
        final List<FunctionCompound> weighted =
            new ArrayList<FunctionCompound>(1 + this.inbind.length + this.membind.length + this.outbind.length);
        weighted.add(this.squash);
        weighted.addAll(Arrays.asList(this.inbind));
        weighted.addAll(Arrays.asList(this.membind));
        weighted.addAll(Arrays.asList(this.outbind));
        return weighted;
    }

    public FunctionCompound getInputBind(final int index) {
        return this.inbind[index];
    }

    public FunctionCompound getMemoryBind(final int index) {
        return this.membind[index];
    }

    public FunctionCompound getOutputBind(final int index) {
        return this.outbind[index];
    }

    public int nInputBind() {
        return this.inbind.length;
    }

    public int nMemoryBind() {
        return this.membind.length;
    }

    public int nOutputBind() {
        return this.outbind.length;
    }

    /** Returns the total weight count of the squash and all binds. */
    @Override
    public int nWeights() {
        int total = 0;
        for (final FunctionCompound component : weightedComponents()) {
            total += component.nWeights();
        }
        return total;
    }

    /**
     * Optimizes every internal component and records the survivors as the activation order.
     *
     * <p>Components whose own {@code optimize()} returns false are dropped. The compound's
     * {@code in}/{@code out} are then re-pointed at the first and last surviving components.
     *
     * @return false if the superclass declined to optimize, true otherwise
     * @throws IllegalStateException if every internal component was optimized away
     */
    @Override
    public boolean optimize() {
        if (!super.optimize()) {
            return false;
        }
        final List<Component> kept = allComponents();
        for (final Iterator<Component> it = kept.iterator(); it.hasNext();) {
            if (!it.next().optimize()) {
                it.remove();
            }
        }
        if (kept.isEmpty()) {
            throw new IllegalStateException("All internal components of " + this.name + " were optimized away");
        }
        this.activate = kept.toArray(new Component[kept.size()]);
        this.in = ((DownstreamComponent) this.activate[0]).asDownstreamLayer();
        this.out = ((UpstreamComponent) this.activate[this.activate.length - 1]).asUpstreamLayer();
        return true;
    }

    /** Delegates to the squash and then to every input, memory and output bind. */
    @Override
    public void processBatch() {
        for (final FunctionCompound component : weightedComponents()) {
            component.processBatch();
        }
    }

    /** Sets the weight initializer on the squash and every bind. */
    @Override
    public void setWeightInitializer(final WeightInitializer initializer) {
        for (final FunctionCompound component : weightedComponents()) {
            component.setWeightInitializer(initializer);
        }
    }

    /** Sets the weight-updater type on the squash and every bind. */
    @Override
    public void setWeightUpdaterType(final WeightUpdaterType updaterType) {
        for (final FunctionCompound component : weightedComponents()) {
            component.setWeightUpdaterType(updaterType);
        }
    }

    /** Returns the number of units in each internal layer and bind. */
    public int size() {
        return this.size;
    }

    /**
     * Appends the superclass description, then each internal component in reverse activation
     * order, indented. Safe to call before {@link #optimize()}.
     */
    @Override
    public void toString(final NetworkStringBuilder sb) {
        super.toString(sb);
        sb.pushIndent();
        // getComponents() falls back to the full component list when not yet optimized.
        final Component[] components = getComponents();
        for (int i = components.length - 1; i >= 0; i--) {
            components[i].toString(sb);
        }
        sb.popIndent();
    }

    /**
     * Creates a bind compound named {@code <name><suffix>} with {@code size} units.
     *
     * @param fcn '/' linear; 'L', 'l', 'S', 's' logistic; 'T', 't' tanh
     * @throws IllegalArgumentException for any other character
     */
    private FunctionCompound translateFcn(final char fcn, final String suffix) {
        switch (fcn) {
            case '/':
                return new LinearCompound(this.name + suffix, this.size);
            case 'L':
            case 'S':
            case 'l':
            case 's':
                return new LogisticCompound(this.name + suffix, this.size);
            case 'T':
            case 't':
                return new TanhCompound(this.name + suffix, this.size);
            default:
                throw new IllegalArgumentException("Unhandled function type: " + fcn);
        }
    }

    /** Forwards {@code truncate} to the squash and every bind. */
    public void truncate(final boolean truncate) {
        for (final FunctionCompound component : weightedComponents()) {
            component.truncate(truncate);
        }
    }

    /** Unbuilds the superclass state and then each component in the activation order. */
    @Override
    public void unbuild() {
        super.unbuild();
        for (final Component component : this.activate) {
            component.unbuild();
        }
    }

    /** Updates eligibilities in reverse activation order. */
    @Override
    public void updateEligibilities() {
        for (int i = this.activate.length - 1; i >= 0; i--) {
            this.activate[i].updateEligibilities();
        }
    }

    /** Updates responsibilities in reverse activation order. */
    @Override
    public void updateResponsibilities() {
        for (int i = this.activate.length - 1; i >= 0; i--) {
            this.activate[i].updateResponsibilities();
        }
    }

    /** Updates weights in reverse activation order. */
    @Override
    public void updateWeights() {
        for (int i = this.activate.length - 1; i >= 0; i--) {
            this.activate[i].updateWeights();
        }
    }
}
