/*
 * Decompiled with CFR 0.152.
 */
package com.o19s.es.ltr.ranker.parser;

import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree;
import com.o19s.es.ltr.ranker.normalizer.Normalizer;
import com.o19s.es.ltr.ranker.normalizer.Normalizers;
import com.o19s.es.ltr.ranker.parser.LtrRankerParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ObjectParser;
import org.opensearch.core.xcontent.XContentParseException;
import org.opensearch.core.xcontent.XContentParser;

/**
 * Parses XGBoost JSON model dumps into a {@link NaiveAdditiveDecisionTree}.
 *
 * <p>Two root shapes are accepted:
 * <ul>
 *   <li>a JSON array whose elements are tree dumps (the default {@code noop} normalizer is used), or</li>
 *   <li>a JSON object with an optional {@code objective} string and a required {@code splits} array of
 *       tree dumps; the objective selects the output normalizer.</li>
 * </ul>
 * Every parsed tree is added to the ensemble with weight {@code 1.0}.
 */
public class XGBoostJsonParser implements LtrRankerParser {
    public static final String TYPE = "model/xgboost+json";

    /**
     * Parses {@code model} as XGBoost JSON, resolving split feature names against {@code set}.
     *
     * @param set   feature set used to validate split feature names and to map them to ordinals
     * @param model JSON text of the model
     * @return an additive ensemble of the parsed trees, each with weight 1
     * @throws IllegalArgumentException if reading the JSON raises an {@link IOException}
     */
    @Override
    public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) {
        XGBoostDefinition modelDefinition;
        try (XContentParser parser = JsonXContent.jsonXContent.createParser(
                NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, model)) {
            modelDefinition = XGBoostDefinition.parse(parser, set);
        } catch (IOException e) {
            throw new IllegalArgumentException("Cannot parse model", e);
        }
        NaiveAdditiveDecisionTree.Node[] trees = modelDefinition.getTrees(set);
        // Trees are combined unweighted: every tree gets weight 1.
        float[] weights = new float[trees.length];
        Arrays.fill(weights, 1.0f);
        return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.normalizer);
    }

    /** Root of an XGBoost model: the per-tree dumps plus the normalizer chosen by the objective. */
    private static class XGBoostDefinition {
        private static final ObjectParser<XGBoostDefinition, FeatureSet> PARSER =
                new ObjectParser<>("xgboost_definition", XGBoostDefinition::new);

        static {
            PARSER.declareString(XGBoostDefinition::setNormalizer, new ParseField("objective"));
            PARSER.declareObjectArray(XGBoostDefinition::setSplitParserStates, SplitParserState::parse,
                    new ParseField("splits"));
        }

        // Used when no "objective" is supplied (always the case for the bare-array form).
        private Normalizer normalizer = Normalizers.get("noop");
        // One entry per tree; null until "splits" is seen in the object form.
        private List<SplitParserState> splitParserStates;

        /**
         * Reads a model definition starting from the parser's next token.
         *
         * @param parser parser positioned before the root value
         * @param set    feature set used to validate split feature names
         * @return the parsed definition, guaranteed to contain at least one tree
         * @throws IOException      if the underlying parser fails while reading
         * @throws ParsingException if the root is neither an object nor an array, if the object form fails
         *                          to parse or lacks {@code splits}, or if no tree is defined
         */
        public static XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException {
            XGBoostDefinition definition;
            XContentParser.Token startToken = parser.nextToken();
            if (startToken == XContentParser.Token.START_OBJECT) {
                try {
                    definition = PARSER.apply(parser, set);
                } catch (XContentParseException e) {
                    throw new ParsingException(parser.getTokenLocation(), "Unable to parse XGBoost object", e);
                }
                if (definition.splitParserStates == null) {
                    throw new ParsingException(parser.getTokenLocation(),
                            "XGBoost model missing required field [splits]");
                }
            } else if (startToken == XContentParser.Token.START_ARRAY) {
                // Bare array form: every element is one tree dump; the default normalizer is kept.
                definition = new XGBoostDefinition();
                definition.splitParserStates = new ArrayList<>();
                while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                    definition.splitParserStates.add(SplitParserState.parse(parser, set));
                }
            } else {
                throw new ParsingException(parser.getTokenLocation(),
                        "Expected [START_ARRAY] or [START_OBJECT] but got [" + startToken + "]");
            }
            if (definition.splitParserStates.isEmpty()) {
                throw new ParsingException(parser.getTokenLocation(), "XGBoost model must define at lease one tree");
            }
            return definition;
        }

        XGBoostDefinition() {
        }

        /**
         * Selects the output normalizer for an XGBoost objective name.
         *
         * <p>Objectives whose predictions are raw margins/scores map to {@code noop}; logistic objectives
         * map to {@code sigmoid}.
         *
         * @param objectiveName XGBoost objective, e.g. {@code rank:pairwise}
         * @throws IllegalArgumentException if the objective is not supported
         */
        void setNormalizer(String objectiveName) {
            switch (objectiveName) {
                case "binary:logitraw":
                case "rank:pairwise":
                case "rank:ndcg":
                case "rank:map":
                case "reg:linear":
                case "reg:squarederror":
                    this.normalizer = Normalizers.get("noop");
                    break;
                case "binary:logistic":
                case "reg:logistic":
                    this.normalizer = Normalizers.get("sigmoid");
                    break;
                default:
                    throw new IllegalArgumentException("Objective [" + objectiveName + "] is not a valid XGBoost objective");
            }
        }

        void setSplitParserStates(List<SplitParserState> splitParserStates) {
            this.splitParserStates = splitParserStates;
        }

        /**
         * Converts every parsed tree dump into a decision-tree node, preserving order.
         *
         * @param set feature set used to map split feature names to ordinals
         * @return one root node per tree
         */
        NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) {
            NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[splitParserStates.size()];
            int i = 0;
            for (SplitParserState state : splitParserStates) {
                trees[i++] = state.toNode(set);
            }
            return trees;
        }
    }

    /**
     * One node of an XGBoost tree dump. A node without {@code leaf} is a split and must carry its two
     * children inline in {@code children}, ordered as [yes, no].
     */
    private static class SplitParserState {
        private static final ObjectParser<SplitParserState, FeatureSet> PARSER =
                new ObjectParser<>("node", SplitParserState::new);

        static {
            PARSER.declareInt(SplitParserState::setNodeId, new ParseField("nodeid"));
            PARSER.declareInt(SplitParserState::setDepth, new ParseField("depth"));
            PARSER.declareString(SplitParserState::setSplit, new ParseField("split"));
            PARSER.declareFloat(SplitParserState::setThreshold, new ParseField("split_condition"));
            PARSER.declareInt(SplitParserState::setRightNodeId, new ParseField("no"));
            PARSER.declareInt(SplitParserState::setLeftNodeId, new ParseField("yes"));
            PARSER.declareInt(SplitParserState::setMissingNodeId, new ParseField("missing"));
            PARSER.declareFloat(SplitParserState::setLeaf, new ParseField("leaf"));
            PARSER.declareObjectArray(SplitParserState::setChildren, SplitParserState::parse,
                    new ParseField("children"));
        }

        private Integer nodeId;
        private Integer depth;
        // Feature name tested by this split.
        private String split;
        private Float threshold;
        // Node id of the "no" branch; must match children[1].
        private Integer rightNodeId;
        // Node id of the "yes" branch; must match children[0].
        private Integer leftNodeId;
        // Parsed but not used when building nodes or validating children.
        private Integer missingNodeId;
        private Float leaf;
        private List<SplitParserState> children;

        private SplitParserState() {
        }

        /**
         * Parses and validates one node (and, recursively, its children).
         *
         * @param parser parser positioned on or before the node object
         * @param set    feature set the split feature must belong to
         * @return the validated node
         * @throws ParsingException if required fields are missing, child ids do not match {@code yes}/{@code no},
         *                          or the split feature is unknown
         */
        public static SplitParserState parse(XContentParser parser, FeatureSet set) {
            SplitParserState split = PARSER.apply(parser, set);
            if (split.isSplit()) {
                if (!split.splitHasAllFields()) {
                    throw new ParsingException(parser.getTokenLocation(), "This split does not have all the required fields");
                }
                if (!split.splitHasValidChildren()) {
                    throw new ParsingException(parser.getTokenLocation(),
                            "Split structure is invalid, yes, no and/or missing branches does not point to the proper children.");
                }
                if (!set.hasFeature(split.split)) {
                    throw new ParsingException(parser.getTokenLocation(), "Unknown feature [" + split.split + "]");
                }
            } else if (!split.leafHasAllFields()) {
                throw new ParsingException(parser.getTokenLocation(), "This leaf does not have all the required fields");
            }
            return split;
        }

        void setNodeId(Integer nodeId) {
            this.nodeId = nodeId;
        }

        void setDepth(Integer depth) {
            this.depth = depth;
        }

        void setSplit(String split) {
            this.split = split;
        }

        void setThreshold(Float threshold) {
            this.threshold = threshold;
        }

        void setRightNodeId(Integer rightNodeId) {
            this.rightNodeId = rightNodeId;
        }

        void setLeftNodeId(Integer leftNodeId) {
            this.leftNodeId = leftNodeId;
        }

        void setMissingNodeId(Integer missingNodeId) {
            this.missingNodeId = missingNodeId;
        }

        void setLeaf(Float leaf) {
            this.leaf = leaf;
        }

        void setChildren(List<SplitParserState> children) {
            this.children = children;
        }

        /** Returns true when every field a split needs is present and exactly two children are given. */
        boolean splitHasAllFields() {
            return nodeId != null
                    && threshold != null
                    && split != null
                    && leftNodeId != null
                    && rightNodeId != null
                    && depth != null
                    && children != null
                    && children.size() == 2;
        }

        /** Returns true when a leaf carries both its id and its value. */
        boolean leafHasAllFields() {
            return nodeId != null && leaf != null;
        }

        /** Returns true when children are ordered [yes, no] according to their node ids. */
        boolean splitHasValidChildren() {
            return children.size() == 2
                    && leftNodeId.equals(children.get(0).nodeId)
                    && rightNodeId.equals(children.get(1).nodeId);
        }

        /** A node is a split exactly when it has no leaf value. */
        boolean isSplit() {
            return leaf == null;
        }

        /**
         * Builds the decision-tree node for this dump node, recursing into children for splits.
         *
         * @param set feature set used to resolve the split feature ordinal
         * @return a {@code Split} (children[0] = yes branch, children[1] = no branch) or a {@code Leaf}
         */
        NaiveAdditiveDecisionTree.Node toNode(FeatureSet set) {
            if (isSplit()) {
                return new NaiveAdditiveDecisionTree.Split(
                        children.get(0).toNode(set),
                        children.get(1).toNode(set),
                        set.featureOrdinal(split),
                        threshold.floatValue());
            }
            return new NaiveAdditiveDecisionTree.Leaf(leaf.floatValue());
        }
    }
}

