/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.tools.math;

import edu.udo.cs.yale.datatable.SimpleDataTable;
import edu.udo.cs.yale.datatable.SimpleDataTableRow;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.NominalAttributeStatistics;
import edu.udo.cs.yale.gui.plotter.SimplePlotterDialog;
import edu.udo.cs.yale.tools.math.WeightedConfidenceAndLabel;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Builds ROC (receiver operating characteristic) curve data from the positive-class
 * confidences of a scored {@link ExampleSet}, can show the curve in a plot dialog, and
 * computes the area under the curve.
 *
 * <p>Instances are stateful: {@link #createROCDataList(ExampleSet)} stores the class
 * totals, the iso-performance slope and the best threshold in fields that
 * {@link #createROCPlot(List, boolean, boolean)}, {@link #calculateAUC(List)} and
 * {@link #getBestThreshold()} read afterwards, so it must be called first. No
 * synchronization is performed.
 */
public class ROCDataGenerator {

    /** Upper bound (plus the always-kept final point) on the rows handed to the plotter. */
    public static final int MAX_ROC_POINTS = 200;

    /** Cost factor for the positive class; multiplies into {@link #slope}. */
    private final double misclassificationCostsPositive;

    /** Cost factor for the negative class; divides into {@link #slope}. */
    private final double misclassificationCostsNegative;

    /** Total (weighted) positive count seen by the last {@code createROCDataList} call. */
    private double sumPos;

    /** Total (weighted) negative count seen by the last {@code createROCDataList} call. */
    private double sumNeg;

    /** Intercept of the best iso-performance line, normalized by {@link #sumPos} when it is non-zero. */
    private double bestIsometricsTpValue = 0.0;

    /** Iso-performance slope in raw (fp, tp) count space. */
    private double slope = 1.0;

    /** Confidence threshold of the best iso-performance point; NaN until computed. */
    private double bestThreshold = Double.NaN;

    /**
     * Creates a generator with the given misclassification cost factors.
     *
     * @param misclassificationCostsPositive cost factor for the positive class
     * @param misclassificationCostsNegative cost factor for the negative class
     */
    public ROCDataGenerator(double misclassificationCostsPositive, double misclassificationCostsNegative) {
        this.misclassificationCostsPositive = misclassificationCostsPositive;
        this.misclassificationCostsNegative = misclassificationCostsNegative;
    }

    /**
     * Returns the confidence threshold chosen by the last {@link #createROCDataList} call.
     *
     * @return {@code NaN} before any call; {@code +Infinity} if no threshold beat the
     *         initial score of zero; {@code -Infinity} if predicting every example as
     *         positive scored best; otherwise the confidence of the negative example at
     *         which the best iso-performance value was reached
     */
    public double getBestThreshold() {
        return this.bestThreshold;
    }

    /**
     * Computes the ROC curve of the given example set as a list of
     * {@code {fp, tp, threshold}} points.
     *
     * <p>{@code fp} and {@code tp} are cumulative counts of negative and positive
     * examples (weighted by the weight attribute when one exists), not rates. Examples
     * are ordered by the natural order of {@link WeightedConfidenceAndLabel} (presumably
     * descending confidence, given the 1.0 start and 0.0 end thresholds - TODO confirm).
     * One point is emitted per distinct confidence value, after every example sharing
     * that value has been counted, so tied confidences form a single diagonal step. The
     * list starts with {@code {0, 0, 1}} and ends with {@code {sumNeg, sumPos, 0}}.
     *
     * <p>Side effects: recalculates the label statistics and stores the class totals,
     * the iso-performance slope, the normalized best iso-performance value and the best
     * threshold in this instance.
     *
     * @param exampleSet scored example set with a label and a predicted label attribute
     * @return the ROC points in emission order
     */
    public List<double[]> createROCDataList(ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        exampleSet.recalculateAttributeStatistics(label);
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        double positiveIndex = label.getMapping().getPositiveIndex();
        String positiveClassName = label.getMapping().mapIndex(label.getMapping().getPositiveIndex());

        WeightedConfidenceAndLabel[] calArray = new WeightedConfidenceAndLabel[exampleSet.size()];
        int index = 0;
        for (Example example : exampleSet) {
            double confidence = example.getConfidence(positiveClassName);
            double labelValue = example.getValue(label);
            double predictedValue = example.getValue(predictedLabel);
            calArray[index++] = weightAttr == null
                    ? new WeightedConfidenceAndLabel(confidence, labelValue, predictedValue)
                    : new WeightedConfidenceAndLabel(confidence, labelValue, predictedValue, example.getValue(weightAttr));
        }
        Arrays.sort(calArray);

        // NOTE(review): the slope uses unweighted class counts from the label statistics and
        // works out to (P/N) * costPositive / costNegative, while tp/fp below are weighted -
        // confirm this is the intended cost model.
        NominalAttributeStatistics stats = (NominalAttributeStatistics) label.getStatistics();
        double ratio = (double) stats.getValueCount(positiveClassName)
                / (double) stats.getValueCount(label.getMapping().mapIndex(label.getMapping().getNegativeIndex()));
        this.slope = ratio / (this.misclassificationCostsNegative / this.misclassificationCostsPositive);

        double tp = 0.0;
        double fp = 0.0;
        this.bestIsometricsTpValue = 0.0;
        this.bestThreshold = Double.POSITIVE_INFINITY;

        LinkedList<double[]> tableData = new LinkedList<double[]>();
        tableData.add(new double[] {0.0, 0.0, 1.0});
        double[] last = null; // end point of the current confidence group; null before the first example
        double oldConfidence = 1.0;
        for (WeightedConfidenceAndLabel wcl : calArray) {
            double weight = wcl.getWeight();
            double currentConfidence = wcl.getConfidence();
            if (wcl.getLabel() == positiveIndex) {
                tp += weight;
            } else {
                // Score the threshold at this negative before counting it: only examples
                // ranked ahead of it are predicted positive.
                double c = tp - fp * this.slope;
                if (c > this.bestIsometricsTpValue) {
                    this.bestIsometricsTpValue = c;
                    this.bestThreshold = currentConfidence;
                }
                fp += weight;
            }
            if (currentConfidence != oldConfidence) {
                // A new confidence value starts, so the previous group is complete.
                if (last != null) {
                    tableData.add(last);
                }
                oldConfidence = currentConfidence;
            }
            // Snapshot taken after this example is counted (including a negative's fp),
            // so a group ending in a negative example lands on the correct fp coordinate.
            last = new double[] {fp, tp, currentConfidence};
        }

        // Finally consider predicting every example as positive.
        double c = tp - fp * this.slope;
        if (c > this.bestIsometricsTpValue) {
            this.bestThreshold = Double.NEGATIVE_INFINITY;
            this.bestIsometricsTpValue = c;
        }
        this.sumPos = tp;
        this.sumNeg = fp;
        if (this.sumPos > 0.0) {
            // Convert the intercept to TP/P units for the plot; skip when there are no positives.
            this.bestIsometricsTpValue /= this.sumPos;
        }
        tableData.add(new double[] {this.sumNeg, this.sumPos, 0.0});
        return tableData;
    }

    /**
     * Shows the ROC curve in a {@link SimplePlotterDialog}.
     *
     * <p>Points are normalized by the totals stored by the last
     * {@link #createROCDataList} call. Only every n-th point is kept, with n chosen so
     * that at most {@link #MAX_ROC_POINTS} points are sampled; the final point is
     * always kept so the curve reaches its end.
     *
     * @param data ROC points as returned by {@link #createROCDataList}
     * @param showSlope whether to also plot the best iso-performance line
     * @param showThresholds whether to also plot the threshold column
     */
    public void createROCPlot(List<double[]> data, boolean showSlope, boolean showThresholds) {
        SimpleDataTable dataTable = new SimpleDataTable("ROC Plot", new String[] {"FP/N", "TP/P", "Slope", "Threshold"});
        // ceil (rather than round) guarantees the sampled count never exceeds MAX_ROC_POINTS.
        int eachPoint = Math.max(1, (int) Math.ceil((double) data.size() / (double) MAX_ROC_POINTS));
        int pointCounter = 0;
        Iterator<double[]> i = data.iterator();
        while (i.hasNext()) {
            double[] point = i.next();
            if (pointCounter % eachPoint == 0 || !i.hasNext()) {
                double fpDivN = point[0] / this.sumNeg;
                double tpDivP = point[1] / this.sumPos;
                double threshold = point[2];
                // The iso-performance line expressed in (FP/N, TP/P) rate space.
                double isometric = this.bestIsometricsTpValue + fpDivN * this.slope * (this.sumNeg / this.sumPos);
                dataTable.add(new SimpleDataTableRow(new double[] {fpDivN, tpDivP, isometric, threshold}));
            }
            ++pointCounter;
        }
        SimplePlotterDialog plotter = new SimplePlotterDialog(dataTable);
        plotter.setXAxis(0);
        plotter.plotColumn(1, true);
        if (showSlope) {
            plotter.plotColumn(2, true);
        }
        if (showThresholds) {
            plotter.plotColumn(3, true);
        }
        plotter.setDrawRange(0.0, 1.0, 0.0, 1.0);
        plotter.setSize(500, 500);
        plotter.setLocationRelativeTo(plotter.getOwner());
        plotter.setVisible(true);
    }

    /**
     * Computes the area under the ROC curve with the trapezoid rule.
     *
     * <p>Coordinates are normalized by the totals stored by the last
     * {@link #createROCDataList} call, so the result is NaN when that call saw no
     * positive or no negative weight (and the list has at least two points).
     *
     * @param rocData ROC points as returned by {@link #createROCDataList}
     * @return the area under the curve; 0.0 for fewer than two points
     */
    public double calculateAUC(List<double[]> rocData) {
        double aucSum = 0.0;
        double prevFpRate = 0.0;
        double prevTpRate = 0.0;
        boolean first = true;
        for (double[] point : rocData) {
            double fpRate = point[0] / this.sumNeg;
            double tpRate = point[1] / this.sumPos;
            if (!first) {
                // Trapezoid: segment width times mean height.
                aucSum += (fpRate - prevFpRate) * (prevTpRate + tpRate) / 2.0;
            }
            prevFpRate = fpRate;
            prevTpRate = tpRate;
            first = false;
        }
        return aucSum;
    }
}

