package dsekercioglu.neural.core.neurox.nn;

import dsekercioglu.neural.core.neurox.nn.activationfunctions.ActivationFunction;
import dsekercioglu.neural.core.neurox.nn.activationfunctions.Softmax;
import dsekercioglu.neural.core.neurox.nn.lossfunctions.LossFunction;
import java.util.ArrayList;

/* loaded from: input_file:dsekercioglu/neural/core/neurox/nn/NeuralNetwork.class */
/**
 * A feed-forward neural network held as an ordered list of {@link Layer}s.
 *
 * <p>Layers are appended with {@link #push(int, ActivationFunction)} in input-to-output order.
 * The most recently pushed layer is always the only one flagged as {@code finalLayer}.
 * {@link #feedForward(double[])} evaluates the network and
 * {@link #backpropogate(double[], double[])} performs one training step.
 */
public class NeuralNetwork {
    /** Bias flag given at construction. Not read in this class; package-visible for collaborators. */
    final boolean BIAS;
    /** Input width of the first layer. */
    private final int INPUT_NUM;
    /** Loss function given at construction. Not read in this class; package-visible for collaborators. */
    final LossFunction LOSS_FUNCTION;
    /** Learning rate, 0.1 unless changed through {@link #learningRate(double)}. */
    double learningRate = 0.1d;
    /** Momentum, 0.9 unless changed through {@link #momentum(double)}. */
    double momentum = 0.9d;
    /** Layers in input-to-output order. */
    ArrayList<Layer> layers = new ArrayList<>();

    /**
     * Creates an empty network; layers must be added with {@link #push(int, ActivationFunction)}.
     *
     * @param i            number of inputs fed to the first layer
     * @param z            bias flag stored in {@link #BIAS}
     * @param lossFunction loss function stored in {@link #LOSS_FUNCTION}
     */
    public NeuralNetwork(int i, boolean z, LossFunction lossFunction) {
        this.BIAS = z;
        this.INPUT_NUM = i;
        this.LOSS_FUNCTION = lossFunction;
    }

    /**
     * Appends a new output-side layer to the network.
     *
     * <p>The new layer's input width is the previous layer's {@code OUTPUT_NUM}, or the network's
     * input count when this is the first layer. A {@link Softmax} activation produces a
     * {@code SoftmaxLayer}; any other activation produces a {@code StdLayer}.
     *
     * <p>The new layer becomes the final layer and the layer it supersedes is demoted, so the
     * {@code finalLayer} flags are correct even if {@link #setup()} is never called.
     *
     * @param i                  number of outputs of the new layer
     * @param activationFunction activation used by the new layer
     */
    public void push(int i, ActivationFunction activationFunction) {
        Layer previous = this.layers.isEmpty() ? null : this.layers.get(this.layers.size() - 1);
        int inputs = previous == null ? this.INPUT_NUM : previous.OUTPUT_NUM;
        Layer layer = activationFunction instanceof Softmax
                ? new SoftmaxLayer(inputs, i, this)
                : new StdLayer(inputs, i, activationFunction, this);
        // Demote only after the new layer was built, so its constructor observes the same
        // network state it always did.
        if (previous != null) {
            previous.finalLayer = false;
        }
        this.layers.add(layer);
        layer.finalLayer = true;
    }

    /**
     * Validates that the network has at least one layer and clears the {@code finalLayer} flag
     * on every layer except the last one.
     *
     * @throws RuntimeException if no layer has been pushed
     */
    public void setup() {
        requireLayers();
        int last = this.layers.size() - 1;
        for (int idx = 0; idx < last; idx++) {
            this.layers.get(idx).finalLayer = false;
        }
    }

    /**
     * Sets the learning rate.
     *
     * @param d new learning rate
     * @return this network, for call chaining
     */
    public NeuralNetwork learningRate(double d) {
        this.learningRate = d;
        return this;
    }

    /**
     * Sets the momentum.
     *
     * @param d new momentum
     * @return this network, for call chaining
     */
    public NeuralNetwork momentum(double d) {
        this.momentum = d;
        return this;
    }

    /**
     * Passes the input through every layer in order and returns the last layer's result.
     *
     * <p>The first layer receives a copy of {@code dArr}, not the caller's array. With no layers
     * pushed, that copy is returned as is.
     *
     * @param dArr network input
     * @return output of the last layer
     */
    public double[] feedForward(double[] dArr) {
        double[] signal = dArr.clone();
        for (Layer layer : this.layers) {
            signal = layer.feedForward(signal);
        }
        return signal;
    }

    /**
     * Performs one training step.
     *
     * <p>Runs a forward pass on {@code dArr}, hands {@code dArr2} to the last layer's
     * {@code backPropogateDeltaRule}, feeds each returned array to the preceding layer's
     * {@code backPropogate}, and finally calls {@code update()} on every layer from last to first.
     *
     * @param dArr  network input for the forward pass
     * @param dArr2 array passed to the last layer's delta rule (presumably the expected output)
     * @throws RuntimeException if no layer has been pushed
     */
    public void backpropogate(double[] dArr, double[] dArr2) {
        // Without this guard an empty network fails later in layers.get(-1).
        requireLayers();
        feedForward(dArr);
        int last = this.layers.size() - 1;
        double[] propagated = this.layers.get(last).backPropogateDeltaRule(dArr2);
        for (int idx = last - 1; idx >= 0; idx--) {
            propagated = this.layers.get(idx).backPropogate(propagated);
        }
        for (int idx = last; idx >= 0; idx--) {
            this.layers.get(idx).update();
        }
    }

    /** Fails with the network's standard error when no layer has been pushed. */
    private void requireLayers() {
        if (this.layers.isEmpty()) {
            throw new RuntimeException("No Layers In the Network");
        }
    }
}
