/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.question_answering;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import java.util.ArrayList;
import java.util.List;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.algorithms.question_answering.QuestionAnsweringTranslator;
import org.opensearch.ml.engine.annotation.Function;

/**
 * Question-answering model that runs a (question, context) pair through a DJL predictor.
 *
 * <p>Inputs are packed into a DJL {@link Input} as two entries (question first, then context)
 * and translated with {@link QuestionAnsweringTranslator}.
 */
@Function(value=FunctionName.QUESTION_ANSWERING)
public class QuestionAnsweringModel extends DLModel {
    @Generated
    private static final Logger log = LogManager.getLogger(QuestionAnsweringModel.class);

    /** Fixed sample question sent through the predictor by {@link #warmUp}. */
    private static final String WARM_UP_QUESTION = "How is the weather?";

    /** Fixed sample context sent through the predictor by {@link #warmUp}. */
    private static final String WARM_UP_CONTEXT = "The weather is nice, it is beautiful day.";

    /**
     * Sends one fixed sample question/context pair through {@code predictor}; the prediction
     * result is discarded.
     *
     * @param predictor   the predictor to exercise
     * @param modelId     model id (not used by this implementation)
     * @param modelConfig model config (not used by this implementation)
     * @throws TranslateException if the predictor fails to translate or run the sample input
     */
    @Override
    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
        predictor.predict(createQaInput(WARM_UP_QUESTION, WARM_UP_CONTEXT));
    }

    /**
     * Answers the question carried by {@code mlInput}'s dataset against its context.
     *
     * @param modelId model id (not used by this implementation)
     * @param mlInput input whose dataset must be a {@link QuestionAnsweringInputDataSet}
     * @return a {@link ModelTensorOutput} holding a single {@link ModelTensors} entry parsed from
     *     the predictor output
     * @throws IllegalArgumentException if the input dataset is missing or is not a
     *     {@link QuestionAnsweringInputDataSet}
     * @throws TranslateException if the predictor fails to translate or run the input
     */
    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        MLInputDataset inputDataSet = mlInput.getInputDataset();
        // Fail with a clear message instead of a bare ClassCastException (wrong type)
        // or NullPointerException (missing dataset).
        if (!(inputDataSet instanceof QuestionAnsweringInputDataSet)) {
            String actual = inputDataSet == null ? "null" : inputDataSet.getClass().getName();
            throw new IllegalArgumentException(
                "Question answering model expects a QuestionAnsweringInputDataSet but got: " + actual);
        }
        QuestionAnsweringInputDataSet qaInputDataSet = (QuestionAnsweringInputDataSet) inputDataSet;

        Input input = createQaInput(qaInputDataSet.getQuestion(), qaInputDataSet.getContext());
        // Cast kept because the predictor's generic signature is not visible here.
        Output output = (Output) this.getPredictor().predict(input);

        List<ModelTensors> tensorOutputs = new ArrayList<>();
        tensorOutputs.add(this.parseModelTensorOutput(output, null));
        return new ModelTensorOutput(tensorOutputs);
    }

    /**
     * Returns a new {@link QuestionAnsweringTranslator}; {@code engine} and {@code modelConfig}
     * are not consulted.
     */
    @Override
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
        return new QuestionAnsweringTranslator();
    }

    /**
     * Always returns {@code null}; this model supplies its translator through
     * {@link #getTranslator} rather than a factory.
     */
    @Override
    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
        return null;
    }

    /**
     * Packs a question and its context into a DJL {@link Input}, question first, then context.
     * NOTE(review): the entry order presumably matches what QuestionAnsweringTranslator
     * expects — confirm before reordering.
     */
    private static Input createQaInput(String question, String context) {
        Input input = new Input();
        input.add(question);
        input.add(context);
        return input;
    }
}

