package dmonner.xlbp.example;

import dmonner.xlbp.Network;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.compound.InputCompound;
import dmonner.xlbp.compound.MemoryCellCompound;
import dmonner.xlbp.compound.XEntropyTargetCompound;
import dmonner.xlbp.stat.TestStat;
import dmonner.xlbp.trial.Step;
import dmonner.xlbp.trial.Trainer;
import dmonner.xlbp.trial.Trial;
import dmonner.xlbp.trial.TrialStreamAdapter;

/* loaded from: input_file:dmonner/xlbp/example/SequentialParity.class */
/**
 * Example trial stream for a sequential parity task.
 *
 * <p>Each training trial is a sequence of random bits, presented one per step. The target at
 * every step is the parity (XOR) of all bits presented so far in that trial. {@link #main}
 * wires up a small network and trains it on this stream.
 */
public class SequentialParity extends TrialStreamAdapter {
    /** Value reported by {@link #nTrainTrials()}. */
    private final int trialsPerEpoch;

    /** Number of steps generated for every training trial. */
    private final int trialLength;

    /**
     * Builds the example network, trains it on a {@code SequentialParity} stream and prints the
     * per-epoch statistic followed by the final results.
     *
     * @param strArr command-line arguments; not read
     */
    public static void main(String[] strArr) {
        Network net = new Network("SeqParityNet");
        // NOTE(review): 0.1f is presumably the learning rate -- confirm against WeightUpdaterType.
        net.setWeightUpdaterType(WeightUpdaterType.basic(0.1f));

        // 100 trials per epoch, 5 steps per trial.
        SequentialParity stream = new SequentialParity(net, 100, 5);

        InputCompound bitLayer = new InputCompound("Bit", 1);
        // NOTE(review): the meaning of 5 and "IO" is defined by MemoryCellCompound (likely cell
        // count and gate selection) -- confirm.
        MemoryCellCompound memoryLayer = new MemoryCellCompound("Mem", 5, "IO");
        XEntropyTargetCompound answerLayer = new XEntropyTargetCompound("Ans", 1);

        // Connectivity: Bit -> Mem -> Ans.
        answerLayer.addUpstreamWeights(memoryLayer);
        memoryLayer.addUpstreamWeights(bitLayer);

        net.add(bitLayer);
        net.add(memoryLayer);
        net.add(answerLayer);

        Trainer trainer = new Trainer(net, stream) {
            /** Prints one row under the "Epoch / Accuracy" header for each finished epoch. */
            @Override
            public void postEpoch(int i, TestStat testStat) {
                System.out.println(
                        i + ":\t" + testStat.getLastTrain().getStepStats().getFraction());
            }
        };

        System.out.println("Epoch\tAccuracy");
        // NOTE(review): 500 is presumably the number of epochs to run -- confirm against Trainer.
        TestStat finalStats = trainer.run(500);
        System.out.println("Final Results:");
        System.out.println(finalStats);
    }

    /**
     * Creates a parity trial stream named {@code "SequentialParity" + i2}.
     *
     * @param network the network passed through to the {@link TrialStreamAdapter} constructor
     * @param i the number of training trials reported per epoch
     * @param i2 the number of steps in each generated trial
     */
    public SequentialParity(Network network, int i, int i2) {
        super("SequentialParity" + i2, network);
        this.trialsPerEpoch = i;
        this.trialLength = i2;
    }

    /**
     * Generates one random training trial.
     *
     * <p>Every step carries a single random bit as input (1.0f or 0.0f) and, as target, the
     * parity of all bits drawn so far in this trial, including the current one.
     *
     * @return a new trial containing {@code trialLength} steps
     */
    @Override
    public Trial nextTrainTrial() {
        final Trial sequence = new Trial(getMetaNetwork());
        boolean runningParity = false;
        for (int remaining = this.trialLength; remaining > 0; remaining--) {
            final boolean bit = Math.random() < 0.5d;
            // A set bit flips the parity; a clear bit leaves it unchanged.
            if (bit) {
                runningParity = !runningParity;
            }
            final Step step = sequence.nextStep();
            step.addInput(new float[] {encode(bit)});
            step.addTarget(new float[] {encode(runningParity)});
        }
        return sequence;
    }

    /**
     * Returns the number of training trials per epoch given at construction time.
     *
     * @return the configured trials-per-epoch value
     */
    @Override
    public int nTrainTrials() {
        return this.trialsPerEpoch;
    }

    /** Maps {@code true} to 1.0f and {@code false} to 0.0f. */
    private static float encode(boolean flag) {
        if (flag) {
            return 1.0f;
        }
        return 0.0f;
    }
}
