package dmonner.xlbp.compound;

import dmonner.xlbp.Component;
import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.UniformWeightInitializer;
import dmonner.xlbp.UpstreamComponent;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.connection.ConnectionType;
import dmonner.xlbp.layer.UpstreamLayer;
import java.util.ArrayList;
import java.util.Arrays;

/* loaded from: input_file:dmonner/xlbp/compound/AbstractWeightedCompound.class */
/**
 * Base class for compounds whose input arrives through one or more {@link WeightBank}s.
 *
 * <p>The compound keeps its banks in an array, in the order they were attached, and forwards
 * the per-step and lifecycle calls declared here (activation, eligibility / responsibility /
 * weight updates, clearing, build / unbuild, batch processing) to every bank in that order.
 * It also remembers a {@link WeightInitializer}, a {@link WeightUpdaterType} and an optional
 * truncation flag, and pushes them onto every bank that gets attached.
 */
public abstract class AbstractWeightedCompound extends AbstractInternalCompound implements WeightedCompound {
    private static final long serialVersionUID = 1L;

    /** Weight initializer applied to every bank attached to this compound. */
    protected WeightInitializer win;

    /** Weight updater type applied to every bank attached to this compound. */
    protected WeightUpdaterType wut;

    /**
     * Attached weight banks in attachment order. Both constructors start with an empty array;
     * attaching a bank replaces the array with a longer copy rather than mutating it in place.
     */
    protected WeightBank[] conn;

    /**
     * Truncation setting forwarded to banks, or {@code null} when {@link #truncate(boolean)} has
     * not been called; a {@code null} value means newly attached banks keep their own setting.
     */
    private Boolean truncate;

    /**
     * Copy constructor used while duplicating a network.
     *
     * <p>The initializer, updater type and truncation flag are taken over by reference. The new
     * compound starts with no banks; every bank of the source is registered with the copier via
     * {@code addWeightBank}, and {@link #copyConnectivityFrom(Component, NetworkCopier)} is the
     * method that fills {@link #conn} from the copier.
     *
     * @param abstractWeightedCompound the compound being copied
     * @param networkCopier the copier coordinating the copy
     */
    public AbstractWeightedCompound(AbstractWeightedCompound abstractWeightedCompound, NetworkCopier networkCopier) {
        super(abstractWeightedCompound, networkCopier);
        this.win = abstractWeightedCompound.win;
        this.wut = abstractWeightedCompound.wut;
        this.truncate = abstractWeightedCompound.truncate;
        this.conn = new WeightBank[0];
        for (WeightBank sourceBank : abstractWeightedCompound.conn) {
            networkCopier.addWeightBank(sourceBank);
        }
    }

    /**
     * Creates a compound with no banks, a {@link UniformWeightInitializer}, the updater type
     * returned by {@link WeightUpdaterType#basic()} and no truncation setting.
     *
     * @param str the name passed to the superclass constructor
     */
    public AbstractWeightedCompound(String str) {
        super(str);
        this.win = new UniformWeightInitializer();
        this.wut = WeightUpdaterType.basic();
        this.truncate = null;
        this.conn = new WeightBank[0];
    }

    /** Calls {@code activateTest()} on every attached bank, in attachment order. */
    @Override // dmonner.xlbp.Component
    public void activateTest() {
        for (WeightBank bank : this.conn) {
            bank.activateTest();
        }
    }

    /** Calls {@code activateTrain()} on every attached bank, in attachment order. */
    @Override // dmonner.xlbp.Component
    public void activateTrain() {
        for (WeightBank bank : this.conn) {
            bank.activateTrain();
        }
    }

    /**
     * Connects an upstream component using the requested kind of connection.
     *
     * <ul>
     *   <li>{@code WEIGHTED}: delegates to {@link #addUpstreamWeights(UpstreamComponent)}.</li>
     *   <li>{@code DIRECT}: delegates to the inherited single-argument {@code addUpstream}.</li>
     *   <li>{@code INDIRECT}: attaches a new {@link IndirectWeightBank} named
     *       {@code <upstream>IndirectTo<this>}.</li>
     *   <li>{@code DIAGONAL}: attaches a new {@link DiagonalWeightBank} named
     *       {@code <upstream>DiagonalTo<this>}.</li>
     * </ul>
     *
     * @param upstreamComponent the component to connect from
     * @param connectionType the kind of connection to create
     * @throws IllegalArgumentException if {@code connectionType} is none of the four values above
     *     (including {@code null})
     */
    @Override // dmonner.xlbp.compound.WeightedCompound
    public void addUpstream(UpstreamComponent upstreamComponent, ConnectionType connectionType) {
        if (connectionType == ConnectionType.WEIGHTED) {
            addUpstreamWeights(upstreamComponent);
        } else if (connectionType == ConnectionType.DIRECT) {
            addUpstream(upstreamComponent);
        } else if (connectionType == ConnectionType.INDIRECT) {
            addUpstreamWeights(new IndirectWeightBank(upstreamComponent.getName() + "IndirectTo" + this.name, upstreamComponent.asUpstreamLayer(), this.in, this.win, this.wut));
        } else if (connectionType == ConnectionType.DIAGONAL) {
            addUpstreamWeights(new DiagonalWeightBank(upstreamComponent.getName() + "DiagonalTo" + this.name, upstreamComponent.asUpstreamLayer(), this.in, this.win, this.wut));
        } else {
            throw new IllegalArgumentException("Unhandled ConnectionType: " + connectionType);
        }
    }

    /**
     * Creates a plain {@link WeightBank} named {@code <str>To<this>} from the given layer into
     * this compound's input and attaches it.
     *
     * @param str the upstream name used as the prefix of the bank's name
     * @param upstreamLayer the layer the bank reads from
     */
    private void addUpstreamWeights(String str, UpstreamLayer upstreamLayer) {
        addUpstreamWeights(new WeightBank(str + "To" + this.name, upstreamLayer, this.in, this.win, this.wut));
    }

    /**
     * Attaches a new plain weight bank fed by the given component's upstream layer; the bank is
     * named after the component.
     *
     * @param upstreamComponent the component to connect from
     */
    @Override // dmonner.xlbp.compound.WeightedCompound
    public void addUpstreamWeights(UpstreamComponent upstreamComponent) {
        addUpstreamWeights(upstreamComponent.getName(), upstreamComponent.asUpstreamLayer());
    }

    /**
     * Attaches an existing bank to the end of {@link #conn}.
     *
     * <p>Before it is stored, the bank is given this compound's updater type and initializer,
     * and its truncation setting if one has been set on this compound.
     *
     * @param weightBank the bank to attach; must not be {@code null}
     */
    public void addUpstreamWeights(WeightBank weightBank) {
        weightBank.setWeightUpdaterType(this.wut);
        weightBank.setWeightInitializer(this.win);
        if (this.truncate != null) {
            weightBank.truncate(this.truncate.booleanValue());
        }
        int slot = this.conn.length;
        // Grow by one via a fresh array so the previous array is never modified.
        this.conn = Arrays.copyOf(this.conn, slot + 1);
        this.conn[slot] = weightBank;
    }

    /**
     * Builds this compound and then each attached bank. Does nothing when the inherited
     * {@code built} flag is already set; sets the flag once all banks have been built.
     */
    @Override // dmonner.xlbp.compound.AbstractCompound, dmonner.xlbp.Component
    public void build() {
        if (this.built) {
            return;
        }
        super.build();
        for (WeightBank bank : this.conn) {
            bank.build();
        }
        this.built = true;
    }

    /** Calls {@code clearActivations()} on every attached bank. */
    @Override // dmonner.xlbp.Component
    public void clearActivations() {
        for (WeightBank bank : this.conn) {
            bank.clearActivations();
        }
    }

    /** Calls {@code clearEligibilities()} on every attached bank. */
    @Override // dmonner.xlbp.Component
    public void clearEligibilities() {
        for (WeightBank bank : this.conn) {
            bank.clearEligibilities();
        }
    }

    /** Calls {@code clearResponsibilities()} on every attached bank. */
    @Override // dmonner.xlbp.Component
    public void clearResponsibilities() {
        for (WeightBank bank : this.conn) {
            bank.clearResponsibilities();
        }
    }

    /**
     * Creates a copy of this compound under the control of the given copier.
     *
     * @param networkCopier the copier coordinating the copy
     * @return the copy
     */
    @Override // dmonner.xlbp.compound.AbstractInternalCompound, dmonner.xlbp.compound.AbstractCompound, dmonner.xlbp.compound.Compound, dmonner.xlbp.UpstreamComponent, dmonner.xlbp.Component
    public abstract AbstractWeightedCompound copy(NetworkCopier networkCopier);

    /**
     * Copies this compound using a fresh {@link NetworkCopier} constructed from {@code str},
     * then calls {@code build()} on that copier before returning the copy.
     *
     * @param str the argument passed to the {@link NetworkCopier} constructor
     * @return the copy produced by {@link #copy(NetworkCopier)}
     */
    @Override // dmonner.xlbp.compound.AbstractInternalCompound, dmonner.xlbp.compound.AbstractCompound, dmonner.xlbp.compound.Compound, dmonner.xlbp.UpstreamComponent, dmonner.xlbp.Component
    public AbstractWeightedCompound copy(String str) {
        NetworkCopier copier = new NetworkCopier(str);
        AbstractWeightedCompound duplicate = copy(copier);
        copier.build();
        return duplicate;
    }

    /**
     * Rebuilds {@link #conn} from the copies the copier holds for the source compound's banks.
     *
     * <p>After the superclass has copied its own connectivity, and only when {@code component}
     * is an {@code AbstractWeightedCompound}, this compound's bank array is replaced by the
     * copier's copies of the source's banks, in the source's order. Source banks for which the
     * copier reports no copy are skipped. For any other kind of component the bank array is
     * left as it is.
     *
     * @param component the component whose connectivity is being mirrored
     * @param networkCopier the copier that owns the bank copies
     */
    @Override // dmonner.xlbp.compound.AbstractInternalCompound, dmonner.xlbp.compound.AbstractCompound, dmonner.xlbp.Component
    public void copyConnectivityFrom(Component component, NetworkCopier networkCopier) {
        super.copyConnectivityFrom(component, networkCopier);
        if (!(component instanceof AbstractWeightedCompound)) {
            return;
        }
        AbstractWeightedCompound source = (AbstractWeightedCompound) component;
        // Typed list (was a raw ArrayList) so element type is checked at compile time.
        ArrayList<WeightBank> copies = new ArrayList<WeightBank>(source.conn.length);
        for (WeightBank sourceBank : source.conn) {
            if (networkCopier.copyExists(sourceBank)) {
                copies.add(networkCopier.getCopyOf(sourceBank));
            }
        }
        this.conn = copies.toArray(new WeightBank[copies.size()]);
    }

    /**
     * Returns the first attached bank.
     *
     * @return the bank at index 0
     * @throws ArrayIndexOutOfBoundsException if no bank has been attached
     */
    @Override // dmonner.xlbp.compound.WeightedCompound
    public WeightBank getUpstreamWeights() {
        return getUpstreamWeights(0);
    }

    /**
     * Returns the attached bank at the given position in attachment order.
     *
     * @param i zero-based index into the attached banks
     * @return the bank at index {@code i}
     * @throws ArrayIndexOutOfBoundsException if {@code i} is negative or not less than
     *     {@link #nUpstreamWeights()}
     */
    @Override // dmonner.xlbp.compound.WeightedCompound
    public WeightBank getUpstreamWeights(int i) {
        return this.conn[i];
    }

    /**
     * Returns how many banks are attached.
     *
     * @return the length of {@link #conn}
     */
    @Override // dmonner.xlbp.compound.WeightedCompound
    public int nUpstreamWeights() {
        return this.conn.length;
    }

    /**
     * Returns the total weight count across all attached banks.
     *
     * @return the sum of {@code nWeights()} over every bank; 0 when no bank is attached
     */
    @Override // dmonner.xlbp.Component
    public int nWeights() {
        int total = 0;
        for (WeightBank bank : this.conn) {
            total += bank.nWeights();
        }
        return total;
    }

    /**
     * Optimizes this compound and then each attached bank.
     *
     * @return {@code false}, without touching the banks, when the superclass's {@code optimize()}
     *     returns {@code false}; {@code true} otherwise
     * @throws IllegalStateException if the superclass optimized successfully but the inherited
     *     {@code in} field is {@code null}
     */
    @Override // dmonner.xlbp.compound.AbstractCompound, dmonner.xlbp.Component
    public boolean optimize() {
        if (!super.optimize()) {
            return false;
        }
        if (this.in == null) {
            throw new IllegalStateException("Missing input layer.");
        }
        for (WeightBank bank : this.conn) {
            bank.optimize();
        }
        return true;
    }

    /** Calls {@code processBatch()} on every attached bank. */
    @Override // dmonner.xlbp.Component
    public void processBatch() {
        for (WeightBank bank : this.conn) {
            bank.processBatch();
        }
    }

    /**
     * Stores the initializer for banks attached in the future and applies it to every bank
     * already attached.
     *
     * @param weightInitializer the initializer to use
     */
    @Override // dmonner.xlbp.Component
    public void setWeightInitializer(WeightInitializer weightInitializer) {
        this.win = weightInitializer;
        for (WeightBank bank : this.conn) {
            bank.setWeightInitializer(weightInitializer);
        }
    }

    /**
     * Stores the updater type for banks attached in the future and applies it to every bank
     * already attached.
     *
     * @param weightUpdaterType the updater type to use
     */
    @Override // dmonner.xlbp.Component
    public void setWeightUpdaterType(WeightUpdaterType weightUpdaterType) {
        this.wut = weightUpdaterType;
        for (WeightBank bank : this.conn) {
            bank.setWeightUpdaterType(weightUpdaterType);
        }
    }

    /**
     * Records the truncation setting for banks attached in the future and forwards it to every
     * bank already attached.
     *
     * @param z the value passed to each bank's {@code truncate(boolean)}
     */
    public void truncate(boolean z) {
        this.truncate = z;
        for (WeightBank bank : this.conn) {
            bank.truncate(z);
        }
    }

    /** Unbuilds this compound via the superclass, then calls {@code unbuild()} on every bank. */
    @Override // dmonner.xlbp.compound.AbstractCompound, dmonner.xlbp.Component
    public void unbuild() {
        super.unbuild();
        for (WeightBank bank : this.conn) {
            bank.unbuild();
        }
    }

    /** Calls {@code updateEligibilities()} on every attached bank. */
    @Override // dmonner.xlbp.Component
    public void updateEligibilities() {
        for (WeightBank bank : this.conn) {
            bank.updateEligibilities();
        }
    }

    /** Calls {@code updateResponsibilities()} on every attached bank. */
    @Override // dmonner.xlbp.Component
    public void updateResponsibilities() {
        for (WeightBank bank : this.conn) {
            bank.updateResponsibilities();
        }
    }

    /** Calls {@code updateWeights()} on every attached bank. */
    @Override // dmonner.xlbp.Component
    public void updateWeights() {
        for (WeightBank bank : this.conn) {
            bank.updateWeights();
        }
    }
}
