/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import com.google.common.collect.ImmutableMap;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.utils.ScriptUtils;
import org.opensearch.script.ScriptService;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.regions.Region;

/**
 * Static helpers for remote-model connectors.
 *
 * <p>They convert an {@link MLInput} into a {@link RemoteInferenceInputDataSet}, convert a raw
 * model response into {@link ModelTensors}, and sign outgoing HTTP requests with {@link Aws4Signer}.
 */
public class ConnectorUtils {
    private static final Aws4Signer signer = Aws4Signer.create();

    /**
     * Converts the dataset carried by {@code mlInput} into a {@link RemoteInferenceInputDataSet}.
     *
     * <p>A {@link TextDocsInputDataSet} is passed (as {@code text_docs}) to the predict action's
     * pre-process script. The script output must be a JSON object containing a {@code parameters}
     * object; string values are kept and any other value is re-serialized to JSON. A
     * {@link RemoteInferenceInputDataSet} is used as-is.
     *
     * <p>Afterwards, every non-null parameter value that {@link StringUtils#isJson} does not accept
     * is escaped with {@link StringEscapeUtils#escapeJson}.
     *
     * <p>NOTE: for a {@link RemoteInferenceInputDataSet} input, the escaped parameters are written
     * back into the caller's dataset object. Processing the same input twice therefore escapes
     * non-JSON values twice.
     *
     * @param mlInput the input to convert; must not be {@code null}
     * @param connector connector whose predict action supplies the pre-process function
     * @param parameters values substituted for {@code ${parameters.NAME}} placeholders in the
     *     pre-process function
     * @param scriptService service used to execute the pre-process function
     * @return the dataset with escaped parameters
     * @throws IllegalArgumentException if the input is null or of an unsupported type, the connector
     *     has no predict action or pre-process function, or the script output is unusable
     */
    public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
        if (mlInput == null) {
            throw new IllegalArgumentException("Input is null");
        }
        RemoteInferenceInputDataSet inputData;
        if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
            inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService);
        } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
            inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
        } else {
            throw new IllegalArgumentException("Wrong input type");
        }
        if (inputData.getParameters() != null) {
            inputData.setParameters(escapeParameters(inputData.getParameters()));
        }
        return inputData;
    }

    /** Runs the predict action's pre-process script over text docs and wraps its parameters in a dataset. */
    private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
        List<String> docs = new ArrayList<>(inputDataSet.getDocs());
        Map<String, Object> scriptParams = ImmutableMap.of("text_docs", docs);
        ConnectorAction predictAction = connector.findPredictAction()
                .orElseThrow(() -> new IllegalArgumentException("no predict action found"));
        String preProcessFunction = predictAction.getPreProcessFunction();
        if (preProcessFunction == null) {
            throw new IllegalArgumentException("Must provide pre_process_function for predict action to process text docs input.");
        }
        preProcessFunction = substituteParameters(preProcessFunction, parameters);
        Optional<String> processedResponse = ScriptUtils.executePreprocessFunction(scriptService, preProcessFunction, scriptParams);
        if (!processedResponse.isPresent()) {
            throw new IllegalArgumentException("Wrong input");
        }

        @SuppressWarnings("unchecked") // Gson deserializes a JSON object into a Map keyed by String
        Map<String, Object> responseMap = (Map<String, Object>) ScriptUtils.gson.fromJson(processedResponse.get(), Map.class);
        Object rawParameters = responseMap == null ? null : responseMap.get("parameters");
        if (!(rawParameters instanceof Map)) {
            // Previously this surfaced as a NullPointerException/ClassCastException.
            throw new IllegalArgumentException("pre_process_function output must contain a \"parameters\" object");
        }
        Map<?, ?> parametersMap = (Map<?, ?>) rawParameters;

        Map<String, String> processedParameters = new HashMap<>();
        // Gson serialization is done inside a privileged block (as in the original code).
        // The action throws no checked exceptions, so PrivilegedAction is sufficient and
        // removes the ambiguous doPrivileged overload.
        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
            for (Map.Entry<?, ?> entry : parametersMap.entrySet()) {
                Object value = entry.getValue();
                processedParameters.put((String) entry.getKey(), value instanceof String ? (String) value : ScriptUtils.gson.toJson(value));
            }
            return null;
        });
        return RemoteInferenceInputDataSet.builder().parameters(processedParameters).build();
    }

    /** Returns a copy of {@code parameters} where each non-null value that is not JSON is JSON-escaped. */
    private static Map<String, String> escapeParameters(Map<String, String> parameters) {
        Map<String, String> escaped = new HashMap<>();
        for (Map.Entry<String, String> entry : parameters.entrySet()) {
            String value = entry.getValue();
            escaped.put(entry.getKey(), value == null || StringUtils.isJson(value) ? value : StringEscapeUtils.escapeJson(value));
        }
        return escaped;
    }

    /** Replaces {@code ${parameters.NAME}} placeholders in {@code function}; returns it unchanged if null or placeholder-free. */
    private static String substituteParameters(String function, Map<String, String> parameters) {
        if (function == null || !function.contains("${parameters")) {
            return function;
        }
        StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
        return substitutor.replace(function);
    }

    /**
     * Converts a raw model response into {@link ModelTensors}.
     *
     * <p>If the predict action has a post-process function, its {@code ${parameters.NAME}}
     * placeholders are substituted and the function is applied to the response. When the script
     * returns nothing, the raw response is used instead. If {@code parameters} contains a
     * {@code response_filter} JsonPath, that path is read from the response before parsing.
     *
     * @param modelResponse raw response text from the remote model; must not be {@code null}
     * @param connector connector providing the predict action and the response parser
     * @param scriptService service used to execute the post-process function
     * @param parameters connector parameters; may be {@code null}, meaning no substitution values
     *     and no response filter
     * @return the parsed tensors
     * @throws IOException if {@link Connector#parseResponse} fails
     * @throws IllegalArgumentException if the response is null or the connector has no predict action
     */
    public static ModelTensors processOutput(String modelResponse, Connector connector, ScriptService scriptService, Map<String, String> parameters) throws IOException {
        if (modelResponse == null) {
            throw new IllegalArgumentException("model response is null");
        }
        ConnectorAction predictAction = connector.findPredictAction()
                .orElseThrow(() -> new IllegalArgumentException("no predict action found"));
        String postProcessFunction = substituteParameters(predictAction.getPostProcessFunction(), parameters);
        Optional<String> processedResponse = ScriptUtils.executePostprocessFunction(scriptService, postProcessFunction, modelResponse);
        String response = processedResponse.orElse(modelResponse);
        // True only when a post-process script actually produced the text being parsed.
        // NOTE(review): presumably tells parseResponse the text is already in tensor form; confirm in Connector.
        boolean fromPostProcess = postProcessFunction != null && processedResponse.isPresent();

        String responseFilter = parameters == null ? null : parameters.get("response_filter");
        Object parsedResponse;
        if (responseFilter == null) {
            parsedResponse = response;
        } else {
            parsedResponse = JsonPath.parse(response).read(responseFilter);
        }

        List<ModelTensor> modelTensors = new ArrayList<>();
        connector.parseResponse(parsedResponse, modelTensors, fromPostProcess);
        return ModelTensors.builder().mlModelTensors(modelTensors).build();
    }

    /**
     * Signs {@code request} with {@link Aws4Signer}.
     *
     * <p>Static credentials are used when {@code sessionToken} is {@code null}; otherwise session
     * credentials are used.
     *
     * @param request the request to sign
     * @param accessKey access key id
     * @param secretKey secret access key
     * @param sessionToken session token, or {@code null} for static credentials
     * @param signingName service signing name
     * @param region region id passed to {@link Region#of}
     * @return the signed request
     */
    public static SdkHttpFullRequest signRequest(SdkHttpFullRequest request, String accessKey, String secretKey, String sessionToken, String signingName, String region) {
        // Declared as the common interface: basic and session credentials are sibling types.
        AwsCredentials credentials = sessionToken == null
                ? AwsBasicCredentials.create(accessKey, secretKey)
                : AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
        Aws4SignerParams params = Aws4SignerParams.builder()
                .awsCredentials(credentials)
                .signingName(signingName)
                .signingRegion(Region.of(region))
                .build();
        return signer.sign(request, params);
    }
}

