/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.securityanalytics.correlation.index.query;

import java.io.IOException;
import java.util.Optional;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;

/**
 * Builds Lucene k-nearest-neighbour vector queries from a {@link CreateQueryRequest}.
 */
public class CorrelationQueryFactory {

    /**
     * Creates a {@link KnnVectorQuery} over the request's field, query vector and {@code k}.
     *
     * <p>If the request carries a filter, the filter is converted to a Lucene {@link Query} using the
     * request's {@link QueryShardContext} and attached to the k-NN query as a pre-filter. Otherwise an
     * unfiltered k-NN query is returned. The request's index name is not used here.
     *
     * @param createQueryRequest the query parameters
     * @return the k-NN query, filtered if the request has a filter
     * @throws RuntimeException if a filter is present but no shard context was supplied, or if the
     *     filter cannot be converted to a Lucene query (the {@link IOException} is kept as the cause)
     */
    public static Query create(CreateQueryRequest createQueryRequest) {
        String fieldName = createQueryRequest.getFieldName();
        int k = createQueryRequest.getK();
        float[] vector = createQueryRequest.getVector();
        Optional<QueryBuilder> filter = createQueryRequest.getFilter();
        if (filter.isPresent()) {
            // Converting the filter builder into a Lucene query requires the shard context.
            QueryShardContext context = createQueryRequest.getContext().orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
            try {
                Query filterQuery = filter.get().toQuery(context);
                return new KnnVectorQuery(fieldName, vector, k, filterQuery);
            } catch (IOException ex) {
                throw new RuntimeException("Cannot create knn query with filter", ex);
            }
        }
        return new KnnVectorQuery(fieldName, vector, k);
    }

    /**
     * Holds the parameters needed to build a correlation k-NN query.
     *
     * <p>The filter and context may be {@code null}. They are exposed as {@link Optional}s. The vector
     * array is stored and returned as-is (no defensive copy).
     */
    static class CreateQueryRequest {
        private final String indexName;
        private final String fieldName;
        private final float[] vector;
        private final int k;
        private final QueryBuilder filter;
        private final QueryShardContext context;

        /**
         * @param indexName name of the index the query targets
         * @param fieldName vector field to search
         * @param vector query vector
         * @param k number of nearest neighbours to retrieve
         * @param filter optional filter query builder; may be {@code null}
         * @param context shard context required to convert {@code filter}; may be {@code null} when
         *     there is no filter
         */
        public CreateQueryRequest(String indexName, String fieldName, float[] vector, int k, QueryBuilder filter, QueryShardContext context) {
            this.indexName = indexName;
            this.fieldName = fieldName;
            this.vector = vector;
            this.k = k;
            this.filter = filter;
            this.context = context;
        }

        /** @return the index name supplied at construction */
        public String getIndexName() {
            return this.indexName;
        }

        /** @return the vector field name to search */
        public String getFieldName() {
            return this.fieldName;
        }

        /** @return the query vector (the same array instance passed to the constructor) */
        public float[] getVector() {
            return this.vector;
        }

        /** @return the number of nearest neighbours requested */
        public int getK() {
            return this.k;
        }

        /** @return the filter, or empty if none was supplied */
        public Optional<QueryBuilder> getFilter() {
            return Optional.ofNullable(this.filter);
        }

        /** @return the shard context, or empty if none was supplied */
        public Optional<QueryShardContext> getContext() {
            return Optional.ofNullable(this.context);
        }
    }
}

