package dmonner.xlbp;

import dmonner.xlbp.compound.Compound;
import dmonner.xlbp.compound.WeightedCompound;
import dmonner.xlbp.util.ListMap;
import dmonner.xlbp.util.ReflectionTools;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;

/* loaded from: input_file:dmonner/xlbp/NetworkConfigurator.class */
/**
 * Builds XLBP components, connections between them, and networks from a parameter map.
 *
 * <p>The constructor reads the following keys from the map:
 *
 * <ul>
 *   <li>{@code Component-NAME}: value is {@code "ClassName arg1 arg2 ..."}. The class is looked up
 *       in {@link #PKGS} and constructed reflectively with {@code (NAME, arg1, arg2, ...)}. Each
 *       argument is resolved by {@link #findParam(String)}.
 *   <li>{@code Connect-Direct-FROM-TO} / {@code Connect-Weighted-FROM-TO}: when the value is
 *       {@code Boolean.TRUE} or a non-zero {@code Integer}, the two named components are connected.
 *   <li>{@code connectionProbability} (a {@code Number}): passed to a {@code UniformWeightInitializer}.
 *   <li>{@code updaterType}: its {@code toString()} is passed to {@code WeightUpdaterType.fromString}.
 *   <li>{@code Network-NAME}: value is a whitespace-separated list of member component names, each
 *       optionally prefixed by any of {@code - + > <} (see {@link #addNetwork(String, String)}).
 * </ul>
 */
public class NetworkConfigurator {
    /** Regex separating the tokens of a component or network specification. */
    private static final String SPEC_SEP = "\\s+";
    /** Packages searched (via ReflectionTools.findClass) for a component specification's class name. */
    private static final String[] PKGS = {"dmonner.xlbp", "dmonner.xlbp.compound", "dmonner.xlbp.layer"};
    /** Alternative key prefix under which a constructor argument may be stored in the parameter map. */
    private static final String PARAM_PREFIX = "Param-";
    /** Key prefix of component definitions. */
    private static final String COMPONENT_PREFIX = "Component-";
    /**
     * Key prefixes of connection requests. Only the first two (direct, weighted) are handled by
     * {@link #addConnection(String, String)}; the third is recognised by the constructor but
     * rejected there with an IllegalArgumentException.
     */
    private static final String[] CONNECT_PREFIXES = {"Connect-Direct-", "Connect-Weighted-", "Connect-WeightedAll-"};
    /** Key prefix of network definitions. */
    private static final String NETWORK_PREFIX = "Network-";
    /** Network-member prefix that clears the member's second add() flag. */
    private static final String ACTIVATE_PREFIX = "-";
    /** Network-member prefix that clears the member's first add() flag. */
    private static final String TRAIN_PREFIX = "+";
    /** Network-member prefix that sets the member's third add() flag. */
    private static final String ENTRY_PREFIX = ">";
    /** Network-member prefix that sets the member's fourth add() flag. */
    private static final String EXIT_PREFIX = "<";

    private final ListMap<String, Object> params;
    private final Network meta;
    private final WeightUpdaterType wut;
    private final WeightInitializer win;
    /** Every named component: specified ones, children of specified compounds, and networks. */
    private final ListMap<String, Component> components = new ListMap<>();
    /** Only the components defined directly by a Component- key; these are added to the meta network. */
    private final ListMap<String, Component> specified = new ListMap<>();
    private final ListMap<String, Network> networks = new ListMap<>();

    /* loaded from: input_file:dmonner/xlbp/NetworkConfigurator$ConnectionType.class */
    /** Kind of connection created by {@link #addConnection(ConnectionType, Component, Component)}. */
    public enum ConnectionType {
        DIRECT,
        WEIGHTED
    }

    /**
     * Builds all components, connections, and networks described by {@code listMap}.
     *
     * <p>Order matters: components are created first so that connections and networks can refer to
     * them by name; the meta network is created from the specified components before any
     * {@code Network-} entries are processed.
     *
     * @param listMap configuration parameters (kept and later consulted by {@link #findParam})
     * @throws IllegalArgumentException if a definition cannot be interpreted, a referenced component
     *     is missing, a {@code Connect-WeightedAll-} entry is enabled, or {@code connectionProbability}
     *     / {@code updaterType} is absent
     */
    public NetworkConfigurator(ListMap<String, Object> listMap) {
        this.params = listMap;
        for (String key : listMap.keyList()) {
            if (key.startsWith(COMPONENT_PREFIX)) {
                addComponent(key.substring(COMPONENT_PREFIX.length()), listMap.get(key).toString());
            }
        }
        for (String key : listMap.keyList()) {
            for (String prefix : CONNECT_PREFIXES) {
                if (key.startsWith(prefix) && asBoolean(listMap.get(key))) {
                    addConnection(prefix, key.substring(prefix.length()));
                }
            }
        }
        // Accept any Number (e.g. Double), not only Float, for the probability.
        float connectionProbability = ((Number) requireParam("connectionProbability")).floatValue();
        String updaterType = requireParam("updaterType").toString();
        this.win = new UniformWeightInitializer(connectionProbability);
        this.wut = WeightUpdaterType.fromString(updaterType, listMap);
        this.meta = new Network("Meta");
        this.meta.setWeightInitializer(this.win);
        this.meta.setWeightUpdaterType(this.wut);
        for (Component component : this.specified.values()) {
            this.meta.add(component);
        }
        for (String key : listMap.keyList()) {
            if (key.startsWith(NETWORK_PREFIX)) {
                addNetwork(key.substring(NETWORK_PREFIX.length()), listMap.get(key).toString());
            }
        }
    }

    /**
     * Creates a component from a textual specification and registers it (and, for a Compound, its
     * children) by name.
     *
     * @param name name of the new component; also passed as the first constructor argument
     * @param spec {@code "ClassName arg1 arg2 ..."}; a blank spec is ignored
     * @throws IllegalArgumentException if an argument cannot be resolved, the constructor cannot be
     *     invoked or throws, or the created object is not a Component
     */
    public void addComponent(String name, String spec) {
        String trimmed = spec.trim();
        if (trimmed.isEmpty()) {
            return;
        }
        // Split the trimmed text: splitting the raw text would produce an empty first token (and so
        // an empty class name) whenever the spec has leading whitespace.
        String[] tokens = trimmed.split(SPEC_SEP);
        String className = tokens[0];
        Class<?> cls = ReflectionTools.findClass(className, PKGS);
        Class<?>[] argTypes = new Class<?>[tokens.length];
        Object[] args = new Object[tokens.length];
        argTypes[0] = String.class;
        args[0] = name;
        for (int i = 1; i < tokens.length; i++) {
            args[i] = findParam(tokens[i]);
            // Use primitive types where possible so constructors taking e.g. int/float are found.
            argTypes[i] = ReflectionTools.unbox(args[i].getClass());
        }
        Object instance;
        try {
            instance = ReflectionTools.findConstructor(cls, argTypes).newInstance(args);
        } catch (IllegalAccessException e) {
            throw new IllegalArgumentException("Inaccessible constructor for " + className + " with signature " + Arrays.deepToString(argTypes), e);
        } catch (InstantiationException e) {
            throw new IllegalArgumentException("Uninstantiable constructor for " + className + " with signature " + Arrays.deepToString(argTypes), e);
        } catch (InvocationTargetException e) {
            throw new IllegalArgumentException("Exception while running constructor for " + className + " with signature " + Arrays.deepToString(argTypes), e);
        }
        if (!(instance instanceof Component)) {
            throw new IllegalArgumentException("Class " + className + " does not produce a Component");
        }
        Component component = (Component) instance;
        this.specified.put(name, component);
        this.components.put(name, component);
        if (instance instanceof Compound) {
            // Expose the compound's internal components so they can be connected/used by name.
            for (Component child : ((Compound) instance).getComponents()) {
                this.components.put(child.getName(), child);
            }
        }
    }

    /**
     * Connects {@code from} into {@code to}.
     *
     * <p>DIRECT calls {@code to.addUpstream(from)} and requires {@code to} to be a
     * DownstreamComponent. WEIGHTED calls {@code to.addUpstreamWeights(from)} and requires {@code to}
     * to be a WeightedCompound. Both require {@code from} to be an UpstreamComponent.
     *
     * @throws IllegalArgumentException if the components do not have the required types
     */
    public void addConnection(ConnectionType connectionType, Component from, Component to) {
        if (connectionType == ConnectionType.DIRECT) {
            if (!(to instanceof DownstreamComponent) || !(from instanceof UpstreamComponent)) {
                throw new IllegalArgumentException("Cannot create direct connection between " + from + " and " + to);
            }
            ((DownstreamComponent) to).addUpstream((UpstreamComponent) from);
            return;
        }
        if (connectionType != ConnectionType.WEIGHTED) {
            throw new IllegalArgumentException("Unhandled ConnectionType: " + connectionType);
        }
        if (!(to instanceof WeightedCompound) || !(from instanceof UpstreamComponent)) {
            throw new IllegalArgumentException("Cannot create weighted connection between " + from + " and " + to);
        }
        ((WeightedCompound) to).addUpstreamWeights((UpstreamComponent) from);
    }

    /**
     * Connects two registered components described by a connection key.
     *
     * @param prefix one of the {@code Connect-} prefixes; only the direct and weighted prefixes are
     *     supported
     * @param spec {@code "FROM-TO"}; component names therefore cannot themselves contain '-'
     * @throws IllegalArgumentException for an unsupported prefix, a malformed spec, or an unknown
     *     component name
     */
    public void addConnection(String prefix, String spec) {
        ConnectionType connectionType;
        if (prefix.equals(CONNECT_PREFIXES[0])) {
            connectionType = ConnectionType.DIRECT;
        } else if (prefix.equals(CONNECT_PREFIXES[1])) {
            connectionType = ConnectionType.WEIGHTED;
        } else {
            throw new IllegalArgumentException("Unhandled ConnectionType: " + prefix);
        }
        String[] names = spec.split("-");
        if (names.length != 2) {
            throw new IllegalArgumentException("Wrong number of layers in connect parameter: " + spec);
        }
        if (!this.components.containsKey(names[0])) {
            throw new IllegalArgumentException("No from-component named: " + names[0]);
        }
        if (!this.components.containsKey(names[1])) {
            throw new IllegalArgumentException("No to-component named: " + names[1]);
        }
        addConnection(connectionType, this.components.get(names[0]), this.components.get(names[1]));
    }

    /**
     * Creates a network, registers it by name (as both a network and a component), and adds the
     * listed members to it.
     *
     * <p>Each whitespace-separated member token may carry any number of leading prefixes, in any
     * order: {@code -} clears the second add() flag, {@code +} clears the first, {@code >} sets the
     * third and {@code <} sets the fourth.
     *
     * @param name name of the new network
     * @param spec whitespace-separated member names; a blank spec creates an empty network
     * @throws IllegalArgumentException if a member name is not a registered component
     */
    public void addNetwork(String name, String spec) {
        Network network = new Network(name);
        network.setWeightInitializer(this.win);
        network.setWeightUpdaterType(this.wut);
        this.networks.put(name, network);
        this.components.put(name, network);
        String trimmed = spec.trim();
        if (trimmed.isEmpty()) {
            return;
        }
        // Split the trimmed text so leading whitespace does not yield an empty member name.
        for (String token : trimmed.split(SPEC_SEP)) {
            // NOTE(review): presumably '-' marks an activate-only member and '+' a train-only one,
            // making these the (activate, train, entry, exit) arguments of Network.add -- confirm.
            boolean activate = true;
            boolean train = true;
            boolean entry = false;
            boolean exit = false;
            String memberName = token;
            boolean stripped = true;
            while (stripped) {
                stripped = true;
                if (memberName.startsWith(ACTIVATE_PREFIX)) {
                    train = false;
                    memberName = memberName.substring(ACTIVATE_PREFIX.length());
                } else if (memberName.startsWith(TRAIN_PREFIX)) {
                    activate = false;
                    memberName = memberName.substring(TRAIN_PREFIX.length());
                } else if (memberName.startsWith(ENTRY_PREFIX)) {
                    entry = true;
                    memberName = memberName.substring(ENTRY_PREFIX.length());
                } else if (memberName.startsWith(EXIT_PREFIX)) {
                    exit = true;
                    memberName = memberName.substring(EXIT_PREFIX.length());
                } else {
                    stripped = false;
                }
            }
            if (!this.components.containsKey(memberName)) {
                throw new IllegalArgumentException("Cannot find component to add to network " + network + ": " + memberName);
            }
            network.add(this.components.get(memberName), activate, train, entry, exit);
        }
    }

    /** Interprets a connection flag: a Boolean as itself, an Integer as non-zero; anything else is false. */
    private boolean asBoolean(Object obj) {
        if (obj instanceof Boolean) {
            return ((Boolean) obj).booleanValue();
        }
        return (obj instanceof Integer) && ((Integer) obj).intValue() != 0;
    }

    /**
     * Returns the non-null value stored under {@code key} in the parameter map.
     *
     * @throws IllegalArgumentException if the key is absent or maps to null
     */
    private Object requireParam(String key) {
        Object value = this.params.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required parameter: " + key);
        }
        return value;
    }

    /** Builds the meta network (and thereby everything added to it). */
    public void build() {
        this.meta.build();
    }

    /**
     * Resolves a constructor argument name. Looks it up, in order, as a parameter key, as a
     * {@code Param-}-prefixed parameter key, and as a registered component name.
     *
     * @throws IllegalArgumentException if none of the lookups yields a value
     */
    private Object findParam(String str) {
        Object obj = this.params.get(str);
        if (obj == null) {
            obj = this.params.get(PARAM_PREFIX + str);
        }
        if (obj == null) {
            obj = this.components.get(str);
        }
        if (obj == null) {
            throw new IllegalArgumentException("Variable not found: " + str);
        }
        return obj;
    }

    /** Returns the registered component with the given name, or null if there is none. */
    public Component getComponent(String str) {
        return this.components.get(str);
    }

    /** Returns all registered components (specified ones, compound children, and networks). */
    public Collection<Component> getComponents() {
        return this.components.values();
    }

    /** Returns the meta network containing every directly specified component. */
    public Network getMetaNetwork() {
        return this.meta;
    }

    /** Returns the network with the given name, or null if there is none. */
    public Network getNetwork(String str) {
        return this.networks.get(str);
    }

    /** Returns all networks created from {@code Network-} entries or {@link #addNetwork}. */
    public Collection<Network> getNetworks() {
        return this.networks.values();
    }

    /** Delegates to the meta network's optimize() and returns its result. */
    public boolean optimize() {
        return this.meta.optimize();
    }

    @Override
    public String toString() {
        return this.components.toString();
    }
}
