/*
 * NeuralNet - Utility class for Robocode bots.
 * Copyright (C) 2002  Joachim Hofer
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 *
 * You can contact the author via email (qohnil@johoop.de) or write to
 * Joachim Hofer, Feldstr. 12, D-91052 Erlangen, Germany.
 */

package qohnil.neural;

import qohnil.neural.util.ActivationFunction;

import java.util.Random;
import java.util.Iterator;
import java.util.List;
import java.io.*;

import robocode.RobocodeFileOutputStream;

/**
 * A small feed-forward neural network with one or two hidden layers, trained
 * by backpropagation with momentum.
 *
 * <p>The weight matrix feeding a layer has one row per unit of the previous
 * layer plus one final bias row (whose input is the constant 1), and one
 * column per unit of the layer itself. Every unit, including the outputs,
 * applies the transfer function obtained from {@link ActivationFunction}.
 * A second hidden layer is present only when {@code numHidden2 > 0}.
 *
 * <p>The last inputs, activations and deltas are kept in (non-transient)
 * fields, so they are mutated by every pass and are also written out by
 * {@link #save(File)}. No method is synchronized; do not share an instance
 * between threads without external locking.
 */
public class NeuralNet implements Serializable {
    /** Number of input units (the bias is not counted). */
    int numInputs = 0;
    /** Number of output units. */
    int numOutputs = 0;
    /** Number of units in the first hidden layer. */
    int numHidden1 = 0;
    /** Number of units in the second hidden layer; {@code <= 0} means no second layer. */
    int numHidden2 = 0;
    /** Number of completed {@link #learnSet} calls. */
    long epoch = 0;

    /** Momentum factor; overwritten by every {@link #adjustWeights(double, double)} call. */
    double momentum = 0.3;
    /** Learning rate; overwritten by every {@link #adjustWeights(double, double)} call. */
    double learningRate = 0.8;

    // Working state of the most recent forward/backward pass.
    double[] inputs = null;
    double[] outputs = null;
    double[] hidden1 = null;
    double[] hidden2 = null;
    double[] target = null;

    // Error terms per layer from the most recent backward pass. deltasInput is
    // computed but not consumed by the weight update in this class.
    double[] deltasInput = null;
    double[] deltasOutput = null;
    double[] deltasHidden2 = null;
    double[] deltasHidden1 = null;

    // Weight matrices, indexed [sourceUnit (+ bias row)][targetUnit].
    double[][] weightsHidden1 = null;
    double[][] weightsHidden2 = null;
    double[][] weightsOutput = null;

    // Previous weight adjustments, used for the momentum term.
    double[][] deltaWeightsHidden1 = null;
    double[][] deltaWeightsHidden2 = null;
    double[][] deltaWeightsOutput = null;

    static final Random random = new Random();

    /**
     * Creates a network with weights drawn uniformly from [-1, 1).
     *
     * @param numInputs  number of input units
     * @param numHidden1 number of units in the first hidden layer
     * @param numHidden2 number of units in the second hidden layer, or
     *                   {@code 0} (or less) for a single hidden layer
     * @param numOutputs number of output units
     */
    public NeuralNet(int numInputs, int numHidden1, int numHidden2,
                     int numOutputs) {
        this.numInputs = numInputs;
        this.numOutputs = numOutputs;
        this.numHidden1 = numHidden1;
        this.numHidden2 = numHidden2;

        // "+ 1" rows everywhere: the last row of each matrix holds the bias weights.
        weightsHidden1 = new double[numInputs + 1][numHidden1];
        deltaWeightsHidden1 = new double[numInputs + 1][numHidden1];

        if (numHidden2 > 0) {
            weightsHidden2 = new double[numHidden1 + 1][numHidden2];
            deltaWeightsHidden2 = new double[numHidden1 + 1][numHidden2];
            weightsOutput = new double[numHidden2 + 1][numOutputs];
            deltaWeightsOutput = new double[numHidden2 + 1][numOutputs];
        } else {
            weightsOutput = new double[numHidden1 + 1][numOutputs];
            deltaWeightsOutput = new double[numHidden1 + 1][numOutputs];
        }

        randomizeWeights();
        resetDeltas();
    }

    /** Stores the input vector for the next {@link #propagate()} (the array is not copied). */
    public void setInputs(double[] inputs) {
        this.inputs = inputs;
    }

    /** Returns the number of completed {@link #learnSet} calls. */
    public long getEpoch() {
        return epoch;
    }

    /** Returns the outputs of the last {@link #propagate()} call (internal array, not a copy). */
    public double[] getOutputs() {
        return outputs;
    }

    /** Stores the expected outputs used by the next backward pass (the array is not copied). */
    public void setTarget(double[] target) {
        this.target = target;
    }

    /** Runs a forward pass on the inputs set by {@link #setInputs} and stores the result. */
    public void propagate() {
        outputs = propagate(inputs);
    }

    /**
     * Runs a forward pass on {@code inputs} and returns the output activations.
     *
     * <p>Note: this overwrites the stored hidden-layer activations but not the
     * stored inputs/outputs, so calling it between a backward pass and
     * {@link #adjustWeights(double, double)} corrupts that weight update.
     */
    public double[] propagate(double[] inputs) {
        hidden1 = getPropagated(inputs, weightsHidden1);
        if (numHidden2 > 0) {
            hidden2 = getPropagated(hidden1, weightsHidden2);
            return getPropagated(hidden2, weightsOutput);
        }
        return getPropagated(hidden1, weightsOutput);
    }

    /** Returns the mean squared error between the last outputs and the target. */
    public double getError() {
        return getError(outputs, target);
    }

    /**
     * Computes the error terms of every layer from the stored outputs and target.
     * Requires calls to setInputs()/setTarget() and propagate() beforehand.
     */
    private void backpropagate() {
        deltasOutput = getOutputDeltas();
        if (numHidden2 > 0) {
            deltasHidden2 = getDeltas(hidden2, deltasOutput, weightsOutput);
            deltasHidden1 = getDeltas(hidden1, deltasHidden2, weightsHidden2);
        } else {
            deltasHidden1 = getDeltas(hidden1, deltasOutput, weightsOutput);
        }
        deltasInput = getDeltas(inputs, deltasHidden1, weightsHidden1);
    }

    /**
     * Applies the error terms of the last backward pass to all weights.
     *
     * @param learningRate step size for the gradient term (remembered in the field)
     * @param momentum     factor applied to each weight's previous adjustment
     *                     (remembered in the field)
     * @throws IllegalStateException if no backward pass has run yet
     */
    public void adjustWeights(double learningRate, double momentum) {
        if (deltasOutput == null || deltasHidden1 == null) {
            throw new IllegalStateException(
                    "adjustWeights() called before any backpropagation pass");
        }
        this.learningRate = learningRate;
        this.momentum = momentum;

        if (numHidden2 > 0) {
            adjustWeights(weightsOutput, deltaWeightsOutput, deltasOutput, hidden2);
            adjustWeights(weightsHidden2, deltaWeightsHidden2, deltasHidden2, hidden1);
        } else {
            adjustWeights(weightsOutput, deltaWeightsOutput, deltasOutput, hidden1);
        }
        adjustWeights(weightsHidden1, deltaWeightsHidden1, deltasHidden1, inputs);
    }

    /**
     * Forward and backward pass for one pattern without changing the weights;
     * the pattern's {@code getOutputs()} is used as the target.
     */
    public double trainSingle(TrainingPattern pattern) {
        return trainSingle(pattern.getInputs(), pattern.getOutputs());
    }

    /**
     * Forward and backward pass for one pattern without changing the weights.
     * Call {@link #adjustWeights(double, double)} afterwards to apply the deltas.
     *
     * @return the mean squared error of the forward pass
     */
    public double trainSingle(double[] inputs, double[] target) {
        setInputs(inputs);
        setTarget(target);
        propagate();
        backpropagate();

        return getError();
    }

    /**
     * Trains on one pattern and updates the weights; the pattern's
     * {@code getOutputs()} is used as the target.
     */
    public double learnSingle(TrainingPattern pattern,
                              double learningRate, double momentum) {
        return learnSingle(pattern.getInputs(), pattern.getOutputs(),
                learningRate, momentum);
    }

    /**
     * Trains on one pattern and updates the weights.
     *
     * @return the mean squared error measured before the weight update
     */
    public double learnSingle(double[] inputs, double[] target,
                            double learningRate, double momentum) {
        double error = trainSingle(inputs, target);
        adjustWeights(learningRate, momentum);
        return error;
    }

    /**
     * Runs one epoch of online training: every pattern is propagated,
     * backpropagated and immediately applied to the weights, in order.
     *
     * @param inputs       input vectors, one per pattern
     * @param target       expected outputs, at least one per input vector
     * @param learningRate step size for the weight updates
     * @param momentum     momentum factor for the weight updates
     * @return the average per-pattern error of this epoch, or {@code 0.0}
     *         for an empty set
     * @throws IllegalArgumentException if {@code target} has fewer entries
     *         than {@code inputs}
     */
    public double learnSet(double[][] inputs, double[][] target,
                           double learningRate, double momentum) {
        // Validate before touching any weights, so a bad call leaves the net intact.
        if (target.length < inputs.length) {
            throw new IllegalArgumentException("target has " + target.length
                    + " patterns but inputs has " + inputs.length);
        }
        epoch++;
        if (inputs.length == 0) {
            return 0.0;
        }

        double totalError = 0.0;
        for (int i = 0; i < inputs.length; i++) {
            totalError += trainSingle(inputs[i], target[i]);
            adjustWeights(learningRate, momentum);
        }
        return totalError / inputs.length;
    }

    /**
     * Updates one weight matrix in place with the momentum rule
     * {@code dw = momentum * previousDw + learningRate * (input * delta)} and
     * remembers {@code dw} for the next call.
     */
    private void adjustWeights(double[][] weights,
                               double[][] previousCorrections,
                               double[] deltas,
                               double[] previousLayerOutputs) {
        int biasRow = previousLayerOutputs.length;
        for (int i = 0; i <= biasRow; i++) {
            // The bias row's input is the constant 1.
            double input = (i == biasRow) ? 1.0 : previousLayerOutputs[i];
            double[] weightRow = weights[i];
            double[] correctionRow = previousCorrections[i];
            for (int j = 0; j < deltas.length; j++) {
                double adjustment = momentum * correctionRow[j]
                        + learningRate * (input * deltas[j]);
                weightRow[j] += adjustment;
                correctionRow[j] = adjustment;
            }
        }
    }

    /**
     * Computes the error terms of a hidden layer from the next layer's error
     * terms and the weights between the two layers (the bias row is skipped).
     */
    private double[] getDeltas(double[] thisLayerOutputs,
                               double[] nextLayerDeltas,
                               double[][] weights) {
        ActivationFunction activation = ActivationFunction.getInstance();
        int units = weights.length - 1;
        double[] deltas = new double[units];
        for (int i = 0; i < units; i++) {
            double sum = 0.0;
            for (int j = 0; j < weights[i].length; j++) {
                sum += nextLayerDeltas[j] * weights[i][j];
            }
            deltas[i] = sum * activation.actDeriv(thisLayerOutputs[i]);
        }
        return deltas;
    }

    /** Computes the output-layer error terms from the stored target and outputs. */
    private double[] getOutputDeltas() {
        ActivationFunction activation = ActivationFunction.getInstance();
        double[] deltas = new double[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            deltas[i] = (target[i] - outputs[i]) * activation.actDeriv(outputs[i]);
        }
        return deltas;
    }

    /**
     * Propagates one layer: each output unit is the activation of the
     * weighted sum of {@code inputs} plus its bias weight.
     */
    private double[] getPropagated(double[] inputs, double[][] weights) {
        ActivationFunction activation = ActivationFunction.getInstance();
        double[] result = new double[weights[0].length];
        int biasRow = inputs.length;
        for (int i = 0; i < result.length; i++) {
            double sum = 0.0;
            for (int j = 0; j < inputs.length; j++) {
                sum += inputs[j] * weights[j][i];
            }
            sum += weights[biasRow][i];
            result[i] = activation.actFunc(sum);
        }
        return result;
    }

    /** Mean of the squared differences; {@code 0.0} for empty arrays (instead of NaN). */
    private double getError(double[] actual, double[] target) {
        if (actual.length == 0) {
            return 0.0;
        }
        double error = 0.0;
        for (int i = 0; i < actual.length; i++) {
            double delta = target[i] - actual[i];
            error += delta * delta;
        }
        return error / actual.length;
    }

    /** Fills all weight matrices with values drawn uniformly from [-1, 1). */
    private void randomizeWeights() {
        fillRandom(weightsHidden1);
        if (numHidden2 > 0) {
            fillRandom(weightsHidden2);
        }
        fillRandom(weightsOutput);
    }

    /** Sets every entry of {@code matrix}, row by row, to a value in [-1, 1). */
    private static void fillRandom(double[][] matrix) {
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                matrix[i][j] = random.nextDouble() * 2.0 - 1.0;
            }
        }
    }

    /** Clears the remembered weight adjustments used by the momentum term. */
    private void resetDeltas() {
        clear(deltaWeightsHidden1);
        if (numHidden2 > 0) {
            clear(deltaWeightsHidden2);
        }
        clear(deltaWeightsOutput);
    }

    /** Sets every entry of {@code matrix} to zero. */
    private static void clear(double[][] matrix) {
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                matrix[i][j] = 0.0;
            }
        }
    }

    /**
     * Serializes this network (including its working state) to {@code file}.
     * The stream is closed even if writing fails.
     *
     * @throws IOException if the file cannot be written
     */
    public void save(File file) throws IOException {
        RobocodeFileOutputStream out = new RobocodeFileOutputStream(file);
        ObjectOutputStream oos = null;
        try {
            oos = new ObjectOutputStream(out);
            oos.writeObject(this);
        } finally {
            // Close exactly once: through the object stream if it was created
            // (this also flushes it), otherwise the raw file stream.
            if (oos != null) {
                oos.close();
            } else {
                out.close();
            }
        }
    }

    /**
     * Reads a network previously written by {@link #save(File)}.
     * The stream is always closed.
     *
     * @return the network, or {@code null} if the stored class cannot be
     *         resolved (the stack trace is printed in that case)
     * @throws IOException if the file cannot be read or is not a valid stream
     */
    public static NeuralNet load(File file) throws IOException {
        FileInputStream in = new FileInputStream(file);
        ObjectInputStream ois = null;
        try {
            ois = new ObjectInputStream(in);
            return (NeuralNet) ois.readObject();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            return null;
        } finally {
            if (ois != null) {
                ois.close();
            } else {
                in.close();
            }
        }
    }

    /** Same as {@link #getEpoch()}. */
    public long getEpochs() {
        return epoch;
    }

    /** Demo: trains a 2-3-1 network on a two-input equivalence (XNOR) table. */
    public static void main(String[] args) {
        double[][] inps = {
            { 0.0, 0.0 }, { 0.0, 1.0 }, { 1.0, 0.0 }, { 1.0, 1.0 }
        };
        double[][] outs = {
            { 1.0 }, { 0.0 }, { 0.0 }, { 1.0 }
        };

        // Cap the epochs: a run that gets stuck would otherwise loop forever.
        final long maxEpochs = 100000;

        NeuralNet nn = new NeuralNet(2, 3, 0, 1);
        double error = Double.MAX_VALUE;
        while (error > 0.01 && nn.getEpoch() < maxEpochs) {
            error = nn.learnSet(inps, outs, 0.8, 0.3);
            System.out.println("err = " + error);
        }
        if (error > 0.01) {
            System.out.println("Did not converge: error = " + error + " after "
                    + nn.getEpoch());
        } else {
            System.out.println("Learned: error = " + error + " in " + nn.getEpoch());
        }

        for (int i = 0; i < inps.length; i++) {
            double[] out = nn.propagate(inps[i]);
            System.out.println(out[0]);
        }
    }
}
