/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.search.AnalyticsQuery;
import org.apache.solr.search.DelegatingCollector;
import org.apache.solr.search.QParser;
import org.apache.solr.search.QParserPlugin;
import org.apache.solr.search.SolrIndexSearcher;

/**
 * Query parser plugin ({@code tlogit}) that runs one training pass of a text logistic-regression
 * model over the documents matched by the main query and reports the updated weights in the
 * response under {@code "logit"}.
 */
public class TextLogisticRegressionQParserPlugin extends QParserPlugin {
    /** Short name of this query parser. */
    public static final String NAME = "tlogit";

    /** This plugin takes no configuration; {@code args} are ignored. */
    @Override
    public void init(NamedList args) {
    }

    @Override
    public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
        return new TextLogisticRegressionQParser(qstr, localParams, params, req);
    }

    /** Holder for the request parameters that drive one training pass. */
    private static class TrainingParams {
        /** Indexed field whose postings supply the feature values. */
        public final String feature;
        /** Feature terms; term {@code k} maps to vector slot {@code k + 1} (slot 0 is the bias). */
        public final String[] terms;
        /** Per-term idf multipliers, index-aligned with {@link #terms}. */
        public final double[] idfs;
        /** Numeric doc-values field read as each document's class label. */
        public final String outcome;
        /** Incoming model weights, one per vector slot. */
        public final double[] weights;
        /** Iteration number from the request; not read by the collector in this file. */
        public final int iteration;
        /** Label value treated as the positive class (mapped to 1; every other value maps to 0). */
        public final int positiveLabel;
        /** Probability cut-off used when evaluating the incoming model. */
        public final double threshold;
        /** Learning rate applied to each gradient step. */
        public final double alpha;

        public TrainingParams(String feature, String[] terms, double[] idfs, String outcome, double[] weights, int iteration, double alpha, int positiveLabel, double threshold) {
            this.feature = feature;
            this.terms = terms;
            this.idfs = idfs;
            this.outcome = outcome;
            this.weights = weights;
            this.alpha = alpha;
            this.iteration = iteration;
            this.positiveLabel = positiveLabel;
            this.threshold = threshold;
        }
    }

    /**
     * Records which matched documents exist and which carry the positive label. In
     * {@link #finish()} it runs one stochastic-gradient-descent pass over their feature vectors and
     * writes the updated model plus an evaluation of the incoming model to the response.
     */
    private static class TextLogisticRegressionCollector extends DelegatingCollector {
        private final TrainingParams trainingParams;
        private final ClassificationEvaluation classificationEvaluation;
        /** Working copy of the incoming weights, updated in place during {@link #finish()}. */
        private final double[] weights;
        private final ResponseBuilder rbsp;
        /** Global ids of collected docs whose label equals {@code positiveLabel}. */
        private final SparseFixedBitSet positiveDocsSet;
        /** Global ids of all collected docs. */
        private final SparseFixedBitSet docsSet;
        private final IndexSearcher searcher;
        /** Outcome values of the current segment; may be null if the segment has none for the field. */
        private NumericDocValues leafOutcomeValue;
        /** Sum of absolute prediction errors of the incoming (pre-update) model. */
        private double totalError;

        TextLogisticRegressionCollector(ResponseBuilder rbsp, IndexSearcher searcher, TrainingParams trainingParams) {
            this.trainingParams = trainingParams;
            this.weights = Arrays.copyOf(trainingParams.weights, trainingParams.weights.length);
            this.rbsp = rbsp;
            this.classificationEvaluation = new ClassificationEvaluation();
            this.searcher = searcher;
            // Global doc ids span [0, maxDoc); numDocs() excludes deletions and would be too small.
            // SparseFixedBitSet rejects a zero length, so clamp for an empty index.
            int bitSetLength = Math.max(1, searcher.getIndexReader().maxDoc());
            this.positiveDocsSet = new SparseFixedBitSet(bitSetLength);
            this.docsSet = new SparseFixedBitSet(bitSetLength);
        }

        @Override
        public void doSetNextReader(LeafReaderContext context) throws IOException {
            super.doSetNextReader(context);
            this.leafOutcomeValue = context.reader().getNumericDocValues(trainingParams.outcome);
        }

        /**
         * Marks {@code doc} as collected and, if its outcome equals {@code positiveLabel}, as
         * positive. A doc without an outcome value is treated as having label 0.
         */
        @Override
        public void collect(int doc) throws IOException {
            int label = 0;
            if (leafOutcomeValue != null) {
                int valuesDocID = leafOutcomeValue.docID();
                if (valuesDocID < doc) {
                    valuesDocID = leafOutcomeValue.advance(doc);
                }
                if (valuesDocID == doc) {
                    label = (int) leafOutcomeValue.longValue();
                }
            }
            int globalDoc = context.docBase + doc;
            if (label == trainingParams.positiveLabel) {
                positiveDocsSet.set(globalDoc);
            }
            docsSet.set(globalDoc);
        }

        /**
         * Runs one SGD pass over the collected docs that contain at least one feature term, writes
         * the results to the response, then finishes the delegate if it is a
         * {@link DelegatingCollector}.
         */
        @Override
        public void finish() throws IOException {
            Map<Integer, double[]> docVectors = buildDocVectors();
            for (Map.Entry<Integer, double[]> entry : docVectors.entrySet()) {
                double[] vector = entry.getValue();
                int outcome = positiveDocsSet.get(entry.getKey()) ? 1 : 0;
                // The gradient uses the weights as already updated earlier in this pass.
                double error = sigmoid(dot(vector, weights)) - outcome;
                // Error and evaluation describe the incoming (pre-update) model.
                double lastSig = sigmoid(dot(vector, trainingParams.weights));
                totalError += Math.abs(lastSig - outcome);
                classificationEvaluation.count(outcome, lastSig >= trainingParams.threshold ? 1 : 0);
                double scale = error * trainingParams.alpha;
                for (int i = 0; i < weights.length; ++i) {
                    weights[i] -= vector[i] * scale;
                }
            }
            writeResponse();
            if (delegate instanceof DelegatingCollector) {
                ((DelegatingCollector) delegate).finish();
            }
        }

        /**
         * Builds a feature vector for every collected doc containing at least one feature term.
         * Slot 0 is a constant bias of 1.0; slot {@code k + 1} holds
         * {@code idfs[k] * (1 + ln(freq))} for term {@code k}, and 0 when the doc lacks the term.
         */
        private Map<Integer, double[]> buildDocVectors() throws IOException {
            // Plain HashMap (no presizing) so the iteration order, and thus the SGD update order,
            // matches the historical behaviour.
            Map<Integer, double[]> docVectors = new HashMap<>();
            Terms terms = ((SolrIndexSearcher) searcher).getSlowAtomicReader().terms(trainingParams.feature);
            TermsEnum termsEnum = terms == null ? TermsEnum.EMPTY : terms.iterator();
            PostingsEnum postingsEnum = null;
            int dimensions = trainingParams.terms.length + 1;
            for (int termIndex = 0; termIndex < trainingParams.terms.length; ++termIndex) {
                if (!termsEnum.seekExact(new BytesRef(trainingParams.terms[termIndex]))) {
                    continue;
                }
                postingsEnum = termsEnum.postings(postingsEnum);
                double idf = trainingParams.idfs[termIndex];
                for (int docId = postingsEnum.nextDoc(); docId != PostingsEnum.NO_MORE_DOCS; docId = postingsEnum.nextDoc()) {
                    if (!docsSet.get(docId)) {
                        continue;
                    }
                    double[] vector = docVectors.get(docId);
                    if (vector == null) {
                        vector = new double[dimensions];
                        vector[0] = 1.0;
                        docVectors.put(docId, vector);
                    }
                    vector[termIndex + 1] = idf * (1.0 + Math.log(postingsEnum.freq()));
                }
            }
            return docVectors;
        }

        /** Adds the {@code "logit"} section (weights, error, evaluation, feature, positiveLabel). */
        private void writeResponse() {
            NamedList<Object> analytics = new NamedList<>();
            rbsp.rsp.add("logit", analytics);
            ArrayList<Double> weightList = new ArrayList<>(weights.length);
            for (double w : weights) {
                weightList.add(w);
            }
            analytics.add("weights", weightList);
            analytics.add("error", totalError);
            analytics.add("evaluation", classificationEvaluation.toMap());
            analytics.add("feature", trainingParams.feature);
            analytics.add("positiveLabel", trainingParams.positiveLabel);
        }

        /** Logistic function {@code 1 / (1 + e^-in)}. */
        private double sigmoid(double in) {
            return 1.0 / (1.0 + Math.exp(-in));
        }

        /** Dot product of two equal-length vectors, summed left to right. */
        private double dot(double[] vals, double[] coefficients) {
            double d = 0.0;
            for (int i = 0; i < vals.length; ++i) {
                d += vals[i] * coefficients[i];
            }
            return d;
        }
    }

    /** Analytics query that plugs {@link TextLogisticRegressionCollector} into the search. */
    private static class TextLogisticRegressionQuery extends AnalyticsQuery {
        private final TrainingParams trainingParams;

        public TextLogisticRegressionQuery(TrainingParams trainingParams) {
            this.trainingParams = trainingParams;
        }

        @Override
        public DelegatingCollector getAnalyticsCollector(ResponseBuilder rbsp, IndexSearcher indexSearcher) {
            return new TextLogisticRegressionCollector(rbsp, indexSearcher, this.trainingParams);
        }
    }

    /**
     * Reads the training parameters from the request params. {@code feature}, {@code terms},
     * {@code idfs} and {@code outcome} are required; {@code weights} defaults to all 1.0.
     */
    private static class TextLogisticRegressionQParser extends QParser {
        TextLogisticRegressionQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
            super(qstr, localParams, params, req);
        }

        /**
         * @throws IllegalArgumentException if a required parameter is missing, a numeric list has
         *     more values than there are slots, or a value is not a valid double
         */
        @Override
        public Query parse() {
            String fs = requiredParam("feature");
            String[] terms = requiredParam("terms").split(",");
            String ws = this.params.get("weights");
            String dfsStr = requiredParam("idfs");
            int iteration = this.params.getInt("iteration", 0);
            String outcome = requiredParam("outcome");
            int positiveLabel = this.params.getInt("positiveLabel", 1);
            double threshold = this.params.getDouble("threshold", 0.5);
            double alpha = this.params.getDouble("alpha", 0.01);
            // One idf per term; missing trailing values stay 0.
            double[] idfs = parseDoubles("idfs", dfsStr, terms.length);
            // One weight per vector slot: bias + one per term.
            double[] weights;
            if (ws != null) {
                weights = parseDoubles("weights", ws, terms.length + 1);
            } else {
                weights = new double[terms.length + 1];
                Arrays.fill(weights, 1.0);
            }
            TrainingParams input = new TrainingParams(fs, terms, idfs, outcome, weights, iteration, alpha, positiveLabel, threshold);
            return new TextLogisticRegressionQuery(input);
        }

        /** Returns the named request parameter, failing fast when it is absent. */
        private String requiredParam(String name) {
            String value = this.params.get(name);
            if (value == null) {
                throw new IllegalArgumentException("Missing required parameter: " + name);
            }
            return value;
        }

        /**
         * Parses a comma-separated list of doubles into an array of exactly {@code length} slots;
         * slots without a value stay 0.
         */
        private static double[] parseDoubles(String name, String csv, int length) {
            String[] parts = csv.split(",");
            if (parts.length > length) {
                throw new IllegalArgumentException("Parameter '" + name + "' has " + parts.length
                        + " values but at most " + length + " are allowed");
            }
            double[] values = new double[length];
            for (int i = 0; i < parts.length; ++i) {
                values[i] = Double.parseDouble(parts[i]);
            }
            return values;
        }
    }
}

