/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.compile.linearization;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.CoVariance;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMZip;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.ReBlock;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;

public interface ILinearize {
    public static final Log LOG = LogFactory.getLog((String)ILinearize.class.getName());

    public static List<Lop> linearize(List<Lop> v) {
        try {
            DagLinearization linearization = ConfigurationManager.getLinearizationOrder();
            switch (linearization) {
                case MAX_PARALLELIZE: {
                    return ILinearize.doMaxParallelizeSort(v);
                }
                case MIN_INTERMEDIATE: {
                    return ILinearize.doMinIntermediateSort(v);
                }
                case BREADTH_FIRST: {
                    return ILinearize.doBreadthFirstSort(v);
                }
            }
            return ILinearize.depthFirst(v);
        }
        catch (Exception e) {
            LOG.warn((Object)("Invalid DAG_LINEARIZATION " + ConfigurationManager.getLinearizationOrder() + ", fallback to DEPTH_FIRST ordering"));
            return ILinearize.depthFirst(v);
        }
    }

    private static List<Lop> depthFirst(List<Lop> v) {
        List<Lop> nodes = Stream.concat(v.stream().filter(l -> !l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getID())), v.stream().filter(l -> l.getOutputs().isEmpty())).collect(Collectors.toList());
        return nodes;
    }

    private static List<Lop> doBreadthFirstSort(List<Lop> v) {
        List<Lop> nodes = v.stream().sorted(Comparator.comparing(Lop::getLevel)).collect(Collectors.toList());
        return nodes;
    }

    private static List<Lop> doMinIntermediateSort(List<Lop> v) {
        ArrayList<Lop> nodes = new ArrayList<Lop>(v.size());
        List<Lop> lowestLevel = v.stream().filter(l -> l.getOutputs().isEmpty()).collect(Collectors.toList());
        LinkedList<Lop> remaining = new LinkedList<Lop>(v);
        ILinearize.sortRecursive(nodes, lowestLevel, remaining);
        while (!remaining.isEmpty()) {
            int maxLevel = remaining.stream().mapToInt(Lop::getLevel).max().orElse(-1);
            List<Lop> lowestNodes = remaining.stream().filter(l -> l.getLevel() == maxLevel).collect(Collectors.toList());
            ILinearize.sortRecursive(nodes, lowestNodes, remaining);
        }
        Collections.reverse(nodes);
        return nodes;
    }

    private static void sortRecursive(List<Lop> result, List<Lop> input, List<Lop> remaining) {
        List memEst = input.stream().distinct().map(l -> new AbstractMap.SimpleEntry<Lop, Long>((Lop)l, l.getOutputs().isEmpty() ? 0L : OptimizerUtils.estimateSizeExactSparsity(l.getOutputParameters().getNumRows(), l.getOutputParameters().getNumCols(), l.getOutputParameters().getNnz()))).sorted(Comparator.comparing(e -> (Long)e.getValue())).collect(Collectors.toList());
        Collections.reverse(memEst);
        for (Map.Entry e2 : memEst) {
            if (result.contains(e2.getKey()) || !result.containsAll(((Lop)e2.getKey()).getOutputs()) && remaining.stream().anyMatch(l -> ((Lop)e2.getKey()).getOutputs().contains(l))) continue;
            result.add((Lop)e2.getKey());
            remaining.remove(e2.getKey());
            ILinearize.sortRecursive(result, ((Lop)e2.getKey()).getInputs(), remaining);
        }
    }

    private static List<Lop> doMaxParallelizeSort(List<Lop> v) {
        List<Lop> final_v = null;
        if (v.stream().anyMatch(ILinearize::isDistributedOp)) {
            HashMap sparkOpCount = new HashMap();
            List<Lop> roots = v.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
            ArrayList sparkRoots = new ArrayList();
            roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r, sparkOpCount, sparkRoots));
            ArrayList<Lop> operatorList = new ArrayList<Lop>();
            sparkRoots.forEach(r -> ILinearize.depthFirst(r, operatorList, sparkOpCount, false));
            roots.forEach(r -> ILinearize.depthFirst(r, operatorList, sparkOpCount, false));
            roots.forEach(Lop::resetVisitStatus);
            final_v = operatorList;
        } else {
            final_v = ILinearize.depthFirst(v);
        }
        return final_v;
    }

    private static void depthFirst(Lop root, ArrayList<Lop> opList, Map<Long, Integer> sparkOpCount, boolean sparkFirst) {
        if (root.isVisited()) {
            return;
        }
        if (root.getInputs().isEmpty()) {
            opList.add(root);
            root.setVisited();
            return;
        }
        Lop[] sortedInputs = root.getInputs().toArray(new Lop[0]);
        if (sparkFirst) {
            Arrays.sort(sortedInputs, (l1, l2) -> (Integer)sparkOpCount.get(l2.getID()) - (Integer)sparkOpCount.get(l1.getID()));
        } else {
            Arrays.sort(sortedInputs, Comparator.comparingInt(l -> (Integer)sparkOpCount.get(l.getID())));
        }
        for (Lop input : sortedInputs) {
            ILinearize.depthFirst(input, opList, sparkOpCount, sparkFirst);
        }
        opList.add(root);
        root.setVisited();
    }

    private static boolean isDistributedOp(Lop lop) {
        return lop.isExecSpark() || lop instanceof UnaryCP && (((UnaryCP)lop).getOpCode().equalsIgnoreCase("prefetch") || ((UnaryCP)lop).getOpCode().equalsIgnoreCase("broadcast"));
    }

    private static List<Lop> addAsyncEagerCheckpointLop(List<Lop> nodes) {
        ArrayList<Lop> nodesWithCheckpoint = new ArrayList<Lop>();
        for (Lop l : nodes) {
            if (ILinearize.isCheckpointNeeded(l)) {
                ArrayList<Lop> oldInputs = new ArrayList<Lop>(l.getInputs());
                for (Lop in : oldInputs) {
                    if (in.getExecType() != Types.ExecType.SPARK) continue;
                    Checkpoint checkpoint = new Checkpoint(in, in.getDataType(), in.getValueType(), Checkpoint.getDefaultStorageLevelString(), true);
                    checkpoint.addOutput(l);
                    l.replaceInput(in, checkpoint);
                    in.removeOutput(l);
                    nodesWithCheckpoint.add(checkpoint);
                }
            }
            nodesWithCheckpoint.add(l);
        }
        return nodesWithCheckpoint;
    }

    private static boolean isCheckpointNeeded(Lop lop) {
        boolean actionOP = lop.getExecType() == Types.ExecType.SPARK && (lop.getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK || lop.getDataType() == Types.DataType.SCALAR || lop instanceof MapMultChain || lop instanceof PickByCount || lop instanceof MMZip || lop instanceof CentralMoment || lop instanceof CoVariance || lop instanceof MMTSJ) && !(lop instanceof Checkpoint) && !(lop instanceof ReBlock) && !(lop instanceof CSVReBlock) && !(lop instanceof UAggOuterChain) && !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused);
        boolean hasParameterizedOut = lop.getOutputs().stream().anyMatch(out -> out instanceof ParameterizedBuiltin || out instanceof GroupedAggregate || out instanceof GroupedAggregateM);
        return actionOP && !hasParameterizedOut;
    }

    public static enum DagLinearization {
        DEPTH_FIRST,
        BREADTH_FIRST,
        MIN_INTERMEDIATE,
        MAX_PARALLELIZE;

    }
}

