/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.example;

import dmonner.xlbp.Network;
import dmonner.xlbp.WeightUpdaterType;
import dmonner.xlbp.compound.InputCompound;
import dmonner.xlbp.compound.MemoryCellCompound;
import dmonner.xlbp.compound.XEntropyTargetCompound;
import dmonner.xlbp.stat.TestStat;
import dmonner.xlbp.trial.Step;
import dmonner.xlbp.trial.Trainer;
import dmonner.xlbp.trial.Trial;
import dmonner.xlbp.trial.TrialStreamAdapter;

/**
 * Sequential-parity task: each training trial is a random stream of bits, presented one bit per
 * step, and the target at every step is the parity (running XOR) of all bits seen so far in that
 * trial (1.0 for odd, 0.0 for even).
 */
public class SequentialParity
extends TrialStreamAdapter {
    /** Number of training trials reported per epoch via {@link #nTrainTrials()}. */
    private final int trialsPerEpoch;
    /** Number of steps (input bits) in each generated trial. */
    private final int trialLength;

    /**
     * Builds a small network (one input bit, a memory-cell layer, one cross-entropy output unit),
     * trains it on this task, and prints the per-epoch training accuracy followed by the final
     * results.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Experiment configuration: every value below is actually used when building and
        // training the network, so changing one here takes effect.
        int trialLength = 5;
        int trialsPerEpoch = 100;
        int epochs = 500;
        int memSize = 5;
        String memType = "IO";
        float learningRate = 0.1f;

        Network net = new Network("SeqParityNet");
        net.setWeightUpdaterType(WeightUpdaterType.basic(learningRate));
        SequentialParity task = new SequentialParity(net, trialsPerEpoch, trialLength);

        // Topology: Bit -> Mem -> Ans
        InputCompound bit = new InputCompound("Bit", 1);
        MemoryCellCompound mem = new MemoryCellCompound("Mem", memSize, memType);
        XEntropyTargetCompound ans = new XEntropyTargetCompound("Ans", 1);
        ans.addUpstreamWeights(mem);
        mem.addUpstreamWeights(bit);
        net.add(bit);
        net.add(mem);
        net.add(ans);

        Trainer trainer = new Trainer(net, task) {
            @Override
            public void postEpoch(int ep, TestStat stat) {
                // One line per epoch: the fraction reported by the step statistics of the most
                // recent training pass (printed under the "Accuracy" column header).
                System.out.println(ep + ":\t" + stat.getLastTrain().getStepStats().getFraction());
            }
        };

        System.out.println("Epoch\tAccuracy");
        TestStat result = trainer.run(epochs);
        System.out.println("Final Results:");
        System.out.println(result);
    }

    /**
     * Creates the task.
     *
     * @param net the network handed to {@link TrialStreamAdapter}
     * @param trialsPerEpoch number of training trials per epoch; must be non-negative
     * @param trialLength number of bits (steps) in each trial; must be at least 1
     * @throws IllegalArgumentException if {@code trialsPerEpoch < 0} or {@code trialLength < 1}
     */
    public SequentialParity(Network net, int trialsPerEpoch, int trialLength) {
        super("SequentialParity" + trialLength, net);
        if (trialsPerEpoch < 0) {
            throw new IllegalArgumentException(
                    "trialsPerEpoch must be >= 0, got " + trialsPerEpoch);
        }
        if (trialLength < 1) {
            throw new IllegalArgumentException("trialLength must be >= 1, got " + trialLength);
        }
        this.trialsPerEpoch = trialsPerEpoch;
        this.trialLength = trialLength;
    }

    /**
     * Generates a fresh random training trial of {@code trialLength} steps. Each step gets a
     * uniformly random input bit, and its target is the parity of all bits so far in this trial.
     *
     * @return the newly built trial
     */
    @Override
    public Trial nextTrainTrial() {
        Trial trial = new Trial(this.getMetaNetwork());
        boolean odd = false; // running XOR of every bit emitted so far in this trial
        for (int i = 0; i < this.trialLength; ++i) {
            boolean input = Math.random() < 0.5;
            odd ^= input;
            Step step = trial.nextStep();
            step.addInput(encode(input));
            step.addTarget(encode(odd));
        }
        return trial;
    }

    /**
     * Returns the number of training trials per epoch supplied at construction.
     *
     * @return trials per epoch
     */
    @Override
    public int nTrainTrials() {
        return this.trialsPerEpoch;
    }

    /** Encodes a boolean as a one-element activation vector: 1.0 for true, 0.0 for false. */
    private static float[] encode(boolean bit) {
        return new float[]{bit ? 1.0f : 0.0f};
    }
}

