/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Predicate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.core.Strings;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;

public class InferenceWaitForAllocation {
    public static final int MAX_PENDING_REQUEST_COUNT = 100;
    private static final Logger logger = LogManager.getLogger(InferenceWaitForAllocation.class);
    private final TrainedModelAssignmentService assignmentService;
    private final BiConsumer<WaitingRequest, TrainedModelAssignment> queuedConsumer;
    private AtomicInteger pendingRequestCount = new AtomicInteger();

    public InferenceWaitForAllocation(TrainedModelAssignmentService assignmentService, BiConsumer<WaitingRequest, TrainedModelAssignment> onInferenceScaledConsumer) {
        this.assignmentService = assignmentService;
        this.queuedConsumer = onInferenceScaledConsumer;
    }

    public synchronized void waitForAssignment(WaitingRequest request) {
        if (this.pendingRequestCount.incrementAndGet() >= 100) {
            this.pendingRequestCount.decrementAndGet();
            request.listener.onFailure((Exception)new ElasticsearchStatusException("Rejected inference request waiting for an allocation of deployment [{}]. Too many pending requests", RestStatus.TOO_MANY_REQUESTS, new Object[]{request.request.getId()}));
            return;
        }
        DeploymentHasAtLeastOneAllocation predicate = new DeploymentHasAtLeastOneAllocation(request.deploymentId());
        this.assignmentService.waitForAssignmentCondition(request.deploymentId(), predicate, request.request().getInferenceTimeout(), new WaitingListener(request, predicate));
    }

    public record WaitingRequest(InferModelAction.Request request, InferModelAction.Response.Builder responseBuilder, TaskId parentTaskId, ActionListener<InferModelAction.Response> listener) {
        public String deploymentId() {
            return this.request.getId();
        }
    }

    private static class DeploymentHasAtLeastOneAllocation
    implements Predicate<ClusterState> {
        private final String deploymentId;
        private AtomicReference<Exception> exception = new AtomicReference();

        DeploymentHasAtLeastOneAllocation(String deploymentId) {
            this.deploymentId = (String)ExceptionsHelper.requireNonNull((Object)deploymentId, (String)"deployment_id");
        }

        @Override
        public boolean test(ClusterState clusterState) {
            TrainedModelAssignment trainedModelAssignment = TrainedModelAssignmentMetadata.assignmentForDeploymentId((ClusterState)clusterState, (String)this.deploymentId).orElse(null);
            if (trainedModelAssignment == null) {
                logger.info(() -> Strings.format((String)"[%s] assignment was null while waiting to scale up", (Object[])new Object[]{this.deploymentId}));
                this.exception.set((Exception)new ElasticsearchStatusException("[{}] Error waiting for a model allocation, model assignment has been removed", RestStatus.CONFLICT, new Object[]{this.deploymentId}));
                return true;
            }
            HashMap<String, String> nodeFailuresAndReasons = new HashMap<String, String>();
            for (Map.Entry nodeIdAndRouting : trainedModelAssignment.getNodeRoutingTable().entrySet()) {
                if (!RoutingState.FAILED.equals((Object)((RoutingInfo)nodeIdAndRouting.getValue()).getState())) continue;
                nodeFailuresAndReasons.put((String)nodeIdAndRouting.getKey(), ((RoutingInfo)nodeIdAndRouting.getValue()).getReason());
            }
            if (!nodeFailuresAndReasons.isEmpty()) {
                if (nodeFailuresAndReasons.size() == trainedModelAssignment.getNodeRoutingTable().size()) {
                    this.exception.set((Exception)new ElasticsearchStatusException("[{}] Error waiting for a model allocation, all nodes have failed with errors [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{trainedModelAssignment.getDeploymentId(), nodeFailuresAndReasons}));
                    return true;
                }
                logger.warn("Deployment [{}] has failed routes [{}]", (Object)trainedModelAssignment.getDeploymentId(), nodeFailuresAndReasons);
            }
            Optional<RoutingInfo> routable = trainedModelAssignment.getNodeRoutingTable().values().stream().filter(RoutingInfo::isRoutable).findFirst();
            return routable.isPresent();
        }
    }

    private class WaitingListener
    implements TrainedModelAssignmentService.WaitForAssignmentListener {
        private final WaitingRequest request;
        private final DeploymentHasAtLeastOneAllocation predicate;

        private WaitingListener(WaitingRequest request, DeploymentHasAtLeastOneAllocation predicate) {
            this.request = request;
            this.predicate = predicate;
        }

        public void onResponse(TrainedModelAssignment assignment) {
            InferenceWaitForAllocation.this.pendingRequestCount.decrementAndGet();
            if (this.predicate.exception.get() != null) {
                this.onFailure(this.predicate.exception.get());
                return;
            }
            InferenceWaitForAllocation.this.queuedConsumer.accept(this.request, assignment);
        }

        public void onFailure(Exception e) {
            InferenceWaitForAllocation.this.pendingRequestCount.decrementAndGet();
            this.request.listener().onFailure(e);
        }
    }
}

