/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.function.DoubleUnaryOperator;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/**
 * An implementation of the Adam stochastic gradient optimiser.
 * <p>
 * For each parameter tensor it keeps a running average of the incoming gradient (first moment)
 * and of the squared gradient (second moment). Each moment is decayed by {@code betaOne} /
 * {@code betaTwo} every step.
 * <p>
 * The returned update is {@code lr_t * m / (sqrt(v) + epsilon)}, where the two bias corrections
 * are folded into a single per-step rate:
 * {@code lr_t = initialLearningRate * sqrt(1 - betaTwo^t) / (1 - betaOne^t)}.
 * <p>
 * See Kingma &amp; Ba, "Adam: A Method for Stochastic Optimization", ICLR 2015.
 * <p>
 * Instances hold mutable, unsynchronised state (the moment buffers and the iteration count).
 * Use {@link #copy()} to obtain an independent instance with the same hyperparameters.
 */
public class Adam
implements StochasticGradientOptimiser {
    /** Default step size used by the no-arg and two-arg constructors. */
    private static final double DEFAULT_LEARNING_RATE = 0.001;
    /** Default decay rate for the first moment estimate. */
    private static final double DEFAULT_BETA_ONE = 0.9;
    /** Default decay rate for the second moment estimate. */
    private static final double DEFAULT_BETA_TWO = 0.999;
    /** Default additive term in the update denominator. */
    private static final double DEFAULT_EPSILON = 1.0E-6;

    @Config(description="Learning rate to scale the gradients by.")
    private double initialLearningRate = DEFAULT_LEARNING_RATE;
    @Config(description="The beta one parameter.")
    private double betaOne = DEFAULT_BETA_ONE;
    // Previously initialised to 0.99 here while every constructor used 0.999; now one shared default.
    @Config(description="The beta two parameter.")
    private double betaTwo = DEFAULT_BETA_TWO;
    @Config(description="Epsilon for numerical stability.")
    private double epsilon = DEFAULT_EPSILON;
    /** Number of calls to {@link #step} since the last {@link #initialise} or {@link #reset}. */
    private int iterations = 0;
    /** Per-parameter-tensor running average of the gradient; null until {@link #initialise}. */
    private Tensor[] firstMoment;
    /** Per-parameter-tensor running average of the squared gradient; null until {@link #initialise}. */
    private Tensor[] secondMoment;

    /**
     * Creates an Adam optimiser with all hyperparameters specified.
     *
     * @param initialLearningRate The base step size, before bias correction.
     * @param betaOne The decay rate of the first moment estimate.
     * @param betaTwo The decay rate of the second moment estimate.
     * @param epsilon Added to {@code sqrt(v)} in the update denominator for numerical stability.
     */
    public Adam(double initialLearningRate, double betaOne, double betaTwo, double epsilon) {
        this.initialLearningRate = initialLearningRate;
        this.betaOne = betaOne;
        this.betaTwo = betaTwo;
        this.epsilon = epsilon;
        this.iterations = 0;
    }

    /**
     * Creates an Adam optimiser with the default betas (0.9, 0.999).
     *
     * @param initialLearningRate The base step size, before bias correction.
     * @param epsilon Added to {@code sqrt(v)} in the update denominator for numerical stability.
     */
    public Adam(double initialLearningRate, double epsilon) {
        this(initialLearningRate, DEFAULT_BETA_ONE, DEFAULT_BETA_TWO, epsilon);
    }

    /**
     * Creates an Adam optimiser with the default hyperparameters:
     * learning rate 0.001, betaOne 0.9, betaTwo 0.999, epsilon 1e-6.
     */
    public Adam() {
        this(DEFAULT_LEARNING_RATE, DEFAULT_BETA_ONE, DEFAULT_BETA_TWO, DEFAULT_EPSILON);
    }

    /**
     * Allocates fresh first- and second-moment buffers and resets the iteration count.
     * <p>
     * Each buffer comes from {@link Parameters#getEmptyCopy()}. These are presumably zero-filled
     * tensors shaped like the model parameters (confirm against {@code Parameters}).
     *
     * @param parameters The parameters that will be optimised.
     */
    @Override
    public void initialise(Parameters parameters) {
        this.firstMoment = parameters.getEmptyCopy();
        this.secondMoment = parameters.getEmptyCopy();
        this.iterations = 0;
    }

    /**
     * Performs one Adam step.
     * <p>
     * Updates the moment estimates from the supplied gradients, then overwrites each element of
     * {@code updates} with the bias-corrected Adam update for that tensor.
     *
     * @param updates The gradients for each parameter tensor. These are modified in place and
     *                returned.
     * @param weight Unused by this optimiser.
     * @return The same {@code updates} array, now holding the Adam updates.
     * @throws IllegalStateException If called before {@link #initialise} or after {@link #reset}.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        if (this.firstMoment == null || this.secondMoment == null) {
            throw new IllegalStateException("Adam.step called before initialise (or after reset).");
        }
        ++this.iterations;
        // Both bias corrections are folded into one per-step learning rate (Kingma & Ba, Sec. 2).
        double learningRate = this.initialLearningRate
                * Math.sqrt(1.0 - Math.pow(this.betaTwo, this.iterations))
                / (1.0 - Math.pow(this.betaOne, this.iterations));

        // Build the element-wise operators once per step instead of once per parameter tensor.
        double oneMinusBetaOne = 1.0 - this.betaOne;
        double oneMinusBetaTwo = 1.0 - this.betaTwo;
        double eps = this.epsilon;
        DoubleUnaryOperator scale = a -> a * learningRate;
        DoubleUnaryOperator firstMomentContribution = g -> g * oneMinusBetaOne;
        DoubleUnaryOperator secondMomentContribution = g -> g * g * oneMinusBetaTwo;
        DoubleUnaryOperator inverseRootSecondMoment = v -> 1.0 / (Math.sqrt(v) + eps);

        for (int i = 0; i < updates.length; ++i) {
            // m <- betaOne * m + (1 - betaOne) * g
            this.firstMoment[i].scaleInPlace(this.betaOne);
            this.firstMoment[i].intersectAndAddInPlace(updates[i], firstMomentContribution);
            // v <- betaTwo * v + (1 - betaTwo) * g^2
            this.secondMoment[i].scaleInPlace(this.betaTwo);
            this.secondMoment[i].intersectAndAddInPlace(updates[i], secondMomentContribution);
            // Zero the gradient tensor in place rather than allocating a new one.
            // NOTE(review): presumably this preserves the sparsity pattern of sparse tensors; confirm.
            updates[i].scaleInPlace(0.0);
            updates[i].intersectAndAddInPlace(this.firstMoment[i], scale);
            // Adam divides by sqrt(v) + epsilon. Multiplying (the previous behaviour) enlarged the
            // steps of high-variance coordinates instead of shrinking them.
            updates[i].hadamardProductInPlace(this.secondMoment[i], inverseRootSecondMoment);
        }
        return updates;
    }

    @Override
    public String toString() {
        return "Adam(learningRate=" + this.initialLearningRate + ",betaOne=" + this.betaOne + ",betaTwo=" + this.betaTwo + ",epsilon=" + this.epsilon + ")";
    }

    /**
     * Discards the moment buffers and resets the iteration count.
     * {@link #initialise} must be called again before the next {@link #step}.
     */
    @Override
    public void reset() {
        this.firstMoment = null;
        this.secondMoment = null;
        this.iterations = 0;
    }

    /**
     * Returns a new optimiser with the same hyperparameters.
     * The copy has no moment state and no iteration history.
     *
     * @return A fresh, uninitialised copy of this optimiser.
     */
    @Override
    public Adam copy() {
        return new Adam(this.initialLearningRate, this.betaOne, this.betaTwo, this.epsilon);
    }

    /**
     * Builds a configured-object provenance for this optimiser.
     * The provenance is tagged with the type name "StochasticGradientOptimiser".
     *
     * @return The provenance of this optimiser.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "StochasticGradientOptimiser");
    }
}

