/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp;

import dmonner.xlbp.Component;
import dmonner.xlbp.DownstreamComponent;
import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.NetworkStringBuilder;
import dmonner.xlbp.UniformWeightInitializer;
import dmonner.xlbp.UpstreamComponent;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.compound.Compound;
import dmonner.xlbp.compound.InputCompound;
import dmonner.xlbp.compound.MemoryCellCompound;
import dmonner.xlbp.compound.TargetCompound;
import dmonner.xlbp.compound.WeightBank;
import dmonner.xlbp.compound.WeightedCompound;
import dmonner.xlbp.compound.XEntropyTargetCompound;
import dmonner.xlbp.layer.InputLayer;
import dmonner.xlbp.layer.Layer;
import dmonner.xlbp.layer.TargetLayer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * A named container that groups {@link Component}s into a trainable network.
 *
 * <p>Components are registered with {@link #add} and recorded in several lists:
 * {@code all} (every component), {@code activate} (components run during
 * activation), {@code train} (components run during training), plus the
 * entry/exit lists used by {@link #addUpstream} and {@link #getExitPoint}.
 * Input and target layers of added components and sub-networks are also
 * collected so they can be fed via {@link #setInput} / {@link #setTarget}.
 *
 * <p>Activation visits {@code activate} in insertion order; eligibility
 * updates visit it in reverse order; responsibility and weight updates visit
 * {@code train} in reverse order.
 *
 * <p>Each list is a backing array paired with a fill count ({@code nAll},
 * {@code nActivate}, ...). Arrays grow by exactly one slot per insertion, so
 * unless the network was constructed with spare capacity the array length
 * equals the count. The entry/exit arrays have no separate count: their length
 * is the count.
 */
public class Network
implements Component {
    private static final long serialVersionUID = 1L;
    private final String name;
    // Every registered component, in insertion order.
    private Component[] all;
    // Components visited by activateTest/activateTrain/updateEligibilities.
    private Component[] activate;
    // Components visited by updateResponsibilities/updateWeights/processBatch.
    private Component[] train;
    // Entry points that receive unweighted upstream connections (length == count).
    private DownstreamComponent[] directEntry;
    // Exit points exposed via getExitPoint (length == count).
    private UpstreamComponent[] directExit;
    // Entry points that receive weighted upstream connections (length == count).
    private WeightedCompound[] weightedEntry;
    private InputLayer[] input;
    private TargetLayer[] target;
    private Network[] subnet;
    private int nAll;
    private int nActivate;
    private int nTrain;
    private int nInput;
    private int nTarget;
    private int nSubnet;
    private WeightInitializer win;
    private WeightUpdaterType wut;
    private boolean built;

    /**
     * Timing benchmark: trains a single-hidden-layer memory-cell network on
     * random input/target vectors and prints per-phase timings in milliseconds.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final String mctype = "IFOPL";
        final int insize = 300;
        final int hidsize = 300;
        final int outsize = 300;
        final int trials = 1000;
        final int repeats = 10;

        InputCompound in = new InputCompound("Input", insize);
        MemoryCellCompound mc = new MemoryCellCompound("Hidden", hidsize, mctype);
        XEntropyTargetCompound out = new XEntropyTargetCompound("Output", outsize);
        out.addUpstreamWeights(mc);
        mc.addUpstreamWeights(in);

        Network net = new Network("TheNet");
        net.setWeightUpdaterType(WeightUpdaterType.basic(0.1f));
        net.setWeightInitializer(new UniformWeightInitializer(1.0f, -0.1f, 0.1f));
        net.add(in);
        net.add(mc);
        net.add(out);
        net.optimize();
        net.build();
        System.out.println("Total weights: " + net.nWeights() + "\n");

        float[] input = new float[insize];
        float[] target = new float[outsize];
        System.out.println("Each row represents training the network for " + trials + " trials.");
        System.out.println("Each entry specifies the time required in milliseconds.\n");
        System.out.println("clear\tactiv\tresp\tweights\t/ total");

        for (int r = 0; r < repeats; ++r) {
            net.clear();
            // Accumulate in nanoseconds and convert once per row, so per-trial
            // truncation to whole milliseconds does not bias the totals.
            long clearNs = 0L;
            long activateNs = 0L;
            long respNs = 0L;
            long learnNs = 0L;
            long start = System.nanoTime();
            for (int t = 0; t < trials; ++t) {
                long t0 = System.nanoTime();
                // NOTE(review): nothing is timed between t0 and t1, so the "clear"
                // column is always ~0 -- a per-trial clear seems to have been removed.
                long t1 = System.nanoTime();
                for (int i = 0; i < insize; ++i) {
                    input[i] = (float) Math.random();
                }
                in.setInput(input);
                net.activateTrain();
                net.updateEligibilities();
                long t2 = System.nanoTime();
                for (int i = 0; i < outsize; ++i) {
                    target[i] = (float) Math.random();
                }
                out.setTarget(target);
                net.updateResponsibilities();
                long t3 = System.nanoTime();
                net.updateWeights();
                long t4 = System.nanoTime();
                clearNs += t1 - t0;
                activateNs += t2 - t1;
                respNs += t3 - t2;
                learnNs += t4 - t3;
            }
            long totalNs = System.nanoTime() - start;
            System.out.println(nanosToMillis(clearNs) + "\t" + nanosToMillis(activateNs) + "\t"
                    + nanosToMillis(respNs) + "\t" + nanosToMillis(learnNs) + "\t/ "
                    + nanosToMillis(totalNs));
        }
    }

    /** Converts a nanosecond duration to whole milliseconds (truncating). */
    private static long nanosToMillis(long nanos) {
        return nanos / 1000000L;
    }

    /**
     * Copy constructor used by {@link NetworkCopier}. Each referenced component
     * is replaced by {@code copier.getCopyOf(...)}; array lengths, fill counts,
     * the weight initializer and the weight updater type are carried over. The
     * {@code built} flag is not copied.
     *
     * @param that   the network to copy
     * @param copier supplies the copy's name and the copies of member components
     */
    public Network(Network that, NetworkCopier copier) {
        this.name = copier.getCopyNameFrom(that);
        this.nAll = that.nAll;
        this.nActivate = that.nActivate;
        this.nTrain = that.nTrain;
        this.nInput = that.nInput;
        this.nTarget = that.nTarget;
        this.nSubnet = that.nSubnet;
        this.win = that.win;
        this.wut = that.wut;
        this.all = new Component[that.all.length];
        this.activate = new Component[that.activate.length];
        this.train = new Component[that.train.length];
        this.input = new InputLayer[that.input.length];
        this.target = new TargetLayer[that.target.length];
        this.subnet = new Network[that.subnet.length];
        this.directEntry = new DownstreamComponent[that.directEntry.length];
        this.directExit = new UpstreamComponent[that.directExit.length];
        this.weightedEntry = new WeightedCompound[that.weightedEntry.length];
        // The copier is consulted in this fixed order; it may register copies as a side effect.
        for (int i = 0; i < that.all.length; ++i) {
            this.all[i] = copier.getCopyOf(that.all[i]);
        }
        for (int i = 0; i < that.activate.length; ++i) {
            this.activate[i] = copier.getCopyOf(that.activate[i]);
        }
        for (int i = 0; i < that.train.length; ++i) {
            this.train[i] = copier.getCopyOf(that.train[i]);
        }
        for (int i = 0; i < that.input.length; ++i) {
            this.input[i] = copier.getCopyOf(that.input[i]);
        }
        for (int i = 0; i < that.target.length; ++i) {
            this.target[i] = copier.getCopyOf(that.target[i]);
        }
        for (int i = 0; i < that.subnet.length; ++i) {
            this.subnet[i] = copier.getCopyOf(that.subnet[i]);
        }
        for (int i = 0; i < that.directEntry.length; ++i) {
            this.directEntry[i] = copier.getCopyOf(that.directEntry[i]);
        }
        for (int i = 0; i < that.directExit.length; ++i) {
            this.directExit[i] = copier.getCopyOf(that.directExit[i]);
        }
        for (int i = 0; i < that.weightedEntry.length; ++i) {
            this.weightedEntry[i] = copier.getCopyOf(that.weightedEntry[i]);
        }
    }

    /**
     * Creates an empty network with no pre-allocated capacity.
     *
     * @param name the network's name
     */
    public Network(String name) {
        this(name, 0, 0, 0, 0, 0, 0);
    }

    /**
     * Creates an empty network whose backing arrays are pre-sized. Counts start
     * at zero; {@link #add} fills the pre-allocated slots before growing.
     *
     * @param name      the network's name
     * @param nAll      initial capacity of the all-components array
     * @param nActivate initial capacity of the activation array
     * @param nTrain    initial capacity of the training array
     * @param nInputs   initial capacity of the input-layer array
     * @param nTargets  initial capacity of the target-layer array
     * @param nSubnet   initial capacity of the sub-network array
     */
    public Network(String name, int nAll, int nActivate, int nTrain, int nInputs, int nTargets, int nSubnet) {
        this.name = name;
        this.all = new Component[nAll];
        this.activate = new Component[nActivate];
        this.train = new Component[nTrain];
        this.input = new InputLayer[nInputs];
        this.target = new TargetLayer[nTargets];
        this.subnet = new Network[nSubnet];
        this.directEntry = new DownstreamComponent[0];
        this.directExit = new UpstreamComponent[0];
        this.weightedEntry = new WeightedCompound[0];
        this.win = new UniformWeightInitializer();
        this.wut = WeightUpdaterType.basic();
    }

    /** Runs {@code activateTest()} on each activation component in insertion order. */
    @Override
    public void activateTest() {
        for (int i = 0; i < this.nActivate; ++i) {
            this.activate[i].activateTest();
        }
    }

    /**
     * Clears responsibilities of all activation components, then runs
     * {@code activateTrain()} on each of them in insertion order.
     */
    @Override
    public void activateTrain() {
        for (int i = 0; i < this.nActivate; ++i) {
            this.activate[i].clearResponsibilities();
        }
        for (int i = 0; i < this.nActivate; ++i) {
            this.activate[i].activateTrain();
        }
    }

    /**
     * Adds a component that participates in both activation and training and
     * is neither an entry nor an exit point.
     *
     * @param component the component to add
     */
    public void add(Component component) {
        this.add(component, true, true, false, false);
    }

    /**
     * Adds a component, pushing this network's weight initializer and weight
     * updater type onto it. Sub-networks, input layers/compounds and target
     * layers/compounds are additionally registered in the matching lists.
     *
     * @param component the component to add
     * @param activate  whether to visit it during activation and eligibility updates
     * @param train     whether to visit it during responsibility/weight updates
     * @param entry     whether it receives upstream connections made to this network
     * @param exit      whether it is exposed as an exit point of this network
     */
    public void add(Component component, boolean activate, boolean train, boolean entry, boolean exit) {
        component.setWeightInitializer(this.win);
        component.setWeightUpdaterType(this.wut);
        if (component instanceof Network) {
            this.addSubnet((Network) component);
        }
        if (component instanceof InputCompound) {
            this.addInput(((InputCompound) component).getInputLayer());
        }
        if (component instanceof InputLayer) {
            this.addInput((InputLayer) component);
        }
        if (component instanceof TargetCompound) {
            this.addTarget(((TargetCompound) component).getTargetLayer());
        }
        if (component instanceof TargetLayer) {
            this.addTarget((TargetLayer) component);
        }
        if (activate) {
            this.addActivate(component);
        }
        if (train) {
            this.addTrain(component);
        }
        if (entry) {
            this.addEntry(component);
        }
        if (exit) {
            this.addExit(component);
        }
        this.addAll(component);
    }

    private void addActivate(Component component) {
        this.ensureActivateCapacity(this.nActivate + 1);
        this.activate[this.nActivate++] = component;
    }

    /**
     * Adds a component that participates in activation only.
     *
     * @param component the component to add
     */
    public void addActivateOnly(Component component) {
        this.add(component, true, false, false, false);
    }

    private void addAll(Component component) {
        this.ensureAllCapacity(this.nAll + 1);
        this.all[this.nAll++] = component;
    }

    /** Appends the component to the weighted and/or direct entry lists, as its type allows. */
    private void addEntry(Component component) {
        // Entry arrays have no separate count: the last slot of the grown array is the new one.
        if (component instanceof WeightedCompound) {
            this.ensureWeightedEntryCapacity(this.weightedEntry.length + 1);
            this.weightedEntry[this.weightedEntry.length - 1] = (WeightedCompound) component;
        }
        if (component instanceof DownstreamComponent) {
            this.ensureDirectEntryCapacity(this.directEntry.length + 1);
            this.directEntry[this.directEntry.length - 1] = (DownstreamComponent) component;
        }
    }

    /** Appends the component to the exit list if it is an {@link UpstreamComponent}. */
    private void addExit(Component component) {
        if (component instanceof UpstreamComponent) {
            this.ensureDirectExitCapacity(this.directExit.length + 1);
            this.directExit[this.directExit.length - 1] = (UpstreamComponent) component;
        }
    }

    private void addInput(InputLayer inLayer) {
        this.ensureInputCapacity(this.nInput + 1);
        this.input[this.nInput++] = inLayer;
    }

    /** Registers a sub-network and adopts its input and target layers as this network's own. */
    private void addSubnet(Network sub) {
        this.ensureSubnetCapacity(this.nSubnet + 1);
        this.subnet[this.nSubnet++] = sub;
        // Only the filled slots: a pre-sized sub-network may have trailing nulls.
        for (int i = 0; i < sub.nInput; ++i) {
            this.addInput(sub.input[i]);
        }
        for (int i = 0; i < sub.nTarget; ++i) {
            this.addTarget(sub.target[i]);
        }
    }

    private void addTarget(TargetLayer tgtLayer) {
        this.ensureTargetCapacity(this.nTarget + 1);
        this.target[this.nTarget++] = tgtLayer;
    }

    private void addTrain(Component component) {
        this.ensureTrainCapacity(this.nTrain + 1);
        this.train[this.nTrain++] = component;
    }

    /**
     * Adds a component that participates in training only.
     *
     * @param component the component to add
     */
    public void addTrainOnly(Component component) {
        this.add(component, false, true, false, false);
    }

    /**
     * Connects {@code upstream} to every direct (unweighted) entry point.
     *
     * @param upstream the component feeding this network
     */
    public void addUpstream(UpstreamComponent upstream) {
        for (DownstreamComponent entry : this.directEntry) {
            entry.addUpstream(upstream);
        }
    }

    /**
     * Connects {@code upstream} to the weighted or the direct entry points.
     *
     * @param upstream the component feeding this network
     * @param weighted {@code true} for {@link #addUpstreamWeights}, otherwise {@link #addUpstream}
     */
    public void addUpstream(UpstreamComponent upstream, boolean weighted) {
        if (weighted) {
            this.addUpstreamWeights(upstream);
        } else {
            this.addUpstream(upstream);
        }
    }

    /**
     * Connects {@code upstream} through weights to every weighted entry point.
     *
     * @param upstream the component feeding this network
     */
    public void addUpstreamWeights(UpstreamComponent upstream) {
        for (WeightedCompound entry : this.weightedEntry) {
            entry.addUpstreamWeights(upstream);
        }
    }

    /** Builds every component once; subsequent calls do nothing until {@link #unbuild}. */
    @Override
    public void build() {
        if (!this.built) {
            for (int i = 0; i < this.nAll; ++i) {
                this.all[i].build();
            }
            this.built = true;
        }
    }

    /** Clears activations, eligibilities and responsibilities of every component. */
    @Override
    public void clear() {
        this.clearActivations();
        this.clearEligibilities();
        this.clearResponsibilities();
    }

    @Override
    public void clearActivations() {
        for (int i = 0; i < this.nAll; ++i) {
            this.all[i].clearActivations();
        }
    }

    @Override
    public void clearEligibilities() {
        for (int i = 0; i < this.nAll; ++i) {
            this.all[i].clearEligibilities();
        }
    }

    /** Clears every registered input layer, then recursively those of each sub-network. */
    public void clearInputs() {
        for (int i = 0; i < this.nInput; ++i) {
            this.input[i].clear();
        }
        // Only filled slots: a pre-sized subnet array may hold trailing nulls.
        for (int i = 0; i < this.nSubnet; ++i) {
            this.subnet[i].clearInputs();
        }
    }

    @Override
    public void clearResponsibilities() {
        for (int i = 0; i < this.nAll; ++i) {
            this.all[i].clearResponsibilities();
        }
    }

    /** Orders components by name. */
    @Override
    public int compareTo(Component that) {
        return this.name.compareTo(that.getName());
    }

    @Override
    public Network copy(NetworkCopier copier) {
        return new Network(this, copier);
    }

    /** Copies this network with the given name suffix, copying neither state nor weights. */
    @Override
    public Network copy(String suffix) {
        return this.copy(suffix, false, false);
    }

    /** Copies this network with the given name suffix and no prefix. */
    public Network copy(String suffix, boolean copyState, boolean copyWeights) {
        return this.copy("", suffix, copyState, copyWeights);
    }

    /**
     * Copies this network through a fresh {@link NetworkCopier} and builds the copier.
     *
     * @param prefix      prefix for copied names
     * @param suffix      suffix for copied names
     * @param copyState   passed through to the copier
     * @param copyWeights passed through to the copier
     * @return the copy
     */
    public Network copy(String prefix, String suffix, boolean copyState, boolean copyWeights) {
        NetworkCopier copier = new NetworkCopier(prefix, suffix, copyState, copyWeights);
        Network copy = this.copy(copier);
        copier.build();
        return copy;
    }

    /**
     * No-op in this class.
     * NOTE(review): presumably member components copy their own connectivity via
     * the copier; confirm that nothing network-level needs copying here.
     */
    @Override
    public void copyConnectivityFrom(Component comp, NetworkCopier copier) {
    }

    /**
     * Returns {@code array} unchanged if it already has at least {@code minLength}
     * slots, otherwise a copy grown to exactly {@code minLength}. Growth is exact
     * (not geometric) because several lists use the array length as their count.
     */
    private static <T> T[] ensureLength(T[] array, int minLength) {
        return minLength > array.length ? Arrays.copyOf(array, minLength) : array;
    }

    public void ensureActivateCapacity(int cActivate) {
        this.activate = ensureLength(this.activate, cActivate);
    }

    public void ensureAllCapacity(int cAll) {
        this.all = ensureLength(this.all, cAll);
    }

    public void ensureDirectEntryCapacity(int cEntry) {
        this.directEntry = ensureLength(this.directEntry, cEntry);
    }

    public void ensureDirectExitCapacity(int cExit) {
        this.directExit = ensureLength(this.directExit, cExit);
    }

    public void ensureInputCapacity(int cInputs) {
        this.input = ensureLength(this.input, cInputs);
    }

    public void ensureSubnetCapacity(int cSubnets) {
        this.subnet = ensureLength(this.subnet, cSubnets);
    }

    public void ensureTargetCapacity(int cTargets) {
        this.target = ensureLength(this.target, cTargets);
    }

    public void ensureTrainCapacity(int cTrain) {
        this.train = ensureLength(this.train, cTrain);
    }

    public void ensureWeightedEntryCapacity(int cEntry) {
        this.weightedEntry = ensureLength(this.weightedEntry, cEntry);
    }

    public Component getActivate(int index) {
        return this.activate[index];
    }

    /** Returns the length of the activation backing array (may exceed the fill count if pre-sized). */
    public int getActivateSize() {
        return this.activate.length;
    }

    public Component getComponent(int index) {
        return this.all[index];
    }

    /**
     * Returns the first registered component whose name equals {@code name}.
     *
     * @param name the name to look for
     * @return the matching component, or {@code null} if none matches
     */
    public Component getComponentByName(String name) {
        // Only filled slots: a pre-sized array may hold trailing nulls.
        for (int i = 0; i < this.nAll; ++i) {
            Component component = this.all[i];
            if (component.getName().equals(name)) {
                return component;
            }
        }
        return null;
    }

    /** Returns the internal all-components backing array (not a copy). */
    public Component[] getComponents() {
        return this.all;
    }

    public UpstreamComponent getExitPoint() {
        return this.getExitPoint(0);
    }

    public UpstreamComponent getExitPoint(int i) {
        return this.directExit[i];
    }

    public InputLayer getInputLayer() {
        return this.getInputLayer(0);
    }

    public InputLayer getInputLayer(int index) {
        return this.input[index];
    }

    /** Returns a copy of the registered input layers (filled slots only). */
    public InputLayer[] getInputLayers() {
        return Arrays.copyOf(this.input, this.nInput);
    }

    @Override
    public String getName() {
        return this.name;
    }

    public int getNExitPoints() {
        return this.directExit.length;
    }

    public TargetLayer getTargetLayer() {
        return this.getTargetLayer(0);
    }

    public TargetLayer getTargetLayer(int index) {
        return this.target[index];
    }

    /** Returns a copy of the registered target layers (filled slots only). */
    public TargetLayer[] getTargetLayers() {
        return Arrays.copyOf(this.target, this.nTarget);
    }

    public Component getTrain(int index) {
        return this.train[index];
    }

    /** Returns the length of the training backing array (may exceed the fill count if pre-sized). */
    public int getTrainSize() {
        return this.train.length;
    }

    @Override
    public boolean isBuilt() {
        return this.built;
    }

    public int nInput() {
        return this.nInput;
    }

    public int nTarget() {
        return this.nTarget;
    }

    /** Sums {@code nWeights()} over the directly registered components. */
    @Override
    public int nWeights() {
        int sum = 0;
        for (int i = 0; i < this.nAll; ++i) {
            sum += this.all[i].nWeights();
        }
        return sum;
    }

    /**
     * Counts weights by walking the component tree breadth-first through
     * compounds, sub-networks and the upstream weight banks of weighted
     * compounds. Counts are keyed by layer (a bank is keyed by its weight-input
     * layer), so a layer reached more than once is counted once, with its last
     * recorded value.
     *
     * @return the total weight count
     * @throws IllegalArgumentException if a component of an unhandled type is reached
     */
    public int nWeightsDeep() {
        HashMap<Layer, Integer> weightsByLayer = new HashMap<Layer, Integer>();
        // LinkedList (not ArrayDeque) so a null child reaches the explicit error below.
        LinkedList<Component> queue = new LinkedList<Component>();
        queue.add(this);
        while (!queue.isEmpty()) {
            Component comp = queue.poll();
            if (comp instanceof WeightedCompound) {
                WeightedCompound wcomp = (WeightedCompound) comp;
                for (int i = 0; i < wcomp.nUpstreamWeights(); ++i) {
                    queue.add(wcomp.getUpstreamWeights(i));
                }
            }
            if (comp instanceof Compound) {
                for (Component sub : ((Compound) comp).getComponents()) {
                    queue.add(sub);
                }
            } else if (comp instanceof Network) {
                Network net = (Network) comp;
                for (int i = 0; i < net.nAll; ++i) {
                    queue.add(net.all[i]);
                }
            } else if (comp instanceof WeightBank) {
                weightsByLayer.put(((WeightBank) comp).getWeightInput(), comp.nWeights());
            } else if (comp instanceof Layer) {
                weightsByLayer.put((Layer) comp, comp.nWeights());
            } else {
                throw new IllegalArgumentException("Unhandled subtype of Component: " + comp);
            }
        }
        int sum = 0;
        for (Integer count : weightsByLayer.values()) {
            sum += count;
        }
        return sum;
    }

    /**
     * Calls {@code optimize()} on every registered component and drops those
     * that return {@code false}. The activation and training lists are then
     * filtered to components that were kept. All three backing arrays are
     * trimmed to their counts.
     *
     * @return always {@code true}
     */
    @Override
    public boolean optimize() {
        // A growable list: the old Arrays.asList view threw on remove().
        List<Component> kept = new ArrayList<Component>(this.nAll);
        for (int i = 0; i < this.nAll; ++i) {
            Component component = this.all[i];
            if (component.optimize()) {
                kept.add(component);
            }
        }
        this.all = kept.toArray(new Component[kept.size()]);
        this.nAll = this.all.length;
        this.activate = retainOnly(this.activate, this.nActivate, kept);
        this.nActivate = this.activate.length;
        this.train = retainOnly(this.train, this.nTrain, kept);
        this.nTrain = this.train.length;
        return true;
    }

    /** Returns the first {@code count} entries of {@code source} that are contained in {@code allowed}, in order. */
    private static Component[] retainOnly(Component[] source, int count, List<Component> allowed) {
        List<Component> result = new ArrayList<Component>(count);
        for (int i = 0; i < count; ++i) {
            if (allowed.contains(source[i])) {
                result.add(source[i]);
            }
        }
        return result.toArray(new Component[result.size()]);
    }

    /** Runs {@code processBatch()} on each training component in insertion order. */
    @Override
    public void processBatch() {
        // Only filled slots: a pre-sized array may hold trailing nulls.
        for (int i = 0; i < this.nTrain; ++i) {
            this.train[i].processBatch();
        }
    }

    public void rebuild() {
        this.unbuild();
        this.build();
    }

    public void setActivateSize(int nActivate) {
        this.nActivate = nActivate;
    }

    public void setAllSize(int nAll) {
        this.nAll = nAll;
    }

    public void setInput(float[] input) {
        this.setInput(0, input);
    }

    public void setInput(int index, float[] activations) {
        this.input[index].setInput(activations);
    }

    public void setInputSize(int nInputs) {
        this.nInput = nInputs;
    }

    public void setTarget(float[] target) {
        this.setTarget(0, target);
    }

    public void setTarget(int index, float[] activations) {
        this.target[index].setTarget(activations);
    }

    public void setTargetSize(int nTargets) {
        this.nTarget = nTargets;
    }

    public void setTrainSize(int nTrain) {
        this.nTrain = nTrain;
    }

    /** Stores the initializer for future {@link #add} calls and pushes it to current components. */
    @Override
    public void setWeightInitializer(WeightInitializer win) {
        this.win = win;
        for (int i = 0; i < this.nAll; ++i) {
            this.all[i].setWeightInitializer(win);
        }
    }

    /** Stores the updater type for future {@link #add} calls and pushes it to current components. */
    @Override
    public void setWeightUpdaterType(WeightUpdaterType wut) {
        this.wut = wut;
        for (int i = 0; i < this.nAll; ++i) {
            this.all[i].setWeightUpdaterType(wut);
        }
    }

    /** Returns the length of the all-components backing array (may exceed the count if pre-sized). */
    public int size() {
        return this.all.length;
    }

    @Override
    public String toString() {
        return this.name;
    }

    /** Appends the network name, then each component's description in reverse insertion order. */
    @Override
    public void toString(NetworkStringBuilder sb) {
        sb.appendln(this.name + ": ");
        for (int i = this.nAll - 1; i >= 0; --i) {
            this.all[i].toString(sb);
        }
    }

    @Override
    public String toString(String show) {
        NetworkStringBuilder sb = new NetworkStringBuilder(show);
        this.toString(sb);
        return sb.toString();
    }

    /** Marks this network unbuilt and unbuilds every component. */
    @Override
    public void unbuild() {
        this.built = false;
        for (int i = 0; i < this.nAll; ++i) {
            this.all[i].unbuild();
        }
    }

    /** Updates eligibilities of activation components in reverse insertion order. */
    @Override
    public void updateEligibilities() {
        for (int i = this.nActivate - 1; i >= 0; --i) {
            this.activate[i].updateEligibilities();
        }
    }

    /** Updates responsibilities of training components in reverse insertion order. */
    @Override
    public void updateResponsibilities() {
        for (int i = this.nTrain - 1; i >= 0; --i) {
            this.train[i].updateResponsibilities();
        }
    }

    /** Updates weights of training components in reverse insertion order. */
    @Override
    public void updateWeights() {
        for (int i = this.nTrain - 1; i >= 0; --i) {
            this.train[i].updateWeights();
        }
    }
}

