/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine;

import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.Executable;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.TrainAndPredictable;
import org.opensearch.ml.engine.Trainable;
import org.opensearch.ml.engine.encryptor.Encryptor;

/**
 * Facade over the ML engine: resolves on-disk cache locations and prebuilt-model
 * URLs, and dispatches train / predict / execute requests to the implementation
 * that {@link MLEngineClassLoader} instantiates for the requested function.
 *
 * <p>All cache paths are derived from the data folder handed to the constructor:
 * {@code <dataFolder>/ml_cache}, with models under {@code ml_cache/models_cache}
 * and configuration under {@code ml_cache/config}.
 */
public class MLEngine {
    @Generated
    private static final Logger log = LogManager.getLogger(MLEngine.class);
    public static final String REGISTER_MODEL_FOLDER = "register";
    public static final String DEPLOY_MODEL_FOLDER = "deploy";
    /** Base URL under which the prebuilt model artifacts are looked up. */
    private static final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";
    private final Path mlConfigPath;
    private final Path mlCachePath;
    private final Path mlModelsCachePath;
    private final Encryptor encryptor;

    /**
     * Creates an engine whose cache directories live below the given data folder.
     *
     * @param opensearchDataFolder root folder under which {@code ml_cache} is resolved
     * @param encryptor encryptor handed to deployed models and used by {@link #encrypt(String)}
     */
    public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
        this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
        this.mlModelsCachePath = this.mlCachePath.resolve("models_cache");
        this.mlConfigPath = this.mlCachePath.resolve("config");
        this.encryptor = encryptor;
    }

    /**
     * @return URL of the JSON listing of pre-trained models in the model repository
     */
    public String getPrebuiltModelMetaListPath() {
        return MODEL_REPO + "/model_listing/pre_trained_models.json";
    }

    /**
     * Builds the URL of a prebuilt model's {@code config.json}.
     *
     * @param modelName model name, used verbatim as a URL path segment
     * @param version model version
     * @param modelFormat model format; its enum name is lower-cased for the URL
     * @return {@code <repo>/<modelName>/<version>/<format>/config.json}
     */
    public String getPrebuiltModelConfigPath(String modelName, String version, MLModelFormat modelFormat) {
        String format = modelFormat.name().toLowerCase(Locale.ROOT);
        // The locale must be the FIRST argument; as a trailing argument it is just
        // an ignored surplus format argument.
        return String.format(Locale.ROOT, "%s/%s/%s/%s/config.json", MODEL_REPO, modelName, version, format);
    }

    /**
     * Builds the URL of a prebuilt model's zip archive.
     *
     * <p>The archive file name is the model name with everything up to and including
     * the first {@code '/'} removed and any remaining {@code '/'} replaced by
     * {@code '_'}, followed by {@code -<version>-<format>}.
     *
     * @param modelName model name, used verbatim as a URL path segment
     * @param version model version
     * @param modelFormat model format; its enum name is lower-cased for the URL
     * @return {@code <repo>/<modelName>/<version>/<format>/<zipFileName>.zip}
     */
    public String getPrebuiltModelPath(String modelName, String version, MLModelFormat modelFormat) {
        // indexOf returns -1 when there is no '/', so the substring then starts at 0.
        int index = modelName.indexOf("/") + 1;
        String format = modelFormat.name().toLowerCase(Locale.ROOT);
        String modelZipFileName = modelName.substring(index).replace("/", "_") + "-" + version + "-" + format;
        return String.format(Locale.ROOT, "%s/%s/%s/%s/%s.zip", MODEL_REPO, modelName, version, format, modelZipFileName);
    }

    /**
     * @return {@code <registerRoot>/<modelId>/<version>/<modelName>}
     */
    public Path getRegisterModelPath(String modelId, String modelName, String version) {
        return this.getRegisterModelPath(modelId).resolve(version).resolve(modelName);
    }

    /**
     * @return {@code <registerRoot>/<modelId>}
     */
    public Path getRegisterModelPath(String modelId) {
        return this.getRegisterModelRootPath().resolve(modelId);
    }

    /**
     * @return the {@value #REGISTER_MODEL_FOLDER} folder inside the models cache
     */
    public Path getRegisterModelRootPath() {
        return this.mlModelsCachePath.resolve(REGISTER_MODEL_FOLDER);
    }

    /**
     * @return {@code <deployRoot>/<modelId>}
     */
    public Path getDeployModelPath(String modelId) {
        return this.getDeployModelRootPath().resolve(modelId);
    }

    /**
     * @return string form of {@code <deployRoot>/<modelId>/<modelName>} with {@code .zip} appended
     */
    public String getDeployModelZipPath(String modelId, String modelName) {
        return this.getDeployModelPath(modelId).resolve(modelName) + ".zip";
    }

    /**
     * @return the {@value #DEPLOY_MODEL_FOLDER} folder inside the models cache
     */
    public Path getDeployModelRootPath() {
        return this.mlModelsCachePath.resolve(DEPLOY_MODEL_FOLDER);
    }

    /**
     * @return {@code <deployRoot>/<modelId>/chunks/<chunkNumber>}
     */
    public Path getDeployModelChunkPath(String modelId, Integer chunkNumber) {
        return this.getDeployModelPath(modelId).resolve("chunks").resolve(String.valueOf(chunkNumber));
    }

    /**
     * @return {@code <modelCacheRoot>/<modelId>/<version>/<modelName>}
     */
    public Path getModelCachePath(String modelId, String modelName, String version) {
        return this.getModelCachePath(modelId).resolve(version).resolve(modelName);
    }

    /**
     * @return {@code <modelCacheRoot>/<modelId>}
     */
    public Path getModelCachePath(String modelId) {
        return this.getModelCacheRootPath().resolve(modelId);
    }

    /**
     * @return the {@code models} folder inside the models cache
     */
    public Path getModelCacheRootPath() {
        return this.mlModelsCachePath.resolve("models");
    }

    /**
     * Trains a model with the algorithm named by the input.
     *
     * @param input must be a valid {@link MLInput} (see {@code validateMLInput})
     * @return the trained model
     * @throws IllegalArgumentException if the input is invalid or no trainable
     *         implementation is available for the algorithm
     */
    public MLModel train(Input input) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput)input;
        Trainable trainable = (Trainable)MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (trainable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
        }
        return trainable.train(mlInput);
    }

    /**
     * Instantiates the predictor for the model's algorithm and initializes it with
     * the model, the given parameters and this engine's encryptor.
     *
     * @param mlModel model to deploy
     * @param params parameters passed through to {@code initModel}
     * @return the initialized predictor
     * @throws IllegalArgumentException if no implementation is available for the algorithm
     */
    public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
        Predictable predictable = (Predictable)MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
        // Same guard as train/predict: fail with a descriptive error, not a bare NPE.
        if (predictable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlModel.getAlgorithm());
        }
        predictable.initModel(mlModel, params, this.encryptor);
        return predictable;
    }

    /**
     * Instantiates the executable for the model's algorithm and initializes it with
     * the model and the given parameters.
     *
     * @param mlModel model to deploy
     * @param params parameters passed through to {@code initModel}
     * @return the initialized executable
     * @throws IllegalArgumentException if no implementation is available for the algorithm
     */
    public MLExecutable deployExecute(MLModel mlModel, Map<String, Object> params) {
        MLExecutable executable = (MLExecutable)MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
        if (executable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlModel.getAlgorithm());
        }
        executable.initModel(mlModel, params);
        return executable;
    }

    /**
     * Runs a prediction with the algorithm named by the input against the given model.
     *
     * @param input must be a valid {@link MLInput} (see {@code validateMLInput})
     * @param model model handed to the predictor
     * @return the prediction output
     * @throws IllegalArgumentException if the input is invalid or no predictor
     *         implementation is available for the algorithm
     */
    public MLOutput predict(Input input, MLModel model) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput)input;
        Predictable predictable = (Predictable)MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (predictable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
        }
        return predictable.predict(mlInput, model);
    }

    /**
     * Trains and predicts in one step with the algorithm named by the input.
     *
     * @param input must be a valid {@link MLInput} (see {@code validateMLInput})
     * @return the output of the combined train-and-predict call
     * @throws IllegalArgumentException if the input is invalid or no implementation
     *         is available for the algorithm
     */
    public MLOutput trainAndPredict(Input input) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput)input;
        TrainAndPredictable trainAndPredictable = (TrainAndPredictable)MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (trainAndPredictable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
        }
        return trainAndPredictable.trainAndPredict(mlInput);
    }

    /**
     * Executes the function named by the input.
     *
     * <p>{@link FunctionName#METRICS_CORRELATION} is instantiated as an
     * {@link MLExecutable}; every other function is instantiated as an {@link Executable}.
     *
     * @param input non-null input with a non-null function name
     * @return the execution output
     * @throws IllegalArgumentException if the input is invalid or no implementation
     *         is available for the function
     * @throws Exception whatever the underlying {@code execute} call throws
     */
    public Output execute(Input input) throws Exception {
        this.validateInput(input);
        if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) {
            MLExecutable mlExecutable = (MLExecutable)MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
            if (mlExecutable == null) {
                throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
            }
            return mlExecutable.execute(input);
        }
        Executable executable = (Executable)MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
        if (executable == null) {
            throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
        }
        return executable.execute(input);
    }

    /**
     * Validates that the input is an {@link MLInput} carrying a data set; when the
     * data set is a {@link DataFrameInputDataset}, its data frame must be non-null
     * and non-empty.
     *
     * @throws IllegalArgumentException if any of these checks fails
     */
    private void validateMLInput(Input input) {
        this.validateInput(input);
        if (!(input instanceof MLInput)) {
            throw new IllegalArgumentException("Input should be MLInput");
        }
        MLInput mlInput = (MLInput)input;
        MLInputDataset inputDataset = mlInput.getInputDataset();
        if (inputDataset == null) {
            throw new IllegalArgumentException("Input data set should not be null");
        }
        if (inputDataset instanceof DataFrameInputDataset) {
            DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame();
            if (dataFrame == null || dataFrame.size() == 0) {
                throw new IllegalArgumentException("Input data frame should not be null or empty");
            }
        }
    }

    /**
     * Validates that the input and its function name are both non-null.
     *
     * @throws IllegalArgumentException if either is null
     */
    private void validateInput(Input input) {
        if (input == null) {
            throw new IllegalArgumentException("Input should not be null");
        }
        if (input.getFunctionName() == null) {
            throw new IllegalArgumentException("Function name should not be null");
        }
    }

    /**
     * Encrypts the credential by delegating to the encryptor supplied at construction.
     *
     * @param credential value to encrypt
     * @return the encryptor's result
     */
    public String encrypt(String credential) {
        return this.encryptor.encrypt(credential);
    }

    /**
     * @return {@code <dataFolder>/ml_cache/config}
     */
    @Generated
    public Path getMlConfigPath() {
        return this.mlConfigPath;
    }

    /**
     * @return {@code <dataFolder>/ml_cache}
     */
    @Generated
    public Path getMlCachePath() {
        return this.mlCachePath;
    }
}

