/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.trial;

import dmonner.xlbp.Network;
import dmonner.xlbp.trial.AbstractTrialStream;
import dmonner.xlbp.trial.Trial;
import dmonner.xlbp.util.ArrayQueue;
import dmonner.xlbp.util.MatrixTools;
import java.util.Random;

/**
 * Trial stream that routes each trial produced by {@link #nextTrial()} to a
 * train, test, or validation cache based on the trial's hash code.
 *
 * <p>One bucket exists per fold. Each bucket points at one of the three
 * caches, in proportion to the train/test/validation fold counts reported by
 * the superclass, and the bucket order is passed through
 * {@code MatrixTools.randomize} at construction time (presumably a shuffle --
 * confirm against {@code MatrixTools}). A trial lands in bucket
 * {@code |hashCode() % nFolds()|}.
 */
public abstract class HashTrialStream
extends AbstractTrialStream {
    /** Multiplier used to report the number of trials per fold. */
    private final int perFold;
    /** Maps a hash bucket (index) to the cache that receives its trials. */
    private final ArrayQueue<Trial>[] permutation;
    /** Buffered trials waiting to be served as training trials. */
    private final ArrayQueue<Trial> trainCache;
    /** Buffered trials waiting to be served as test trials. */
    private final ArrayQueue<Trial> testCache;
    /** Buffered trials waiting to be served as validation trials. */
    private final ArrayQueue<Trial> validCache;

    /**
     * Creates a stream from explicit fold counts, using a freshly seeded
     * {@link Random} to order the buckets.
     *
     * @param name      stream name, forwarded to the superclass
     * @param net       network, forwarded to the superclass
     * @param perFold   trials per fold, used by the {@code n*Trials()} counts
     * @param cacheSize capacity given to each of the three caches
     * @param train     number of training folds
     * @param test      number of test folds
     * @param valid     number of validation folds
     */
    public HashTrialStream(String name, Network net, int perFold, int cacheSize, int train, int test, int valid) {
        this(name, net, perFold, cacheSize, train, test, valid, new Random());
    }

    /**
     * Creates a stream from explicit fold counts; the counts are converted to
     * a split string with {@code makeSplitString}.
     *
     * @param name      stream name, forwarded to the superclass
     * @param net       network, forwarded to the superclass
     * @param perFold   trials per fold, used by the {@code n*Trials()} counts
     * @param cacheSize capacity given to each of the three caches
     * @param train     number of training folds
     * @param test      number of test folds
     * @param valid     number of validation folds
     * @param random    randomness source used to order the buckets
     */
    public HashTrialStream(String name, Network net, int perFold, int cacheSize, int train, int test, int valid, Random random) {
        this(name, net, perFold, cacheSize, HashTrialStream.makeSplitString(train, test, valid), random);
    }

    /**
     * Creates a stream from a split string, using a freshly seeded
     * {@link Random} to order the buckets.
     *
     * @param name      stream name, forwarded to the superclass
     * @param net       network, forwarded to the superclass
     * @param perFold   trials per fold, used by the {@code n*Trials()} counts
     * @param cacheSize capacity given to each of the three caches
     * @param split     split description, forwarded to the superclass
     */
    public HashTrialStream(String name, Network net, int perFold, int cacheSize, String split) {
        this(name, net, perFold, cacheSize, split, new Random());
    }

    /**
     * Creates a stream from a split string.
     *
     * @param name      stream name, forwarded to the superclass
     * @param net       network, forwarded to the superclass
     * @param perFold   trials per fold, used by the {@code n*Trials()} counts
     * @param cacheSize capacity given to each of the three caches
     * @param split     split description, forwarded to the superclass
     * @param random    randomness source used to order the buckets
     */
    public HashTrialStream(String name, Network net, int perFold, int cacheSize, String split, Random random) {
        super(name, net, split);
        this.perFold = perFold;
        this.trainCache = new ArrayQueue<>(cacheSize);
        this.testCache = new ArrayQueue<>(cacheSize);
        this.validCache = new ArrayQueue<>(cacheSize);

        // Java cannot create a generic array directly; every slot is filled
        // below with an ArrayQueue<Trial>, so the unchecked conversion is safe.
        @SuppressWarnings({"unchecked", "rawtypes"})
        ArrayQueue<Trial>[] buckets = new ArrayQueue[this.nFolds()];

        // Lay the buckets out as [train..., test..., valid...] before handing
        // them to MatrixTools.randomize.
        int p = 0;
        for (int i = 0; i < this.nTrainFolds(); ++i) {
            buckets[p++] = this.trainCache;
        }
        for (int i = 0; i < this.nTestFolds(); ++i) {
            buckets[p++] = this.testCache;
        }
        for (int i = 0; i < this.nValidationFolds(); ++i) {
            buckets[p++] = this.validCache;
        }
        MatrixTools.randomize(buckets, random);
        this.permutation = buckets;
    }

    /**
     * Pulls one trial from {@link #nextTrial()} and pushes it onto the cache
     * selected by its hash code. If that cache is already full the trial is
     * dropped.
     */
    private void consume() {
        Trial t = this.nextTrial();
        // Take the remainder BEFORE the absolute value: Math.abs(hashCode())
        // is still negative for Integer.MIN_VALUE and would yield a negative
        // index, whereas |hash % n| always lies in [0, n). For every other
        // hash value the two forms select the same bucket.
        int mod = Math.abs(t.hashCode() % this.nFolds());
        ArrayQueue<Trial> q = this.permutation[mod];
        if (!q.isFull()) {
            q.push(t);
        }
    }

    /**
     * Returns the next test trial, consuming source trials until one lands in
     * the test cache. The returned trial is marked with {@code setKnown(false)}.
     *
     * @return the next test trial, or {@code null} if there are no test folds
     */
    @Override
    public Trial nextTestTrial() {
        // With no test folds no bucket feeds the test cache, so the loop
        // below could never terminate; mirror nextValidationTrial() instead.
        if (this.nTestFolds() == 0) {
            return null;
        }
        while (this.testCache.isEmpty()) {
            this.consume();
        }
        Trial next = this.testCache.pop();
        next.setKnown(false);
        return next;
    }

    /**
     * Returns the next training trial, consuming source trials until one lands
     * in the training cache. The returned trial is marked with
     * {@code setKnown(true)}.
     *
     * @return the next training trial, or {@code null} if there are no
     *         training folds
     */
    @Override
    public Trial nextTrainTrial() {
        // With no training folds no bucket feeds the training cache, so the
        // loop below could never terminate; mirror nextValidationTrial().
        if (this.nTrainFolds() == 0) {
            return null;
        }
        while (this.trainCache.isEmpty()) {
            this.consume();
        }
        Trial next = this.trainCache.pop();
        next.setKnown(true);
        return next;
    }

    /**
     * Produces the next raw trial to be routed to one of the caches.
     *
     * @return the next trial from the underlying source
     */
    public abstract Trial nextTrial();

    /**
     * Returns the next validation trial, consuming source trials until one
     * lands in the validation cache. The returned trial is marked with
     * {@code setKnown(false)}.
     *
     * @return the next validation trial, or {@code null} if there are no
     *         validation folds
     */
    @Override
    public Trial nextValidationTrial() {
        if (this.nValidationFolds() == 0) {
            return null;
        }
        while (this.validCache.isEmpty()) {
            this.consume();
        }
        Trial next = this.validCache.pop();
        next.setKnown(false);
        return next;
    }

    /**
     * @return the number of test folds multiplied by the per-fold trial count
     */
    @Override
    public int nTestTrials() {
        return this.nTestFolds() * this.perFold;
    }

    /**
     * @return the number of training folds multiplied by the per-fold trial
     *         count
     */
    @Override
    public int nTrainTrials() {
        return this.nTrainFolds() * this.perFold;
    }

    /**
     * @return the number of validation folds multiplied by the per-fold trial
     *         count
     */
    @Override
    public int nValidationTrials() {
        return this.nValidationFolds() * this.perFold;
    }

    /**
     * No-op: this implementation ignores the requested fold.
     *
     * @param fold ignored
     */
    @Override
    public void setFold(int fold) {
    }
}

