/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.BinarySpatialFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.FilterExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
import org.elasticsearch.xpack.esql.stats.SearchStats;

/**
 * Local physical optimizer rule that lets spatial aggregations load their field values from doc-values.
 * <p>
 * Starting from an {@link AggregateExec}, the plan is rewritten with {@code transformDown}:
 * <ul>
 *   <li>every {@link SpatialAggregateFunction} whose argument is a {@link FieldAttribute} accepted by
 *       {@link #allowedForDocValues} is switched to {@code FieldExtractPreference.DOC_VALUES}, and that field is
 *       recorded;</li>
 *   <li>every {@link BinarySpatialFunction} inside {@link EvalExec} fields and {@link FilterExec} conditions is told
 *       which of its operands are recorded fields;</li>
 *   <li>every {@link FieldExtractExec} that extracts a recorded field is asked to read it via
 *       {@code withDocValuesAttributes}.</li>
 * </ul>
 */
public class SpatialDocValuesExtraction
extends PhysicalOptimizerRules.ParameterizedOptimizerRule<AggregateExec, LocalPhysicalOptimizerContext> {

    @Override
    protected PhysicalPlan rule(AggregateExec aggregate, LocalPhysicalOptimizerContext ctx) {
        // Filled while rewriting AggregateExec nodes and read when rewriting the other node types, so the
        // rewrite depends on transformDown visiting the aggregate before the nodes beneath it.
        Set<FieldAttribute> foundAttributes = new HashSet<>();
        return aggregate.transformDown(UnaryExec.class, exec -> {
            // Independent ifs (not else-if): each step sees the node produced by the previous one.
            if (exec instanceof AggregateExec agg) {
                exec = withDocValuesAggregates(agg, ctx, foundAttributes);
            }
            if (exec instanceof EvalExec evalExec) {
                exec = withDocValuesEval(evalExec, foundAttributes);
            }
            if (exec instanceof FilterExec filterExec) {
                exec = withDocValuesFilter(filterExec, foundAttributes);
            }
            if (exec instanceof FieldExtractExec fieldExtractExec) {
                exec = withDocValuesExtraction(fieldExtractExec, foundAttributes);
            }
            return exec;
        });
    }

    /**
     * Switches eligible spatial aggregates of {@code agg} to doc-values extraction.
     * <p>
     * An aggregate is eligible when it is an {@link Alias} of a {@link SpatialAggregateFunction} whose field is a
     * {@link FieldAttribute} accepted by {@link #allowedForDocValues}. Each accepted field is added to
     * {@code foundAttributes} immediately, so it is taken into account when checking later aggregates.
     *
     * @param agg             the aggregate node to rewrite
     * @param ctx             optimizer context supplying the search stats
     * @param foundAttributes fields already selected for doc-values; updated in place
     * @return a new {@link AggregateExec} if any aggregate changed, otherwise {@code agg} itself
     */
    private UnaryExec withDocValuesAggregates(
        AggregateExec agg,
        LocalPhysicalOptimizerContext ctx,
        Set<FieldAttribute> foundAttributes
    ) {
        List<NamedExpression> orderedAggregates = new ArrayList<>();
        boolean changedAggregates = false;
        for (NamedExpression aggExpr : agg.aggregates()) {
            if (aggExpr instanceof Alias as
                && as.child() instanceof SpatialAggregateFunction af
                && af.field() instanceof FieldAttribute fieldAttribute
                && allowedForDocValues(fieldAttribute, ctx.searchStats(), agg, foundAttributes)) {
                // Record the field so FieldExtractExec loads it from doc-values, and tell the aggregate to expect that.
                foundAttributes.add(fieldAttribute);
                changedAggregates = true;
                orderedAggregates.add(
                    as.replaceChild(af.withFieldExtractPreference(MappedFieldType.FieldExtractPreference.DOC_VALUES))
                );
            } else {
                orderedAggregates.add(aggExpr);
            }
        }
        if (changedAggregates == false) {
            return agg;
        }
        return new AggregateExec(
            agg.source(),
            agg.child(),
            agg.groupings(),
            orderedAggregates,
            agg.getMode(),
            agg.intermediateAttributes(),
            agg.estimatedRowSize()
        );
    }

    /**
     * Marks doc-values operands on every {@link BinarySpatialFunction} inside the fields of {@code evalExec}.
     *
     * @return a new {@link EvalExec} if any field changed, otherwise {@code evalExec} itself
     */
    private UnaryExec withDocValuesEval(EvalExec evalExec, Set<FieldAttribute> foundAttributes) {
        List<Alias> fields = evalExec.fields();
        List<Alias> changed = fields.stream()
            .map(f -> (Alias) f.transformDown(BinarySpatialFunction.class, s -> withDocValues(s, foundAttributes)))
            .toList();
        return changed.equals(fields) ? evalExec : new EvalExec(evalExec.source(), evalExec.child(), changed);
    }

    /**
     * Marks doc-values operands on every {@link BinarySpatialFunction} inside the condition of {@code filterExec}.
     *
     * @return a new {@link FilterExec} if the condition changed, otherwise {@code filterExec} itself
     */
    private UnaryExec withDocValuesFilter(FilterExec filterExec, Set<FieldAttribute> foundAttributes) {
        Expression condition = filterExec.condition()
            .transformDown(BinarySpatialFunction.class, s -> withDocValues(s, foundAttributes));
        return filterExec.condition().equals(condition)
            ? filterExec
            : new FilterExec(filterExec.source(), filterExec.child(), condition);
    }

    /**
     * Asks {@code fieldExtractExec} to use doc-values for those of {@code foundAttributes} that it extracts.
     *
     * @return the result of {@code withDocValuesAttributes} if any found field is extracted here,
     *         otherwise {@code fieldExtractExec} itself
     */
    private UnaryExec withDocValuesExtraction(FieldExtractExec fieldExtractExec, Set<FieldAttribute> foundAttributes) {
        List<Attribute> attributesToExtract = fieldExtractExec.attributesToExtract();
        Set<Attribute> docValuesAttributes = new HashSet<>();
        for (FieldAttribute found : foundAttributes) {
            if (attributesToExtract.contains(found)) {
                docValuesAttributes.add(found);
            }
        }
        return docValuesAttributes.isEmpty() ? fieldExtractExec : fieldExtractExec.withDocValuesAttributes(docValuesAttributes);
    }

    /**
     * Returns {@code spatial} told which of its operands are in {@code foundAttributes}, or {@code spatial}
     * unchanged when neither operand is.
     */
    private BinarySpatialFunction withDocValues(BinarySpatialFunction spatial, Set<FieldAttribute> foundAttributes) {
        boolean foundLeft = foundField(spatial.left(), foundAttributes);
        boolean foundRight = foundField(spatial.right(), foundAttributes);
        return foundLeft || foundRight ? spatial.withDocValues(foundLeft, foundRight) : spatial;
    }

    /** Returns {@code true} if the left or right operand of {@code spatial} is one of {@code foundAttributes}. */
    private boolean hasFieldAttribute(BinarySpatialFunction spatial, Set<FieldAttribute> foundAttributes) {
        return foundField(spatial.left(), foundAttributes) || foundField(spatial.right(), foundAttributes);
    }

    /** Returns {@code true} if {@code expression} is a {@link FieldAttribute} contained in {@code foundAttributes}. */
    private boolean foundField(Expression expression, Set<FieldAttribute> foundAttributes) {
        return expression instanceof FieldAttribute field && foundAttributes.contains(field);
    }

    /**
     * Decides whether {@code fieldAttribute} may be loaded from doc-values.
     * <p>
     * Three conditions must all hold:
     * <ul>
     *   <li>{@code stats} reports doc-values for the field;</li>
     *   <li>the field is neither {@code GEO_SHAPE} nor {@code CARTESIAN_SHAPE};</li>
     *   <li>taking the field together with {@code foundAttributes}, fewer than two of them appear directly as the
     *       left or right operand of a {@link SpatialRelatesFunction} reached by {@code agg.forEachExpressionDown}.</li>
     * </ul>
     * NOTE(review): the last check presumably exists because relating two doc-values operands is not supported;
     * confirm against {@code SpatialRelatesFunction}.
     *
     * @param fieldAttribute  the candidate field
     * @param stats           search statistics used to check for doc-values
     * @param agg             the aggregate whose expressions are scanned for spatial relations
     * @param foundAttributes fields already selected for doc-values (not modified)
     * @return {@code true} if the field may be extracted from doc-values
     */
    private boolean allowedForDocValues(
        FieldAttribute fieldAttribute,
        SearchStats stats,
        AggregateExec agg,
        Set<FieldAttribute> foundAttributes
    ) {
        if (stats.hasDocValues(fieldAttribute.fieldName()) == false) {
            return false;
        }
        if (fieldAttribute.dataType() == DataType.GEO_SHAPE || fieldAttribute.dataType() == DataType.CARTESIAN_SHAPE) {
            return false;
        }
        Set<FieldAttribute> candidateDocValuesAttributes = new HashSet<>(foundAttributes);
        candidateDocValuesAttributes.add(fieldAttribute);
        Set<FieldAttribute> spatialRelatesAttributes = new HashSet<>();
        agg.forEachExpressionDown(SpatialRelatesFunction.class, relatesFunction -> {
            for (FieldAttribute candidate : candidateDocValuesAttributes) {
                if (hasFieldAttribute(relatesFunction, Set.of(candidate))) {
                    spatialRelatesAttributes.add(candidate);
                }
            }
        });
        return spatialRelatesAttributes.size() < 2;
    }
}

