/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.validation;

import edu.udo.cs.yale.Statistics;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.gui.SimplePlotterDialog;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.parameter.ParameterTypeBoolean;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import edu.udo.cs.yale.operator.performance.AUCCriterion;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.operator.validation.Threshold;
import edu.udo.cs.yale.tools.math.Averagable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Operator that chooses a confidence threshold for a two-class problem by maximizing
 * {@code TP - FP * costRatio} along the ROC curve (the cost "isometrics"), where
 * {@code costRatio} is the ratio of the two misclassification-cost parameters.
 *
 * <p>It reads the input example set's confidences for the positive class (via
 * {@link Example#getConfidence(String)}), optionally weighted by the weight attribute. It
 * delivers the example set together with a {@link Threshold}. Optionally it displays the ROC
 * curve and delivers the area under it (AUC) as a {@link PerformanceVector}.
 */
public class ThresholdFinder extends Operator {

    /** Approximate maximum number of points drawn in the ROC plot; longer curves are subsampled. */
    public static final int MAX_ROC_POINTS = 200;

    private static final String COSTS_NEG = "misclassification_costs_first";
    private static final String COSTS_POS = "misclassification_costs_second";
    private static final String SHOW_PLOT = "show_roc_plot";
    private static final String CREATE_AUC = "create_AUC_performance";

    public ThresholdFinder(OperatorDescription description) {
        super(description);
    }

    /**
     * Computes the cost-optimal threshold and the ROC curve of the input example set.
     *
     * @return the input example set and the found {@link Threshold}, followed by a
     *     {@link PerformanceVector} holding the AUC if {@value #CREATE_AUC} is set
     * @throws OperatorException if the label is missing, not nominal or not a two-class label
     *     (user errors 105, 101, 118), or if a parameter cannot be read
     */
    public IOObject[] apply() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) getInput(ExampleSet.class);
        Attribute label = exampleSet.getLabel();
        validateBinominalLabel(label);

        int negativeLabelIndex = label.getNegativeIndex();
        int positiveLabelIndex = label.getPositiveIndex();
        String negativeClass = label.mapIndex(negativeLabelIndex);
        String positiveClass = label.mapIndex(positiveLabelIndex);

        WeightedConfidenceAndLabel[] sorted = readSortedConfidences(exampleSet, label, positiveClass);

        // Slope of the cost isometrics in unnormalized ROC space (FP weight on x, TP weight on y).
        double slope = getParameterAsDouble(COSTS_NEG) / getParameterAsDouble(COSTS_POS);

        double tp = 0.0;    // weight of positive examples seen so far
        double total = 0.0; // weight of all examples seen so far
        double bestIsometricsTpValue = 0.0; // value 0 corresponds to "predict everything negative"
        double bestThreshold = Double.POSITIVE_INFINITY;
        // Each entry: {false positive weight, true positive weight, confidence of that point}.
        List<double[]> rocPoints = new LinkedList<double[]>();
        rocPoints.add(new double[] {0.0, 0.0, 1.0});
        for (int k = 0; k < sorted.length; k++) {
            WeightedConfidenceAndLabel current = sorted[k];
            double weight = current.getWeight();
            if (current.getLabel() == positiveLabelIndex) {
                tp += weight;
            } else {
                // Cutting just above this negative example predicts every example seen so far
                // (all with a higher confidence) as positive.
                double fpBefore = total - tp;
                double c = tp - fpBefore * slope;
                if (c > bestIsometricsTpValue) {
                    bestIsometricsTpValue = c;
                    // NOTE(review): presumably Threshold predicts positive only for confidences
                    // strictly above this value -- confirm against Threshold.
                    bestThreshold = current.getConfidence();
                }
            }
            total += weight;
            // Record the point only after counting the current example, so a negative example
            // moves the curve right immediately (a staircase step) instead of being merged
            // diagonally with the next positive step, which would inflate the AUC.
            rocPoints.add(new double[] {total - tp, tp, current.getConfidence()});
        }
        // Finally consider predicting every example as positive.
        double c = tp - (total - tp) * slope;
        if (c > bestIsometricsTpValue) {
            bestThreshold = Double.NEGATIVE_INFINITY;
            bestIsometricsTpValue = c;
        }
        double sumPos = tp;
        double sumNeg = total - sumPos;
        // NOTE(review): if one class has no (weighted) examples, sumPos or sumNeg is 0 and the
        // normalized values below, including the AUC, become NaN or infinite.
        bestIsometricsTpValue /= sumPos;
        rocPoints.add(new double[] {sumNeg, sumPos, 0.0});

        Statistics stats = new Statistics("ROCplot");
        stats.init(new String[] {"FP/N", "TP/P", "Slope", "Treshold"});
        int sampleEvery = (int) Math.round((double) rocPoints.size() / MAX_ROC_POINTS);
        double aucSum = 0.0;
        double[] last = null;
        int pointCounter = 0;
        for (Iterator<double[]> it = rocPoints.iterator(); it.hasNext(); pointCounter++) {
            double[] point = it.next();
            double fpRate = point[0] / sumNeg;
            double tpRate = point[1] / sumPos;
            // Plot only every sampleEvery-th point, but always include the final one.
            if (sampleEvery < 1 || pointCounter % sampleEvery == 0 || !it.hasNext()) {
                double isometric = bestIsometricsTpValue + fpRate * slope * (sumNeg / sumPos);
                stats.add(new Object[] {
                    Double.valueOf(fpRate), Double.valueOf(tpRate),
                    Double.valueOf(isometric), Double.valueOf(point[2])});
            }
            if (last != null) {
                // Trapezoid between the previous and the current ROC point.
                aucSum += (fpRate - last[0]) * (last[1] + tpRate) / 2.0;
            }
            last = new double[] {fpRate, tpRate};
        }

        if (getParameterAsBoolean(SHOW_PLOT)) {
            displayRocPlot(stats);
        }

        Threshold threshold = new Threshold(bestThreshold, negativeClass, positiveClass);
        if (getParameterAsBoolean(CREATE_AUC)) {
            PerformanceVector aucPerformanceVector = new PerformanceVector();
            AUCCriterion aucCriterion = new AUCCriterion(positiveClass, aucSum);
            aucPerformanceVector.addCriterion(aucCriterion);
            aucPerformanceVector.setMainCriterionName(aucCriterion.getName());
            return new IOObject[] {exampleSet, threshold, aucPerformanceVector};
        }
        return new IOObject[] {exampleSet, threshold};
    }

    /** Throws a user error unless the label exists, is nominal and is a two-class label. */
    private void validateBinominalLabel(Attribute label) throws OperatorException {
        if (label == null) {
            throw new UserError(this, 105);
        }
        if (!label.isNominal()) {
            throw new UserError((Operator) this, 101, label, (Object) "threshold finding");
        }
        if (!label.isBooleanClassification()) {
            throw new UserError((Operator) this, 118, new Object[] {
                label, Integer.valueOf(label.getValues().size()), Integer.valueOf(2)});
        }
    }

    /**
     * Reads confidence for {@code positiveClass}, label value and weight (1.0 when the example set
     * has no weight attribute) of every example, sorted by descending confidence.
     */
    private WeightedConfidenceAndLabel[] readSortedConfidences(
            ExampleSet exampleSet, Attribute label, String positiveClass) throws OperatorException {
        Attribute weightAttr = exampleSet.getWeight();
        WeightedConfidenceAndLabel[] result = new WeightedConfidenceAndLabel[exampleSet.getSize()];
        ExampleReader reader = exampleSet.getExampleReader();
        int index = 0;
        while (reader.hasNext()) {
            Example example = reader.next();
            double weight = weightAttr == null ? 1.0 : example.getValue(weightAttr);
            result[index++] = new WeightedConfidenceAndLabel(
                    example.getConfidence(positiveClass), example.getValue(label), weight);
        }
        Arrays.sort(result);
        return result;
    }

    /** Opens a plotter dialog showing TP/P, the best cost isometric and the threshold over FP/N. */
    private void displayRocPlot(Statistics stats) {
        SimplePlotterDialog plotter = new SimplePlotterDialog(stats);
        plotter.setXAxis(0);
        plotter.plotColumn(1, true);
        plotter.plotColumn(2, true);
        plotter.plotColumn(3, true);
        plotter.setDrawRange(0.0, 1.0, 0.0, 1.0);
        plotter.setVisible(true);
    }

    public Class[] getInputClasses() {
        return new Class[] {ExampleSet.class};
    }

    public Class[] getOutputClasses() {
        if (getParameterAsBoolean(CREATE_AUC)) {
            return new Class[] {ExampleSet.class, Threshold.class, PerformanceVector.class};
        }
        return new Class[] {ExampleSet.class, Threshold.class};
    }

    public List getParameterTypes() {
        List list = super.getParameterTypes();
        list.add(new ParameterTypeDouble(COSTS_NEG, "The costs assigned when an example of the first class is classified as one of the second.", 0.0, Double.POSITIVE_INFINITY, 1.0));
        list.add(new ParameterTypeDouble(COSTS_POS, "The costs assigned when an example of the second class is classified as one of the first.", 0.0, Double.POSITIVE_INFINITY, 1.0));
        list.add(new ParameterTypeBoolean(SHOW_PLOT, "Display a plot of the ROC curve.", false));
        list.add(new ParameterTypeBoolean(CREATE_AUC, "Indicates if the area under the ROC curve should be delivered as performance criterion.", false));
        return list;
    }

    /**
     * Confidence for the positive class, label value and weight of a single example.
     *
     * <p>Orders by descending confidence; ties are ordered by ascending label value.
     * NOTE(review): examples with equal confidence therefore produce separate ROC points in label
     * order (a staircase) rather than one diagonal segment.
     */
    private static class WeightedConfidenceAndLabel implements Comparable<WeightedConfidenceAndLabel> {
        private final double confidence;
        private final double label;
        private final double weight;

        WeightedConfidenceAndLabel(double confidence, double label, double weight) {
            this.confidence = confidence;
            this.label = label;
            this.weight = weight;
        }

        public int compareTo(WeightedConfidenceAndLabel other) {
            int byConfidence = Double.compare(other.confidence, this.confidence);
            return byConfidence != 0 ? byConfidence : Double.compare(this.label, other.label);
        }

        public double getLabel() {
            return this.label;
        }

        public double getConfidence() {
            return this.confidence;
        }

        public double getWeight() {
            return this.weight;
        }

        public String toString() {
            return "conf: " + this.confidence + ", label: " + this.label + ", weight: " + this.weight;
        }
    }
}

