package dmonner.xlbp.trial;

import dmonner.xlbp.Input;
import dmonner.xlbp.Network;
import dmonner.xlbp.Target;
import dmonner.xlbp.layer.InputLayer;
import dmonner.xlbp.layer.Layer;
import dmonner.xlbp.layer.TargetLayer;
import dmonner.xlbp.stat.StepStat;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:dmonner/xlbp/trial/Step.class */
/**
 * One step of a {@link Trial}: the inputs to present to a network, the targets to apply when
 * training, the layer checks used for evaluation, and the layers selected for recording.
 *
 * <p>{@link #run(boolean)} applies every input, activates the network in training or test mode,
 * and then (if enabled) evaluates and/or records the step. Inputs, targets and checks are stored
 * in maps keyed by their layer, so adding a second entry for the same layer replaces the first.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class Step {
    /** The trial this step belongs to. */
    private final Trial trial;
    /** The network this step operates on; initially the trial's meta-network. */
    private Network net;
    /** Result of the most recent {@link #evaluate()}, or {@code null} if none / cleared. */
    private StepStat evaluation;
    /** Result of the most recent {@link #record()}, or {@code null} if none / cleared. */
    private StepRecord recording;
    /** Inputs to apply, keyed by the input layer each one targets. */
    private final Map<InputLayer, Input> inputs = new HashMap<InputLayer, Input>();
    /** Targets to apply during training, keyed by target layer. */
    private final Map<TargetLayer, Target> targets = new HashMap<TargetLayer, Target>();
    /** Evaluation checks, keyed by the layer each one inspects. */
    private final Map<Layer, LayerCheck> checks = new HashMap<Layer, LayerCheck>();
    /** Layers selected for recording. */
    private final Set<Layer> recordLayers = new HashSet<Layer>();
    /** Whether {@link #run(boolean)} calls {@link #evaluate()} after activation. */
    private boolean evaluate = true;
    /** Whether {@link #run(boolean)} calls {@link #record()} after activation. */
    private boolean record = true;

    /**
     * Creates an empty step for the given trial, operating on the trial's meta-network.
     *
     * @param trial the owning trial; must not be {@code null}
     */
    public Step(Trial trial) {
        this.trial = trial;
        this.net = trial.getMetaNetwork();
    }

    /**
     * Adds a {@link LayerCheck} built from {@code layer} and {@code values}, replacing any
     * existing check registered under the same layer.
     *
     * @param layer  the layer to check
     * @param values the values passed to the {@link LayerCheck} constructor
     */
    public void addCheck(Layer layer, float[] values) {
        addCheck(new LayerCheck(layer, values));
    }

    /**
     * Adds a check, keyed by {@code layerCheck.getLayer()}; replaces any existing check for that
     * layer.
     *
     * @param layerCheck the check to add
     */
    public void addCheck(LayerCheck layerCheck) {
        this.checks.put(layerCheck.getLayer(), layerCheck);
    }

    /**
     * Adds an input for the network's default input layer ({@code net.getInputLayer()}).
     *
     * @param values the input values
     */
    public void addInput(float[] values) {
        addInput(this.net.getInputLayer(), values);
    }

    /**
     * Adds an input, keyed by {@code input.getLayer()}; replaces any existing input for that
     * layer.
     *
     * @param input the input to add
     */
    public void addInput(Input input) {
        this.inputs.put(input.getLayer(), input);
    }

    /**
     * Adds an input for the given input layer.
     *
     * @param inputLayer the layer the values are applied to
     * @param values     the input values
     */
    public void addInput(InputLayer inputLayer, float[] values) {
        addInput(new Input(inputLayer, values));
    }

    /**
     * Adds an input for the input layer at {@code index} ({@code net.getInputLayer(index)}).
     *
     * @param index  index of the input layer in the network
     * @param values the input values
     */
    public void addInput(int index, float[] values) {
        addInput(this.net.getInputLayer(index), values);
    }

    /**
     * Marks a layer for recording.
     *
     * @param layer the layer to record
     */
    public void addRecordLayer(Layer layer) {
        this.recordLayers.add(layer);
    }

    /**
     * Adds a target for the network's default target layer ({@code net.getTargetLayer()}).
     *
     * @param values the target values
     */
    public void addTarget(float[] values) {
        addTarget(this.net.getTargetLayer(), values);
    }

    /**
     * Adds a target for the target layer at {@code index} ({@code net.getTargetLayer(index)}).
     *
     * @param index  index of the target layer in the network
     * @param values the target values
     */
    public void addTarget(int index, float[] values) {
        addTarget(this.net.getTargetLayer(index), values);
    }

    /**
     * Adds a target, keyed by {@code target.getLayer()}; replaces any existing target for that
     * layer.
     *
     * @param target the target to add
     */
    public void addTarget(Target target) {
        this.targets.put(target.getLayer(), target);
    }

    /**
     * Adds a target for the given target layer.
     *
     * @param targetLayer the layer the values are applied to
     * @param values      the target values
     */
    public void addTarget(TargetLayer targetLayer, float[] values) {
        addTarget(new Target(targetLayer, values));
    }

    /**
     * Discards the last evaluation and recording. Inputs, targets, checks and record layers are
     * kept; see {@link #initialize()} for a full reset.
     */
    public void clear() {
        this.evaluation = null;
        this.recording = null;
    }

    /**
     * Two steps are equal when their input maps and target maps are equal. Checks, record layers,
     * the network, the trial and the evaluate/record flags are not compared.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Step)) {
            return false;
        }
        Step other = (Step) obj;
        return this.inputs.equals(other.inputs) && this.targets.equals(other.targets);
    }

    /**
     * Builds a new evaluation via {@link #makeEvaluation()}, stores it as the last evaluation,
     * and returns it.
     *
     * @return the new evaluation
     */
    public StepStat evaluate() {
        this.evaluation = makeEvaluation();
        return this.evaluation;
    }

    /**
     * Returns the checks of this step.
     *
     * @return a live view of the check map's values
     */
    public Collection<LayerCheck> getChecks() {
        return this.checks.values();
    }

    /**
     * Returns the input registered for the network's default input layer.
     *
     * @return the input, or {@code null} if none was added
     */
    public Input getInput() {
        return getInput(this.net.getInputLayer());
    }

    /**
     * Returns the input registered for the given layer.
     *
     * @param inputLayer the input layer
     * @return the input, or {@code null} if none was added
     */
    public Input getInput(InputLayer inputLayer) {
        return this.inputs.get(inputLayer);
    }

    /**
     * Returns the input registered for the input layer at {@code index}.
     *
     * @param index index of the input layer in the network
     * @return the input, or {@code null} if none was added
     */
    public Input getInput(int index) {
        return getInput(this.net.getInputLayer(index));
    }

    /**
     * Returns the inputs of this step.
     *
     * @return a live view of the input map's values
     */
    public Collection<Input> getInputs() {
        return this.inputs.values();
    }

    /**
     * Returns the result of the most recent {@link #evaluate()}.
     *
     * @return the last evaluation, or {@code null} if none has been made since creation,
     *         {@link #clear()} or {@link #initialize()}
     */
    public StepStat getLastEvaluation() {
        return this.evaluation;
    }

    /**
     * Returns the result of the most recent {@link #record()}.
     *
     * @return the last recording, or {@code null} if none has been made since creation,
     *         {@link #clear()} or {@link #initialize()}
     */
    public StepRecord getLastRecording() {
        return this.recording;
    }

    /**
     * Returns the network this step operates on.
     *
     * @return the current network
     */
    public Network getNetwork() {
        return this.net;
    }

    /**
     * Returns the layers selected for recording.
     *
     * @return the internal (mutable) set; changes to it affect this step
     */
    public Set<Layer> getRecordLayers() {
        return this.recordLayers;
    }

    /**
     * Returns the target registered for the network's default target layer.
     *
     * @return the target, or {@code null} if none was added
     */
    public Target getTarget() {
        return getTarget(this.net.getTargetLayer());
    }

    /**
     * Returns the target registered for the target layer at {@code index}.
     *
     * @param index index of the target layer in the network
     * @return the target, or {@code null} if none was added
     */
    public Target getTarget(int index) {
        return getTarget(this.net.getTargetLayer(index));
    }

    /**
     * Returns the target registered for the given layer.
     *
     * @param targetLayer the target layer
     * @return the target, or {@code null} if none was added
     */
    public Target getTarget(TargetLayer targetLayer) {
        return this.targets.get(targetLayer);
    }

    /**
     * Returns the targets of this step.
     *
     * @return a live view of the target map's values
     */
    public Collection<Target> getTargets() {
        return this.targets.values();
    }

    /**
     * Returns the trial this step belongs to.
     *
     * @return the owning trial
     */
    public Trial getTrial() {
        return this.trial;
    }

    /** Hash derived from the same fields {@link #equals(Object)} compares: inputs and targets. */
    @Override
    public int hashCode() {
        return this.inputs.hashCode() + this.targets.hashCode();
    }

    /**
     * Resets the step: removes all inputs, targets, checks and record layers and discards the last
     * evaluation and recording. The network and the evaluate/record flags are left unchanged.
     */
    public void initialize() {
        this.inputs.clear();
        this.targets.clear();
        this.checks.clear();
        this.recordLayers.clear();
        this.evaluation = null;
        this.recording = null;
    }

    /**
     * Factory hook used by {@link #evaluate()}; subclasses may override to supply a different
     * evaluation.
     *
     * @return a new {@link StepStat} for this step
     */
    protected StepStat makeEvaluation() {
        return new StepStat(this);
    }

    /**
     * Creates a new {@link StepRecord} for this step without storing it; {@link #record()} stores
     * the result.
     *
     * @return a new record for this step
     */
    public StepRecord makeRecord() {
        return new StepRecord(this);
    }

    /** @return the number of checks in this step */
    public int nEvals() {
        return this.checks.size();
    }

    /** @return the number of inputs in this step */
    public int nInputs() {
        return this.inputs.size();
    }

    /** @return the value of {@code net.nTarget()} for the current network */
    public int nOutputs() {
        return this.net.nTarget();
    }

    /** @return the number of layers selected for recording */
    public int nRecordLayers() {
        return this.recordLayers.size();
    }

    /** @return the number of targets in this step */
    public int nTargets() {
        return this.targets.size();
    }

    /**
     * Builds a new record via {@link #makeRecord()}, stores it as the last recording, and
     * returns it.
     *
     * @return the new recording
     */
    public StepRecord record() {
        this.recording = makeRecord();
        return this.recording;
    }

    /** Runs the step in test mode; equivalent to {@code run(false)}. */
    public void run() {
        run(false);
    }

    /**
     * Runs the step.
     *
     * <p>Every input is applied first. The next stage depends on {@code train}:
     * <ul>
     *   <li>When {@code train} is true, the network is activated in training mode and its
     *       eligibilities are updated. Then every target is applied, and, if at least one target
     *       exists, responsibilities and weights are updated.</li>
     *   <li>Otherwise the network is activated in test mode.</li>
     * </ul>
     * Finally the step is evaluated and/or recorded according to the evaluate/record flags.
     *
     * @param train {@code true} to train, {@code false} to test
     */
    public void run(boolean train) {
        for (Input input : getInputs()) {
            input.apply();
        }
        if (train) {
            this.net.activateTrain();
            this.net.updateEligibilities();
            for (Target target : getTargets()) {
                target.apply();
            }
            // Learning updates happen only when this step actually supplied targets.
            if (nTargets() > 0) {
                this.net.updateResponsibilities();
                this.net.updateWeights();
            }
        } else {
            this.net.activateTest();
        }
        if (this.evaluate) {
            evaluate();
        }
        if (this.record) {
            record();
        }
    }

    /**
     * Sets whether {@link #run(boolean)} evaluates the step after activation.
     *
     * @param evaluate {@code true} to evaluate after each run
     */
    public void setEvaluate(boolean evaluate) {
        this.evaluate = evaluate;
    }

    /**
     * Replaces the network this step operates on. Existing inputs/targets keep their original
     * layer keys.
     *
     * @param network the new network
     */
    public void setNetwork(Network network) {
        this.net = network;
    }

    /**
     * Sets whether {@link #run(boolean)} records the step after activation.
     *
     * @param record {@code true} to record after each run
     */
    public void setRecord(boolean record) {
        this.record = record;
    }

    /** Runs the step in test mode; equivalent to {@code run(false)}. */
    public void test() {
        run(false);
    }

    /**
     * Returns every input followed by every target, each on its own line (newline-terminated).
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Input input : getInputs()) {
            sb.append(input);
            sb.append("\n");
        }
        for (Target target : getTargets()) {
            sb.append(target);
            sb.append("\n");
        }
        return sb.toString();
    }

    /** Runs the step in training mode; equivalent to {@code run(true)}. */
    public void train() {
        run(true);
    }
}
