/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp;

import dmonner.xlbp.Component;
import dmonner.xlbp.DownstreamComponent;
import dmonner.xlbp.Network;
import dmonner.xlbp.UniformWeightInitializer;
import dmonner.xlbp.UpstreamComponent;
import dmonner.xlbp.WeightInitializer;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.compound.Compound;
import dmonner.xlbp.compound.WeightedCompound;
import dmonner.xlbp.util.ListMap;
import dmonner.xlbp.util.ReflectionTools;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Collection;

public class NetworkConfigurator {
    /** Regex separating the tokens of a Component- or Network- spec. */
    private static final String SPEC_SEP = "\\s+";
    /** Packages searched when resolving the class name at the start of a Component- spec. */
    private static final String[] PKGS = {"dmonner.xlbp", "dmonner.xlbp.compound", "dmonner.xlbp.layer"};
    /** Alternate key prefix under which constructor-argument parameters may be stored. */
    private static final String PARAM_PREFIX = "Param-";
    /** Key prefix declaring a component; the remainder of the key is the component name. */
    private static final String COMPONENT_PREFIX = "Component-";
    /**
     * Key prefixes declaring connections. Only indices 0 (direct) and 1 (weighted) are handled by
     * {@link #addConnection(String, String)}; a truthy "Connect-WeightedAll-" key is rejected there.
     */
    private static final String[] CONNECT_PREFIXES = {"Connect-Direct-", "Connect-Weighted-", "Connect-WeightedAll-"};
    /** Key prefix declaring a network; the remainder of the key is the network name. */
    private static final String NETWORK_PREFIX = "Network-";
    /** Network-spec token prefix that clears the {@code train} flag for that member. */
    private static final String ACTIVATE_PREFIX = "-";
    /** Network-spec token prefix that clears the {@code activate} flag for that member. */
    private static final String TRAIN_PREFIX = "+";
    /** Network-spec token prefix that sets the {@code entry} flag for that member. */
    private static final String ENTRY_PREFIX = ">";
    /** Network-spec token prefix that sets the {@code exit} flag for that member. */
    private static final String EXIT_PREFIX = "<";

    /** The configuration map this instance was built from. */
    private final ListMap<String, Object> params;
    /** Every named component: declared ones, sub-components of compounds, and networks. */
    private final ListMap<String, Component> components;
    /** Only the components declared directly via "Component-" keys; these are added to the meta network. */
    private final ListMap<String, Component> specified;
    /** Networks declared via "Network-" keys. */
    private final ListMap<String, Network> networks;
    /** Network named "Meta" containing every directly declared component. */
    private final Network meta;
    private final WeightUpdaterType wut;
    private final WeightInitializer win;

    /**
     * Builds components, connections, and networks from a flat parameter map.
     *
     * <p>Processing order: all {@code Component-<name>} keys, then all truthy {@code Connect-*} keys,
     * then the meta network (from {@code connectionProbability} and {@code updaterType}), then all
     * {@code Network-<name>} keys.
     *
     * <p>NOTE(review): this constructor invokes the public, overridable methods
     * {@link #addComponent}, {@link #addConnection(String, String)} and {@link #addNetwork}; a
     * subclass overriding them will see a partially initialised instance.
     *
     * @param params configuration map; must contain a numeric {@code connectionProbability} and an
     *     {@code updaterType}
     * @throws IllegalArgumentException if a required parameter is missing or malformed, or any
     *     component/connection/network spec is invalid
     */
    public NetworkConfigurator(ListMap<String, Object> params) {
        this.params = params;
        this.components = new ListMap<String, Component>();
        this.specified = new ListMap<String, Component>();
        this.networks = new ListMap<String, Network>();
        for (String key : params.keyList()) {
            if (key.startsWith(COMPONENT_PREFIX)) {
                this.addComponent(key.substring(COMPONENT_PREFIX.length()), params.get(key).toString());
            }
        }
        for (String key : params.keyList()) {
            for (String prefix : CONNECT_PREFIXES) {
                if (key.startsWith(prefix) && this.asBoolean(params.get(key))) {
                    this.addConnection(prefix, key.substring(prefix.length()));
                }
            }
        }
        // Accept any Number (not just Float) so that e.g. a Double-valued probability also works.
        float connectionProbability = requireNumber(params, "connectionProbability").floatValue();
        String updaterType = requireParam(params, "updaterType").toString();
        this.win = new UniformWeightInitializer(connectionProbability);
        this.wut = WeightUpdaterType.fromString(updaterType, params);
        this.meta = new Network("Meta");
        this.meta.setWeightInitializer(this.win);
        this.meta.setWeightUpdaterType(this.wut);
        for (Component comp : this.specified.values()) {
            this.meta.add(comp);
        }
        for (String key : params.keyList()) {
            if (key.startsWith(NETWORK_PREFIX)) {
                this.addNetwork(key.substring(NETWORK_PREFIX.length()), params.get(key).toString());
            }
        }
    }

    /**
     * Instantiates a component by reflection and registers it (and, for a {@link Compound}, its
     * sub-components) under their names.
     *
     * <p>The spec is whitespace-separated: the first token is a class name resolved in
     * {@link #PKGS}; every later token is resolved with {@link #findParam(String)}. The constructor is
     * looked up with {@code String} (for {@code name}) followed by the unboxed runtime classes of those
     * resolved values. A blank spec is ignored.
     *
     * @param name the component name, passed as the first constructor argument
     * @param spec the component spec
     * @throws IllegalArgumentException if a token cannot be resolved, no suitable constructor exists,
     *     construction fails, or the constructed object is not a {@link Component}
     */
    public void addComponent(String name, String spec) {
        String trimmed = spec.trim();
        if (trimmed.isEmpty()) {
            return;
        }
        // Split the trimmed spec: splitting on \s+ with leading whitespace yields an empty first token.
        String[] part = trimmed.split(SPEC_SEP);
        String classname = part[0];
        Class<?> clazz = ReflectionTools.findClass(classname, PKGS);
        Class<?>[] sign = new Class<?>[part.length];
        Object[] args = new Object[part.length];
        sign[0] = String.class;
        args[0] = name;
        for (int i = 1; i < args.length; ++i) {
            args[i] = this.findParam(part[i]);
            sign[i] = ReflectionTools.unbox(args[i].getClass());
        }
        try {
            Constructor<?> constr = ReflectionTools.findConstructor(clazz, sign);
            // NOTE(review): unclear from here whether findConstructor can return null; guard anyway.
            if (constr == null) {
                throw new IllegalArgumentException("No constructor for " + classname + " with signature " + Arrays.deepToString(sign));
            }
            Object inst = constr.newInstance(args);
            if (!(inst instanceof Component)) {
                throw new IllegalArgumentException("Class " + classname + " does not produce a Component: " + inst);
            }
            Component comp = (Component) inst;
            this.specified.put(name, comp);
            this.components.put(name, comp);
            if (inst instanceof Compound) {
                // Sub-components become addressable by name (e.g. as connection endpoints) but are
                // not added to the meta network directly.
                for (Component sub : ((Compound) inst).getComponents()) {
                    this.components.put(sub.getName(), sub);
                }
            }
        } catch (InvocationTargetException ex) {
            throw new IllegalArgumentException("Exception while running constructor for " + classname + " with signature " + Arrays.deepToString(sign), ex);
        } catch (IllegalAccessException ex) {
            throw new IllegalArgumentException("Inaccessible constructor for " + classname + " with signature " + Arrays.deepToString(sign), ex);
        } catch (InstantiationException ex) {
            throw new IllegalArgumentException("Uninstantiable constructor for " + classname + " with signature " + Arrays.deepToString(sign), ex);
        }
    }

    /**
     * Connects two components.
     *
     * <p>{@link ConnectionType#DIRECT} requires {@code from} to be an {@link UpstreamComponent} and
     * {@code to} a {@link DownstreamComponent}, and calls {@code to.addUpstream(from)}.
     * {@link ConnectionType#WEIGHTED} requires {@code from} to be an {@link UpstreamComponent} and
     * {@code to} a {@link WeightedCompound}, and calls {@code to.addUpstreamWeights(from)}.
     *
     * @throws IllegalArgumentException if {@code type} is null/unhandled or the components have the
     *     wrong types for it
     */
    public void addConnection(ConnectionType type, Component from, Component to) {
        if (type == ConnectionType.DIRECT) {
            if (!(from instanceof UpstreamComponent) || !(to instanceof DownstreamComponent)) {
                throw new IllegalArgumentException("Cannot create direct connection between " + from + " and " + to);
            }
            ((DownstreamComponent) to).addUpstream((UpstreamComponent) from);
        } else if (type == ConnectionType.WEIGHTED) {
            if (!(from instanceof UpstreamComponent) || !(to instanceof WeightedCompound)) {
                throw new IllegalArgumentException("Cannot create weighted connection between " + from + " and " + to);
            }
            ((WeightedCompound) to).addUpstreamWeights((UpstreamComponent) from);
        } else {
            throw new IllegalArgumentException("Unhandled ConnectionType: " + type);
        }
    }

    /**
     * Parses the remainder of a {@code Connect-*} key and connects the named components.
     *
     * @param prefix one of the connect prefixes; only "Connect-Direct-" and "Connect-Weighted-" are
     *     handled
     * @param layers exactly two component names joined by a single {@code -}, e.g. {@code "A-B"}
     * @throws IllegalArgumentException for an unhandled prefix, a malformed {@code layers} string, or
     *     an unknown component name
     */
    public void addConnection(String prefix, String layers) {
        ConnectionType type;
        if (CONNECT_PREFIXES[0].equals(prefix)) {
            type = ConnectionType.DIRECT;
        } else if (CONNECT_PREFIXES[1].equals(prefix)) {
            type = ConnectionType.WEIGHTED;
        } else {
            throw new IllegalArgumentException("Unhandled ConnectionType: " + prefix);
        }
        // Limit -1 keeps trailing empty strings, so a malformed key such as "A-B-" is rejected
        // instead of being silently read as "A-B".
        String[] parts = layers.split("-", -1);
        if (parts.length != 2) {
            throw new IllegalArgumentException("Wrong number of layers in connect parameter: " + layers);
        }
        if (!this.components.containsKey(parts[0])) {
            throw new IllegalArgumentException("No from-component named: " + parts[0]);
        }
        if (!this.components.containsKey(parts[1])) {
            throw new IllegalArgumentException("No to-component named: " + parts[1]);
        }
        this.addConnection(type, this.components.get(parts[0]), this.components.get(parts[1]));
    }

    /**
     * Creates a network, registers it as both a network and a component (so later networks can
     * include it), and adds the members listed in {@code spec}.
     *
     * <p>Each whitespace-separated token names a known component, optionally preceded by any
     * combination of the flag prefixes {@code -} (train=false), {@code +} (activate=false),
     * {@code >} (entry=true) and {@code <} (exit=true).
     *
     * @param name the network name
     * @param spec member tokens; a blank spec yields an empty network
     * @throws IllegalArgumentException if a token names an unknown component
     */
    public void addNetwork(String name, String spec) {
        Network net = new Network(name);
        net.setWeightInitializer(this.win);
        net.setWeightUpdaterType(this.wut);
        this.networks.put(name, net);
        this.components.put(name, net);
        String trimmed = spec.trim();
        if (trimmed.isEmpty()) {
            return;
        }
        // Split the trimmed spec: a leading space would otherwise produce an empty "" member token.
        for (String token : trimmed.split(SPEC_SEP)) {
            boolean activate = true;
            boolean train = true;
            boolean entry = false;
            boolean exit = false;
            String part = token;
            boolean stripped;
            // Peel off flag prefixes one at a time, in any order, until none remain.
            do {
                stripped = true;
                if (part.startsWith(ACTIVATE_PREFIX)) {
                    train = false;
                    part = part.substring(ACTIVATE_PREFIX.length());
                } else if (part.startsWith(TRAIN_PREFIX)) {
                    activate = false;
                    part = part.substring(TRAIN_PREFIX.length());
                } else if (part.startsWith(ENTRY_PREFIX)) {
                    entry = true;
                    part = part.substring(ENTRY_PREFIX.length());
                } else if (part.startsWith(EXIT_PREFIX)) {
                    exit = true;
                    part = part.substring(EXIT_PREFIX.length());
                } else {
                    stripped = false;
                }
            } while (stripped);
            if (!this.components.containsKey(part)) {
                throw new IllegalArgumentException("Cannot find component to add to network " + net + ": " + part);
            }
            net.add(this.components.get(part), activate, train, entry, exit);
        }
    }

    /**
     * Interprets a parameter value as a flag: a {@link Boolean} as itself, any {@link Number} as
     * {@code true} when non-zero, anything else (including null) as {@code false}.
     */
    private boolean asBoolean(Object obj) {
        if (obj instanceof Boolean) {
            return (Boolean) obj;
        }
        if (obj instanceof Number) {
            return ((Number) obj).doubleValue() != 0.0;
        }
        return false;
    }

    /** Returns the value stored under {@code key}, or throws if it is absent (null). */
    private static Object requireParam(ListMap<String, Object> params, String key) {
        Object val = params.get(key);
        if (val == null) {
            throw new IllegalArgumentException("Missing required parameter: " + key);
        }
        return val;
    }

    /** Returns the numeric value stored under {@code key}, or throws if it is absent or not a Number. */
    private static Number requireNumber(ListMap<String, Object> params, String key) {
        Object val = requireParam(params, key);
        if (!(val instanceof Number)) {
            throw new IllegalArgumentException("Parameter " + key + " must be a number, but was: " + val);
        }
        return (Number) val;
    }

    /** Builds the meta network (which contains every directly declared component). */
    public void build() {
        this.meta.build();
    }

    /**
     * Resolves a constructor-argument token: first as a parameter key, then as
     * {@code "Param-" + key}, then as a component name.
     *
     * @throws IllegalArgumentException if none of the lookups yields a non-null value
     */
    private Object findParam(String key) {
        Object val = this.params.get(key);
        if (val == null) {
            val = this.params.get(PARAM_PREFIX + key);
        }
        if (val == null) {
            val = this.components.get(key);
        }
        if (val == null) {
            throw new IllegalArgumentException("Variable not found: " + key);
        }
        return val;
    }

    /** Returns the component (or network) registered under {@code name}, as returned by the map lookup. */
    public Component getComponent(String name) {
        return this.components.get(name);
    }

    /** Returns every registered component, including compound sub-components and networks. */
    public Collection<Component> getComponents() {
        return this.components.values();
    }

    /** Returns the network named "Meta" that holds every directly declared component. */
    public Network getMetaNetwork() {
        return this.meta;
    }

    /** Returns the network declared under {@code name}, as returned by the map lookup. */
    public Network getNetwork(String name) {
        return this.networks.get(name);
    }

    /** Returns all networks declared via "Network-" keys (the meta network is not included). */
    public Collection<Network> getNetworks() {
        return this.networks.values();
    }

    /** Delegates to the meta network's {@code optimize()} and returns its result. */
    public boolean optimize() {
        return this.meta.optimize();
    }

    @Override
    public String toString() {
        return this.components.toString();
    }

    /** Kinds of connection that {@link #addConnection(ConnectionType, Component, Component)} can create. */
    public static enum ConnectionType {
        DIRECT,
        WEIGHTED
    }
}

