/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.util.Util;

/**
 * A {@link ConfigurableDataSource} which samples clustering data from a mixture of five
 * multivariate Gaussians.
 * <p>
 * Examples are generated as follows:
 * <ul>
 * <li>A mixture component is drawn from {@code mixingDistribution}.</li>
 * <li>A point is sampled from that component's Gaussian.</li>
 * <li>The example's output is a {@link ClusterID} holding the component index (0-4).</li>
 * <li>Its features are named "A", "B", "C", "D", truncated to the dimensionality of the means
 * (1-4).</li>
 * </ul>
 * <p>
 * Each covariance matrix is supplied as a flattened, row-major vector of length {@code d * d},
 * where {@code d} is the dimensionality of the means. All examples are generated eagerly in
 * {@link #postConfig()} and then held in memory. All randomness derives from {@code seed}.
 */
public final class GaussianClusterDataSource implements ConfigurableDataSource<ClusterID> {
    private static final ClusteringFactory factory = new ClusteringFactory();

    /** Feature names; truncated to the dimensionality of the Gaussians. */
    private static final String[] allFeatureNames = new String[]{"A", "B", "C", "D"};

    /** Relative tolerance used when checking that a covariance matrix is symmetric. */
    private static final double SYMMETRY_TOLERANCE = 1.0E-10;

    @Config(mandatory=true, description="The number of samples to draw.")
    private int numSamples;
    @Config(description="The probability of sampling from each Gaussian, must sum to 1.0.")
    private double[] mixingDistribution = new double[]{0.1, 0.35, 0.05, 0.25, 0.25};
    @Config(description="The mean of the first Gaussian.")
    private double[] firstMean = new double[]{0.0, 0.0};
    @Config(description="A vector representing the first Gaussian's covariance matrix.")
    private double[] firstVariance = new double[]{1.0, 0.0, 0.0, 1.0};
    @Config(description="The mean of the second Gaussian.")
    private double[] secondMean = new double[]{5.0, 5.0};
    @Config(description="A vector representing the second Gaussian's covariance matrix.")
    private double[] secondVariance = new double[]{1.0, 0.0, 0.0, 1.0};
    @Config(description="The mean of the third Gaussian.")
    private double[] thirdMean = new double[]{2.5, 2.5};
    @Config(description="A vector representing the third Gaussian's covariance matrix.")
    private double[] thirdVariance = new double[]{1.0, 0.5, 0.5, 1.0};
    @Config(description="The mean of the fourth Gaussian.")
    private double[] fourthMean = new double[]{10.0, 0.0};
    @Config(description="A vector representing the fourth Gaussian's covariance matrix.")
    private double[] fourthVariance = new double[]{0.1, 0.0, 0.0, 0.1};
    @Config(description="The mean of the fifth Gaussian.")
    private double[] fifthMean = new double[]{-1.0, 0.0};
    @Config(description="A vector representing the fifth Gaussian's covariance matrix.")
    private double[] fifthVariance = new double[]{1.0, 0.0, 0.0, 0.1};
    @Config(description="The RNG seed.")
    private long seed = 12345L;

    /** The generated examples, populated by {@link #postConfig()}; unmodifiable. */
    private List<Example<ClusterID>> examples;

    /**
     * For the OLCUT configuration system. The {@code @Config} fields are populated reflectively
     * before {@link #postConfig()} is invoked.
     */
    private GaussianClusterDataSource() {
    }

    /**
     * Generates {@code numSamples} examples using the default mixture parameters.
     *
     * @param numSamples The number of samples to draw; must be positive.
     * @param seed       The RNG seed.
     * @throws PropertyException if {@code numSamples} is not positive.
     */
    public GaussianClusterDataSource(int numSamples, long seed) {
        this.numSamples = numSamples;
        this.seed = seed;
        postConfig();
    }

    /**
     * Generates {@code numSamples} examples from the supplied five-component mixture.
     * <p>
     * All arrays are copied, so mutating them afterwards does not affect this source or its
     * provenance.
     *
     * @param numSamples         The number of samples to draw; must be positive.
     * @param mixingDistribution The probability of each component; 5 non-negative values summing to 1.0.
     * @param firstMean          The mean of the first Gaussian (1-4 dimensions).
     * @param firstVariance      The first covariance matrix, flattened row-major.
     * @param secondMean         The mean of the second Gaussian.
     * @param secondVariance     The second covariance matrix, flattened row-major.
     * @param thirdMean          The mean of the third Gaussian.
     * @param thirdVariance      The third covariance matrix, flattened row-major.
     * @param fourthMean         The mean of the fourth Gaussian.
     * @param fourthVariance     The fourth covariance matrix, flattened row-major.
     * @param fifthMean          The mean of the fifth Gaussian.
     * @param fifthVariance      The fifth covariance matrix, flattened row-major.
     * @param seed               The RNG seed.
     * @throws PropertyException    if any parameter is invalid.
     * @throws NullPointerException if any array is null.
     */
    public GaussianClusterDataSource(int numSamples, double[] mixingDistribution, double[] firstMean, double[] firstVariance, double[] secondMean, double[] secondVariance, double[] thirdMean, double[] thirdVariance, double[] fourthMean, double[] fourthVariance, double[] fifthMean, double[] fifthVariance, long seed) {
        this.numSamples = numSamples;
        // Defensive copies: the provenance is built from these fields, so later mutation by the
        // caller would otherwise make the recorded configuration disagree with the generated data.
        this.mixingDistribution = mixingDistribution.clone();
        this.firstMean = firstMean.clone();
        this.firstVariance = firstVariance.clone();
        this.secondMean = secondMean.clone();
        this.secondVariance = secondVariance.clone();
        this.thirdMean = thirdMean.clone();
        this.thirdVariance = thirdVariance.clone();
        this.fourthMean = fourthMean.clone();
        this.fourthVariance = fourthVariance.clone();
        this.fifthMean = fifthMean.clone();
        this.fifthVariance = fifthVariance.clone();
        this.seed = seed;
        postConfig();
    }

    /**
     * Validates the configuration and generates the examples.
     * <p>
     * The five Gaussians are each given their own generator, seeded in order from a
     * {@link Random} built from {@code seed`}. The same {@code Random} then drives component
     * selection, so the output is deterministic for a given configuration.
     *
     * @throws PropertyException if any configuration value is invalid.
     */
    @Override
    public void postConfig() {
        if (numSamples < 1) {
            throw new PropertyException("", "numSamples", "numSamples must be positive, found " + numSamples);
        }
        if (mixingDistribution.length != 5) {
            throw new PropertyException("", "mixingDistribution", "mixingDistribution must have 5 elements, found " + mixingDistribution.length);
        }
        double mixingSum = Util.sum(mixingDistribution);
        // Negated comparison so that a NaN sum (e.g. from a NaN entry) is rejected; the
        // un-negated "> tolerance" form evaluates to false for NaN and would let it through.
        if (!(Math.abs(mixingSum - 1.0) <= 1.0E-10)) {
            throw new PropertyException("", "mixingDistribution", "mixingDistribution must sum to 1.0, found " + mixingSum);
        }
        int numDims = firstMean.length;
        if (numDims > allFeatureNames.length || numDims == 0) {
            throw new PropertyException("", "firstMean", "Must have 1-4 features, found " + numDims);
        }
        int covarianceSize = numDims * numDims;
        if (firstVariance.length != covarianceSize) {
            throw new PropertyException("", "firstVariance", "Invalid first covariance matrix, expected " + covarianceSize + " elements, found " + firstVariance.length);
        }
        validateDimensions(secondMean, "secondMean", secondVariance, "secondVariance", numDims);
        validateDimensions(thirdMean, "thirdMean", thirdVariance, "thirdVariance", numDims);
        validateDimensions(fourthMean, "fourthMean", fourthVariance, "fourthVariance", numDims);
        validateDimensions(fifthMean, "fifthMean", fifthVariance, "fifthVariance", numDims);
        for (double probability : mixingDistribution) {
            if (probability < 0.0) {
                throw new PropertyException("", "mixingDistribution", "Probability values in the mixing distribution must be non-negative, found " + Arrays.toString(mixingDistribution));
            }
        }

        double[] mixingCDF = Util.generateCDF(mixingDistribution);
        String[] featureNames = Arrays.copyOf(allFeatureNames, numDims);
        Random rng = new Random(seed);

        double[][] means = {firstMean, secondMean, thirdMean, fourthMean, fifthMean};
        double[][] variances = {firstVariance, secondVariance, thirdVariance, fourthVariance, fifthVariance};
        String[] varianceNames = {"firstVariance", "secondVariance", "thirdVariance", "fourthVariance", "fifthVariance"};
        MultivariateNormalDistribution[] gaussians = new MultivariateNormalDistribution[means.length];
        for (int i = 0; i < gaussians.length; i++) {
            // Draw the per-Gaussian seed before validating the covariance, in component order,
            // so the sequence of draws from rng is fixed for reproducibility.
            RandomGenerator gaussianRng = new JDKRandomGenerator(rng.nextInt());
            gaussians[i] = new MultivariateNormalDistribution(gaussianRng, means[i], reshapeAndValidate(variances[i], varianceNames[i]));
        }

        List<Example<ClusterID>> generated = new ArrayList<>(numSamples);
        for (int i = 0; i < numSamples; i++) {
            int centroid = Util.sampleFromCDF(mixingCDF, rng);
            double[] sample = gaussians[centroid].sample();
            generated.add(new ArrayExample<>(new ClusterID(centroid), featureNames, sample));
        }
        examples = Collections.unmodifiableList(generated);
    }

    /**
     * Checks that a Gaussian's mean and flattened covariance match the expected dimensionality.
     *
     * @param mean         The mean vector.
     * @param meanName     The config field name of the mean, used in error messages.
     * @param variance     The flattened covariance matrix.
     * @param varianceName The config field name of the covariance, used in error messages.
     * @param numDims      The expected number of dimensions.
     * @throws PropertyException if either array has the wrong length.
     */
    private static void validateDimensions(double[] mean, String meanName, double[] variance, String varianceName, int numDims) {
        if (mean.length != numDims) {
            throw new PropertyException("", meanName, "All Gaussians must have the same number of dimensions, expected " + numDims + ", found " + mean.length);
        }
        int covarianceSize = numDims * numDims;
        if (variance.length != covarianceSize) {
            throw new PropertyException("", varianceName, varianceName + " is invalid, expected " + covarianceSize + ", found " + variance.length);
        }
    }

    @Override
    public OutputFactory<ClusterID> getOutputFactory() {
        return factory;
    }

    @Override
    public DataSourceProvenance getProvenance() {
        return new GaussianClusterDataSourceProvenance(this);
    }

    /**
     * Iterates the pre-generated examples; each call starts from the first example.
     *
     * @return An iterator over the unmodifiable example list.
     */
    @Override
    public Iterator<Example<ClusterID>> iterator() {
        return examples.iterator();
    }

    /**
     * Reshapes a flattened, row-major square matrix into a 2D array and checks the structural
     * properties required of a covariance matrix:
     * <ul>
     * <li>The diagonal variances are non-negative.</li>
     * <li>The matrix is symmetric, up to a relative tolerance.</li>
     * </ul>
     * Off-diagonal entries may be negative, since they represent negative correlation.
     * Positive-definiteness is not checked here; it is left to the
     * {@link MultivariateNormalDistribution} constructor.
     *
     * @param vector    The flattened matrix.
     * @param fieldName The config field name, used in error messages.
     * @return The matrix as a {@code double[length][length]} array.
     * @throws IllegalArgumentException if {@code vector.length} is not a perfect square.
     * @throws PropertyException        if a diagonal entry is negative or NaN, or the matrix is not symmetric.
     */
    private static double[][] reshapeAndValidate(double[] vector, String fieldName) {
        int length = (int) Math.sqrt(vector.length);
        if (length * length != vector.length) {
            throw new IllegalArgumentException("The vector does not represent a square matrix, found " + vector.length + " elements, which is not square.");
        }
        double[][] matrix = new double[length][length];
        for (int row = 0; row < length; row++) {
            for (int col = 0; col < length; col++) {
                double value = vector[row * length + col];
                // Only variances (the diagonal) must be non-negative; "!(>=)" also rejects NaN.
                if (row == col && !(value >= 0.0)) {
                    throw new PropertyException("", fieldName, fieldName + " must have non-negative variances on the diagonal, found " + Arrays.toString(vector));
                }
                matrix[row][col] = value;
            }
        }
        for (int row = 0; row < length; row++) {
            for (int col = row + 1; col < length; col++) {
                double upper = matrix[row][col];
                double lower = matrix[col][row];
                if (Math.abs(upper - lower) > SYMMETRY_TOLERANCE * Math.max(Math.abs(upper), Math.abs(lower))) {
                    throw new PropertyException("", fieldName, fieldName + " must be a symmetric matrix, found " + Arrays.toString(vector));
                }
            }
        }
        return matrix;
    }

    /**
     * Generates a dataset by sampling from the supplied five-component Gaussian mixture.
     *
     * @param numSamples         The number of samples to draw; must be positive.
     * @param mixingDistribution The probability of each component; 5 non-negative values summing to 1.0.
     * @param firstMean          The mean of the first Gaussian (1-4 dimensions).
     * @param firstVariance      The first covariance matrix, flattened row-major.
     * @param secondMean         The mean of the second Gaussian.
     * @param secondVariance     The second covariance matrix, flattened row-major.
     * @param thirdMean          The mean of the third Gaussian.
     * @param thirdVariance      The third covariance matrix, flattened row-major.
     * @param fourthMean         The mean of the fourth Gaussian.
     * @param fourthVariance     The fourth covariance matrix, flattened row-major.
     * @param fifthMean          The mean of the fifth Gaussian.
     * @param fifthVariance      The fifth covariance matrix, flattened row-major.
     * @param seed               The RNG seed.
     * @return A dataset containing the sampled examples.
     * @throws PropertyException if any parameter is invalid.
     */
    public static MutableDataset<ClusterID> generateDataset(int numSamples, double[] mixingDistribution, double[] firstMean, double[] firstVariance, double[] secondMean, double[] secondVariance, double[] thirdMean, double[] thirdVariance, double[] fourthMean, double[] fourthVariance, double[] fifthMean, double[] fifthVariance, long seed) {
        GaussianClusterDataSource source = new GaussianClusterDataSource(numSamples, mixingDistribution, firstMean, firstVariance, secondMean, secondVariance, thirdMean, thirdVariance, fourthMean, fourthVariance, fifthMean, fifthVariance, seed);
        return new MutableDataset<>(source);
    }

    /**
     * Provenance for {@link GaussianClusterDataSource}, recording its configured fields.
     */
    public static final class GaussianClusterDataSourceProvenance
            extends SkeletalConfiguredObjectProvenance
            implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Builds a provenance from the host's configuration.
         *
         * @param host The data source to record.
         */
        GaussianClusterDataSourceProvenance(GaussianClusterDataSource host) {
            super(host, "DataSource");
        }

        /**
         * Rebuilds a provenance from its map form (e.g. when deserializing).
         *
         * @param map The provenance map.
         */
        public GaussianClusterDataSourceProvenance(Map<String, Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        private GaussianClusterDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
        }

        /**
         * Reads the "class-name" and "host-short-name" entries from a copy of {@code map}.
         * Everything left in that copy is passed on as the configured parameters.
         *
         * @param map The provenance map; not modified.
         * @return The extracted provenance information.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            Map<String, Provenance> configuredParameters = new HashMap<>(map);
            String provClassName = GaussianClusterDataSourceProvenance.class.getSimpleName();
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "class-name", StringProvenance.class, provClassName).getValue();
            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "host-short-name", StringProvenance.class, provClassName).getValue();
            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
        }
    }
}

