/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.question_answering;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.algorithms.question_answering.QuestionAnsweringTranslator;
import org.opensearch.ml.engine.algorithms.question_answering.SentenceHighlightingQATranslator;
import org.opensearch.ml.engine.annotation.Function;

@Function(value=FunctionName.QUESTION_ANSWERING)
public class QuestionAnsweringModel
extends DLModel {
    @Generated
    private static final Logger log = LogManager.getLogger(QuestionAnsweringModel.class);
    private MLModelConfig modelConfig;
    private Translator<Input, Output> translator;

    @Override
    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
        if (predictor == null) {
            throw new IllegalArgumentException("predictor is null");
        }
        if (modelId == null) {
            throw new IllegalArgumentException("model id is null");
        }
        if (modelConfig != null) {
            this.modelConfig = modelConfig;
        }
        Input input = new Input();
        if (this.isSentenceHighlightingModel()) {
            input.add("question", "How is the weather?");
            input.add("context", "The weather is nice, it is beautiful day. The sun is shining. The sky is blue.");
            input.add("chunk", "0");
        } else {
            input.add("How is the weather?");
            input.add("The weather is nice, it is beautiful day. The sun is shining. The sky is blue.");
        }
        predictor.predict((Object)input);
    }

    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        MLInputDataset inputDataSet = mlInput.getInputDataset();
        QuestionAnsweringInputDataSet qaInputDataSet = (QuestionAnsweringInputDataSet)inputDataSet;
        String question = qaInputDataSet.getQuestion();
        String context = qaInputDataSet.getContext();
        if (this.isSentenceHighlightingModel()) {
            return this.predictSentenceHighlightingQA(question, context);
        }
        return this.predictStandardQA(question, context);
    }

    private boolean isSentenceHighlightingModel() {
        return this.modelConfig != null && "sentence_highlighting".equalsIgnoreCase(this.modelConfig.getModelType());
    }

    private ModelTensorOutput predictStandardQA(String question, String context) throws TranslateException {
        Input input = new Input();
        input.add(question);
        input.add(context);
        try {
            Output output = (Output)this.getPredictor().predict((Object)input);
            ModelTensors tensors = this.parseModelTensorOutput(output, null);
            return new ModelTensorOutput(List.of(tensors));
        }
        catch (Exception e) {
            log.error("Error processing standard QA model prediction", (Throwable)e);
            throw new TranslateException("Failed to process standard QA model prediction", (Throwable)e);
        }
    }

    private ModelTensorOutput predictSentenceHighlightingQA(String question, String context) throws TranslateException {
        SentenceHighlightingQATranslator translator = (SentenceHighlightingQATranslator)this.getTranslator("PyTorch", this.modelConfig);
        try {
            ArrayList<Map<String, Object>> allHighlights = new ArrayList<Map<String, Object>>();
            this.processChunk(question, context, "0", allHighlights);
            Encoding encodings = translator.getTokenizer().encode(question, context);
            Encoding[] overflowEncodings = encodings.getOverflowing();
            if (overflowEncodings != null && overflowEncodings.length > 0) {
                for (int i = 0; i < overflowEncodings.length; ++i) {
                    this.processChunk(question, context, String.valueOf(i + 1), allHighlights);
                }
            }
            return this.createHighlightOutput(allHighlights);
        }
        catch (Exception e) {
            log.error("Error processing sentence highlighting model prediction", (Throwable)e);
            throw new TranslateException("Failed to process chunks for sentence highlighting", (Throwable)e);
        }
    }

    private void processChunk(String question, String context, String chunkNumber, List<Map<String, Object>> allHighlights) throws TranslateException {
        Input chunkInput = new Input();
        chunkInput.add("question", question);
        chunkInput.add("context", context);
        chunkInput.add("chunk", chunkNumber);
        List outputs = this.getPredictor().batchPredict(List.of(chunkInput));
        if (outputs.isEmpty()) {
            return;
        }
        for (Output output : outputs) {
            ModelTensors tensors = this.parseModelTensorOutput(output, null);
            allHighlights.addAll(this.extractHighlights(tensors));
        }
    }

    private List<Map<String, Object>> extractHighlights(ModelTensors tensors) throws TranslateException {
        ArrayList<Map<String, Object>> highlights = new ArrayList<Map<String, Object>>();
        for (ModelTensor tensor : tensors.getMlModelTensors()) {
            Map dataAsMap = tensor.getDataAsMap();
            if (dataAsMap == null || !dataAsMap.containsKey("highlights")) continue;
            try {
                List tensorHighlights = (List)dataAsMap.get("highlights");
                highlights.addAll(tensorHighlights);
            }
            catch (ClassCastException e) {
                log.error("Failed to cast highlights data to expected format", (Throwable)e);
                throw new TranslateException("Failed to cast highlights data to expected format", (Throwable)e);
            }
        }
        return highlights;
    }

    private ModelTensorOutput createHighlightOutput(List<Map<String, Object>> highlights) {
        HashMap<String, List<Map<String, Object>>> combinedData = new HashMap<String, List<Map<String, Object>>>();
        List<Map<String, Object>> uniqueSortedHighlights = this.removeDuplicatesAndSort(highlights);
        combinedData.put("highlights", uniqueSortedHighlights);
        ModelTensor combinedTensor = ModelTensor.builder().name("highlights").dataAsMap(combinedData).build();
        return new ModelTensorOutput(List.of(new ModelTensors(List.of(combinedTensor))));
    }

    private List<Map<String, Object>> removeDuplicatesAndSort(List<Map<String, Object>> highlights) {
        HashMap<Number, Map<String, Object>> uniqueMap = new HashMap<Number, Map<String, Object>>();
        for (Map<String, Object> highlight : highlights) {
            Number position = (Number)highlight.get("position");
            if (uniqueMap.containsKey(position)) continue;
            uniqueMap.put(position, highlight);
        }
        ArrayList<Map<String, Object>> uniqueHighlights = new ArrayList<Map<String, Object>>(uniqueMap.values());
        uniqueHighlights.sort((a, b) -> {
            Number posA = (Number)a.get("position");
            Number posB = (Number)b.get("position");
            return Double.compare(posA.doubleValue(), posB.doubleValue());
        });
        return uniqueHighlights;
    }

    @Override
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
        if (this.translator == null) {
            this.translator = modelConfig != null && "sentence_highlighting".equalsIgnoreCase(modelConfig.getModelType()) ? SentenceHighlightingQATranslator.create(modelConfig) : new QuestionAnsweringTranslator();
        }
        return this.translator;
    }

    @Override
    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
        return null;
    }

    @Generated
    public static QuestionAnsweringModelBuilder builder() {
        return new QuestionAnsweringModelBuilder();
    }

    @Generated
    public QuestionAnsweringModelBuilder toBuilder() {
        return new QuestionAnsweringModelBuilder().modelConfig(this.modelConfig).translator(this.translator);
    }

    @Generated
    public QuestionAnsweringModel(MLModelConfig modelConfig, Translator<Input, Output> translator) {
        this.modelConfig = modelConfig;
        this.translator = translator;
    }

    @Generated
    public QuestionAnsweringModel() {
    }

    @Generated
    public static class QuestionAnsweringModelBuilder {
        @Generated
        private MLModelConfig modelConfig;
        @Generated
        private Translator<Input, Output> translator;

        @Generated
        QuestionAnsweringModelBuilder() {
        }

        @Generated
        public QuestionAnsweringModelBuilder modelConfig(MLModelConfig modelConfig) {
            this.modelConfig = modelConfig;
            return this;
        }

        @Generated
        public QuestionAnsweringModelBuilder translator(Translator<Input, Output> translator) {
            this.translator = translator;
            return this;
        }

        @Generated
        public QuestionAnsweringModel build() {
            return new QuestionAnsweringModel(this.modelConfig, this.translator);
        }

        @Generated
        public String toString() {
            return "QuestionAnsweringModel.QuestionAnsweringModelBuilder(modelConfig=" + String.valueOf(this.modelConfig) + ", translator=" + String.valueOf(this.translator) + ")";
        }
    }
}

