/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.trial;

import dmonner.xlbp.Network;
import dmonner.xlbp.Target;
import dmonner.xlbp.layer.TargetLayer;
import dmonner.xlbp.stat.BitStat;
import dmonner.xlbp.trial.AbstractTrialStream;
import dmonner.xlbp.trial.Step;
import dmonner.xlbp.trial.Trial;
import dmonner.xlbp.util.ArrayQueue;
import dmonner.xlbp.util.MatrixTools;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * A trial stream backed by a fixed array of {@link Trial}s that is partitioned into folds.
 *
 * <p>The folds are built once, at construction time, either from explicit per-fold sizes or by
 * shuffling the supplied trials and cutting them into nearly equal pieces. Calling
 * {@link #setFold(int)} then chooses which folds make up the training, test and validation sets;
 * the three sets are taken from consecutive folds (wrapping around), starting at the given fold.
 *
 * <p>The training/test/validation arrays are only populated by {@link #setFold(int)}; nothing in
 * this class calls it automatically, so the accessors and {@code next*Trial()} methods are only
 * meaningful after it has been invoked.
 *
 * <p>NOTE(review): {@code nFolds()}, {@code nTrainFolds()}, {@code nTestFolds()},
 * {@code nValidationFolds()}, {@code getMetaNetwork()} and {@code makeSplitString(...)} are not
 * defined here; they are presumably inherited from {@link AbstractTrialStream} -- confirm.
 */
public class TrialSet
extends AbstractTrialStream {
    /** The trials of each fold; fixed at construction time. */
    private final Trial[][] folds;
    /** Current training trials; assigned by {@link #setFold(int)}. */
    private Trial[] train;
    /** Current test trials; assigned by {@link #setFold(int)}. */
    private Trial[] test;
    /** Current validation trials; assigned by {@link #setFold(int)}. */
    private Trial[] valid;
    /** Whether {@link #setFold(int)} re-weights the training targets to balance classes. */
    private boolean balance;
    /** Queue from which {@link #nextTrainTrial()} serves trials. */
    private final ArrayQueue<Trial> trainCache;
    /** Queue from which {@link #nextTestTrial()} serves trials. */
    private final ArrayQueue<Trial> testCache;
    /** Queue from which {@link #nextValidationTrial()} serves trials. */
    private final ArrayQueue<Trial> validCache;
    /** Source of randomness used for shuffling. */
    private final Random rand;

    /**
     * Parses a {@code "/"}-separated list of fold sizes, e.g. {@code "10/10/5"}.
     *
     * @param foldSplit the fold sizes; the empty string means "no explicit sizes"
     * @return the parsed sizes, or an empty array when {@code foldSplit} is empty
     * @throws IllegalArgumentException if any entry is not an integer
     */
    private static int[] parseFoldSplit(String foldSplit) {
        if (foldSplit.isEmpty()) {
            return new int[0];
        }
        String[] parts = foldSplit.split("/");
        int[] sizes = new int[parts.length];
        try {
            for (int i = 0; i < parts.length; ++i) {
                sizes[i] = Integer.parseInt(parts[i].trim());
            }
        }
        catch (NumberFormatException ex) {
            throw new IllegalArgumentException("Malformed foldSplit string; all entries must be numbers: " + foldSplit, ex);
        }
        return sizes;
    }

    /**
     * Copies consecutive runs of {@code set} into folds of the requested sizes.
     *
     * @param set the trials to distribute, taken in order
     * @param foldSizes the number of trials for each fold
     * @param nFolds the number of folds the stream expects
     * @return one array of trials per fold
     * @throws IllegalArgumentException if the number of sizes differs from {@code nFolds}, if a
     *         size is negative, or if the sizes add up to more trials than {@code set} holds
     */
    private static Trial[][] partitionBySizes(Trial[] set, int[] foldSizes, int nFolds) {
        if (foldSizes.length != nFolds) {
            throw new IllegalArgumentException("Number of foldSizes entries (" + foldSizes.length + ") does not match number of folds (" + nFolds + ").");
        }
        // Validate up front so a bad size is reported clearly instead of surfacing as an
        // array-index or negative-array-size failure halfway through the copy.
        long total = 0L;
        for (int size : foldSizes) {
            if (size < 0) {
                throw new IllegalArgumentException("Fold sizes must be non-negative, but got " + size + ".");
            }
            total += size;
        }
        if (total > set.length) {
            throw new IllegalArgumentException("Fold sizes sum to " + total + ", but only " + set.length + " trials were provided.");
        }
        Trial[][] result = new Trial[nFolds][];
        int offset = 0;
        for (int f = 0; f < nFolds; ++f) {
            int size = foldSizes[f];
            result[f] = new Trial[size];
            System.arraycopy(set, offset, result[f], 0, size);
            offset += size;
        }
        return result;
    }

    /**
     * Cuts {@code set} into {@code nFolds} consecutive folds whose sizes differ by at most one.
     *
     * <p>Boundaries are computed with integer arithmetic so that they are exact; in particular
     * the last fold always ends at {@code set.length} and no trial is lost to rounding.
     *
     * @param set the trials to distribute, taken in order
     * @param nFolds the number of folds to create
     * @return one array of trials per fold
     */
    private static Trial[][] partitionEvenly(Trial[] set, int nFolds) {
        Trial[][] result = new Trial[nFolds][];
        for (int f = 0; f < nFolds; ++f) {
            // long intermediate avoids int overflow in f * set.length
            int start = (int)((long)f * set.length / nFolds);
            int end = (int)((long)(f + 1) * set.length / nFolds);
            int size = end - start;
            result[f] = new Trial[size];
            System.arraycopy(set, start, result[f], 0, size);
        }
        return result;
    }

    /**
     * Creates a set with shuffled, evenly sized folds and a fresh {@link Random}.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition; shuffled in place
     * @param train number of training folds
     * @param test number of test folds
     * @param valid number of validation folds
     */
    public TrialSet(String name, Network net, Trial[] set, int train, int test, int valid) {
        this(name, net, set, train, test, valid, new Random());
    }

    /**
     * Creates a set with shuffled, evenly sized folds.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition; shuffled in place
     * @param train number of training folds
     * @param test number of test folds
     * @param valid number of validation folds
     * @param random the source of randomness used for shuffling
     */
    public TrialSet(String name, Network net, Trial[] set, int train, int test, int valid, Random random) {
        this(name, net, set, TrialSet.makeSplitString(train, test, valid), random);
    }

    /**
     * Creates a set with explicitly sized folds and a fresh {@link Random}; the folds that are
     * neither test nor validation folds are used for training.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition, taken in order
     * @param foldSizes the number of trials in each fold
     * @param test number of test folds
     * @param valid number of validation folds
     */
    public TrialSet(String name, Network net, Trial[] set, int[] foldSizes, int test, int valid) {
        this(name, net, set, foldSizes, test, valid, new Random());
    }

    /**
     * Creates a set with explicitly sized folds; the folds that are neither test nor validation
     * folds are used for training.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition, taken in order
     * @param foldSizes the number of trials in each fold
     * @param test number of test folds
     * @param valid number of validation folds
     * @param random the source of randomness used for shuffling
     */
    public TrialSet(String name, Network net, Trial[] set, int[] foldSizes, int test, int valid, Random random) {
        this(name, net, set, foldSizes, TrialSet.makeSplitString(foldSizes.length - test - valid, test, valid), random);
    }

    /**
     * Creates a set from fold sizes and a split string, using a fresh {@link Random}.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition
     * @param foldSizes the number of trials in each fold, or an empty array for an even split
     * @param split the split specification passed to the superclass
     */
    public TrialSet(String name, Network net, Trial[] set, int[] foldSizes, String split) {
        this(name, net, set, foldSizes, split, new Random());
    }

    /**
     * Creates a set from fold sizes and a split string. All other constructors delegate here.
     *
     * <p>When {@code foldSizes} is non-empty, fold {@code f} receives the next
     * {@code foldSizes[f]} trials of {@code set} in their given order; trials beyond the sum of
     * the sizes are left unused. When {@code foldSizes} is empty, {@code set} is shuffled
     * <em>in place</em> with {@code random} and then cut into nearly equal folds.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition
     * @param foldSizes the number of trials in each fold, or an empty array for an even split
     * @param split the split specification passed to the superclass
     * @param random the source of randomness used for shuffling
     * @throws IllegalArgumentException if {@code foldSizes} is non-empty and its length differs
     *         from the number of folds, contains a negative entry, or sums to more than
     *         {@code set.length}
     */
    public TrialSet(String name, Network net, Trial[] set, int[] foldSizes, String split, Random random) {
        super(name, net, split);
        this.rand = random;
        this.trainCache = new ArrayQueue<Trial>();
        this.testCache = new ArrayQueue<Trial>();
        this.validCache = new ArrayQueue<Trial>();
        final int nFolds = this.nFolds();
        if (foldSizes.length > 0) {
            this.folds = TrialSet.partitionBySizes(set, foldSizes, nFolds);
        } else {
            MatrixTools.randomize(set, this.rand);
            this.folds = TrialSet.partitionEvenly(set, nFolds);
        }
    }

    /**
     * Creates a set with shuffled, evenly sized folds and a fresh {@link Random}.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition; shuffled in place
     * @param split the split specification passed to the superclass
     */
    public TrialSet(String name, Network net, Trial[] set, String split) {
        this(name, net, set, "", split, new Random());
    }

    /**
     * Creates a set with shuffled, evenly sized folds.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition; shuffled in place
     * @param split the split specification passed to the superclass
     * @param random the source of randomness used for shuffling
     */
    public TrialSet(String name, Network net, Trial[] set, String split, Random random) {
        this(name, net, set, "", split, random);
    }

    /**
     * Creates a set from a {@code "/"}-separated fold-size string, using a fresh {@link Random}.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition
     * @param foldSplit fold sizes such as {@code "10/10/5"}, or the empty string for an even split
     * @param split the split specification passed to the superclass
     * @throws IllegalArgumentException if {@code foldSplit} contains a non-integer entry
     */
    public TrialSet(String name, Network net, Trial[] set, String foldSplit, String split) {
        this(name, net, set, TrialSet.parseFoldSplit(foldSplit), split, new Random());
    }

    /**
     * Creates a set from a {@code "/"}-separated fold-size string.
     *
     * @param name the name passed to the superclass
     * @param net the network passed to the superclass
     * @param set the trials to partition
     * @param foldSplit fold sizes such as {@code "10/10/5"}, or the empty string for an even split
     * @param split the split specification passed to the superclass
     * @param random the source of randomness used for shuffling
     * @throws IllegalArgumentException if {@code foldSplit} contains a non-integer entry
     */
    public TrialSet(String name, Network net, Trial[] set, String foldSplit, String split, Random random) {
        this(name, net, set, TrialSet.parseFoldSplit(foldSplit), split, random);
    }

    /**
     * Re-weights every target in the current training set so that, per target layer, each class
     * contributes equally: a target's weight is the size of the smallest non-empty class divided
     * by the size of its own class. The per-layer class counts and weights are printed to
     * standard output.
     *
     * <p>A layer of size 1 is treated as having two classes (bit off / bit on); any other layer
     * has one class per unit.
     *
     * @throws IllegalStateException if a multi-unit target has no bit set or more than one
     */
    private void balanceTrainingSet() {
        Map<TargetLayer, List<Target>> byLayer = this.groupTargetsByLayer(this.train);
        for (Map.Entry<TargetLayer, List<Target>> entry : byLayer.entrySet()) {
            TargetLayer layer = entry.getKey();
            int n = layer.size() == 1 ? 2 : layer.size();

            // Bucket this layer's targets by the class (set bit) they encode.
            List<List<Target>> byBit = new ArrayList<List<Target>>(n);
            for (int c = 0; c < n; ++c) {
                byBit.add(new ArrayList<Target>());
            }
            for (Target target : entry.getValue()) {
                byBit.get(this.getSetBitIndex(target)).add(target);
            }

            int[] sizes = new int[n];
            for (int c = 0; c < n; ++c) {
                sizes[c] = byBit.get(c).size();
            }

            // Use the smallest NON-EMPTY class as the reference. Including empty classes would
            // make the reference 0 and force every weight to 0 (or NaN for the empty class).
            int smallest = 0;
            for (int count : sizes) {
                if (count > 0 && (smallest == 0 || count < smallest)) {
                    smallest = count;
                }
            }

            float[] weights = new float[n];
            for (int c = 0; c < n; ++c) {
                // An empty class has no targets to weight; report 0 for it.
                weights[c] = sizes[c] > 0 ? (float)smallest / (float)sizes[c] : 0.0f;
            }
            for (int c = 0; c < n; ++c) {
                for (Target target : byBit.get(c)) {
                    target.setWeight(weights[c]);
                }
            }

            System.out.println(layer.getName() + ":");
            System.out.println(Arrays.toString(sizes));
            System.out.println(MatrixTools.toString(weights));
        }
    }

    /**
     * Determines which class a target encodes, treating a unit as "set" when its value is at
     * least {@code BitStat.MID}.
     *
     * @param target the target to classify
     * @return for a single-unit target, 1 if the unit is set and 0 otherwise; for a multi-unit
     *         target, the index of the one unit that is set
     * @throws IllegalStateException if a multi-unit target has no unit set or more than one
     */
    private int getSetBitIndex(Target target) {
        float[] value = target.getValue();
        if (value.length == 1) {
            return value[0] >= BitStat.MID ? 1 : 0;
        }
        int index = -1;
        for (int i = 0; i < value.length; ++i) {
            if (value[i] >= BitStat.MID) {
                if (index >= 0) {
                    throw new IllegalStateException("Cannot balance with more than one bit per Target.");
                }
                index = i;
            }
        }
        if (index < 0) {
            throw new IllegalStateException("Cannot balance with no bits set in a Target.");
        }
        return index;
    }

    /**
     * Returns a trial of the current test set; requires a prior {@link #setFold(int)}.
     *
     * @param index position within the current test set
     * @return the test trial at {@code index}
     */
    public Trial getTestTrial(int index) {
        return this.test[index];
    }

    /**
     * Returns a trial of the current training set; requires a prior {@link #setFold(int)}.
     *
     * @param index position within the current training set
     * @return the training trial at {@code index}
     */
    public Trial getTrainTrial(int index) {
        return this.train[index];
    }

    /**
     * Returns a trial of the current validation set; requires a prior {@link #setFold(int)}.
     *
     * @param index position within the current validation set
     * @return the validation trial at {@code index}
     */
    public Trial getValidationTrial(int index) {
        return this.valid[index];
    }

    /**
     * Collects every target of every step of the given trials, grouped by target layer. The map
     * is pre-populated with one (possibly empty) list for each target layer of the meta network.
     *
     * @param set the trials whose targets are collected
     * @return the targets keyed by the layer they belong to
     */
    private Map<TargetLayer, List<Target>> groupTargetsByLayer(Trial[] set) {
        Map<TargetLayer, List<Target>> byLayer = new HashMap<TargetLayer, List<Target>>();
        for (TargetLayer layer : this.getMetaNetwork().getTargetLayers()) {
            byLayer.put(layer, new ArrayList<Target>());
        }
        for (Trial trial : set) {
            for (Step step : trial.getSteps()) {
                for (Target target : step.getTargets()) {
                    byLayer.get(target.getLayer()).add(target);
                }
            }
        }
        return byLayer;
    }

    /**
     * Returns the next test trial, refilling the queue from the current test set when it has
     * run empty.
     */
    @Override
    public Trial nextTestTrial() {
        if (this.testCache.isEmpty()) {
            this.testCache.fill(this.test);
        }
        return this.testCache.pop();
    }

    /**
     * Returns the next training trial, refilling the queue from the current training set when it
     * has run empty.
     */
    @Override
    public Trial nextTrainTrial() {
        if (this.trainCache.isEmpty()) {
            this.trainCache.fill(this.train);
        }
        return this.trainCache.pop();
    }

    /**
     * Returns the next validation trial, refilling the queue from the current validation set
     * when it has run empty.
     */
    @Override
    public Trial nextValidationTrial() {
        if (this.validCache.isEmpty()) {
            this.validCache.fill(this.valid);
        }
        return this.validCache.pop();
    }

    /** Returns the number of trials in the current test set. */
    @Override
    public int nTestTrials() {
        return this.test.length;
    }

    /** Returns the number of trials in the current training set. */
    @Override
    public int nTrainTrials() {
        return this.train.length;
    }

    /** Returns the number of trials in the current validation set. */
    @Override
    public int nValidationTrials() {
        return this.valid.length;
    }

    /**
     * Concatenates the trials of {@code num} consecutive folds, wrapping around past the last
     * fold.
     *
     * @param start index of the first fold to include
     * @param num number of folds to include
     * @return a new array holding the selected folds' trials in fold order
     */
    private Trial[] select(int start, int num) {
        int total = 0;
        for (int i = 0; i < num; ++i) {
            total += this.folds[(start + i) % this.folds.length].length;
        }
        Trial[] selected = new Trial[total];
        int offset = 0;
        for (int i = 0; i < num; ++i) {
            Trial[] fold = this.folds[(start + i) % this.folds.length];
            System.arraycopy(fold, 0, selected, offset, fold.length);
            offset += fold.length;
        }
        return selected;
    }

    /**
     * Selects the test trials for the given fold (the folds directly after the training folds)
     * and flags each of them as not known.
     */
    private Trial[] selectTest(int fold) {
        Trial[] trials = this.select((fold + this.nTrainFolds()) % this.folds.length, this.nTestFolds());
        for (Trial trial : trials) {
            trial.setKnown(false);
        }
        return trials;
    }

    /**
     * Selects the training trials for the given fold (starting at that fold) and flags each of
     * them as known.
     */
    private Trial[] selectTrain(int fold) {
        Trial[] trials = this.select(fold, this.nTrainFolds());
        for (Trial trial : trials) {
            trial.setKnown(true);
        }
        return trials;
    }

    /**
     * Selects the validation trials for the given fold (the folds directly after the test folds)
     * and flags each of them as not known.
     */
    private Trial[] selectValidation(int fold) {
        Trial[] trials = this.select((fold + this.nTrainFolds() + this.nTestFolds()) % this.folds.length, this.nValidationFolds());
        for (Trial trial : trials) {
            trial.setKnown(false);
        }
        return trials;
    }

    /**
     * Enables or disables class balancing of the training targets. The setting is read by
     * {@link #setFold(int)}, so it takes effect the next time a fold is selected.
     *
     * @param balance {@code true} to re-weight training targets on each fold selection
     */
    public void setBalance(boolean balance) {
        this.balance = balance;
    }

    /**
     * Makes {@code fold} the first training fold and rebuilds the training, test and validation
     * sets accordingly. In order, this: selects the three sets (updating each trial's known
     * flag), optionally balances the training targets, clears every trial in every fold,
     * shuffles the training set, and refills the three queues.
     *
     * @param fold index of the first training fold
     */
    @Override
    public void setFold(int fold) {
        this.train = this.selectTrain(fold);
        this.test = this.selectTest(fold);
        this.valid = this.selectValidation(fold);
        if (this.balance) {
            this.balanceTrainingSet();
        }
        for (Trial[] foldTrials : this.folds) {
            for (Trial trial : foldTrials) {
                trial.clear();
            }
        }
        MatrixTools.randomize(this.train, this.rand);
        this.trainCache.fill(this.train);
        this.testCache.fill(this.test);
        this.validCache.fill(this.valid);
    }
}

