/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.visualization;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.InputPortExtender;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.OutputPortExtender;
import com.rapidminer.operator.ports.metadata.GenerateNewMDRule;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.operator.ports.metadata.PredictionModelMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.operator.visualization.ROCComparison;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.math.ROCBias;
import com.rapidminer.tools.math.ROCData;
import com.rapidminer.tools.math.ROCDataGenerator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Operator chain that compares the ROC curves of several prediction models.
 *
 * <p>The inner subprocess ("Model Generation") receives a training set on each of its "train"
 * ports and must deliver one or more prediction models on its "model" ports. Every model is then
 * applied to held-out data and its ROC data is collected under the model's source name. Evaluation
 * uses either a single train/test split (when {@code number_of_folds} is negative) or a cross
 * validation with {@code number_of_folds} folds. The label attribute must be nominal and have
 * exactly two values.
 */
public class ROCBasedComparisonOperator extends OperatorChain {
    public static final String PARAMETER_NUMBER_OF_FOLDS = "number_of_folds";
    public static final String PARAMETER_SPLIT_RATIO = "split_ratio";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_USE_EXAMPLE_WEIGHTS = "use_example_weights";

    private final InputPort exampleSetInput = this.getInputPorts().createPort("example set", ExampleSet.class);
    private final OutputPort exampleSetOutput = this.getOutputPorts().createPort("exampleSet");
    private final OutputPort rocComparisonOutput = this.getOutputPorts().createPort("rocComparison");
    private final OutputPortExtender trainingSetExtender = new OutputPortExtender("train", this.getSubprocess(0).getInnerSources());
    private final InputPortExtender modelExtender = new InputPortExtender("model", this.getSubprocess(0).getInnerSinks()) {

        // Every model port of the subprocess must be fed with a prediction model.
        @Override
        public Precondition makePrecondition(InputPort inputPort) {
            return new SimplePrecondition(inputPort, new PredictionModelMetaData(PredictionModel.class), false);
        }
    };

    public ROCBasedComparisonOperator(OperatorDescription description) {
        super(description, "Model Generation");
        this.trainingSetExtender.start();
        this.modelExtender.start();
        this.getTransformer().addRule(this.trainingSetExtender.makePassThroughRule(this.exampleSetInput));
        this.getTransformer().addPassThroughRule(this.exampleSetInput, this.exampleSetOutput);
        this.getTransformer().addRule(new SubprocessTransformRule(this.getSubprocess(0)));
        this.getTransformer().addRule(new GenerateNewMDRule(this.rocComparisonOutput, ROCComparison.class));
    }

    /**
     * Trains the models of the inner subprocess, evaluates them on held-out data and delivers the
     * collected ROC data as a {@link ROCComparison}, together with the unchanged input example set.
     *
     * @throws UserError if the input has no label (105), the label is not nominal (101), the label
     *     does not have exactly two values (114), or a model does not produce a predicted label (107)
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.exampleSetInput.getData();
        if (exampleSet.getAttributes().getLabel() == null) {
            throw new UserError(this, 105);
        }
        if (!exampleSet.getAttributes().getLabel().isNominal()) {
            throw new UserError(this, 101, "ROC Comparison", exampleSet.getAttributes().getLabel());
        }
        if (exampleSet.getAttributes().getLabel().getMapping().getValues().size() != 2) {
            throw new UserError(this, 114, "ROC Comparison", exampleSet.getAttributes().getLabel());
        }

        // A LinkedHashMap keeps the curves in the order in which their models were first evaluated,
        // so the resulting comparison is deterministic.
        HashMap<String, List<ROCData>> rocData = new LinkedHashMap<String, List<ROCData>>();
        int numberOfFolds = this.getParameterAsInt(PARAMETER_NUMBER_OF_FOLDS);
        if (numberOfFolds < 0) {
            // Single split: subset 0 is delivered for training, subset 1 is used for evaluation.
            double splitRatio = this.getParameterAsDouble(PARAMETER_SPLIT_RATIO);
            SplittedExampleSet eSet = new SplittedExampleSet((ExampleSet) exampleSet.clone(), splitRatio, this.getParameterAsInt(PARAMETER_SAMPLING_TYPE), this.getParameterAsBoolean("use_local_random_seed"), this.getParameterAsInt("local_random_seed"));
            PredictionModel.removePredictedLabel(eSet);
            eSet.selectSingleSubset(0);
            this.trainingSetExtender.deliverToAll(eSet, false);
            this.getSubprocess(0).execute();
            List<Model> models = this.modelExtender.getData(true);
            eSet.selectSingleSubset(1);
            this.collectROCData(models, eSet, rocData);
        } else {
            // Cross validation: train on all folds but one, evaluate on the remaining fold; the ROC
            // data of every fold is appended to the model's list.
            SplittedExampleSet eSet = new SplittedExampleSet((ExampleSet) exampleSet.clone(), numberOfFolds, this.getParameterAsInt(PARAMETER_SAMPLING_TYPE), this.getParameterAsBoolean("use_local_random_seed"), this.getParameterAsInt("local_random_seed"));
            PredictionModel.removePredictedLabel(eSet);
            for (int iteration = 0; iteration < numberOfFolds; ++iteration) {
                eSet.selectAllSubsetsBut(iteration);
                this.trainingSetExtender.deliverToAll(eSet, false);
                this.getSubprocess(0).execute();
                List<Model> models = this.modelExtender.getData(true);
                eSet.selectSingleSubset(iteration);
                this.collectROCData(models, eSet, rocData);
                this.inApplyLoop();
            }
        }
        this.exampleSetOutput.deliver(exampleSet);
        this.rocComparisonOutput.deliver(new ROCComparison(rocData));
    }

    /**
     * Applies each model to the currently selected subset of {@code testSet} and appends the
     * resulting ROC data to the list stored under the model's source name.
     *
     * <p>Appending (instead of replacing) keeps the data of several folds, and of several models
     * sharing a source name, instead of silently dropping earlier curves.
     *
     * @param models the models delivered by the inner subprocess
     * @param testSet the example set whose selected subset is used for evaluation
     * @param rocData the collected ROC data keyed by model source; updated in place
     * @throws UserError (107) if a model does not produce a predicted label
     */
    private void collectROCData(List<Model> models, ExampleSet testSet, Map<String, List<ROCData>> rocData) throws OperatorException {
        boolean useExampleWeights = this.getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS);
        for (Model model : models) {
            ExampleSet resultSet = model.apply(testSet);
            if (resultSet.getAttributes().getPredictedLabel() == null) {
                throw new UserError(this, 107);
            }
            // A fresh generator per model, as in the original code; NOTE(review): confirm whether
            // ROCDataGenerator keeps state between calls before sharing one instance.
            ROCDataGenerator rocDataGenerator = new ROCDataGenerator(1.0, 1.0);
            ROCData rocPoints = rocDataGenerator.createROCData(resultSet, useExampleWeights, ROCBias.getROCBiasParameter(this));
            List<ROCData> dataList = rocData.get(model.getSource());
            if (dataList == null) {
                dataList = new LinkedList<ROCData>();
                rocData.put(model.getSource(), dataList);
            }
            dataList.add(rocPoints);
            // Presumably strips the prediction again so the test data is clean for the next model.
            PredictionModel.removePredictedLabel(resultSet);
        }
    }

    /**
     * Declares the operator's parameters: fold count / split ratio, sampling type, the random
     * generator parameters, example weighting and the ROC bias.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeInt type = new ParameterTypeInt(PARAMETER_NUMBER_OF_FOLDS, "The number of folds used for a cross validation evaluation (-1: use simple split ratio).", -1, Integer.MAX_VALUE, 10);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeDouble(PARAMETER_SPLIT_RATIO, "Relative size of the training set", 0.0, 1.0, 0.7));
        types.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        types.add(new ParameterTypeBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS, "Indicates if example weights should be regarded (use weight 1 for each example otherwise).", true));
        types.add(ROCBias.makeParameterType());
        return types;
    }
}

