/*
 * Decompiled with CFR 0.152.
 */
package dmonner.xlbp.layer;

import dmonner.xlbp.NetworkCopier;
import dmonner.xlbp.Responsibilities;
import dmonner.xlbp.layer.AbstractInternalLayer;

/**
 * Pass-through layer that keeps an exponential moving average of its input activations during
 * training and feeds a scaled difference between that average and the current activations back
 * into the upstream layer's responsibilities.
 *
 * <p>The layer does not transform activations: after {@link #build()} its activation array is the
 * upstream layer's array itself (shared, not copied). NOTE(review): the class name suggests the
 * feedback term is meant to push activations away from their running mean; confirm against the
 * semantics of {@code Responsibilities.copyPlusScaledDiff}.
 */
public class RepulsionLayer extends AbstractInternalLayer {
    private static final long serialVersionUID = 1L;

    /** Exponential moving average of the activations, updated by {@link #activateTrain()}. */
    private float[] mu;
    /** Array allocated in {@link #build()} and deep-copied by the copy constructor; not otherwise read here. */
    private float[] buf;
    /** Whether {@link #mu} has been seeded with a first set of activations. */
    private boolean mu_init;
    /** Weight of the newest activations in the moving average; always {@code 1 - retain}. */
    private final float adjust;
    /** Weight of the previous average in the moving average; lies in the open interval (0, 1). */
    private final float retain;
    /** Scale factor passed to {@code copyPlusScaledDiff} when propagating responsibilities. */
    private final float amount;

    /**
     * Copy constructor used by {@link #copy(NetworkCopier)}.
     *
     * @param that   the layer to copy; its {@code mu} and {@code buf} arrays are deep-copied
     * @param copier the copier coordinating the network-wide copy
     */
    public RepulsionLayer(RepulsionLayer that, NetworkCopier copier) {
        super(that, copier);
        // Each array is cloned from its own source so the copy's state evolves independently.
        this.mu = that.mu == null ? null : that.mu.clone();
        this.buf = that.buf == null ? null : that.buf.clone();
        this.adjust = that.adjust;
        this.retain = that.retain;
        this.amount = that.amount;
        this.mu_init = that.mu_init;
    }

    /**
     * Creates a layer with {@code retain = 0.95} and {@code amount = 0.1}.
     *
     * @param name the layer name
     * @param size the number of units
     */
    public RepulsionLayer(String name, int size) {
        this(name, size, 0.95f, 0.1f);
    }

    /**
     * Creates a layer with explicit moving-average and feedback parameters.
     *
     * @param name   the layer name
     * @param size   the number of units
     * @param retain weight given to the previous average at each training step; must lie
     *               strictly between 0 and 1
     * @param amount scale factor applied when propagating responsibilities upstream
     * @throws IllegalArgumentException if {@code retain} is not in (0, 1), including NaN
     */
    public RepulsionLayer(String name, int size, float retain, float amount) {
        super(name, size);
        // Written as a negated in-range test: comparisons with NaN are always false, so the
        // previous form (0 >= retain || retain >= 1) let NaN through.
        if (!(retain > 0.0f && retain < 1.0f)) {
            throw new IllegalArgumentException("Retain value must be in interval (0, 1).");
        }
        this.adjust = 1.0f - retain;
        this.retain = retain;
        this.amount = amount;
    }

    /**
     * Does nothing. Activations are shared with the upstream layer, and the running average is
     * updated only during training.
     */
    @Override
    public void activateTest() {
    }

    /**
     * Updates the running average from the current activations. The first call seeds
     * {@code mu} with a copy of the activations. Later calls apply
     * {@code mu[j] = adjust * y[j] + retain * mu[j]}.
     */
    @Override
    public void activateTrain() {
        if (!this.mu_init) {
            System.arraycopy(this.y, 0, this.mu, 0, this.size);
            this.mu_init = true;
            return;
        }
        for (int j = 0; j < this.size; ++j) {
            this.mu[j] = this.y[j] * this.adjust + this.mu[j] * this.retain;
        }
    }

    /**
     * Builds the layer once. This builds the upstream layer first, aliases its activation array
     * as this layer's activations, and allocates {@code mu}, {@code buf} and the responsibilities.
     * Later calls do nothing.
     */
    @Override
    public void build() {
        if (!this.built) {
            super.build();
            this.upstream.build();
            this.y = this.upstream.getActivations();
            this.mu = new float[this.size];
            this.buf = new float[this.size];
            this.d = new Responsibilities(this.size);
            this.built = true;
        }
    }

    /**
     * Does nothing. The activation array is the upstream layer's array, so there is no local
     * state to clear. NOTE(review): presumably the upstream layer clears it; confirm.
     */
    @Override
    public void clearActivations() {
    }

    /** {@inheritDoc} */
    @Override
    public RepulsionLayer copy(NetworkCopier copier) {
        return new RepulsionLayer(this, copier);
    }

    /**
     * Copies this layer using a fresh {@link NetworkCopier} built from {@code nameSuffix}.
     *
     * @param nameSuffix suffix passed to the new copier
     * @return the copied layer
     */
    @Override
    public RepulsionLayer copy(String nameSuffix) {
        return this.copy(new NetworkCopier(nameSuffix));
    }

    /**
     * Asks the downstream layer to update this layer's responsibilities, but only when a
     * downstream copy layer is attached. Otherwise that work happens in
     * {@link #updateResponsibilities()}.
     */
    @Override
    public void updateEligibilities() {
        if (this.downstreamCopyLayer != null) {
            this.downstream.updateUpstreamResponsibilities(this.myIndexInDownstream);
        }
    }

    /**
     * Asks the downstream layer to update this layer's responsibilities, but only when no
     * downstream copy layer is attached. Otherwise that work happens in
     * {@link #updateEligibilities()}.
     */
    @Override
    public void updateResponsibilities() {
        if (this.downstreamCopyLayer == null) {
            this.downstream.updateUpstreamResponsibilities(this.myIndexInDownstream);
        }
    }

    /**
     * Writes into the upstream layer's responsibilities by calling
     * {@code copyPlusScaledDiff(d, mu, y, amount)}. NOTE(review): presumably this sets them to
     * {@code d + amount * (mu - y)} element-wise; confirm in {@code Responsibilities}.
     *
     * @param index ignored; the upstream slot is selected by {@code myIndexInUpstream}
     */
    @Override
    public void updateUpstreamResponsibilities(int index) {
        this.upstream.getResponsibilities(this.myIndexInUpstream)
                .copyPlusScaledDiff(this.d, this.mu, this.y, this.amount);
    }
}

