/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Transform-encodes a {@link FrameBlock} directly into a {@link CompressedMatrixBlock}, producing one column group
 * per column encoder of a {@link MultiColumnEncoder} rather than building an uncompressed matrix first.
 *
 * <p>Supported column encoders are recode-to-dummy, recode and pass-through; any other encoder raises a
 * {@link NotImplementedException}. For the two recode variants, the first encoder inside the
 * {@link ColumnEncoderComposite} is replaced with a {@link ColumnEncoderRecode} built from the column's recode map,
 * i.e. the encoder specification is modified as a side effect.
 */
public class CompressedEncode {
    protected static final Log LOG = LogFactory.getLog(CompressedEncode.class.getName());
    /** Encoder specification whose column encoders drive the per-column encoding. */
    private final MultiColumnEncoder enc;
    /** Frame to encode. */
    private final FrameBlock in;
    /** Requested degree of parallelism. */
    private final int k;

    private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) {
        this.enc = enc;
        this.in = in;
        this.k = k;
    }

    /**
     * Encodes {@code in} according to {@code enc} into a compressed matrix block.
     *
     * @param enc the multi-column encoder; recode-based column encoders are modified in place
     * @param in  the frame to encode
     * @param k   degree of parallelism; columns are encoded in parallel only when {@code k > 1} and the encoder
     *            has more than one sub-encoder
     * @return a compressed matrix block with one column group per column encoder, laid out in encoder order
     * @throws NotImplementedException if a column uses an unsupported encoder (wrapped in a
     *                                 {@link DMLRuntimeException} when encoding in parallel)
     * @throws DMLRuntimeException     if parallel encoding fails or is interrupted
     */
    public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, int k) {
        return new CompressedEncode(enc, in, k).apply();
    }

    /** Encodes every column, lays the groups out side by side and wraps them in a compressed block. */
    private MatrixBlock apply() {
        final List<ColumnEncoderComposite> encoders = enc.getColumnEncoders();
        final List<AColGroup> groups = isParallel() ? multiThread(encoders) : singleThread(encoders);
        final int cols = shiftGroups(groups);
        // non-zero count is passed as -1 and recomputed right below
        final CompressedMatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1L, false, groups);
        mb.recomputeNonZeros();
        logging(mb);
        return mb;
    }

    /** @return true if more than one thread is requested and there is more than one encoder to distribute. */
    private boolean isParallel() {
        return k > 1 && enc.getEncoders().size() > 1;
    }

    /** Encodes the columns one after another on the calling thread, preserving encoder order. */
    private List<AColGroup> singleThread(List<ColumnEncoderComposite> encoders) {
        final List<AColGroup> groups = new ArrayList<AColGroup>(encoders.size());
        for (ColumnEncoderComposite c : encoders)
            groups.add(encode(c));
        return groups;
    }

    /**
     * Encodes the columns as one task per encoder on a pool of {@code k} threads, preserving encoder order.
     *
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted (the interrupt flag is
     *                             restored in that case)
     */
    private List<AColGroup> multiThread(List<ColumnEncoderComposite> encoders) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<EncodeTask> tasks = new ArrayList<EncodeTask>(encoders.size());
            for (ColumnEncoderComposite c : encoders)
                tasks.add(new EncodeTask(c));
            final List<AColGroup> groups = new ArrayList<AColGroup>(encoders.size());
            // invokeAll returns futures in task order, so groups stay aligned with encoders
            for (Future<AColGroup> f : pool.invokeAll(tasks))
                groups.add(f.get());
            return groups;
        }
        catch (InterruptedException ex) {
            // preserve the interrupt status for callers further up the stack
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Failed parallel compressed transform encode", ex);
        }
        catch (ExecutionException ex) {
            throw new DMLRuntimeException("Failed parallel compressed transform encode", ex);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Shifts the column indices of the groups so that they occupy consecutive, non-overlapping column ranges in
     * list order. Groups are assumed to be built with indices starting at column 0.
     *
     * @param groups column groups; entries after the first are replaced by their shifted versions
     * @return total number of output columns (0 for an empty list)
     */
    private int shiftGroups(List<AColGroup> groups) {
        if (groups.isEmpty())
            return 0;
        int cols = groups.get(0).getColIndices().size();
        for (int i = 1; i < groups.size(); ++i) {
            final AColGroup shifted = groups.get(i).shiftColIndices(cols);
            groups.set(i, shifted);
            cols += shifted.getColIndices().size();
        }
        return cols;
    }

    /**
     * Dispatches a single column encoder to the matching encoding routine.
     *
     * @throws NotImplementedException if the encoder is neither recode-to-dummy, recode nor pass-through
     */
    private AColGroup encode(ColumnEncoderComposite c) {
        if (c.isRecodeToDummy())
            return recodeToDummy(c);
        if (c.isRecode())
            return recode(c);
        if (c.isPassThrough())
            return passThrough(c);
        throw new NotImplementedException("Not supporting : " + c);
    }

    /**
     * Recode followed by dummy coding: one output column per distinct value, represented as an identity
     * dictionary plus a row-to-recode-id mapping. Replaces the first encoder of {@code c} with a
     * {@link ColumnEncoderRecode} over the column's recode map.
     *
     * <p>NOTE(review): null cells are not assigned in the mapping (the recode map is not given a null key here),
     * so they presumably keep the mapping's initial index 0, i.e. the first category; confirm this is intended.
     */
    private AColGroup recodeToDummy(ColumnEncoderComposite c) {
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        HashMap<Object, Long> map = a.getRecodeMap();
        final int domain = map.size();
        final IColIndex colIndexes = ColIndexFactory.create(0, domain);
        final IdentityDictionary d = new IdentityDictionary(colIndexes.size());
        final AMapToData m = createMappingAMapToData(a, map);
        final List<ColumnEncoder> r = c.getEncoders();
        r.set(0, new ColumnEncoderRecode(colId, map));
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /**
     * Recode: a single output column whose dictionary maps recode id {@code i} to the value {@code i + 1}.
     * Replaces the first encoder of {@code c} with a {@link ColumnEncoderRecode} over the column's recode map.
     *
     * <p>NOTE(review): null cells are not assigned in the mapping and presumably keep index 0; confirm intended.
     */
    private AColGroup recode(ColumnEncoderComposite c) {
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        HashMap<Object, Long> map = a.getRecodeMap();
        final int domain = map.size();
        final IColIndex colIndexes = ColIndexFactory.create(1);
        final MatrixBlock incrementing = new MatrixBlock(domain, 1, false);
        for (int i = 0; i < domain; ++i)
            incrementing.quickSetValue(i, 0, i + 1);
        final MatrixBlockDictionary d = MatrixBlockDictionary.create(incrementing);
        final AMapToData m = createMappingAMapToData(a, map);
        final List<ColumnEncoder> r = c.getEncoders();
        r.set(0, new ColumnEncoderRecode(colId, map));
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /**
     * Pass-through: a single output column holding the numeric value of each cell.
     *
     * <p>If the number of distinct values is at least the configured {@code sysds.defaultblocksize}, the column is
     * converted to FP64 and stored uncompressed. Otherwise, a dictionary of the distinct values' numeric
     * representations is built. Null cells get their own dictionary entry, holding {@code a.getAsDouble} of a null
     * cell, and are mapped to it.
     */
    private AColGroup passThrough(ColumnEncoderComposite c) {
        final IColIndex colIndexes = ColIndexFactory.create(1);
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        final HashMap<?, Long> recodeMap = a.getRecodeMap();
        final int blockSz = ConfigurationManager.getDMLConfig().getIntValue("sysds.defaultblocksize");
        if (recodeMap.size() >= blockSz) {
            final double[] vals = (double[]) a.changeType(Types.ValueType.FP64).get();
            final MatrixBlock col = new MatrixBlock(a.size(), 1, vals);
            col.recomputeNonZeros();
            return ColGroupUncompressed.create(colIndexes, col, false);
        }
        // Work on a private copy: NOTE(review) getRecodeMap() may return a map cached on the column, and the
        // null entry added below must not leak into it.
        final HashMap<Object, Long> map = new HashMap<Object, Long>(recodeMap);
        if (a.containsNull() && !map.containsKey(null))
            map.put(null, Long.valueOf(map.size()));
        final double[] vals = new double[map.size()];
        for (int i = 0; i < a.size(); ++i)
            vals[map.get(a.get(i)).intValue()] = a.getAsDouble(i);
        final Dictionary d = Dictionary.create(vals);
        final AMapToData m = createMappingAMapToData(a, map);
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /**
     * Builds the row-to-dictionary-index mapping of column {@code a} from its recode map.
     *
     * <p>Null cells are mapped to {@code map.get(null)} when the map has a null key. Otherwise they are left
     * unset (NOTE(review): presumably the mapping's initial index 0; confirm).
     *
     * @param a   column values
     * @param map value to dictionary index
     * @return mapping with {@code in.getNumRows()} rows and {@code map.size()} distinct indices
     */
    private AMapToData createMappingAMapToData(Array<?> a, HashMap<?, Long> map) {
        final AMapToData m = MapToFactory.create(in.getNumRows(), map.size());
        final boolean hasNullEntry = map.containsKey(null);
        final Array.ArrayIterator it = a.getIterator();
        while (it.hasNext()) {
            final Object v = it.next();
            if (v == null && !hasNullEntry)
                continue;
            m.set(it.getIndex(), map.get(v).intValue());
        }
        return m;
    }

    /** Logs size estimates and compression ratios of the result when debug logging is enabled. */
    private void logging(MatrixBlock mb) {
        if (!LOG.isDebugEnabled())
            return;
        final long denseSize = mb.estimateSizeDenseInMemory();
        final long sparseSize = mb.estimateSizeSparseInMemory();
        final long compressedSize = mb.estimateSizeInMemory();
        LOG.debug(String.format("Uncompressed transform encode Dense size:   %16d", denseSize));
        LOG.debug(String.format("Uncompressed transform encode Sparse size:  %16d", sparseSize));
        LOG.debug(String.format("Compressed transform encode size:           %16d", compressedSize));
        // cast before dividing: long / long would truncate the ratio to an integer
        final double ratio = (double) Math.min(denseSize, sparseSize) / compressedSize;
        final double denseRatio = (double) denseSize / compressedSize;
        LOG.debug(String.format("Compression ratio: %10.3f", ratio));
        LOG.debug(String.format("Dense ratio:       %10.3f", denseRatio));
    }

    /** Task encoding a single column encoder; submitted by {@link #multiThread(List)}. */
    private class EncodeTask implements Callable<AColGroup> {
        private final ColumnEncoderComposite c;

        protected EncodeTask(ColumnEncoderComposite c) {
            this.c = c;
        }

        @Override
        public AColGroup call() throws Exception {
            return CompressedEncode.this.encode(c);
        }
    }
}

