package dsekercioglu.mega.megaCore;

import dsekercioglu.mega.megaCore.jk.kdTree.KDTree;
import dsekercioglu.mega.megaCore.jk.kdTree.KDTree.SearchResult;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.ListIterator;

/**
 * Predicts a histogram over {@code BINS} bins from the k nearest stored samples in a KD-tree.
 *
 * <p>Each stored sample is a feature vector (scaled component-wise by the configured weights)
 * whose payload is {@code {bin, weight}}. A prediction sums the payload weights of the nearest
 * neighbours into their bins, optionally attenuated by distance.
 */
public class KNNPredictor extends CorePredictor {

    /** KD-tree of weighted feature vectors; each payload is a {@code double[]{bin, weight}}. */
    public KDTree predictor;
    private int BINS;
    // Stored by setup(); predictBins() derives k from ALGORITHM instead of reading this field.
    private int K;
    private double[] WEIGHTS;
    private boolean WBD;
    private KDeterminationAlgorithm ALGORITHM;

    /**
     * Initialises (or resets) the predictor with an empty tree.
     *
     * @param weights          per-dimension scale factors applied to every feature vector;
     *                         a copy is kept, so later changes to the array have no effect
     * @param k                stored for reference; the neighbour count actually comes from
     *                         {@code algorithm}
     * @param bins             number of output bins returned by {@link #predictBins}
     * @param weightByDistance if {@code true}, each neighbour's contribution is divided by
     *                         {@code (distance + 1) * neighbourCount}
     * @param algorithm        strategy choosing how many neighbours to query for a given tree size
     */
    public void setup(double[] weights, int k, int bins, boolean weightByDistance, KDeterminationAlgorithm algorithm) {
        predictor = new KDTree.Euclidean(weights.length);
        BINS = bins;
        // Defensive copy: points already in the tree were scaled with these exact values, so the
        // caller must not be able to change them out from under us.
        WEIGHTS = weights.clone();
        WBD = weightByDistance;
        ALGORITHM = algorithm;
        this.K = k;
    }

    /**
     * Same as {@link #setup(double[], int, int, boolean, KDeterminationAlgorithm)} with
     * distance weighting enabled.
     */
    public void setup(double[] weights, int k, int bins, KDeterminationAlgorithm algorithm) {
        setup(weights, k, bins, true, algorithm);
    }

    /**
     * Builds a bin histogram from the nearest stored samples to {@code data}.
     *
     * @param data raw (unweighted) feature vector; not modified
     * @return array of length {@code BINS}; all zeros when no samples have been added
     */
    @Override
    public double[] predictBins(double[] data) {
        double[] bins = new double[BINS];
        int size = predictor.size();
        if (size == 0) {
            return bins;
        }
        double[] query = getWeightedData(data.clone());
        // getK() is clamped to >= 1 by both strategies; never ask for more points than exist.
        int predictionNum = Math.min(ALGORITHM.getK(size), size);
        ArrayList<SearchResult<double[]>> neighbours = predictor.nearestNeighbours(query, predictionNum);
        for (SearchResult<double[]> result : neighbours) {
            double[] payload = result.payload;
            int guessFactor = (int) payload[0];
            double distance = result.distance;
            double finalWeight = payload[1];
            if (WBD) {
                // Closer neighbours count more; dividing by the neighbour count keeps the total
                // mass roughly independent of k. Without WBD the raw weights are summed as-is.
                finalWeight /= (distance + 1) * predictionNum;
            }
            bins[guessFactor] += finalWeight;
        }
        return bins;
    }

    /**
     * Stores a sample in the tree.
     *
     * @param data   raw (unweighted) feature vector; not modified
     * @param bin    bin index this sample votes for (expected in {@code [0, BINS)})
     * @param weight vote strength
     */
    @Override
    public void addData(double[] data, int bin, double weight) {
        predictor.addPoint(getWeightedData(data.clone()), new double[]{bin, weight});
    }

    /** Scales {@code data} in place by the configured weights and returns the same array. */
    private double[] getWeightedData(double[] data) {
        for (int i = 0; i < data.length; i++) {
            data[i] *= WEIGHTS[i];
        }
        return data;
    }

    /** Returns a strategy with {@code k = clamp((int) (sqrt(size) * multiplier), 1, maxK)}. */
    public KDeterminationAlgorithm sqrtSize(int maxK, double multiplier) {
        return new SqrtSize(maxK, multiplier);
    }

    /** Returns a strategy with {@code k = clamp((int) (size / divisor), 1, maxK)}. */
    public KDeterminationAlgorithm divisionK(int maxK, double divisor) {
        return new DivisionK(maxK, divisor);
    }

    /** Strategy deciding how many neighbours to query given the current number of stored points. */
    public abstract static class KDeterminationAlgorithm {

        final int MAX_K;

        public KDeterminationAlgorithm(int maxK) {
            MAX_K = maxK;
        }

        /**
         * @param size number of points currently stored
         * @return the neighbour count to query
         */
        public abstract int getK(int size);
    }

    /** k grows with the square root of the data size. */
    private static final class SqrtSize extends KDeterminationAlgorithm {

        private final double MULTIPLIER;

        public SqrtSize(int maxK, double multiplier) {
            super(maxK);
            MULTIPLIER = multiplier;
        }

        @Override
        public int getK(int size) {
            return Math.max(Math.min(MAX_K, (int) (Math.sqrt(size) * MULTIPLIER)), 1);
        }
    }

    /** k grows linearly with the data size. */
    private static final class DivisionK extends KDeterminationAlgorithm {

        final double DIVISOR;

        public DivisionK(int maxK, double divisor) {
            super(maxK);
            DIVISOR = divisor;
        }

        @Override
        public int getK(int size) {
            return Math.max(Math.min(MAX_K, (int) (size / DIVISOR)), 1);
        }
    }
}
