/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeTypeException;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SortedExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.ExampleSetPrecondition;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.operator.postprocessing.Threshold;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import java.util.List;

public class RecallChooser
extends Operator {
    private static final String PARAMETER_USE_EXAMPLE_WEIGHTS = "use_example_weights";
    private static final String PARAMETER_RECALL = "min_recall";
    private InputPort exampleSetInput = this.getInputPorts().createPort("example set", ExampleSet.class);
    private OutputPort exampleSetOutput = (OutputPort)this.getOutputPorts().createPort("example set");
    private OutputPort thresholdOutput = (OutputPort)this.getOutputPorts().createPort("threshold");

    public RecallChooser(OperatorDescription description) {
        super(description);
        this.exampleSetInput.addPrecondition((Precondition)new ExampleSetPrecondition(this.exampleSetInput, 0, new String[]{"label", "prediction", "confidence"}));
        this.getTransformer().addPassThroughRule(this.exampleSetInput, this.exampleSetOutput);
        this.getTransformer().addGenerationRule(this.thresholdOutput, Threshold.class);
    }

    public void doWork() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet)this.exampleSetInput.getData();
        boolean useWeights = this.getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS);
        Attribute label = exampleSet.getAttributes().getLabel();
        exampleSet.recalculateAttributeStatistics(label);
        if (label == null) {
            throw new UserError((Operator)this, 105);
        }
        if (!label.isNominal()) {
            throw new UserError((Operator)this, 101, new Object[]{label, "threshold finding"});
        }
        if (label.getMapping().size() != 2) {
            throw new UserError((Operator)this, 118, new Object[]{label, label.getMapping().getValues().size(), 2});
        }
        if (exampleSet.getAttributes().getPredictedLabel() == null) {
            throw new UserError((Operator)this, 107);
        }
        String positiveClassName = null;
        int positiveIndex = label.getMapping().getPositiveIndex();
        if (label.isNominal() && label.getMapping().size() == 2) {
            positiveClassName = label.getMapping().mapIndex(positiveIndex);
        } else if (label.isNominal() && label.getMapping().size() == 1) {
            positiveClassName = label.getMapping().mapIndex(0);
        } else {
            throw new AttributeTypeException("Cannot calculate ROC data for non-classification labels or for labels with more than 2 classes.");
        }
        double totalSum = 0.0;
        for (Example e : exampleSet) {
            if (e.getLabel() != (double)positiveIndex) continue;
            if (useWeights) {
                double w = e.getWeight();
                if (Double.isNaN(w)) {
                    w = 1.0;
                }
                totalSum += w;
                continue;
            }
            totalSum += 1.0;
        }
        double currentSum = 0.0;
        double desiredRecall = this.getParameterAsDouble(PARAMETER_RECALL);
        double thresholdValue = 0.0;
        Attribute confidenceAttribute = exampleSet.getAttributes().getSpecial("confidence_" + positiveClassName);
        SortedExampleSet sortedExampleSet = new SortedExampleSet(exampleSet, confidenceAttribute, 0);
        for (Example e : sortedExampleSet) {
            if (e.getLabel() != (double)positiveIndex) continue;
            if (useWeights) {
                double w = e.getWeight();
                if (Double.isNaN(w)) {
                    w = 1.0;
                }
                currentSum += w;
            } else {
                currentSum += 1.0;
            }
            if (currentSum / totalSum >= 1.0 - desiredRecall) break;
            thresholdValue = (e.getConfidence(positiveClassName) + thresholdValue) / 2.0;
        }
        this.exampleSetOutput.deliver((IOObject)exampleSet);
        this.thresholdOutput.deliver((IOObject)new Threshold(thresholdValue, label.getMapping().getNegativeString(), label.getMapping().getPositiveString()));
    }

    public List<ParameterType> getParameterTypes() {
        List list = super.getParameterTypes();
        list.add(new ParameterTypeDouble(PARAMETER_RECALL, "The minimal desired recall on the positive class.", 0.0, 1.0, 0.7, false));
        list.add(new ParameterTypeBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS, "Indicates if example weights should be used.", true));
        return list;
    }
}

