/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.tools.math;

import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.gui.plotter.SimplePlotterDialog;
import com.rapidminer.tools.math.WeightedConfidenceAndLabel;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Computes the points of a lift chart from a scored {@link ExampleSet} and
 * displays them in a {@link SimplePlotterDialog}.
 *
 * <p>The instance remembers the largest non-NaN lift found by the most recent
 * call to {@link #createLiftDataList(ExampleSet)};
 * {@link #createLiftChartPlot(List)} uses that value as a replacement for NaN
 * lifts. Because of this shared field the two methods are meant to be called
 * in that order on the same instance.</p>
 */
public class LiftDataGenerator {
    /** Target number of points handed to the plotter; longer lists are thinned out. */
    public static final int MAX_LIFT_POINTS = 500;

    // Cell indices of the weighted confusion matrix accumulated in createLiftDataList.
    private static final int TP = 0;
    private static final int FP = 1;
    private static final int FN = 2;
    private static final int TN = 3;

    /** Largest non-NaN lift of the last generated list; 0.0 if none was computed. */
    private double maxLift = 0.0;

    /**
     * Builds one {@code {index, lift}} pair per example of the given set.
     *
     * <p>Every example is wrapped in a {@link WeightedConfidenceAndLabel}
     * (using the weight attribute when the set has one), the wrappers are
     * sorted by their natural ordering, and a weighted confusion matrix is
     * accumulated while walking the sorted array. After each example the
     * lift of the examples seen so far is appended to the result. A lift is
     * NaN while the denominator of the formula is zero.</p>
     *
     * @param exampleSet the scored examples; the code dereferences its label
     *                   and passes its predicted label to
     *                   {@code Example.getValue}, so both are expected to be
     *                   present
     * @return a list with one {@code double[]{index, lift}} entry per
     *         example, where {@code index} is the zero-based position in
     *         the sorted order (not a normalized fraction)
     */
    public List<double[]> createLiftDataList(ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        int positiveIndex = label.getMapping().getPositiveIndex();
        String positiveClassName = label.getMapping().mapIndex(positiveIndex);

        WeightedConfidenceAndLabel[] sorted = new WeightedConfidenceAndLabel[exampleSet.size()];
        int index = 0;
        for (Example example : exampleSet) {
            // The confidence is stored negated; presumably so that the natural
            // ordering puts the most confident examples first - confirm against
            // WeightedConfidenceAndLabel.compareTo.
            double negatedConfidence = -1.0 * example.getConfidence(positiveClassName);
            double labelValue = example.getValue(label);
            if (weightAttr == null) {
                sorted[index++] = new WeightedConfidenceAndLabel(negatedConfidence, labelValue,
                        example.getValue(predictedLabel));
            } else {
                sorted[index++] = new WeightedConfidenceAndLabel(negatedConfidence, labelValue,
                        example.getValue(weightAttr), example.getValue(predictedLabel));
            }
        }
        Arrays.sort(sorted);

        LinkedList<double[]> tableData = new LinkedList<double[]>();
        double[] confusion = new double[4];
        double best = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < sorted.length; ++i) {
            WeightedConfidenceAndLabel wcl = sorted[i];
            boolean actualPositive = wcl.getLabel() == (double) positiveIndex;
            boolean predictedPositive = wcl.getPrediction() == (double) positiveIndex;
            int cell;
            if (actualPositive) {
                cell = predictedPositive ? TP : FN;
            } else {
                cell = predictedPositive ? FP : TN;
            }
            confusion[cell] += wcl.getWeight();

            // NOTE(review): the textbook lift multiplies TP by the total weight
            // (TP + FP + FN + TN); this code multiplies by FP + TN. Kept as is
            // because callers may depend on the current values - confirm intent.
            double lift = confusion[TP] * (confusion[FP] + confusion[TN])
                    / ((confusion[TP] + confusion[FP]) * (confusion[TP] + confusion[FN]));
            if (!Double.isNaN(lift)) {
                best = Math.max(lift, best);
            }
            tableData.add(new double[]{i, lift});
        }
        // Without any computable lift (empty set or NaN only) fall back to the
        // field's initial value instead of leaving -Infinity behind, which the
        // plot method would otherwise use as the replacement for NaN points.
        this.maxLift = best == Double.NEGATIVE_INFINITY ? 0.0 : best;
        return tableData;
    }

    /**
     * Shows the given lift points in a {@link SimplePlotterDialog}.
     *
     * <p>At most roughly {@link #MAX_LIFT_POINTS} points are plotted: only
     * every n-th point is copied into the data table, plus the last point of
     * the list. NaN lifts are replaced by the maximum lift remembered from
     * the last {@link #createLiftDataList(ExampleSet)} call.</p>
     *
     * @param data points as {@code double[]{x, lift}}, typically the list
     *             returned by {@link #createLiftDataList(ExampleSet)}
     */
    public void createLiftChartPlot(List<double[]> data) {
        SimpleDataTable dataTable = new SimpleDataTable("Lift Chart", new String[]{"Fraction", "Lift"});
        // Sampling stride; never below 1 so that short lists are plotted completely.
        int eachPoint = Math.max(1, (int) Math.round((double) data.size() / MAX_LIFT_POINTS));
        int pointCounter = 0;
        Iterator<double[]> it = data.iterator();
        while (it.hasNext()) {
            double[] point = it.next();
            // Keep every eachPoint-th point (this includes the first one) and
            // always the final point of the list.
            if (pointCounter % eachPoint == 0 || !it.hasNext()) {
                double lift = Double.isNaN(point[1]) ? this.maxLift : point[1];
                dataTable.add(new SimpleDataTableRow(new double[]{point[0], lift}));
            }
            ++pointCounter;
        }
        SimplePlotterDialog plotter = new SimplePlotterDialog(dataTable);
        plotter.setXAxis(0);
        plotter.plotColumn(1, true);
        plotter.setVisible(true);
    }
}

