/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.syncup;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.nio.file.Path;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/**
 * Node-level handler for the internal ML sync-up action
 * ({@code cluster:admin/opensearch/mlinternal/syncup}).
 *
 * <p>For every node it runs on, this action:
 * <ul>
 *   <li>applies the added/removed worker-node changes carried by the {@link MLSyncUpInput};</li>
 *   <li>optionally reports the locally deployed models and running deploy-model tasks;</li>
 *   <li>clears or replaces the local model routing table;</li>
 *   <li>marks local ML tasks whose last update is older than the configured task timeout as FAILED;</li>
 *   <li>deletes on-disk model files for models that are neither known to the task manager nor running on this node.</li>
 * </ul>
 */
public class TransportSyncUpOnNodeAction
    extends TransportNodesAction<MLSyncUpNodesRequest, MLSyncUpNodesResponse, MLSyncUpNodeRequest, MLSyncUpNodeResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportSyncUpOnNodeAction.class);

    /**
     * Timeout argument passed to {@code MLTaskManager.updateMLTask} when persisting a timed-out task.
     * NOTE(review): presumably milliseconds — confirm against MLTaskManager.
     */
    private static final long TASK_UPDATE_TIMEOUT = 10_000L;

    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    MLEngine mlEngine;
    /** Task timeout in seconds; updated dynamically from the cluster setting, hence volatile. */
    private volatile Integer mlTaskTimeout;
    private final MLModelCacheHelper mlModelCacheHelper;

    @Inject
    public TransportSyncUpOnNodeAction(
        TransportService transportService,
        Settings settings,
        ActionFilters actionFilters,
        ModelHelper modelHelper,
        MLTaskManager mlTaskManager,
        MLModelManager mlModelManager,
        ClusterService clusterService,
        ThreadPool threadPool,
        Client client,
        NamedXContentRegistry xContentRegistry,
        MLEngine mlEngine,
        MLModelCacheHelper mlModelCacheHelper
    ) {
        super(
            "cluster:admin/opensearch/mlinternal/syncup",
            threadPool,
            clusterService,
            transportService,
            actionFilters,
            MLSyncUpNodesRequest::new,
            MLSyncUpNodeRequest::new,
            "management",
            MLSyncUpNodeResponse.class
        );
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mlTaskManager;
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.mlEngine = mlEngine;
        this.mlModelCacheHelper = mlModelCacheHelper;
        this.mlTaskTimeout = (Integer) MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS.get(settings);
        // Keep the timeout in sync with dynamic cluster-setting updates.
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS, it -> this.mlTaskTimeout = it);
    }

    @Override
    protected MLSyncUpNodesResponse newResponse(
        MLSyncUpNodesRequest nodesRequest,
        List<MLSyncUpNodeResponse> responses,
        List<FailedNodeException> failures
    ) {
        return new MLSyncUpNodesResponse(clusterService.getClusterName(), responses, failures);
    }

    @Override
    protected MLSyncUpNodeRequest newNodeRequest(MLSyncUpNodesRequest request) {
        return new MLSyncUpNodeRequest(request);
    }

    @Override
    protected MLSyncUpNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new MLSyncUpNodeResponse(in);
    }

    @Override
    protected MLSyncUpNodeResponse nodeOperation(MLSyncUpNodeRequest request) {
        return createSyncUpNodeResponse(request.getSyncUpNodesRequest());
    }

    /**
     * Applies the sync-up input to this node's state and builds the node response.
     *
     * @param syncUpNodesRequest the cluster-wide sync-up request forwarded to this node
     * @return an "ok" response; deployed-model and running-deploy-task arrays are populated only when
     *     {@code isGetDeployedModels()} is set, otherwise they are {@code null}
     */
    private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncUpNodesRequest) {
        MLSyncUpInput syncUpInput = syncUpNodesRequest.getSyncUpInput();
        Map<String, String[]> addedWorkerNodes = syncUpInput.getAddedWorkerNodes();
        Map<String, String[]> removedWorkerNodes = syncUpInput.getRemovedWorkerNodes();
        Map<String, Set<String>> modelRoutingTable = syncUpInput.getModelRoutingTable();
        Map<String, Set<String>> runningDeployModelTasks = syncUpInput.getRunningDeployModelTasks();
        // Only key membership is consulted, so the value type does not matter here.
        Map<String, ?> deployToAllNodes = syncUpInput.getDeployToAllNodes();

        if (addedWorkerNodes != null) {
            for (Map.Entry<String, String[]> entry : addedWorkerNodes.entrySet()) {
                mlModelManager.addModelWorkerNode(entry.getKey(), entry.getValue());
            }
        }
        if (removedWorkerNodes != null) {
            for (Map.Entry<String, String[]> entry : removedWorkerNodes.entrySet()) {
                String modelId = entry.getKey();
                boolean inDeployToAllNodes = deployToAllNodes != null && deployToAllNodes.containsKey(modelId);
                mlModelManager.removeModelWorkerNode(modelId, inDeployToAllNodes, entry.getValue());
            }
        }

        String[] deployedModelIds = null;
        String[] runningDeployModelTaskIds = null;
        String[] runningDeployModelIds = null;
        if (syncUpInput.isGetDeployedModels()) {
            deployedModelIds = mlModelManager.getLocalDeployedModels();
            List<String[]> localRunningDeployModel = mlTaskManager.getLocalRunningDeployModelTasks();
            // The first element is reported as task ids, the second as the corresponding model ids.
            runningDeployModelTaskIds = localRunningDeployModel.get(0);
            runningDeployModelIds = localRunningDeployModel.get(1);
        }

        if (syncUpInput.isClearRoutingTable()) {
            mlModelManager.clearRoutingTable();
        } else if (modelRoutingTable != null) {
            if (log.isDebugEnabled()) {
                for (Map.Entry<String, Set<String>> entry : modelRoutingTable.entrySet()) {
                    log.debug("latest routing table for model: {}:  {}", entry.getKey(), (Object) entry.getValue().toArray(new String[0]));
                }
            }
            mlModelManager.syncModelWorkerNodes(modelRoutingTable);
        }

        cleanUpLocalCache(runningDeployModelTasks);
        cleanUpLocalCacheFiles();
        return new MLSyncUpNodeResponse(clusterService.localNode(), "ok", deployedModelIds, runningDeployModelIds, runningDeployModelTaskIds);
    }

    /**
     * Marks local ML tasks as FAILED once their last update is older than the task timeout.
     *
     * <p>A DEPLOY_MODEL task in state CREATED is exempt when its id is a key of
     * {@code runningDeployModelTasks}.
     *
     * @param runningDeployModelTasks deploy-model task ids (keys) reported by the sync-up input; may be {@code null}
     */
    @VisibleForTesting
    void cleanUpLocalCache(Map<String, Set<String>> runningDeployModelTasks) {
        String[] allTaskIds = mlTaskManager.getAllTaskIds();
        if (allTaskIds == null) {
            return;
        }
        // Snapshot the volatile timeout once so the threshold and the error message always agree,
        // and evaluate every task against the same reference time.
        int timeoutSeconds = mlTaskTimeout;
        Instant now = Instant.now();
        for (String taskId : allTaskIds) {
            MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId);
            // The task may have left the cache after the id snapshot was taken; skip instead of
            // throwing, which would fail the whole sync-up for this node.
            if (mlTaskCache == null || mlTaskCache.getMlTask() == null) {
                continue;
            }
            MLTask mlTask = mlTaskCache.getMlTask();
            Instant lastUpdateTime = mlTask.getLastUpdateTime();
            if (lastUpdateTime == null || !now.isAfter(lastUpdateTime.plusSeconds(timeoutSeconds))) {
                continue;
            }
            log.info("ML task timeout. task id: {}, task type: {}", taskId, mlTask.getTaskType());
            boolean exemptDeployTask = mlTask.getTaskType() == MLTaskType.DEPLOY_MODEL
                && mlTask.getState() == MLTaskState.CREATED
                && runningDeployModelTasks != null
                && runningDeployModelTasks.containsKey(taskId);
            if (exemptDeployTask) {
                continue;
            }
            Map<String, Object> failedFields = ImmutableMap
                .of("state", MLTaskState.FAILED, "error", "timeout after " + timeoutSeconds + " seconds");
            mlTaskManager.updateMLTask(taskId, failedFields, TASK_UPDATE_TIMEOUT, true);
        }
    }

    /**
     * Deletes cached model files for every model directory found under the register, deploy and
     * model-cache root paths whose id is not referenced by the task manager and is not running on
     * this node.
     */
    private void cleanUpLocalCacheFiles() {
        Path registerModelRootPath = mlEngine.getRegisterModelRootPath();
        Path deployModelRootPath = mlEngine.getDeployModelRootPath();
        Path modelCacheRootPath = mlEngine.getModelCacheRootPath();
        Set<String> modelsInCacheFolder =
            FileUtils.getFileNames(new Path[] { registerModelRootPath, deployModelRootPath, modelCacheRootPath });
        if (modelsInCacheFolder.isEmpty()) {
            return;
        }
        log.debug("Found {} models in cache folder: {}", modelsInCacheFolder.size(), Arrays.toString(modelsInCacheFolder.toArray(new String[0])));
        for (String modelId : modelsInCacheFolder) {
            // NOTE(review): contains() is called with a model id; presumably file names may also be
            // task ids — confirm against MLTaskManager.
            boolean stillReferenced = mlTaskManager.contains(modelId)
                || mlTaskManager.containsModel(modelId)
                || mlModelManager.isModelRunningOnNode(modelId);
            if (stillReferenced) {
                continue;
            }
            log.info("ML model not in cache. Remove all of its cache files. model id: {}", modelId);
            deleteFileCache(modelId);
        }
    }

    /** Quietly deletes the model-cache, deploy and register paths for the given model id. */
    private void deleteFileCache(String modelId) {
        FileUtils.deleteFileQuietly(mlEngine.getModelCachePath(modelId));
        FileUtils.deleteFileQuietly(mlEngine.getDeployModelPath(modelId));
        FileUtils.deleteFileQuietly(mlEngine.getRegisterModelPath(modelId));
    }
}

