/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.ltr.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

/**
 * Scoring model that evaluates a feed-forward neural network over the feature vector.
 *
 * <p>{@link #score(float[])} passes the input vector through each configured {@link Layer} in
 * order, feeding the output of one layer into the next. A {@link DefaultLayer} computes
 * {@code activation(W * x + b)} for its weight matrix {@code W} and bias vector {@code b}.
 * {@link #validate()} checks that consecutive layer dimensions line up and that the last layer
 * produces exactly one value, which is returned as the score.
 */
public class NeuralNetworkModel extends LTRScoringModel {

    /** Network layers in evaluation order; populated by {@link #setLayers(Object)}. */
    private List<Layer> layers;

    /**
     * Creates a single layer from its configuration object.
     *
     * <p>Subclasses may override this to supply a custom {@link Layer} implementation.
     *
     * @param o the layer configuration. When non-null it must be a {@code Map}, whose entries are
     *     applied to a new {@link DefaultLayer} via {@link SolrPluginUtils#invokeSetters}.
     *     Presumably keys such as {@code matrix}, {@code bias} and {@code activation} map onto the
     *     matching setters; confirm against SolrPluginUtils.
     * @return the newly created layer
     */
    protected Layer createLayer(Object o) {
        final DefaultLayer layer = new DefaultLayer();
        if (o != null) {
            @SuppressWarnings("unchecked") // per-layer config arrives as an untyped model param
            final Map<String, Object> layerParams = (Map<String, Object>) o;
            SolrPluginUtils.invokeSetters(layer, layerParams.entrySet());
        }
        return layer;
    }

    /**
     * Builds the network from its configuration, replacing any previously configured layers.
     *
     * @param layers a {@code List} holding one configuration entry per layer, in evaluation
     *     order; each entry is turned into a layer by {@link #createLayer(Object)}
     */
    public void setLayers(Object layers) {
        final List<?> layerConfigs = (List<?>) layers;
        // The field must be assigned before any layer is created. DefaultLayer's constructor
        // reads this.layers.size() to derive its own layer index.
        this.layers = new ArrayList<Layer>(layerConfigs.size());
        for (Object layerConfig : layerConfigs) {
            this.layers.add(createLayer(layerConfig));
        }
    }

    /**
     * Creates the model; all arguments are forwarded unchanged to {@link LTRScoringModel}.
     * Layers are configured afterwards through {@link #setLayers(Object)}.
     */
    public NeuralNetworkModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
        super(name, features, norms, featureStoreName, allFeatures, params);
    }

    /**
     * Validates the network in addition to the checks performed by the superclass.
     *
     * @throws ModelException if no layers were configured, if any layer is malformed or cannot
     *     accept the previous layer's output width, or if the final output is not a single value
     */
    @Override
    protected void validate() throws ModelException {
        super.validate();
        if (this.layers == null) {
            throw new ModelException("No layers were configured for model \"" + this.name + "\".");
        }
        // Thread the expected input width through the network. Each layer checks it and
        // returns its own output width.
        int inputDim = this.features.size();
        for (Layer layer : this.layers) {
            inputDim = layer.validate(inputDim);
        }
        if (inputDim != 1) {
            throw new ModelException("The output matrix for model \"" + this.name + "\" has " + Integer.toString(inputDim) + " rows, but should only have one.");
        }
    }

    /**
     * Runs the feature vector through every layer in order.
     *
     * @param inputFeatures the model's feature values
     * @return the first element of the last layer's output (the only one, once
     *     {@link #validate()} has passed)
     */
    @Override
    public float score(float[] inputFeatures) {
        float[] outputVec = inputFeatures;
        for (Layer layer : this.layers) {
            outputVec = layer.calculateOutput(outputVec);
        }
        return outputVec[0];
    }

    /**
     * Produces a single-level explanation: the model name, each feature's value and a
     * dimension/activation summary of every layer.
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
        final StringBuilder modelDescription = new StringBuilder();
        modelDescription.append("(name=").append(getName());
        modelDescription.append(",featureValues=[");
        for (int i = 0; i < featureExplanations.size(); ++i) {
            if (i > 0) {
                modelDescription.append(',');
            }
            // Feature explanations are assumed to be in the same order as this.features.
            final String key = this.features.get(i).getName();
            modelDescription.append(key).append('=').append(featureExplanations.get(i).getValue());
        }
        modelDescription.append("],layers=[");
        for (int i = 0; i < this.layers.size(); ++i) {
            if (i > 0) {
                modelDescription.append(',');
            }
            modelDescription.append(this.layers.get(i).describe());
        }
        modelDescription.append("])");
        return Explanation.match(finalScore, modelDescription.toString());
    }

    /**
     * Dense layer computing {@code activation(W * x + b)}.
     *
     * <p>Configured through {@link #setMatrix(Object)}, {@link #setBias(Object)} and
     * {@link #setActivation(Object)}. Configuration errors are reported by
     * {@link #validate(int)} rather than by the setters.
     */
    public class DefaultLayer implements Layer {
        /** Position of this layer in the network (0 for the first layer). */
        private int layerID;
        /** Weights indexed as {@code [outputUnit][inputUnit]}. */
        private float[][] weightMatrix;
        private int matrixRows;
        private int matrixCols;
        private float[] biasVector;
        /** Length of {@link #biasVector}; must equal {@link #matrixRows}. */
        private int numUnits;
        /** First weight-matrix row whose length differs from row 0, or -1 if all rows agree. */
        private int raggedRow = -1;
        /** Length of row {@link #raggedRow}; only meaningful when {@code raggedRow >= 0}. */
        private int raggedRowLength;
        protected String activationStr;
        protected Activation activation;

        /** Creates a layer whose index is the number of layers already in the network. */
        public DefaultLayer() {
            final List<Layer> existing = NeuralNetworkModel.this.layers;
            this.layerID = existing == null ? 0 : existing.size();
        }

        /**
         * Sets the weight matrix.
         *
         * @param matrixObj a {@code List} of rows, where each row is a {@code List} of numbers.
         *     Row {@code i} holds the input weights of output unit {@code i}. Rows whose length
         *     differs from row 0 are recorded and rejected later by {@link #validate(int)}.
         */
        public void setMatrix(Object matrixObj) {
            final List<?> matrix = (List<?>) matrixObj;
            this.matrixRows = matrix.size();
            // An empty matrix has no row 0 to take the width from.
            this.matrixCols = this.matrixRows == 0 ? 0 : ((List<?>) matrix.get(0)).size();
            this.weightMatrix = new float[this.matrixRows][this.matrixCols];
            this.raggedRow = -1;
            for (int i = 0; i < this.matrixRows; ++i) {
                final List<?> row = (List<?>) matrix.get(i);
                if (row.size() != this.matrixCols) {
                    if (this.raggedRow < 0) {
                        this.raggedRow = i;
                        this.raggedRowLength = row.size();
                    }
                    continue; // reported by validate(); never read past this row's end
                }
                for (int j = 0; j < this.matrixCols; ++j) {
                    // Number rather than Double: integer-valued weights may be parsed as Long.
                    this.weightMatrix[i][j] = ((Number) row.get(j)).floatValue();
                }
            }
        }

        /**
         * Sets the bias vector.
         *
         * @param biasObj a {@code List} of numbers, one per output unit
         */
        public void setBias(Object biasObj) {
            final List<?> vector = (List<?>) biasObj;
            this.numUnits = vector.size();
            this.biasVector = new float[this.numUnits];
            for (int i = 0; i < this.numUnits; ++i) {
                this.biasVector[i] = ((Number) vector.get(i)).floatValue();
            }
        }

        /**
         * Selects the activation function by name: {@code "relu"}, {@code "sigmoid"} or
         * {@code "identity"}. Any other value, including {@code null}, leaves the activation
         * unset so that {@link #validate(int)} reports it.
         *
         * @param activationStr the activation name (expected to be a {@code String})
         */
        public void setActivation(Object activationStr) {
            this.activationStr = (String) activationStr;
            if (this.activationStr == null) {
                // A switch on a null String would throw an NPE here.
                this.activation = null;
                return;
            }
            switch (this.activationStr) {
                case "relu":
                    this.activation = new Activation() {
                        @Override
                        public float apply(float in) {
                            return in < 0.0f ? 0.0f : in;
                        }
                    };
                    break;
                case "sigmoid":
                    this.activation = new Activation() {
                        @Override
                        public float apply(float in) {
                            return (float) (1.0 / (1.0 + Math.exp(-in)));
                        }
                    };
                    break;
                case "identity":
                    this.activation = new Activation() {
                        @Override
                        public float apply(float in) {
                            return in;
                        }
                    };
                    break;
                default:
                    this.activation = null;
                    break;
            }
        }

        /**
         * Computes {@code activation(W * inputVec + b)}.
         *
         * <p>Assumes {@link #validate(int)} has passed: the activation is set and
         * {@code inputVec} has at least {@code matrixCols} elements.
         */
        @Override
        public float[] calculateOutput(float[] inputVec) {
            final float[] outputVec = new float[this.matrixRows];
            for (int i = 0; i < this.matrixRows; ++i) {
                float outputVal = this.biasVector[i];
                for (int j = 0; j < this.matrixCols; ++j) {
                    outputVal += this.weightMatrix[i][j] * inputVec[j];
                }
                outputVec[i] = this.activation.apply(outputVal);
            }
            return outputVec;
        }

        /**
         * Checks that the layer is well-formed and accepts an input of width {@code inputDim}.
         *
         * @return this layer's output width (number of weight-matrix rows)
         * @throws ModelException if the matrix is ragged, the bias length and row count differ,
         *     the activation is invalid, or the column count differs from {@code inputDim}
         */
        @Override
        public int validate(int inputDim) throws ModelException {
            if (this.raggedRow >= 0) {
                throw new ModelException("Dimension mismatch in model \"" + NeuralNetworkModel.this.name + "\". Row " + Integer.toString(this.raggedRow) + " of the weight matrix for layer " + Integer.toString(this.layerID) + " has " + Integer.toString(this.raggedRowLength) + " columns, but row 0 has " + Integer.toString(this.matrixCols) + " columns.");
            }
            if (this.numUnits != this.matrixRows) {
                throw new ModelException("Dimension mismatch in model \"" + NeuralNetworkModel.this.name + "\". Layer " + Integer.toString(this.layerID) + " has " + Integer.toString(this.numUnits) + " bias weights but " + Integer.toString(this.matrixRows) + " weight matrix rows.");
            }
            if (this.activation == null) {
                throw new ModelException("Invalid activation function (\"" + this.activationStr + "\") in layer " + Integer.toString(this.layerID) + " of model \"" + NeuralNetworkModel.this.name + "\".");
            }
            if (inputDim != this.matrixCols) {
                if (this.layerID == 0) {
                    throw new ModelException("Dimension mismatch in model \"" + NeuralNetworkModel.this.name + "\". The input has " + Integer.toString(inputDim) + " features, but the weight matrix for layer 0 has " + Integer.toString(this.matrixCols) + " columns.");
                }
                throw new ModelException("Dimension mismatch in model \"" + NeuralNetworkModel.this.name + "\". The weight matrix for layer " + Integer.toString(this.layerID - 1) + " has " + Integer.toString(inputDim) + " rows, but the weight matrix for layer " + Integer.toString(this.layerID) + " has " + Integer.toString(this.matrixCols) + " columns.");
            }
            return this.matrixRows;
        }

        /** Summarizes the layer as {@code (matrix=RxC,activation=NAME)}. */
        @Override
        public String describe() {
            return new StringBuilder()
                .append("(matrix=").append(Integer.toString(this.matrixRows))
                .append('x').append(Integer.toString(this.matrixCols))
                .append(",activation=").append(this.activationStr)
                .append(")")
                .toString();
        }
    }

    /** One stage of the network. */
    public interface Layer {
        /** Maps an input vector to this layer's output vector. */
        float[] calculateOutput(float[] inputVec);

        /**
         * Checks the layer against the width of its input.
         *
         * @return the width of this layer's output
         * @throws ModelException if the layer is misconfigured
         */
        int validate(int inputDim) throws ModelException;

        /** Returns a short human-readable summary used in explanations. */
        String describe();
    }

    /** Element-wise activation function applied to each output unit. */
    protected interface Activation {
        float apply(float in);
    }
}

