/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.trial;

import dmonner.xlbp.Network;
import dmonner.xlbp.stat.SetStat;
import dmonner.xlbp.stat.TestStat;
import dmonner.xlbp.stat.TrialStat;
import dmonner.xlbp.trial.NeverBreaker;
import dmonner.xlbp.trial.TrainingBreaker;
import dmonner.xlbp.trial.Trial;
import dmonner.xlbp.trial.TrialRecord;
import dmonner.xlbp.trial.TrialStream;
import dmonner.xlbp.util.CSVWriter;
import java.io.IOException;

/**
 * Drives training of a {@link Network} over the folds and epochs supplied by a {@link TrialStream}.
 *
 * <p>For each fold it evaluates the train, validation and test sets every epoch, logs a
 * {@link SetStat} summary per set to optional CSV writers, accumulates per-epoch results in a
 * {@link TestStat}, and consults a {@link TrainingBreaker} to decide when to stop the fold early.
 * The {@code pre*}/{@code post*} hook methods are no-ops here and exist to be overridden.
 */
public class Trainer {
    /** Network under training; {@link #run(int)} calls {@code rebuild()} on it before each fold. */
    private final Network net;
    /** Supplies the fold structure and the train/validation/test trials. */
    private final TrialStream stream;
    /**
     * Per-fold evaluations of the most recently evaluated trial set (allocated only while
     * {@code keepEvaluations} is on). The train, validation and test evaluation methods all write
     * the same {@code evals[fold]} slot, so after an epoch it holds whichever set ran last (test if
     * present, else validation, else train).
     * NOTE(review): confirm that sharing one slot across all three sets is intended.
     */
    private TrialStat[][] evals;
    /** Per-fold snapshot of {@link #evals} taken whenever an epoch is reported as a new best. */
    private TrialStat[][] bestEvals;
    /**
     * Per-fold trial records (allocated only while {@code keepRecords} is on).
     * NOTE(review): nothing in this class ever fills these arrays, so {@link #getRecords()} yields
     * no records even when keepRecords is enabled; record collection appears to be missing.
     */
    private TrialRecord[][] records;
    /** Per-fold snapshot of {@link #records} taken whenever an epoch is reported as a new best. */
    private TrialRecord[][] bestRecords;
    private boolean keepEvaluations;
    private boolean keepRecords;
    /** Optional CSV sinks for per-epoch summaries; a null writer disables that log. */
    private CSVWriter trainlog;
    private CSVWriter testlog;
    private CSVWriter validlog;
    /** Decides when a fold stops early; defaults to a {@link NeverBreaker}. */
    private TrainingBreaker breaker;

    /**
     * Creates a trainer that keeps neither evaluations nor records and never breaks early.
     *
     * @param net    the network to train
     * @param stream source of folds and trials
     */
    public Trainer(Network net, TrialStream stream) {
        this.net = net;
        this.stream = stream;
        this.keepEvaluations = false;
        this.keepRecords = false;
        this.breaker = new NeverBreaker();
    }

    /**
     * Runs every test trial once (via {@code trial.run()}) with evaluation and recording enabled.
     *
     * @param fold index of the current fold
     * @param ep   epoch number, used only to label the summary
     * @return the analyzed summary of all test-trial evaluations
     */
    private SetStat evaluateTest(int fold, int ep) {
        // Read the count once so the kept-evaluations array and the loop bound always agree.
        int nTrials = this.stream.nTestTrials();
        if (this.keepEvaluations) {
            this.evals[fold] = new TrialStat[nTrials];
        }
        String time = "F" + fold + "-" + ep;
        SetStat summary = new SetStat("Test" + time);
        for (int i = 0; i < nTrials; ++i) {
            Trial trial = this.stream.nextTestTrial();
            trial.setEvaluate(true);
            trial.setRecord(true);
            trial.run();
            TrialStat result = trial.getLastEvaluation();
            summary.add(result);
            if (this.keepEvaluations) {
                this.evals[fold][i] = result;
            }
            this.postTestTrial(trial, result);
        }
        summary.analyze();
        return summary;
    }

    /**
     * Runs every training trial once with evaluation and recording enabled, then calls
     * {@code net.processBatch()}.
     *
     * @param fold  index of the current fold
     * @param ep    epoch number, used only to label the summary
     * @param train passed to {@code trial.run(train)}; epoch 0 passes {@code false} for a baseline
     * @return the analyzed summary of all training-trial evaluations
     */
    private SetStat evaluateTrain(int fold, int ep, boolean train) {
        int nTrials = this.stream.nTrainTrials();
        if (this.keepEvaluations) {
            this.evals[fold] = new TrialStat[nTrials];
        }
        String time = "F" + fold + "-" + ep;
        SetStat summary = new SetStat("Train" + time);
        for (int i = 0; i < nTrials; ++i) {
            Trial trial = this.stream.nextTrainTrial();
            trial.setEvaluate(true);
            trial.setRecord(true);
            trial.run(train);
            TrialStat result = trial.getLastEvaluation();
            summary.add(result);
            if (this.keepEvaluations) {
                this.evals[fold][i] = result;
            }
            this.postTrainTrial(trial, result);
        }
        this.net.processBatch();
        summary.analyze();
        return summary;
    }

    /**
     * Runs every validation trial once (via {@code trial.run()}) with evaluation and recording
     * enabled.
     *
     * @param fold index of the current fold
     * @param ep   epoch number, used only to label the summary
     * @return the analyzed summary of all validation-trial evaluations
     */
    private SetStat evaluateValid(int fold, int ep) {
        int nTrials = this.stream.nValidationTrials();
        if (this.keepEvaluations) {
            this.evals[fold] = new TrialStat[nTrials];
        }
        String time = "F" + fold + "-" + ep;
        SetStat summary = new SetStat("Valid" + time);
        for (int i = 0; i < nTrials; ++i) {
            Trial trial = this.stream.nextValidationTrial();
            trial.setEvaluate(true);
            trial.setRecord(true);
            trial.run();
            TrialStat result = trial.getLastEvaluation();
            summary.add(result);
            if (this.keepEvaluations) {
                this.evals[fold][i] = result;
            }
            this.postValidationTrial(trial, result);
        }
        summary.analyze();
        return summary;
    }

    /**
     * Returns the kept evaluations from each fold's best epoch, concatenated in fold order.
     *
     * @return the flattened evaluations; empty if evaluations are not being kept or no fold has
     *         recorded a best epoch yet (folds without a best epoch are skipped)
     */
    public TrialStat[] getEvaluations() {
        if (this.bestEvals == null) {
            return new TrialStat[0];
        }
        TrialStat[] all = new TrialStat[totalLength(this.bestEvals)];
        copyInto(this.bestEvals, all);
        return all;
    }

    /** @return the network this trainer operates on */
    public Network getNetwork() {
        return this.net;
    }

    /**
     * Returns the kept records from each fold's best epoch, concatenated in fold order.
     *
     * @return the flattened records; empty if records are not being kept or no fold has a
     *         snapshot (folds without one are skipped)
     */
    public TrialRecord[] getRecords() {
        if (this.bestRecords == null) {
            return new TrialRecord[0];
        }
        TrialRecord[] all = new TrialRecord[totalLength(this.bestRecords)];
        copyInto(this.bestRecords, all);
        return all;
    }

    /** Sums the lengths of the non-null sub-arrays of {@code parts}. */
    private static int totalLength(Object[][] parts) {
        int n = 0;
        for (Object[] part : parts) {
            if (part != null) {
                n += part.length;
            }
        }
        return n;
    }

    /**
     * Copies the non-null sub-arrays of {@code parts} back to back into {@code dest}, which must
     * be at least {@link #totalLength(Object[][])} long and of a compatible component type.
     */
    private static void copyInto(Object[][] parts, Object[] dest) {
        int t = 0;
        for (Object[] part : parts) {
            if (part != null) {
                System.arraycopy(part, 0, dest, t, part.length);
                t += part.length;
            }
        }
    }

    /**
     * Appends {@code summary} to {@code log}, writing the header row first on epoch 0 of fold 0.
     * Logging is best-effort: an {@link IOException} is printed and otherwise ignored so a log
     * failure does not abort training.
     *
     * @param log     destination writer, or null to skip logging
     * @param summary the set summary to write
     * @param ep      epoch number
     * @param fold    fold index
     */
    private void log(CSVWriter log, SetStat summary, int ep, int fold) {
        if (log == null) {
            return;
        }
        try {
            if (ep == 0 && fold == 0) {
                summary.saveHeader(log);
            }
            summary.saveData(log);
        } catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    /** Hook called at the end of each epoch (1..maxEpochs) with the fold's running stat. */
    public void postEpoch(int ep, TestStat stat) {
    }

    /** Hook called when a fold finishes, with that fold's stat. */
    public void postFold(int fold, TestStat stat) {
    }

    /** Hook called after each test trial, including during the epoch-0 baseline. */
    public void postTestTrial(Trial trial, TrialStat stat) {
    }

    /** Hook called after each training trial, including during the epoch-0 baseline. */
    public void postTrainTrial(Trial trial, TrialStat stat) {
    }

    /** Hook called after each validation trial, including during the epoch-0 baseline. */
    public void postValidationTrial(Trial trial, TrialStat stat) {
    }

    /** Hook called at the start of each epoch (1..maxEpochs). */
    public void preEpoch(int ep) {
    }

    /** Hook called after the fold is selected and the breaker reset, before epoch-0 evaluation. */
    public void preFold(int fold) {
    }

    /** Hook called each epoch (1..maxEpochs) before the validation and test evaluations. */
    public void preTest(int fold) {
    }

    /** Hook called each epoch (1..maxEpochs) before training, even if there are no train folds. */
    public void preTrain(int fold) {
    }

    /**
     * Runs every fold of the stream, rebuilding the network before each one.
     *
     * @param maxEpochs maximum number of training epochs per fold
     * @return the analyzed aggregate of all per-fold stats
     */
    public TestStat run(int maxEpochs) {
        TestStat total = new TestStat();
        for (int f = 0; f < this.stream.nFolds(); ++f) {
            this.net.rebuild();
            total.add(this.runFold(f, maxEpochs));
        }
        total.analyze();
        return total;
    }

    /**
     * Runs a single fold.
     *
     * <p>First logs an epoch-0 baseline: the train set is evaluated with {@code train = false},
     * and validation/test are evaluated when those folds exist. This baseline is logged only; it
     * is not added to the returned stat. Then for each epoch 1..maxEpochs it trains, evaluates,
     * logs, and adds the three set summaries (null for absent sets) to the fold's stat. It stops
     * early when the breaker says so.
     * NOTE(review): epoch 0 evaluates the train set unconditionally, while later epochs require
     * {@code nTrainFolds() > 0}; confirm this asymmetry is intended.
     *
     * @param fold      fold index to run
     * @param maxEpochs maximum number of training epochs
     * @return the fold's per-epoch stat
     */
    public TestStat runFold(int fold, int maxEpochs) {
        this.stream.setFold(fold);
        this.breaker.reset();
        this.preFold(fold);
        TestStat stat = new TestStat();
        this.log(this.trainlog, this.evaluateTrain(fold, 0, false), 0, fold);
        if (this.stream.nValidationFolds() > 0) {
            this.log(this.validlog, this.evaluateValid(fold, 0), 0, fold);
        }
        if (this.stream.nTestFolds() > 0) {
            this.log(this.testlog, this.evaluateTest(fold, 0), 0, fold);
        }
        for (int ep = 1; ep <= maxEpochs; ++ep) {
            SetStat trainStat = null;
            SetStat validStat = null;
            SetStat testStat = null;
            this.preEpoch(ep);
            this.preTrain(fold);
            if (this.stream.nTrainFolds() > 0) {
                trainStat = this.evaluateTrain(fold, ep, true);
                this.log(this.trainlog, trainStat, ep, fold);
            }
            this.preTest(fold);
            if (this.stream.nValidationFolds() > 0) {
                validStat = this.evaluateValid(fold, ep);
                this.log(this.validlog, validStat, ep, fold);
            }
            if (this.stream.nTestFolds() > 0) {
                testStat = this.evaluateTest(fold, ep);
                this.log(this.testlog, testStat, ep, fold);
            }
            // stat.add reports whether this epoch is a new best (the decompiled local was named
            // "newBest"); if so, snapshot this fold's kept evaluations/records.
            if (stat.add(trainStat, validStat, testStat)) {
                this.updateBest(fold);
            }
            this.postEpoch(ep, stat);
            if (this.breaker.isBreakTime(stat)) {
                break;
            }
        }
        this.postFold(fold, stat);
        return stat;
    }

    /** Replaces the early-stopping policy consulted after each epoch. */
    public void setBreaker(TrainingBreaker breaker) {
        this.breaker = breaker;
    }

    /**
     * Enables or disables keeping per-trial evaluations. Enabling allocates fresh per-fold
     * storage sized by the stream's current {@code nFolds()}; disabling discards any kept data.
     */
    public void setKeepEvaluations(boolean keepEvaluations) {
        this.keepEvaluations = keepEvaluations;
        if (keepEvaluations) {
            this.evals = new TrialStat[this.stream.nFolds()][];
            this.bestEvals = new TrialStat[this.stream.nFolds()][];
        } else {
            this.evals = null;
            this.bestEvals = null;
        }
    }

    /**
     * Enables or disables keeping per-trial records. Enabling allocates fresh per-fold storage
     * sized by the stream's current {@code nFolds()}; disabling discards any kept data.
     */
    public void setKeepRecords(boolean keepRecords) {
        this.keepRecords = keepRecords;
        if (keepRecords) {
            this.records = new TrialRecord[this.stream.nFolds()][];
            this.bestRecords = new TrialRecord[this.stream.nFolds()][];
        } else {
            this.records = null;
            this.bestRecords = null;
        }
    }

    /** Sets the CSV writer for test-set summaries; null disables test logging. */
    public void setTestLog(CSVWriter log) {
        this.testlog = log;
    }

    /** Sets the CSV writer for training-set summaries; null disables training logging. */
    public void setTrainLog(CSVWriter log) {
        this.trainlog = log;
    }

    /** Sets the CSV writer for validation-set summaries; null disables validation logging. */
    public void setValidationLog(CSVWriter log) {
        this.validlog = log;
    }

    /**
     * Snapshots the current per-fold evaluations/records as the fold's best. A reference copy is
     * sufficient because each evaluation pass allocates a fresh {@code evals[fold]} array.
     */
    private void updateBest(int fold) {
        if (this.keepEvaluations) {
            this.bestEvals[fold] = this.evals[fold];
        }
        if (this.keepRecords) {
            this.bestRecords[fold] = this.records[fold];
        }
    }
}

