/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.tools.math;

import com.rapidminer.datatable.DataTable;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeTypeException;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.gui.plotter.SimplePlotterDialog;
import com.rapidminer.gui.viewer.ROCChartPlotter;
import com.rapidminer.tools.math.ROCBias;
import com.rapidminer.tools.math.ROCData;
import com.rapidminer.tools.math.ROCPoint;
import com.rapidminer.tools.math.WeightedConfidenceAndLabel;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import javax.swing.JDialog;

public class ROCDataGenerator
implements Serializable {
    private static final long serialVersionUID = -4473681331604071436L;
    public static final int MAX_ROC_POINTS = 200;
    private double misclassificationCostsPositive = 1.0;
    private double misclassificationCostsNegative = 1.0;
    private double slope = 1.0;
    private double bestThreshold = Double.NaN;

    public ROCDataGenerator(double misclassificationCostsPositive, double misclassificationCostsNegative) {
        this.misclassificationCostsPositive = misclassificationCostsPositive;
        this.misclassificationCostsNegative = misclassificationCostsNegative;
    }

    public double getBestThreshold() {
        return this.bestThreshold;
    }

    public ROCData createROCData(ExampleSet exampleSet, boolean useExampleWeights, ROCBias method) {
        Attribute label = exampleSet.getAttributes().getLabel();
        exampleSet.recalculateAttributeStatistics(label);
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        WeightedConfidenceAndLabel[] calArray = new WeightedConfidenceAndLabel[exampleSet.size()];
        Attribute weightAttr = null;
        if (useExampleWeights) {
            weightAttr = exampleSet.getAttributes().getWeight();
        }
        Attribute labelAttr = exampleSet.getAttributes().getLabel();
        String positiveClassName = null;
        int positiveIndex = label.getMapping().getPositiveIndex();
        if (label.isNominal() && label.getMapping().size() == 2) {
            positiveClassName = labelAttr.getMapping().mapIndex(positiveIndex);
        } else if (label.isNominal() && label.getMapping().size() == 1) {
            positiveClassName = labelAttr.getMapping().mapIndex(0);
        } else {
            throw new AttributeTypeException("Cannot calculate ROC data for non-classification labels or for labels with more than 2 classes.");
        }
        int index = 0;
        for (Example example : exampleSet) {
            WeightedConfidenceAndLabel wcl = weightAttr == null ? new WeightedConfidenceAndLabel(example.getConfidence(positiveClassName), example.getValue(labelAttr), example.getValue(predictedLabel)) : new WeightedConfidenceAndLabel(example.getConfidence(positiveClassName), example.getValue(labelAttr), example.getValue(predictedLabel), example.getValue(weightAttr));
            calArray[index++] = wcl;
        }
        Arrays.sort(calArray, new WeightedConfidenceAndLabel.WCALComparator(method));
        double ratio = exampleSet.getStatistics(label, "count", positiveClassName) / exampleSet.getStatistics(label, "count", label.getMapping().mapIndex(label.getMapping().getNegativeIndex()));
        this.slope = this.misclassificationCostsNegative / this.misclassificationCostsPositive;
        this.slope = ratio / this.slope;
        double truePositiveWeight = 0.0;
        double totalWeight = 0.0;
        double bestIsometricsTpValue = 0.0;
        this.bestThreshold = Double.POSITIVE_INFINITY;
        double oldConfidence = 1.0;
        ROCData rocData = new ROCData();
        ROCPoint last = new ROCPoint(0.0, 0.0, 1.0);
        double oldLabel = -1.0;
        for (int i = 0; i < calArray.length; ++i) {
            WeightedConfidenceAndLabel wcl = calArray[i];
            double currentConfidence = wcl.getConfidence();
            boolean mustStartNewPoint = false;
            mustStartNewPoint |= currentConfidence != oldConfidence;
            if (method != ROCBias.NEUTRAL) {
                mustStartNewPoint |= oldLabel != wcl.getLabel();
            }
            if (mustStartNewPoint) {
                rocData.addPoint(last);
                oldConfidence = currentConfidence;
                oldLabel = wcl.getLabel();
            }
            double weight = wcl.getWeight();
            double falsePositiveWeight = totalWeight - truePositiveWeight;
            if (wcl.getLabel() == (double)positiveIndex) {
                truePositiveWeight += weight;
            } else {
                double c = truePositiveWeight - falsePositiveWeight * this.slope;
                if (c > bestIsometricsTpValue) {
                    bestIsometricsTpValue = c;
                    this.bestThreshold = wcl.getConfidence();
                }
            }
            last = new ROCPoint((totalWeight += weight) - truePositiveWeight, truePositiveWeight, currentConfidence);
        }
        rocData.addPoint(last);
        double c = truePositiveWeight - (totalWeight - truePositiveWeight) * this.slope;
        if (c > bestIsometricsTpValue) {
            this.bestThreshold = Double.NEGATIVE_INFINITY;
            bestIsometricsTpValue = c;
        }
        rocData.setTotalPositives(truePositiveWeight);
        rocData.setTotalNegatives(totalWeight - truePositiveWeight);
        rocData.setBestIsometricsTPValue(bestIsometricsTpValue / truePositiveWeight);
        return rocData;
    }

    private DataTable createDataTable(ROCData data, boolean showSlope, boolean showThresholds) {
        SimpleDataTable dataTable = new SimpleDataTable("ROC Plot", new String[]{"FP/N", "TP/P", "Slope", "Threshold"});
        Iterator<ROCPoint> i = data.iterator();
        int pointCounter = 0;
        int eachPoint = Math.max(1, (int)Math.round((double)data.getNumberOfPoints() / 200.0));
        while (i.hasNext()) {
            ROCPoint point = i.next();
            if (pointCounter == 0 || pointCounter % eachPoint == 0 || !i.hasNext()) {
                double fpRate = point.getFalsePositives() / data.getTotalNegatives();
                double tpRate = point.getTruePositives() / data.getTotalPositives();
                double threshold = point.getConfidence();
                dataTable.add(new SimpleDataTableRow(new double[]{fpRate, tpRate, data.getBestIsometricsTPValue() + fpRate * this.slope * (data.getTotalNegatives() / data.getTotalPositives()), threshold}));
            }
            ++pointCounter;
        }
        return dataTable;
    }

    public void createROCPlotDialog(ROCData data, boolean showSlope, boolean showThresholds) {
        SimplePlotterDialog plotter = new SimplePlotterDialog(this.createDataTable(data, showSlope, showThresholds));
        plotter.setXAxis(0);
        plotter.plotColumn(1, true);
        if (showSlope) {
            plotter.plotColumn(2, true);
        }
        if (showThresholds) {
            plotter.plotColumn(3, true);
        }
        plotter.setDrawRange(0.0, 1.0, 0.0, 1.0);
        plotter.setPointType(1);
        plotter.setSize(500, 500);
        plotter.setLocationRelativeTo(plotter.getOwner());
        plotter.setVisible(true);
    }

    public void createROCPlotDialog(ROCData data) {
        ROCChartPlotter plotter = new ROCChartPlotter();
        plotter.addROCData("ROC", data);
        JDialog dialog = new JDialog();
        dialog.setTitle("ROC Plot");
        dialog.add(plotter);
        dialog.setSize(500, 500);
        dialog.setLocationRelativeTo(null);
        dialog.setVisible(true);
    }

    public double calculateAUC(ROCData rocData) {
        if (rocData.getNumberOfPoints() == 2) {
            return 0.5;
        }
        double aucSum = 0.0;
        double[] last = null;
        for (ROCPoint point : rocData) {
            double fpDivN = point.getFalsePositives() / rocData.getTotalNegatives();
            double tpDivP = point.getTruePositives() / rocData.getTotalPositives();
            if (last != null) {
                double width = fpDivN - last[0];
                void leftHeight = last[1];
                double rightHeight = tpDivP;
                aucSum += leftHeight * width + (rightHeight - leftHeight) * width / 2.0;
            }
            last = new double[]{fpDivN, tpDivP};
        }
        return aucSum;
    }
}

