package kc.mega.model;

import ags.utils.KdTree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import jk.math.FastTrig;
import kc.mega.utils.MathUtils;
import kc.mega.wave.WaveWithFeatures;

/* loaded from: input_file:kc/mega/model/WaveKNN.class */
/**
 * K-nearest-neighbours model over embedded wave features, backed by a Manhattan-distance
 * {@link KdTree}.
 *
 * <p>Each wave is turned into a point by {@link #embed(WaveWithFeatures)}: the configured
 * features are read from {@link #getNormalizedFeatures(WaveWithFeatures)}, transformed with the
 * per-feature {@code params}, and optionally re-scaled by the output of a small feed-forward
 * network.
 *
 * <p>Instances are not safe for concurrent use: the embedding step reuses per-layer activation
 * buffers that are stored on the instance.
 *
 * @param <T> type of the payload stored alongside each point in the tree
 */
public class WaveKNN<T> {
    /** Lower bound applied to the size-dependent neighbour count (before the upper cap). */
    private static final int MIN_NEIGHBORS = 5;
    /** Timer features are clamped to this many ticks before being divided by bullet flight time. */
    private static final int MAX_TIMER = 70;
    /** Small offset added inside the power transform of three-parameter features. */
    private static final double POW_EPSILON = 1.0E-4;

    private final KdTree<T> tree;
    private final String[] features;
    private final double[][] params;
    private final double distanceScale;
    private final int maxNeighbors;
    private final double neighborhoodSizeDivider;
    /** One weight matrix per layer, indexed {@code [input][output]}. Empty when no network is configured. */
    private final List<double[][]> neuralNet = new ArrayList<>();
    /** Reusable output buffer for each layer; overwritten on every embedding. */
    private final List<double[]> activations = new ArrayList<>();

    /**
     * Fluent builder for {@link WaveKNN}.
     *
     * <p>{@code features} and {@code params} must be set before {@link #build()}; the tree size
     * defaults to 50000 and the network is optional.
     *
     * @param <T> type of the payload stored alongside each point in the tree
     */
    public static class Builder<T> {
        private String[] features;
        private double[][] params;
        private double distanceScale;
        private double neighborhoodSizeDivider;
        private int maxNeighbors;
        private int maxTreeSize = 50000;
        private String neuralNet;

        /** Sets the names of the normalized features that make up the embedding, in order. */
        public Builder<T> features(String[] features) {
            this.features = features;
            return this;
        }

        /**
         * Sets the per-feature transform parameters, parallel to {@code features}. A row of length
         * one is a plain weight; otherwise the row is read as {@code {weight, offset, exponent}}.
         */
        public Builder<T> params(double[][] params) {
            this.params = params;
            return this;
        }

        /** Sets the factor applied to neighbour distances in {@link WaveKNN#getWeights(List)}. */
        public Builder<T> distanceScale(double distanceScale) {
            this.distanceScale = distanceScale;
            return this;
        }

        /** Sets the divisor applied to the tree size in {@link WaveKNN#getNumNeighbors()}. */
        public Builder<T> neighborhoodSizeDivider(double neighborhoodSizeDivider) {
            this.neighborhoodSizeDivider = neighborhoodSizeDivider;
            return this;
        }

        /** Sets the upper cap on the neighbour count returned by {@link WaveKNN#getNumNeighbors()}. */
        public Builder<T> maxNeighbors(int maxNeighbors) {
            this.maxNeighbors = maxNeighbors;
            return this;
        }

        /** Sets the size limit handed to the underlying tree (default 50000). */
        public Builder<T> maxTreeSize(int maxTreeSize) {
            this.maxTreeSize = maxTreeSize;
            return this;
        }

        /**
         * Sets the optional network weights as text: one layer per line, rows separated by
         * {@code ","}, values within a row separated by a single space.
         */
        public Builder<T> nn(String neuralNet) {
            this.neuralNet = neuralNet;
            return this;
        }

        /** Creates the model from the current builder state. */
        public WaveKNN<T> build() {
            return new WaveKNN<>(this);
        }
    }

    /**
     * Creates a model from the given builder, allocating the tree and parsing the optional
     * network weights.
     *
     * @param builder configuration; its {@code features} must be non-null
     * @throws NumberFormatException if a network weight cannot be parsed as a double
     */
    public WaveKNN(Builder<T> builder) {
        this.tree = new KdTree.Manhattan<>(builder.features.length, Integer.valueOf(builder.maxTreeSize));
        this.features = builder.features;
        this.params = builder.params;
        this.distanceScale = builder.distanceScale;
        this.maxNeighbors = builder.maxNeighbors;
        this.neighborhoodSizeDivider = builder.neighborhoodSizeDivider;
        if (builder.neuralNet != null) {
            for (String line : builder.neuralNet.split("\n")) {
                if (line.isBlank()) {
                    // Tolerate leading/embedded empty lines instead of failing on parsing "".
                    continue;
                }
                double[][] layer = parseLayer(line);
                this.neuralNet.add(layer);
                this.activations.add(new double[layer[0].length]);
            }
        }
    }

    /** Parses one line of network text into a weight matrix indexed {@code [row][column]}. */
    private static double[][] parseLayer(String line) {
        String[] rows = line.split(",");
        // Size the columns from the stripped first row so that it agrees with how rows are parsed
        // below; leading whitespace would otherwise add a phantom column.
        int columns = rows[0].strip().split(" ").length;
        double[][] layer = new double[rows.length][columns];
        for (int r = 0; r < rows.length; r++) {
            String[] values = rows[r].strip().split(" ");
            for (int c = 0; c < values.length; c++) {
                layer[r][c] = Double.parseDouble(values[c]);
            }
        }
        return layer;
    }

    /** Embeds the wave and stores {@code t} at the resulting point. */
    public void addPoint(WaveWithFeatures waveWithFeatures, T t) {
        this.tree.addPoint(embed(waveWithFeatures), t);
    }

    /** Returns the nearest stored entries to the wave, using {@link #getNumNeighbors()} as the count. */
    public List<KdTree.Entry<T>> getNeighbors(WaveWithFeatures waveWithFeatures) {
        return this.tree.nearestNeighbor(embed(waveWithFeatures), getNumNeighbors(), false);
    }

    /** Returns the nearest stored entries to the wave, requesting {@code i} neighbours. */
    public List<KdTree.Entry<T>> getNeighbors(WaveWithFeatures waveWithFeatures, int i) {
        return this.tree.nearestNeighbor(embed(waveWithFeatures), i, false);
    }

    /**
     * Returns the neighbour count to query for: the tree size divided by the neighbourhood size
     * divider, raised to at least 5 and then capped at {@code maxNeighbors}.
     */
    public int getNumNeighbors() {
        int sizeBased = (int) (this.tree.size() / this.neighborhoodSizeDivider);
        return Math.min(this.maxNeighbors, Math.max(MIN_NEIGHBORS, sizeBased));
    }

    /** Returns {@code true} when no points have been stored in the tree. */
    public boolean isEmpty() {
        return this.tree.size() == 0;
    }

    /** Returns the embedding of the wave's normalized features. */
    public double[] embed(WaveWithFeatures waveWithFeatures) {
        return embed(getNormalizedFeatures(waveWithFeatures));
    }

    /**
     * Builds the embedding from a map of normalized features.
     *
     * <p>Each configured feature is weighted (and, for multi-parameter rows, shifted and raised to
     * a power). When a network is configured, the weighted vector is fed through it (ReLU on all
     * but the last layer) and each embedding component is multiplied by the matching output.
     *
     * @throws NullPointerException if a configured feature is missing from {@code normalized}
     */
    private double[] embed(Map<String, Double> normalized) {
        double[] embedding = new double[this.features.length];
        for (int i = 0; i < this.features.length; i++) {
            double value = normalized.get(this.features[i]);
            double[] p = this.params[i];
            double transformed = p.length == 1 ? value : Math.pow(POW_EPSILON + p[1] + value, p[2]);
            embedding[i] = p[0] * transformed;
        }
        if (this.neuralNet.isEmpty()) {
            return embedding;
        }

        double[] input = embedding;
        int lastLayer = this.neuralNet.size() - 1;
        for (int layer = 0; layer <= lastLayer; layer++) {
            double[][] weights = this.neuralNet.get(layer);
            double[] output = this.activations.get(layer);
            for (int in = 0; in < input.length; in++) {
                for (int out = 0; out < output.length; out++) {
                    double contribution = weights[in][out] * input[in];
                    if (in == 0) {
                        // The first input overwrites the buffer, clearing the previous call's values.
                        output[out] = contribution;
                    } else {
                        output[out] += contribution;
                    }
                }
            }
            if (layer != lastLayer) {
                for (int out = 0; out < output.length; out++) {
                    output[out] = Math.max(0.0, output[out]);
                }
            }
            input = output;
        }
        // The final layer's outputs act as per-dimension multipliers on the embedding.
        for (int i = 0; i < embedding.length; i++) {
            embedding[i] *= input[i];
        }
        return embedding;
    }

    /**
     * Computes one weight per neighbour by passing {@code distance * distanceScale} through
     * {@code MathUtils.softmax} together with the largest such value.
     *
     * <p>NOTE(review): the distance is not negated here, so nearer neighbours only get larger
     * weights if {@code distanceScale} is configured negative - confirm against the callers.
     *
     * @param list neighbours as returned by {@code getNeighbors}
     * @return array parallel to {@code list}, as left by {@code MathUtils.softmax}
     */
    public double[] getWeights(List<KdTree.Entry<T>> list) {
        double[] weights = new double[list.size()];
        double maxLogit = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < list.size(); i++) {
            double logit = list.get(i).distance * this.distanceScale;
            weights[i] = logit;
            maxLogit = Math.max(logit, maxLogit);
        }
        MathUtils.softmax(weights, maxLogit);
        return weights;
    }

    /**
     * Derives the full set of named, rescaled features for a wave.
     *
     * <p>NOTE(review): {@code advVel} is computed as {@code (x + 16) / 8}, which does not follow
     * the {@code (x + max) / (2 * max)} shape used for {@code advDir} and {@code currentGF}. It is
     * left unchanged because tuned params / network weights would depend on it - confirm intent.
     *
     * @return a new mutable map from feature name to value
     */
    public static Map<String, Double> getNormalizedFeatures(WaveWithFeatures waveWithFeatures) {
        WaveWithFeatures wave = waveWithFeatures;
        double bulletFlightTime = wave.distance / wave.speed;
        double lateralVelocity = wave.velocity * FastTrig.sin(wave.relativeHeading);
        double advancingVelocity = wave.velocity * FastTrig.cos(wave.relativeHeading);
        double advancingDirection = wave.moveDirection * FastTrig.cos(wave.relativeHeading);
        double halfPi = Math.PI / 2;

        Map<String, Double> normalized = new HashMap<>();
        normalized.put("virtuality", wave.virtuality / 5.0);
        normalized.put("power", wave.power / 3.0);
        normalized.put("bft", bulletFlightTime / 100.0);
        normalized.put("accel", Math.max(2.0 + wave.accel, 0.0) / 2.0);
        // Explicit widening below keeps boxing to Double valid whatever the numeric field type is.
        normalized.put("accelSign", (double) Math.signum(wave.accel));
        normalized.put("latVel", Math.abs(lateralVelocity) / 8.0);
        normalized.put("vel", Math.abs(wave.velocity) / 8.0);
        normalized.put("vel=8", Math.abs(wave.velocity) > 7.9 ? 1.0 : 0.0);
        normalized.put("advVel", (advancingVelocity + 16.0) / 8.0);
        normalized.put("advDir", (advancingDirection + 1.0) / 2.0);
        normalized.put("vChangeTimer", Math.min(wave.vChangeTimer, MAX_TIMER) / bulletFlightTime);
        normalized.put("dirChangeTimer", Math.min(wave.dirChangeTimer, MAX_TIMER) / bulletFlightTime);
        normalized.put("decelTimer", Math.min(wave.decelTimer, MAX_TIMER) / bulletFlightTime);
        normalized.put("distanceLast10", wave.distanceLast10 / 80.0);
        normalized.put("distanceLast20", wave.distanceLast20 / 160.0);
        normalized.put("mirrorOffset", wave.mirrorOffset + Math.PI);
        normalized.put("orbitalWallAhead", wave.orbitalWallAhead / 1.5);
        normalized.put("orbitalWallReverse", wave.orbitalWallReverse / 1.5);
        normalized.put("maeWallAhead", (double) wave.maeWallAhead);
        normalized.put("maeWallReverse", (double) wave.maeWallReverse);
        normalized.put("stickWallAhead", wave.stickWallAhead / halfPi);
        normalized.put("stickWallReverse", wave.stickWallReverse / halfPi);
        normalized.put("stickWallAhead2", wave.stickWallAhead2 / halfPi);
        normalized.put("stickWallReverse2", wave.stickWallReverse2 / halfPi);
        normalized.put("stickWallAhead=0", wave.stickWallAhead < 0.001 ? 1.0 : 0.0);
        normalized.put("stickWallReverse=0", wave.stickWallReverse < 0.001 ? 1.0 : 0.0);
        normalized.put("gameTime", wave.fireTime / 500.0);
        normalized.put("shotsFired", wave.shotsFired / 1000.0);
        normalized.put("currentGF", (1.0 + wave.currentGF) / 2.0);
        normalized.put("didHit", wave.didHit ? 1.0 : 0.0);
        normalized.put("didCollide", wave.didCollide ? 1.0 : 0.0);
        return normalized;
    }
}
