/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.classification.knn;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.classification.knn.KnnModelData;
import org.apache.flink.ml.classification.knn.KnnParams;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseMatrix;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/**
 * Estimator that prepares the training data for k-nearest-neighbors classification.
 *
 * <p>{@link #fit(Table...)} performs no iterative training. It reads the label and features
 * columns of the single input table, computes the square of {@code BLAS.norm2} for every feature
 * vector, and packs all data points into one {@link KnnModelData} record that backs the returned
 * {@link KnnModel}.
 */
public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
    /** Parameter values of this stage, pre-populated with defaults by the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Creates an estimator whose parameters are initialized to their default values. */
    public Knn() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Builds a {@link KnnModel} from the given training table.
     *
     * @param inputs exactly one table containing the label and features columns
     * @return a model whose model data table holds the packed training set and whose parameters
     *     are updated from this estimator's parameter map
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public KnnModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        DataStream<Row> trainingRows = tEnv.toDataStream(inputs[0]);
        DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNorm =
                computeNormSquare(trainingRows);
        DataStream<KnnModelData> modelData = genModelData(inputDataWithNorm);

        KnnModel model = new KnnModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams(model, getParamMap());
        return model;
    }

    /** Returns the live (mutable) parameter map of this estimator. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /**
     * Saves the metadata of this estimator to the given path.
     *
     * @param path destination path
     * @throws IOException if writing the metadata fails
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads a {@link Knn} estimator from the given path.
     *
     * @param tEnv table environment; not used by this method
     * @param path path the estimator was saved to
     * @return the restored estimator
     * @throws IOException if reading the saved stage fails
     */
    public static Knn load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (Knn) ReadWriteUtils.loadStageParam(path);
    }

    /**
     * Packs all data points into a single {@link KnnModelData} record.
     *
     * <p>The resulting transformation is forced to parallelism 1.
     *
     * @param inputDataWithNormSquare tuples of (feature vector, label, squared norm)
     * @return a stream emitting one {@link KnnModelData} per processed partition
     */
    private static DataStream<KnnModelData> genModelData(
            DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNormSquare) {
        DataStream<KnnModelData> modelData =
                DataStreamUtils.mapPartition(
                        inputDataWithNormSquare,
                        new RichMapPartitionFunction<
                                Tuple3<DenseVector, Double, Double>, KnnModelData>() {
                            @Override
                            public void mapPartition(
                                    Iterable<Tuple3<DenseVector, Double, Double>> dataPoints,
                                    Collector<KnnModelData> out) {
                                // Buffer first: the matrix dimensions depend on the total count.
                                List<Tuple3<DenseVector, Double, Double>> bufferedDataPoints =
                                        new ArrayList<>();
                                for (Tuple3<DenseVector, Double, Double> dataPoint : dataPoints) {
                                    bufferedDataPoints.add(dataPoint);
                                }

                                // Fail with a clear message instead of the bare
                                // IndexOutOfBoundsException that get(0) would raise below.
                                Preconditions.checkArgument(
                                        !bufferedDataPoints.isEmpty(),
                                        "Cannot build Knn model data from an empty training set.");

                                int numPoints = bufferedDataPoints.size();
                                int featureDim = bufferedDataPoints.get(0).f0.size();
                                DenseMatrix packedFeatures = new DenseMatrix(featureDim, numPoints);
                                DenseVector normSquares = new DenseVector(numPoints);
                                DenseVector labels = new DenseVector(numPoints);

                                int offset = 0;
                                for (Tuple3<DenseVector, Double, Double> dataPoint :
                                        bufferedDataPoints) {
                                    DenseVector features = dataPoint.f0;
                                    // A longer vector would be silently truncated by the copy
                                    // below and a shorter one would overrun its own array.
                                    Preconditions.checkArgument(
                                            features.size() == featureDim,
                                            "Feature vector at index %s has dimension %s, expected %s.",
                                            offset,
                                            features.size(),
                                            featureDim);
                                    // Each data point fills featureDim consecutive slots of the
                                    // matrix's backing array.
                                    System.arraycopy(
                                            features.values,
                                            0,
                                            packedFeatures.values,
                                            offset * featureDim,
                                            featureDim);
                                    labels.values[offset] = dataPoint.f1;
                                    normSquares.values[offset] = dataPoint.f2;
                                    offset++;
                                }
                                out.collect(new KnnModelData(packedFeatures, normSquares, labels));
                            }
                        });
        modelData.getTransformation().setParallelism(1);
        return modelData;
    }

    /**
     * Converts each input row into a (feature vector, label, squared norm) tuple.
     *
     * <p>The label is read from the label column as a {@link Number} and the features column is
     * read as a {@link Vector} and converted to a {@link DenseVector}.
     *
     * @param inputData rows of the training table
     * @return one tuple per input row
     */
    private DataStream<Tuple3<DenseVector, Double, Double>> computeNormSquare(
            DataStream<Row> inputData) {
        // Resolve the column names once here rather than once per record inside map().
        final String labelCol = getLabelCol();
        final String featuresCol = getFeaturesCol();
        return inputData.map(
                new MapFunction<Row, Tuple3<DenseVector, Double, Double>>() {
                    @Override
                    public Tuple3<DenseVector, Double, Double> map(Row value) {
                        double label = ((Number) value.getField(labelCol)).doubleValue();
                        DenseVector feature = ((Vector) value.getField(featuresCol)).toDense();
                        return Tuple3.of(feature, label, Math.pow(BLAS.norm2(feature), 2.0));
                    }
                });
    }
}

