package dmonner.xlbp;

import dmonner.xlbp.compound.Compound;
import dmonner.xlbp.compound.InputCompound;
import dmonner.xlbp.compound.MemoryCellCompound;
import dmonner.xlbp.compound.TargetCompound;
import dmonner.xlbp.compound.WeightBank;
import dmonner.xlbp.compound.WeightedCompound;
import dmonner.xlbp.compound.XEntropyTargetCompound;
import dmonner.xlbp.layer.InputLayer;
import dmonner.xlbp.layer.Layer;
import dmonner.xlbp.layer.TargetLayer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * A named container of {@link Component}s that is itself a {@link Component}.
 *
 * <p>Components are registered through {@link #add(Component, boolean, boolean, boolean, boolean)}
 * into up to four roles:
 * <ul>
 *   <li><b>activate</b>: run in insertion order by {@link #activateTest()} and
 *       {@link #activateTrain()}, and in reverse order by {@link #updateEligibilities()};</li>
 *   <li><b>train</b>: run in reverse insertion order by {@link #updateResponsibilities()} and
 *       {@link #updateWeights()}, and in insertion order by {@link #processBatch()};</li>
 *   <li><b>entry</b>: receives connections made via {@link #addUpstream(UpstreamComponent)} /
 *       {@link #addUpstreamWeights(UpstreamComponent)};</li>
 *   <li><b>exit</b>: exposed through {@link #getExitPoint(int)}.</li>
 * </ul>
 * Every added component is also recorded in the "all" list, which drives building, clearing,
 * weight counting and propagation of the weight initializer / updater type. Input and target
 * layers (including those of nested sub-networks) are indexed separately for
 * {@link #setInput(int, float[])} and {@link #setTarget(int, float[])}.
 *
 * <p>Most role lists are stored as an array plus a live count. Only the first {@code n*} slots
 * of each array are meaningful. When the network is filled via {@code add} from the
 * {@link #Network(String)} constructor, arrays are grown to exactly the required size.
 */
public class Network implements Component {
    private static final long serialVersionUID = 1;

    /** Name reported by {@link #getName()} and {@link #toString()}, and used for ordering. */
    private final String name;
    /** Every added component; the first {@code nAll} slots are live. */
    private Component[] all;
    /** Components driven by the activation / eligibility passes; first {@code nActivate} live. */
    private Component[] activate;
    /** Components driven by the responsibility / weight passes; first {@code nTrain} live. */
    private Component[] train;
    /** Entry components receiving unweighted upstream connections (always exactly sized). */
    private DownstreamComponent[] directEntry;
    /** Exit components returned by {@link #getExitPoint(int)} (always exactly sized). */
    private UpstreamComponent[] directExit;
    /** Entry components receiving weighted upstream connections (always exactly sized). */
    private WeightedCompound[] weightedEntry;
    /** Input layers, including those of sub-networks; first {@code nInput} live. */
    private InputLayer[] input;
    /** Target layers, including those of sub-networks; first {@code nTarget} live. */
    private TargetLayer[] target;
    /** Nested networks; first {@code nSubnet} live. */
    private Network[] subnet;
    private int nAll;
    private int nActivate;
    private int nTrain;
    private int nInput;
    private int nTarget;
    private int nSubnet;
    /** Weight initializer handed to every added component. */
    private WeightInitializer win;
    /** Weight-updater type handed to every added component. */
    private WeightUpdaterType wut;
    /** True once {@link #build()} has run and until {@link #unbuild()} is called. */
    private boolean built;

    /**
     * Micro-benchmark. Builds a 300-unit input -> memory-cell -> cross-entropy-target network and
     * prints, for 10 rows of 1000 random training trials each, the milliseconds spent clearing,
     * activating, computing responsibilities and updating weights.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        final int units = 300;
        final int trialsPerRow = 1000;
        final int rows = 10;

        final InputCompound in = new InputCompound("Input", units);
        final MemoryCellCompound hidden = new MemoryCellCompound("Hidden", units, "IFOPL");
        final XEntropyTargetCompound out = new XEntropyTargetCompound("Output", units);
        out.addUpstreamWeights(hidden);
        hidden.addUpstreamWeights(in);

        final Network net = new Network("TheNet");
        net.setWeightUpdaterType(WeightUpdaterType.basic(0.1f));
        net.setWeightInitializer(new UniformWeightInitializer(1.0f, -0.1f, 0.1f));
        net.add(in);
        net.add(hidden);
        net.add(out);
        net.optimize();
        net.build();
        System.out.println("Total weights: " + net.nWeights() + "\n");

        final float[] inputs = new float[units];
        final float[] targets = new float[units];
        System.out.println("Each row represents training the network for 1000 trials.");
        System.out.println("Each entry specifies the time required in milliseconds.\n");
        System.out.println("clear\tactiv\tresp\tweights\t/ total");

        for (int row = 0; row < rows; row++) {
            // nanoTime is monotonic and allocation-free, unlike new Date() in the hot loop.
            final long rowStart = System.nanoTime();
            net.clear();
            final long clearNs = System.nanoTime() - rowStart;
            long activNs = 0;
            long respNs = 0;
            long weightsNs = 0;
            for (int trial = 0; trial < trialsPerRow; trial++) {
                // Random data generation is kept outside the timed intervals.
                fillRandom(inputs);
                final long t0 = System.nanoTime();
                in.setInput(inputs);
                net.activateTrain();
                net.updateEligibilities();
                final long t1 = System.nanoTime();
                fillRandom(targets);
                final long t2 = System.nanoTime();
                out.setTarget(targets);
                net.updateResponsibilities();
                final long t3 = System.nanoTime();
                net.updateWeights();
                final long t4 = System.nanoTime();
                activNs += t1 - t0;
                respNs += t3 - t2;
                weightsNs += t4 - t3;
            }
            final long totalNs = System.nanoTime() - rowStart;
            System.out.println(nanosToMillis(clearNs) + "\t" + nanosToMillis(activNs) + "\t"
                    + nanosToMillis(respNs) + "\t" + nanosToMillis(weightsNs) + "\t/ "
                    + nanosToMillis(totalNs));
        }
    }

    /** Fills {@code values} with independent samples from {@link Math#random()}. */
    private static void fillRandom(float[] values) {
        for (int i = 0; i < values.length; i++) {
            values[i] = (float) Math.random();
        }
    }

    /** Converts a nanosecond duration to whole milliseconds (truncating). */
    private static long nanosToMillis(long nanos) {
        return nanos / 1000000L;
    }

    /**
     * Copy constructor used by {@link #copy(NetworkCopier)}. Each member component is replaced by
     * the copy that {@code networkCopier} hands back for it (whether the copier de-duplicates
     * shared components is not visible here; presumably it does, so confirm in NetworkCopier).
     * Counts and the weight initializer / updater-type references are taken over from
     * {@code network}; the new network starts out unbuilt.
     *
     * @param network the network to copy
     * @param networkCopier supplies the copy name and the component copies
     */
    public Network(Network network, NetworkCopier networkCopier) {
        this.name = networkCopier.getCopyNameFrom(network);
        this.nAll = network.nAll;
        this.nActivate = network.nActivate;
        this.nTrain = network.nTrain;
        this.nInput = network.nInput;
        this.nTarget = network.nTarget;
        this.nSubnet = network.nSubnet;
        this.win = network.win;
        this.wut = network.wut;
        this.all = new Component[network.all.length];
        this.activate = new Component[network.activate.length];
        this.train = new Component[network.train.length];
        this.input = new InputLayer[network.input.length];
        this.target = new TargetLayer[network.target.length];
        this.subnet = new Network[network.subnet.length];
        this.directEntry = new DownstreamComponent[network.directEntry.length];
        this.directExit = new UpstreamComponent[network.directExit.length];
        this.weightedEntry = new WeightedCompound[network.weightedEntry.length];
        for (int i = 0; i < network.all.length; i++) {
            this.all[i] = networkCopier.getCopyOf(network.all[i]);
        }
        for (int i = 0; i < network.activate.length; i++) {
            this.activate[i] = networkCopier.getCopyOf(network.activate[i]);
        }
        for (int i = 0; i < network.train.length; i++) {
            this.train[i] = networkCopier.getCopyOf(network.train[i]);
        }
        for (int i = 0; i < network.input.length; i++) {
            this.input[i] = (InputLayer) networkCopier.getCopyOf(network.input[i]);
        }
        for (int i = 0; i < network.target.length; i++) {
            this.target[i] = (TargetLayer) networkCopier.getCopyOf(network.target[i]);
        }
        for (int i = 0; i < network.subnet.length; i++) {
            this.subnet[i] = (Network) networkCopier.getCopyOf(network.subnet[i]);
        }
        for (int i = 0; i < network.directEntry.length; i++) {
            this.directEntry[i] = (DownstreamComponent) networkCopier.getCopyOf(network.directEntry[i]);
        }
        for (int i = 0; i < network.directExit.length; i++) {
            this.directExit[i] = (UpstreamComponent) networkCopier.getCopyOf(network.directExit[i]);
        }
        for (int i = 0; i < network.weightedEntry.length; i++) {
            this.weightedEntry[i] = (WeightedCompound) networkCopier.getCopyOf(network.weightedEntry[i]);
        }
    }

    /**
     * Creates an empty network with no pre-allocated capacity.
     *
     * @param str the network's name
     */
    public Network(String str) {
        this(str, 0, 0, 0, 0, 0, 0);
    }

    /**
     * Creates an empty network whose role arrays are pre-allocated. All live counts start at
     * zero. Uses a default {@code UniformWeightInitializer} and {@code WeightUpdaterType.basic()}.
     *
     * @param str the network's name
     * @param i initial capacity of the "all" list
     * @param i2 initial capacity of the activate list
     * @param i3 initial capacity of the train list
     * @param i4 initial capacity of the input-layer list
     * @param i5 initial capacity of the target-layer list
     * @param i6 initial capacity of the sub-network list
     */
    public Network(String str, int i, int i2, int i3, int i4, int i5, int i6) {
        this.name = str;
        this.all = new Component[i];
        this.activate = new Component[i2];
        this.train = new Component[i3];
        this.input = new InputLayer[i4];
        this.target = new TargetLayer[i5];
        this.subnet = new Network[i6];
        this.directEntry = new DownstreamComponent[0];
        this.directExit = new UpstreamComponent[0];
        this.weightedEntry = new WeightedCompound[0];
        this.win = new UniformWeightInitializer();
        this.wut = WeightUpdaterType.basic();
    }

    /** Runs {@code activateTest()} on each activate component in insertion order. */
    @Override
    public void activateTest() {
        for (int i = 0; i < this.nActivate; i++) {
            this.activate[i].activateTest();
        }
    }

    /**
     * Clears responsibilities on every activate component, then runs {@code activateTrain()} on
     * each of them in insertion order.
     */
    @Override
    public void activateTrain() {
        for (int i = 0; i < this.nActivate; i++) {
            this.activate[i].clearResponsibilities();
        }
        for (int i = 0; i < this.nActivate; i++) {
            this.activate[i].activateTrain();
        }
    }

    /**
     * Adds a component in the activate and train roles.
     *
     * @param component the component to add
     */
    public void add(Component component) {
        add(component, true, true, false, false);
    }

    /**
     * Adds a component to this network. The component first receives this network's weight
     * initializer and updater type. Sub-networks, input layers/compounds and target
     * layers/compounds are additionally indexed, and the component is always appended to the
     * "all" list.
     *
     * @param component the component to add
     * @param z whether to add it to the activate role
     * @param z2 whether to add it to the train role
     * @param z3 whether to add it as an entry point
     * @param z4 whether to add it as an exit point
     */
    public void add(Component component, boolean z, boolean z2, boolean z3, boolean z4) {
        component.setWeightInitializer(this.win);
        component.setWeightUpdaterType(this.wut);
        if (component instanceof Network) {
            addSubnet((Network) component);
        }
        if (component instanceof InputCompound) {
            addInput(((InputCompound) component).getInputLayer());
        }
        if (component instanceof InputLayer) {
            addInput((InputLayer) component);
        }
        if (component instanceof TargetCompound) {
            addTarget(((TargetCompound) component).getTargetLayer());
        }
        if (component instanceof TargetLayer) {
            addTarget((TargetLayer) component);
        }
        if (z) {
            addActivate(component);
        }
        if (z2) {
            addTrain(component);
        }
        if (z3) {
            addEntry(component);
        }
        if (z4) {
            addExit(component);
        }
        addAll(component);
    }

    /** Appends {@code component} to the activate list. */
    private void addActivate(Component component) {
        ensureActivateCapacity(this.nActivate + 1);
        this.activate[this.nActivate++] = component;
    }

    /**
     * Adds a component in the activate role only.
     *
     * @param component the component to add
     */
    public void addActivateOnly(Component component) {
        add(component, true, false, false, false);
    }

    /** Appends {@code component} to the "all" list. */
    private void addAll(Component component) {
        ensureAllCapacity(this.nAll + 1);
        this.all[this.nAll++] = component;
    }

    /**
     * Registers {@code component} as an entry point: as a weighted entry if it is a
     * {@link WeightedCompound}, and as a direct entry if it is a {@link DownstreamComponent}.
     */
    private void addEntry(Component component) {
        if (component instanceof WeightedCompound) {
            ensureWeightedEntryCapacity(this.weightedEntry.length + 1);
            this.weightedEntry[this.weightedEntry.length - 1] = (WeightedCompound) component;
        }
        if (component instanceof DownstreamComponent) {
            ensureDirectEntryCapacity(this.directEntry.length + 1);
            this.directEntry[this.directEntry.length - 1] = (DownstreamComponent) component;
        }
    }

    /** Registers {@code component} as an exit point if it is an {@link UpstreamComponent}. */
    private void addExit(Component component) {
        if (component instanceof UpstreamComponent) {
            ensureDirectExitCapacity(this.directExit.length + 1);
            this.directExit[this.directExit.length - 1] = (UpstreamComponent) component;
        }
    }

    /** Appends {@code inputLayer} to the input-layer list. */
    private void addInput(InputLayer inputLayer) {
        ensureInputCapacity(this.nInput + 1);
        this.input[this.nInput++] = inputLayer;
    }

    /**
     * Appends {@code network} to the sub-network list and re-exports its live input and target
     * layers as this network's own.
     */
    private void addSubnet(Network network) {
        ensureSubnetCapacity(this.nSubnet + 1);
        this.subnet[this.nSubnet++] = network;
        // Only the sub-network's live slots; pre-allocated tail slots are null.
        for (int i = 0; i < network.nInput; i++) {
            addInput(network.input[i]);
        }
        for (int i = 0; i < network.nTarget; i++) {
            addTarget(network.target[i]);
        }
    }

    /** Appends {@code targetLayer} to the target-layer list. */
    private void addTarget(TargetLayer targetLayer) {
        ensureTargetCapacity(this.nTarget + 1);
        this.target[this.nTarget++] = targetLayer;
    }

    /** Appends {@code component} to the train list. */
    private void addTrain(Component component) {
        ensureTrainCapacity(this.nTrain + 1);
        this.train[this.nTrain++] = component;
    }

    /**
     * Adds a component in the train role only.
     *
     * @param component the component to add
     */
    public void addTrainOnly(Component component) {
        add(component, false, true, false, false);
    }

    /**
     * Connects {@code upstreamComponent} (without weights) to every direct entry point.
     *
     * @param upstreamComponent the upstream source
     */
    public void addUpstream(UpstreamComponent upstreamComponent) {
        for (DownstreamComponent downstreamComponent : this.directEntry) {
            downstreamComponent.addUpstream(upstreamComponent);
        }
    }

    /**
     * Connects {@code upstreamComponent} to the entry points, with or without weights.
     *
     * @param upstreamComponent the upstream source
     * @param z {@code true} to use {@link #addUpstreamWeights}, {@code false} for
     *     {@link #addUpstream(UpstreamComponent)}
     */
    public void addUpstream(UpstreamComponent upstreamComponent, boolean z) {
        if (z) {
            addUpstreamWeights(upstreamComponent);
        } else {
            addUpstream(upstreamComponent);
        }
    }

    /**
     * Connects {@code upstreamComponent} through weights to every weighted entry point.
     *
     * @param upstreamComponent the upstream source
     */
    public void addUpstreamWeights(UpstreamComponent upstreamComponent) {
        for (WeightedCompound weightedCompound : this.weightedEntry) {
            weightedCompound.addUpstreamWeights(upstreamComponent);
        }
    }

    /** Builds every live component once; subsequent calls are no-ops until {@link #unbuild()}. */
    @Override
    public void build() {
        if (!this.built) {
            for (int i = 0; i < this.nAll; i++) {
                this.all[i].build();
            }
        }
        this.built = true;
    }

    /** Clears activations, eligibilities and responsibilities of every live component. */
    @Override
    public void clear() {
        clearActivations();
        clearEligibilities();
        clearResponsibilities();
    }

    /** Clears activations of every live component. */
    @Override
    public void clearActivations() {
        for (int i = 0; i < this.nAll; i++) {
            this.all[i].clearActivations();
        }
    }

    /** Clears eligibilities of every live component. */
    @Override
    public void clearEligibilities() {
        for (int i = 0; i < this.nAll; i++) {
            this.all[i].clearEligibilities();
        }
    }

    /** Clears every live input layer, then recursively the inputs of every live sub-network. */
    public void clearInputs() {
        for (int i = 0; i < this.nInput; i++) {
            this.input[i].clear();
        }
        for (int i = 0; i < this.nSubnet; i++) {
            this.subnet[i].clearInputs();
        }
    }

    /** Clears responsibilities of every live component. */
    @Override
    public void clearResponsibilities() {
        for (int i = 0; i < this.nAll; i++) {
            this.all[i].clearResponsibilities();
        }
    }

    /** Orders components by name. */
    @Override
    public int compareTo(Component component) {
        return this.name.compareTo(component.getName());
    }

    /** Creates a copy of this network using {@code networkCopier} for its members. */
    @Override
    public Network copy(NetworkCopier networkCopier) {
        return new Network(this, networkCopier);
    }

    /** Equivalent to {@code copy(str, false, false)}. */
    @Override
    public Network copy(String str) {
        return copy(str, false, false);
    }

    /** Equivalent to {@code copy("", str, z, z2)}. */
    public Network copy(String str, boolean z, boolean z2) {
        return copy("", str, z, z2);
    }

    /**
     * Copies this network with a freshly created {@link NetworkCopier}, then calls the copier's
     * {@code build()} before returning the copy.
     */
    public Network copy(String str, String str2, boolean z, boolean z2) {
        final NetworkCopier networkCopier = new NetworkCopier(str, str2, z, z2);
        final Network copy = copy(networkCopier);
        networkCopier.build();
        return copy;
    }

    /** No-op: a network's connectivity is carried by its member components. */
    @Override
    public void copyConnectivityFrom(Component component, NetworkCopier networkCopier) {
    }

    /**
     * Returns {@code arr} if it already holds at least {@code minLength} slots, otherwise a copy
     * grown to exactly {@code minLength} (new slots are {@code null}).
     */
    private static <T> T[] grow(T[] arr, int minLength) {
        return minLength > arr.length ? Arrays.copyOf(arr, minLength) : arr;
    }

    /** Ensures the activate array can hold at least {@code i} entries. */
    public void ensureActivateCapacity(int i) {
        this.activate = grow(this.activate, i);
    }

    /** Ensures the "all" array can hold at least {@code i} entries. */
    public void ensureAllCapacity(int i) {
        this.all = grow(this.all, i);
    }

    /** Ensures the direct-entry array can hold at least {@code i} entries. */
    public void ensureDirectEntryCapacity(int i) {
        this.directEntry = grow(this.directEntry, i);
    }

    /** Ensures the direct-exit array can hold at least {@code i} entries. */
    public void ensureDirectExitCapacity(int i) {
        this.directExit = grow(this.directExit, i);
    }

    /** Ensures the input-layer array can hold at least {@code i} entries. */
    public void ensureInputCapacity(int i) {
        this.input = grow(this.input, i);
    }

    /** Ensures the sub-network array can hold at least {@code i} entries. */
    public void ensureSubnetCapacity(int i) {
        this.subnet = grow(this.subnet, i);
    }

    /** Ensures the target-layer array can hold at least {@code i} entries. */
    public void ensureTargetCapacity(int i) {
        this.target = grow(this.target, i);
    }

    /** Ensures the train array can hold at least {@code i} entries. */
    public void ensureTrainCapacity(int i) {
        this.train = grow(this.train, i);
    }

    /** Ensures the weighted-entry array can hold at least {@code i} entries. */
    public void ensureWeightedEntryCapacity(int i) {
        this.weightedEntry = grow(this.weightedEntry, i);
    }

    /** Returns the {@code i}-th activate component. */
    public Component getActivate(int i) {
        return this.activate[i];
    }

    /** Returns the number of live activate components (consistent with {@link #setActivateSize}). */
    public int getActivateSize() {
        return this.nActivate;
    }

    /** Returns the {@code i}-th component of the "all" list. */
    public Component getComponent(int i) {
        return this.all[i];
    }

    /**
     * Returns the first live component whose name equals {@code str}, or {@code null} if none.
     *
     * @param str the name to look for
     */
    public Component getComponentByName(String str) {
        for (int i = 0; i < this.nAll; i++) {
            final Component component = this.all[i];
            if (component != null && component.getName().equals(str)) {
                return component;
            }
        }
        return null;
    }

    /**
     * Returns the live backing array of the "all" list (not a copy). Only the first
     * {@link #size()} slots are meaningful.
     */
    public Component[] getComponents() {
        return this.all;
    }

    /** Returns the first exit point. */
    public UpstreamComponent getExitPoint() {
        return getExitPoint(0);
    }

    /** Returns the {@code i}-th exit point. */
    public UpstreamComponent getExitPoint(int i) {
        return this.directExit[i];
    }

    /** Returns the first input layer. */
    public InputLayer getInputLayer() {
        return getInputLayer(0);
    }

    /** Returns the {@code i}-th input layer. */
    public InputLayer getInputLayer(int i) {
        return this.input[i];
    }

    /** Returns a fresh array holding the live input layers. */
    public InputLayer[] getInputLayers() {
        return Arrays.copyOf(this.input, this.nInput);
    }

    @Override
    public String getName() {
        return this.name;
    }

    /** Returns the number of exit points. */
    public int getNExitPoints() {
        return this.directExit.length;
    }

    /** Returns the first target layer. */
    public TargetLayer getTargetLayer() {
        return getTargetLayer(0);
    }

    /** Returns the {@code i}-th target layer. */
    public TargetLayer getTargetLayer(int i) {
        return this.target[i];
    }

    /** Returns a fresh array holding the live target layers. */
    public TargetLayer[] getTargetLayers() {
        return Arrays.copyOf(this.target, this.nTarget);
    }

    /** Returns the {@code i}-th train component. */
    public Component getTrain(int i) {
        return this.train[i];
    }

    /** Returns the number of live train components (consistent with {@link #setTrainSize}). */
    public int getTrainSize() {
        return this.nTrain;
    }

    @Override
    public boolean isBuilt() {
        return this.built;
    }

    /** Returns the number of live input layers. */
    public int nInput() {
        return this.nInput;
    }

    /** Returns the number of live target layers. */
    public int nTarget() {
        return this.nTarget;
    }

    /** Returns the sum of {@code nWeights()} over every live component. */
    @Override
    public int nWeights() {
        int total = 0;
        for (int i = 0; i < this.nAll; i++) {
            total += this.all[i].nWeights();
        }
        return total;
    }

    /**
     * Counts weights by a breadth-first walk through this network, nested networks, compounds
     * and their upstream weight banks. Each weight bank is keyed by its weight input and each
     * layer by itself, so an entry reached along several paths is counted once.
     *
     * @return the de-duplicated weight count
     * @throws IllegalArgumentException if a component of an unrecognized subtype is reached
     */
    public int nWeightsDeep() {
        final Map<Object, Integer> weightsByKey = new HashMap<Object, Integer>();
        final Deque<Component> pending = new ArrayDeque<Component>();
        pending.add(this);
        while (!pending.isEmpty()) {
            final Component component = pending.poll();
            if (component instanceof WeightedCompound) {
                final WeightedCompound weighted = (WeightedCompound) component;
                for (int i = 0; i < weighted.nUpstreamWeights(); i++) {
                    pending.add(weighted.getUpstreamWeights(i));
                }
            }
            if (component instanceof Compound) {
                for (Component child : ((Compound) component).getComponents()) {
                    pending.add(child);
                }
            } else if (component instanceof Network) {
                final Network net = (Network) component;
                // Only live slots; pre-allocated tail slots are null.
                for (int i = 0; i < net.nAll; i++) {
                    if (net.all[i] != null) {
                        pending.add(net.all[i]);
                    }
                }
            } else if (component instanceof WeightBank) {
                weightsByKey.put(((WeightBank) component).getWeightInput(),
                        Integer.valueOf(component.nWeights()));
            } else {
                if (!(component instanceof Layer)) {
                    throw new IllegalArgumentException("Unhandled subtype of Component: " + component);
                }
                weightsByKey.put(component, Integer.valueOf(component.nWeights()));
            }
        }
        int total = 0;
        for (int weights : weightsByKey.values()) {
            total += weights;
        }
        return total;
    }

    /**
     * Calls {@code optimize()} on every live component and drops those that return
     * {@code false} from the "all", activate and train lists (relative order is kept). The
     * backing arrays are trimmed to their live size.
     *
     * @return always {@code true}
     */
    @Override
    public boolean optimize() {
        // A real ArrayList: the previous Arrays.asList view rejected Iterator.remove().
        final List<Component> kept = new ArrayList<Component>(Math.max(this.nAll, 0));
        final int liveAll = Math.min(this.nAll, this.all.length);
        for (int i = 0; i < liveAll; i++) {
            final Component component = this.all[i];
            if (component != null && component.optimize()) {
                kept.add(component);
            }
        }
        this.all = kept.toArray(new Component[kept.size()]);
        this.nAll = this.all.length;
        this.activate = retainKept(this.activate, this.nActivate, kept);
        this.nActivate = this.activate.length;
        this.train = retainKept(this.train, this.nTrain, kept);
        this.nTrain = this.train.length;
        return true;
    }

    /**
     * Returns a new, exactly sized array containing, in order, those of the first {@code count}
     * entries of {@code components} that are contained in {@code kept}.
     */
    private static Component[] retainKept(Component[] components, int count, List<Component> kept) {
        final int live = Math.min(count, components.length);
        final List<Component> result = new ArrayList<Component>(Math.max(live, 0));
        for (int i = 0; i < live; i++) {
            if (kept.contains(components[i])) {
                result.add(components[i]);
            }
        }
        return result.toArray(new Component[result.size()]);
    }

    /** Runs {@code processBatch()} on each live train component in insertion order. */
    @Override
    public void processBatch() {
        for (int i = 0; i < this.nTrain; i++) {
            this.train[i].processBatch();
        }
    }

    /** Unbuilds, then builds, every live component. */
    public void rebuild() {
        unbuild();
        build();
    }

    /** Sets the number of live activate components. */
    public void setActivateSize(int i) {
        this.nActivate = i;
    }

    /** Sets the number of live components in the "all" list. */
    public void setAllSize(int i) {
        this.nAll = i;
    }

    /** Sets the input of the first input layer. */
    public void setInput(float[] fArr) {
        setInput(0, fArr);
    }

    /** Sets the input of the {@code i}-th input layer. */
    public void setInput(int i, float[] fArr) {
        this.input[i].setInput(fArr);
    }

    /** Sets the number of live input layers. */
    public void setInputSize(int i) {
        this.nInput = i;
    }

    /** Sets the target of the first target layer. */
    public void setTarget(float[] fArr) {
        setTarget(0, fArr);
    }

    /** Sets the target of the {@code i}-th target layer. */
    public void setTarget(int i, float[] fArr) {
        this.target[i].setTarget(fArr);
    }

    /** Sets the number of live target layers. */
    public void setTargetSize(int i) {
        this.nTarget = i;
    }

    /** Sets the number of live train components. */
    public void setTrainSize(int i) {
        this.nTrain = i;
    }

    /** Stores the initializer for future additions and forwards it to every live component. */
    @Override
    public void setWeightInitializer(WeightInitializer weightInitializer) {
        this.win = weightInitializer;
        for (int i = 0; i < this.nAll; i++) {
            this.all[i].setWeightInitializer(weightInitializer);
        }
    }

    /** Stores the updater type for future additions and forwards it to every live component. */
    @Override
    public void setWeightUpdaterType(WeightUpdaterType weightUpdaterType) {
        this.wut = weightUpdaterType;
        for (int i = 0; i < this.nAll; i++) {
            this.all[i].setWeightUpdaterType(weightUpdaterType);
        }
    }

    /** Returns the number of live components (consistent with {@link #setAllSize}). */
    public int size() {
        return this.nAll;
    }

    @Override
    public String toString() {
        return this.name;
    }

    /** Appends this network's name, then each live component in reverse insertion order. */
    @Override
    public void toString(NetworkStringBuilder networkStringBuilder) {
        networkStringBuilder.appendln(this.name + ": ");
        for (int i = this.nAll - 1; i >= 0; i--) {
            this.all[i].toString(networkStringBuilder);
        }
    }

    /** Renders this network with a {@link NetworkStringBuilder} constructed from {@code str}. */
    @Override
    public String toString(String str) {
        final NetworkStringBuilder networkStringBuilder = new NetworkStringBuilder(str);
        toString(networkStringBuilder);
        return networkStringBuilder.toString();
    }

    /** Marks the network unbuilt and unbuilds every live component. */
    @Override
    public void unbuild() {
        this.built = false;
        for (int i = 0; i < this.nAll; i++) {
            this.all[i].unbuild();
        }
    }

    /** Runs {@code updateEligibilities()} on each activate component in reverse order. */
    @Override
    public void updateEligibilities() {
        for (int i = this.nActivate - 1; i >= 0; i--) {
            this.activate[i].updateEligibilities();
        }
    }

    /** Runs {@code updateResponsibilities()} on each train component in reverse order. */
    @Override
    public void updateResponsibilities() {
        for (int i = this.nTrain - 1; i >= 0; i--) {
            this.train[i].updateResponsibilities();
        }
    }

    /** Runs {@code updateWeights()} on each train component in reverse order. */
    @Override
    public void updateWeights() {
        for (int i = this.nTrain - 1; i >= 0; i--) {
            this.train[i].updateWeights();
        }
    }
}
