/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.trial;

import dmonner.xlbp.Input;
import dmonner.xlbp.Network;
import dmonner.xlbp.Target;
import dmonner.xlbp.layer.InputLayer;
import dmonner.xlbp.layer.Layer;
import dmonner.xlbp.layer.TargetLayer;
import dmonner.xlbp.stat.StepStat;
import dmonner.xlbp.trial.LayerCheck;
import dmonner.xlbp.trial.StepRecord;
import dmonner.xlbp.trial.Trial;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * A single step of a {@link Trial}: the inputs, targets, checks and record layers to be applied to
 * a {@link Network} for one activation. After the step has been run, it optionally produces a
 * {@link StepStat} evaluation and/or a {@link StepRecord} recording.
 *
 * <p>Equality and hash code are based only on the attached inputs and targets. Checks, record
 * layers and last results are not part of equality.
 */
public class Step {
    /** The trial this step belongs to. */
    private final Trial trial;
    /** Network the step runs on; initialized from {@code trial.getMetaNetwork()}. */
    private Network net;
    /** At most one input per input layer; a later add for the same layer replaces the earlier one. */
    private final Map<InputLayer, Input> inputs;
    /** At most one target per target layer; a later add for the same layer replaces the earlier one. */
    private final Map<TargetLayer, Target> targets;
    /** At most one check per layer; a later add for the same layer replaces the earlier one. */
    private final Map<Layer, LayerCheck> checks;
    /** Layers selected for recording. */
    private final Set<Layer> recordLayers;
    /** Result of the most recent {@link #evaluate()}, or {@code null}. */
    private StepStat evaluation;
    /** Result of the most recent {@link #record()}, or {@code null}. */
    private StepRecord recording;
    /** Whether {@link #run(boolean)} calls {@link #evaluate()} afterwards. */
    private boolean evaluate;
    /** Whether {@link #run(boolean)} calls {@link #record()} afterwards. */
    private boolean record;

    /**
     * Creates an empty step for the given trial, bound to the trial's meta network. Evaluation and
     * recording are both enabled by default.
     *
     * @param trial the owning trial; dereferenced immediately, so it must not be {@code null}
     */
    public Step(Trial trial) {
        this.trial = trial;
        this.net = trial.getMetaNetwork();
        this.inputs = new HashMap<>();
        this.targets = new HashMap<>();
        this.checks = new HashMap<>();
        this.recordLayers = new HashSet<>();
        this.evaluate = true;
        this.record = true;
    }

    /**
     * Adds (or replaces) a check for the given layer.
     *
     * @param layer the layer to check
     * @param check the values to check the layer against
     */
    public void addCheck(Layer layer, float[] check) {
        this.addCheck(new LayerCheck(layer, check));
    }

    /**
     * Adds a check, replacing any existing check for the same layer.
     *
     * @param check the check to add
     */
    public void addCheck(LayerCheck check) {
        this.checks.put(check.getLayer(), check);
    }

    /**
     * Adds (or replaces) an input for the network's default input layer
     * ({@code net.getInputLayer()}).
     *
     * @param input the input values
     */
    public void addInput(float[] input) {
        this.addInput(this.net.getInputLayer(), input);
    }

    /**
     * Adds an input, replacing any existing input for the same layer.
     *
     * @param input the input to add
     */
    public void addInput(Input input) {
        this.inputs.put(input.getLayer(), input);
    }

    /**
     * Adds (or replaces) an input for the given input layer.
     *
     * @param layer the input layer
     * @param input the input values
     */
    public void addInput(InputLayer layer, float[] input) {
        this.addInput(new Input(layer, input));
    }

    /**
     * Adds (or replaces) an input for the network input layer at the given index.
     *
     * @param inputIndex index passed to {@code net.getInputLayer(int)}
     * @param input the input values
     */
    public void addInput(int inputIndex, float[] input) {
        this.addInput(this.net.getInputLayer(inputIndex), input);
    }

    /**
     * Marks a layer for recording.
     *
     * @param record the layer to record
     */
    public void addRecordLayer(Layer record) {
        this.recordLayers.add(record);
    }

    /**
     * Adds (or replaces) a target for the network's default target layer
     * ({@code net.getTargetLayer()}).
     *
     * @param target the target values
     */
    public void addTarget(float[] target) {
        this.addTarget(this.net.getTargetLayer(), target);
    }

    /**
     * Adds (or replaces) a target for the network target layer at the given index.
     *
     * @param targetIndex index passed to {@code net.getTargetLayer(int)}
     * @param target the target values
     */
    public void addTarget(int targetIndex, float[] target) {
        this.addTarget(this.net.getTargetLayer(targetIndex), target);
    }

    /**
     * Adds a target, replacing any existing target for the same layer.
     *
     * @param target the target to add
     */
    public void addTarget(Target target) {
        this.targets.put(target.getLayer(), target);
    }

    /**
     * Adds (or replaces) a target for the given target layer.
     *
     * @param layer the target layer
     * @param target the target values
     */
    public void addTarget(TargetLayer layer, float[] target) {
        this.addTarget(new Target(layer, target));
    }

    /**
     * Discards the last evaluation and recording.
     *
     * <p>Inputs, targets, checks and record layers are kept; see {@link #initialize()} to reset
     * those as well.
     */
    public void clear() {
        this.evaluation = null;
        this.recording = null;
    }

    /**
     * Two steps are equal when they have equal input maps and equal target maps.
     *
     * @param other the object to compare with
     * @return {@code true} if {@code other} is this step or a {@code Step} with equal inputs and
     *     targets
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof Step)) {
            return false;
        }
        Step that = (Step) other;
        return this.inputs.equals(that.inputs) && this.targets.equals(that.targets);
    }

    /**
     * Evaluates the current network state for this step and stores the result as the last
     * evaluation.
     *
     * @return the new evaluation, as built by {@link #makeEvaluation()}
     */
    public StepStat evaluate() {
        this.evaluation = this.makeEvaluation();
        return this.evaluation;
    }

    /**
     * Returns the checks attached to this step.
     *
     * @return a live view backed by the internal map (iteration order unspecified)
     */
    public Collection<LayerCheck> getChecks() {
        return this.checks.values();
    }

    /**
     * Returns the input for the network's default input layer.
     *
     * @return the input, or {@code null} if none was added for that layer
     */
    public Input getInput() {
        return this.getInput(this.net.getInputLayer());
    }

    /**
     * Returns the input for the given layer.
     *
     * @param layer the input layer
     * @return the input, or {@code null} if none was added for that layer
     */
    public Input getInput(InputLayer layer) {
        return this.inputs.get(layer);
    }

    /**
     * Returns the input for the network input layer at the given index.
     *
     * @param inputIndex index passed to {@code net.getInputLayer(int)}
     * @return the input, or {@code null} if none was added for that layer
     */
    public Input getInput(int inputIndex) {
        return this.getInput(this.net.getInputLayer(inputIndex));
    }

    /**
     * Returns the inputs attached to this step.
     *
     * @return a live view backed by the internal map (iteration order unspecified)
     */
    public Collection<Input> getInputs() {
        return this.inputs.values();
    }

    /**
     * Returns the evaluation stored by the most recent {@link #evaluate()}.
     *
     * @return the last evaluation, or {@code null} if none or after {@link #clear()}
     */
    public StepStat getLastEvaluation() {
        return this.evaluation;
    }

    /**
     * Returns the recording stored by the most recent {@link #record()}.
     *
     * @return the last recording, or {@code null} if none or after {@link #clear()}
     */
    public StepRecord getLastRecording() {
        return this.recording;
    }

    /**
     * Returns the network this step runs on.
     *
     * @return the network
     */
    public Network getNetwork() {
        return this.net;
    }

    /**
     * Returns the layers selected for recording.
     *
     * @return the internal, mutable set; changes to it affect this step
     */
    public Set<Layer> getRecordLayers() {
        return this.recordLayers;
    }

    /**
     * Returns the target for the network's default target layer.
     *
     * @return the target, or {@code null} if none was added for that layer
     */
    public Target getTarget() {
        return this.getTarget(this.net.getTargetLayer());
    }

    /**
     * Returns the target for the network target layer at the given index.
     *
     * @param targetIndex index passed to {@code net.getTargetLayer(int)}
     * @return the target, or {@code null} if none was added for that layer
     */
    public Target getTarget(int targetIndex) {
        return this.getTarget(this.net.getTargetLayer(targetIndex));
    }

    /**
     * Returns the target for the given layer.
     *
     * @param layer the target layer
     * @return the target, or {@code null} if none was added for that layer
     */
    public Target getTarget(TargetLayer layer) {
        return this.targets.get(layer);
    }

    /**
     * Returns the targets attached to this step.
     *
     * @return a live view backed by the internal map (iteration order unspecified)
     */
    public Collection<Target> getTargets() {
        return this.targets.values();
    }

    /**
     * Returns the trial this step belongs to.
     *
     * @return the owning trial
     */
    public Trial getTrial() {
        return this.trial;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: derived only from inputs and targets.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        return this.inputs.hashCode() + this.targets.hashCode();
    }

    /**
     * Resets this step to empty.
     *
     * <p>Removes all inputs, targets, checks and record layers, and discards the last evaluation and
     * recording. The network binding and the evaluate/record flags are left unchanged.
     */
    public void initialize() {
        this.inputs.clear();
        this.targets.clear();
        this.checks.clear();
        this.recordLayers.clear();
        this.evaluation = null;
        this.recording = null;
    }

    /**
     * Builds a new evaluation of this step. Subclasses may override this to supply a specialized
     * {@link StepStat}.
     *
     * @return a new evaluation
     */
    protected StepStat makeEvaluation() {
        return new StepStat(this);
    }

    /**
     * Builds a new recording of this step without storing it as the last recording.
     *
     * @return a new recording
     */
    public StepRecord makeRecord() {
        return new StepRecord(this);
    }

    /**
     * Returns the number of checks attached to this step.
     *
     * @return the check count
     */
    public int nEvals() {
        return this.checks.size();
    }

    /**
     * Returns the number of inputs attached to this step.
     *
     * @return the input count
     */
    public int nInputs() {
        return this.inputs.size();
    }

    /**
     * Returns the network's {@code nTarget()} count.
     *
     * <p>Note that this is a property of the network, not the number of targets attached to this
     * step; for that, see {@link #nTargets()}.
     *
     * @return the network's target count
     */
    public int nOutputs() {
        return this.net.nTarget();
    }

    /**
     * Returns the number of layers selected for recording.
     *
     * @return the record-layer count
     */
    public int nRecordLayers() {
        return this.recordLayers.size();
    }

    /**
     * Returns the number of targets attached to this step.
     *
     * @return the target count
     */
    public int nTargets() {
        return this.targets.size();
    }

    /**
     * Records this step and stores the result as the last recording.
     *
     * @return the new recording, as built by {@link #makeRecord()}
     */
    public StepRecord record() {
        this.recording = this.makeRecord();
        return this.recording;
    }

    /** Runs this step in test (non-training) mode; equivalent to {@code run(false)}. */
    public void run() {
        this.run(false);
    }

    /**
     * Applies this step's inputs to the network and activates it, then evaluates and/or records if
     * those options are enabled.
     *
     * <p>In training mode the sequence is as follows:
     *
     * <ol>
     *   <li>{@code activateTrain()}
     *   <li>{@code updateEligibilities()}
     *   <li>apply all targets
     *   <li>{@code updateResponsibilities()} and {@code updateWeights()}, only if at least one target
     *       is attached
     * </ol>
     *
     * <p>In test mode the sequence is {@code activateTest()} only. Targets are <em>not</em> applied
     * in test mode.
     *
     * @param train {@code true} to run the training sequence, {@code false} for test activation
     */
    public void run(boolean train) {
        for (Input input : this.getInputs()) {
            input.apply();
        }
        if (train) {
            this.net.activateTrain();
            this.net.updateEligibilities();
            for (Target target : this.getTargets()) {
                target.apply();
            }
            // Without any target there is nothing to learn from, so skip the weight update.
            if (this.nTargets() > 0) {
                this.net.updateResponsibilities();
                this.net.updateWeights();
            }
        } else {
            this.net.activateTest();
        }
        if (this.evaluate) {
            this.evaluate();
        }
        if (this.record) {
            this.record();
        }
    }

    /**
     * Sets whether {@link #run(boolean)} evaluates the step afterwards.
     *
     * @param evaluate {@code true} to evaluate after each run
     */
    public void setEvaluate(boolean evaluate) {
        this.evaluate = evaluate;
    }

    /**
     * Rebinds this step to a different network.
     *
     * <p>Inputs, targets and checks that were already added are kept as-is; they remain keyed by
     * the layers they were created with.
     *
     * @param net the network to run on
     */
    public void setNetwork(Network net) {
        this.net = net;
    }

    /**
     * Sets whether {@link #run(boolean)} records the step afterwards.
     *
     * @param record {@code true} to record after each run
     */
    public void setRecord(boolean record) {
        this.record = record;
    }

    /** Runs this step in test (non-training) mode; equivalent to {@code run(false)}. */
    public void test() {
        this.run(false);
    }

    /**
     * Returns all inputs followed by all targets, one per line (each followed by {@code "\n"}).
     * Within each group the order follows map iteration order, which is unspecified.
     *
     * @return the textual form of this step
     */
    @Override
    public String toString() {
        // Local, single-threaded buffer: StringBuilder avoids StringBuffer's needless locking.
        StringBuilder sb = new StringBuilder();
        for (Input input : this.getInputs()) {
            sb.append(input).append("\n");
        }
        for (Target target : this.getTargets()) {
            sb.append(target).append("\n");
        }
        return sb.toString();
    }

    /** Runs this step in training mode; equivalent to {@code run(true)}. */
    public void train() {
        this.run(true);
    }
}

