/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.helper;

import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.action.get.GetRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.IdsQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.MatchPhraseQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.search.builder.SearchSourceBuilder;

/**
 * Decides whether a user may access ML model groups, and builds search filters that limit
 * model-group searches to the groups a user is allowed to see.
 *
 * <p>Checks are bypassed when there is no user, or when the
 * {@code ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED} setting is off. Users holding the
 * {@code all_access} role are also treated as unrestricted by {@link #skipModelAccessControl}
 * and {@link #validateModelGroupAccess}.
 */
public class ModelAccessControlHelper {
    @Generated
    private static final Logger log = LogManager.getLogger(ModelAccessControlHelper.class);

    /** Index holding model group documents. */
    private static final String MODEL_GROUP_INDEX = ".plugins-ml-model-group";

    /** Role that grants unrestricted access. */
    private static final String ALL_ACCESS_ROLE = "all_access";

    private static final List<Class<?>> SUPPORTED_QUERY_TYPES = ImmutableList.of(IdsQueryBuilder.class, MatchQueryBuilder.class, MatchAllQueryBuilder.class, MatchPhraseQueryBuilder.class, TermQueryBuilder.class, TermsQueryBuilder.class, ExistsQueryBuilder.class, RangeQueryBuilder.class);

    /**
     * Current value of the model-access-control setting. It is volatile because the
     * settings-update consumer registered in the constructor rewrites it at runtime.
     */
    private volatile boolean modelAccessControlEnabled;

    /**
     * Reads the initial model-access-control setting and subscribes to dynamic updates of it.
     *
     * @param clusterService used to register the settings-update consumer
     * @param settings node settings providing the initial value
     */
    public ModelAccessControlHelper(ClusterService clusterService, Settings settings) {
        this.modelAccessControlEnabled = MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED, it -> this.modelAccessControlEnabled = it);
    }

    /**
     * Asynchronously checks whether {@code user} may access the model group {@code modelGroupId}.
     *
     * <p>The listener receives {@code true} immediately when no group id is given, when the user
     * is an admin, or when there is no user or access control is disabled. Otherwise the group
     * document is fetched under a stashed thread context, and the answer depends on the group's
     * access mode.
     *
     * <p>Failures delivered to the listener:
     * <ul>
     *   <li>{@link MLResourceNotFoundException} when the group or its index does not exist;</li>
     *   <li>{@link MLValidationException} for other get failures;</li>
     *   <li>the original exception when the document cannot be parsed, or when the group is in an
     *       inconsistent state (restricted without backend roles, or an unknown access mode).</li>
     * </ul>
     *
     * @param user the requesting user, may be {@code null}
     * @param modelGroupId id of the model group, may be {@code null}
     * @param client client used to read the model group index
     * @param listener completed exactly once with the access decision or a failure
     */
    public void validateModelGroupAccess(User user, String modelGroupId, Client client, ActionListener<Boolean> listener) {
        if (modelGroupId == null || isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) {
            listener.onResponse(true);
            return;
        }
        List<String> userBackendRoles = user.getBackendRoles();
        GetRequest getModelGroupRequest = new GetRequest(MODEL_GROUP_INDEX).id(modelGroupId);
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            // Restore the caller's thread context before handing the result back.
            ActionListener<Boolean> wrappedListener = ActionListener.runBefore(listener, context::restore);
            client.get(getModelGroupRequest, ActionListener.wrap(r -> {
                if (r == null || !r.isExists()) {
                    wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
                    return;
                }
                boolean hasAccess;
                try (XContentParser parser = MLNodeUtils.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef())) {
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
                    hasAccess = canAccessModelGroup(user, userBackendRoles, mlModelGroup);
                } catch (Exception e) {
                    log.error("Failed to parse ml model group", e);
                    wrappedListener.onFailure(e);
                    return;
                }
                wrappedListener.onResponse(hasAccess);
            }, e -> {
                if (e instanceof IndexNotFoundException) {
                    wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
                } else {
                    log.error("Fail to get model group", e);
                    wrappedListener.onFailure(new MLValidationException("Fail to get model group"));
                }
            }));
        } catch (Exception e) {
            log.error("Failed to validate Access", e);
            listener.onFailure(e);
        }
    }

    /**
     * Applies the access-mode rules to an already-loaded model group.
     * A group without an owner is accessible to everyone.
     *
     * @throws IllegalStateException if a restricted group has no backend roles, or the access mode
     *     is none of restricted/public/private (previously this left the listener uncompleted)
     */
    private boolean canAccessModelGroup(User user, List<String> userBackendRoles, MLModelGroup mlModelGroup) {
        // Resolved before the owner check so that a malformed access value still surfaces as an error.
        AccessMode modelAccessMode = AccessMode.from(mlModelGroup.getAccess());
        if (mlModelGroup.getOwner() == null) {
            return true;
        }
        if (AccessMode.RESTRICTED == modelAccessMode) {
            if (CollectionUtils.isEmpty(mlModelGroup.getBackendRoles())) {
                throw new IllegalStateException("Backend roles shouldn't be null");
            }
            return sharesAnyBackendRole(userBackendRoles, mlModelGroup.getBackendRoles());
        }
        if (AccessMode.PUBLIC == modelAccessMode) {
            return true;
        }
        if (AccessMode.PRIVATE == modelAccessMode) {
            return isOwner(mlModelGroup.getOwner(), user);
        }
        throw new IllegalStateException("Access shouldn't be null");
    }

    /** Returns {@code true} when both role lists are non-null and have at least one role in common. */
    private static boolean sharesAnyBackendRole(List<String> userBackendRoles, List<String> groupBackendRoles) {
        if (userBackendRoles == null || groupBackendRoles == null) {
            return false;
        }
        return groupBackendRoles.stream().anyMatch(userBackendRoles::contains);
    }

    /**
     * Returns {@code true} when model access checks should be skipped entirely: there is no user,
     * access control is disabled, or the user is an admin.
     */
    public boolean skipModelAccessControl(User user) {
        return user == null || !this.modelAccessControlEnabled || isAdmin(user);
    }

    /** Returns {@code true} when a user is present and model access control is enabled. */
    public boolean isSecurityEnabledAndModelAccessControlEnabled(User user) {
        return user != null && this.modelAccessControlEnabled;
    }

    /** Returns {@code true} when {@code user} is non-null and holds the {@code all_access} role. */
    public boolean isAdmin(User user) {
        return user != null
            && !CollectionUtils.isEmpty(user.getRoles())
            && user.getRoles().contains(ALL_ACCESS_ROLE);
    }

    /**
     * Returns {@code true} when {@code owner} and {@code user} are both present and have the same
     * name. An owner with no name never matches, which denies access instead of throwing.
     */
    public boolean isOwner(User owner, User user) {
        if (user == null || owner == null || owner.getName() == null) {
            return false;
        }
        return owner.getName().equals(user.getName());
    }

    /**
     * Returns whether {@code user} can see {@code mlModelGroup} based on its access mode and backend
     * roles: always for public, never for private, and on a shared backend role for restricted.
     */
    public boolean isUserHasBackendRole(User user, MLModelGroup mlModelGroup) {
        AccessMode modelAccessMode = AccessMode.from(mlModelGroup.getAccess());
        if (AccessMode.PUBLIC == modelAccessMode) {
            return true;
        }
        if (AccessMode.PRIVATE == modelAccessMode) {
            return false;
        }
        return user != null && sharesAnyBackendRole(user.getBackendRoles(), mlModelGroup.getBackendRoles());
    }

    /**
     * Returns whether {@code user} still qualifies for {@code mlModelGroup} under its current access
     * mode. This is always true when access control does not apply to the user.
     *
     * @throws IllegalStateException if a restricted group has no backend roles, or the access mode is unknown
     */
    public boolean isOwnerStillHasPermission(User user, MLModelGroup mlModelGroup) {
        if (!isSecurityEnabledAndModelAccessControlEnabled(user)) {
            return true;
        }
        AccessMode access = AccessMode.from(mlModelGroup.getAccess());
        if (AccessMode.PUBLIC == access) {
            return true;
        }
        if (AccessMode.PRIVATE == access) {
            // isOwner takes (owner, user); the original passed these arguments swapped.
            return isOwner(mlModelGroup.getOwner(), user);
        }
        if (AccessMode.RESTRICTED == access) {
            if (CollectionUtils.isEmpty(mlModelGroup.getBackendRoles())) {
                throw new IllegalStateException("Backend roles should not be null");
            }
            return sharesAnyBackendRole(user.getBackendRoles(), mlModelGroup.getBackendRoles());
        }
        throw new IllegalStateException("Access shouldn't be null");
    }

    /** Returns the current value of the model-access-control setting. */
    public boolean isModelAccessControlEnabled() {
        return this.modelAccessControlEnabled;
    }

    /**
     * Restricts {@code searchSourceBuilder} to model groups that {@code user} can see. These are
     * public groups, groups sharing one of the user's backend roles, and private groups the user owns.
     *
     * <p>If the builder has no query, the access filter becomes the query. If it has a bool query,
     * the filter is added to it in place. Any other query is wrapped in a new bool query.
     *
     * @param user the requesting user; must not be {@code null}
     * @param searchSourceBuilder builder to modify
     * @return the same {@code searchSourceBuilder}, for chaining
     */
    public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuilder searchSourceBuilder) {
        // Only should clauses: a document must match at least one of them.
        BoolQueryBuilder accessFilter = new BoolQueryBuilder();
        accessFilter.should(QueryBuilders.termQuery("access", AccessMode.PUBLIC.getValue()));
        List<String> backendRoles = user.getBackendRoles();
        // termsQuery throws on a null collection, and an empty one matches nothing. Either way a
        // user without backend roles cannot match by role, so the clause is simply omitted.
        if (!CollectionUtils.isEmpty(backendRoles)) {
            accessFilter.should(QueryBuilders.termsQuery("backend_roles.keyword", backendRoles));
        }
        BoolQueryBuilder privateBoolQuery = new BoolQueryBuilder();
        TermQueryBuilder ownerNameTermQuery = QueryBuilders.termQuery("owner.name.keyword", user.getName());
        privateBoolQuery.must(new NestedQueryBuilder("owner", ownerNameTermQuery, ScoreMode.None));
        privateBoolQuery.must(QueryBuilders.termQuery("access", AccessMode.PRIVATE.getValue()));
        accessFilter.should(privateBoolQuery);

        QueryBuilder query = searchSourceBuilder.query();
        if (query == null) {
            searchSourceBuilder.query(accessFilter);
        } else if (query instanceof BoolQueryBuilder) {
            ((BoolQueryBuilder) query).filter(accessFilter);
        } else {
            BoolQueryBuilder rewriteQuery = new BoolQueryBuilder();
            rewriteQuery.must(query);
            rewriteQuery.filter(accessFilter);
            searchSourceBuilder.query(rewriteQuery);
        }
        return searchSourceBuilder;
    }

    /** Returns a new {@link SearchSourceBuilder} with only the access filter for {@code user} applied. */
    public SearchSourceBuilder createSearchSourceBuilder(User user) {
        return addUserBackendRolesFilter(user, new SearchSourceBuilder());
    }
}

