/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.meta;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.PortPairExtender;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetPassThroughRule;
import com.rapidminer.operator.ports.metadata.SetRelation;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

/**
 * Estimates a learning curve for the learner(s) placed in the inner subprocesses.
 *
 * <p>The input example set is split once into a fixed training partition (its share is
 * {@value #PARAMETER_TRAINING_RATIO}) and a fixed test partition (the rest). Starting with
 * {@value #PARAMETER_START_FRACTION} and growing by {@value #PARAMETER_STEP_FRACTION} per
 * iteration, a fraction of the training partition is delivered to the "Training" subprocess.
 * The objects delivered to its "through" sinks are passed on to the "Test" subprocess. That
 * subprocess receives the test partition and must deliver a {@link PerformanceVector}.
 *
 * <p>The loggable values "fraction", "performance" and "deviation" expose, respectively, the
 * fraction of the training partition used in the most recent iteration and the average and
 * standard deviation of the main criterion of the performance vector it produced.
 */
public class LearningCurveOperator
extends OperatorChain {
    public static final String PARAMETER_TRAINING_RATIO = "training_ratio";
    public static final String PARAMETER_STEP_FRACTION = "step_fraction";
    public static final String PARAMETER_START_FRACTION = "start_fraction";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";

    /**
     * Slack for the loop bound. Without it, floating-point rounding in
     * {@code start + i * step} could drop the final iteration at fraction 1.0.
     */
    private static final double FRACTION_TOLERANCE = 1.0E-9;

    private final InputPort exampleSetInput = this.getInputPorts().createPort("exampleSet", ExampleSet.class);
    /** Inner source of the "Training" subprocess that receives the growing training subset. */
    private final OutputPort trainingSource = (OutputPort)this.getSubprocess(0).getInnerSources().createPort("training set");
    /** Inner source of the "Test" subprocess that receives the fixed test partition. */
    private final OutputPort testSource = (OutputPort)this.getSubprocess(1).getInnerSources().createPort("test set");
    /** Forwards objects from the "Training" inner sinks to the "Test" inner sources. */
    private final PortPairExtender throughExtender = new PortPairExtender("through", this.getSubprocess(0).getInnerSinks(), this.getSubprocess(1).getInnerSources());
    /** Inner sink of the "Test" subprocess that must receive the performance vector. */
    private final InputPort performanceInnerSink = this.getSubprocess(1).getInnerSinks().createPort("performance", PerformanceVector.class);

    private double lastFraction = Double.NaN;
    private double lastPerformance = Double.NaN;
    private double lastDeviation = Double.NaN;

    /**
     * Creates the operator with its "Training" and "Test" subprocesses. It also registers the
     * meta-data transformation rules and the loggable values.
     *
     * @param description the operator description
     */
    public LearningCurveOperator(OperatorDescription description) {
        super(description, "Training", "Test");
        this.throughExtender.start();
        // Both inner sources receive only a subset of the input, so the example count shrinks
        // by an amount that is unknown at meta-data time.
        this.getTransformer().addRule(new ExampleSetPassThroughRule(this.exampleSetInput, this.trainingSource, SetRelation.EQUAL){

            @Override
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData metaData) throws UndefinedParameterError {
                metaData.getNumberOfExamples().reduceByUnknownAmount();
                return super.modifyExampleSet(metaData);
            }
        });
        this.getTransformer().addRule(new SubprocessTransformRule(this.getSubprocess(0)));
        this.getTransformer().addRule(this.throughExtender.makePassThroughRule());
        this.getTransformer().addRule(new ExampleSetPassThroughRule(this.exampleSetInput, this.testSource, SetRelation.EQUAL){

            @Override
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData metaData) throws UndefinedParameterError {
                metaData.getNumberOfExamples().reduceByUnknownAmount();
                return super.modifyExampleSet(metaData);
            }
        });
        this.getTransformer().addRule(new SubprocessTransformRule(this.getSubprocess(1)));
        this.addValue(new ValueDouble("fraction", "The used fraction of data."){

            @Override
            public double getDoubleValue() {
                return LearningCurveOperator.this.lastFraction;
            }
        });
        this.addValue(new ValueDouble("performance", "The last performance (main criterion)."){

            @Override
            public double getDoubleValue() {
                return LearningCurveOperator.this.lastPerformance;
            }
        });
        this.addValue(new ValueDouble("deviation", "The standard deviation of the last performance (main criterion)."){

            @Override
            public double getDoubleValue() {
                return LearningCurveOperator.this.lastDeviation;
            }
        });
    }

    /**
     * Runs the learning-curve loop. For each fraction from {@value #PARAMETER_START_FRACTION} up
     * to 1.0, in steps of {@value #PARAMETER_STEP_FRACTION}, it trains on that fraction of the
     * fixed training partition and evaluates on the fixed test partition.
     *
     * @throws OperatorException if {@value #PARAMETER_STEP_FRACTION} is not positive, or if
     *         reading parameters/data or executing a subprocess fails
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet originalExampleSet = (ExampleSet)this.exampleSetInput.getData();
        double trainingRatio = this.getParameterAsDouble(PARAMETER_TRAINING_RATIO);
        double stepFraction = this.getParameterAsDouble(PARAMETER_STEP_FRACTION);
        double startFraction = this.getParameterAsDouble(PARAMETER_START_FRACTION);
        int samplingType = this.getParameterAsInt(PARAMETER_SAMPLING_TYPE);
        boolean useLocalRandomSeed = this.getParameterAsBoolean("use_local_random_seed");
        int localRandomSeed = this.getParameterAsInt("local_random_seed");
        if (stepFraction <= 0.0) {
            // The parameter range admits 0.0, which would never advance the fraction and loop forever.
            throw new OperatorException("The parameter '" + PARAMETER_STEP_FRACTION + "' must be greater than 0, but is " + stepFraction + ".");
        }
        // Fixed split: subset 0 is the (maximal) training partition, subset 1 the test partition.
        SplittedExampleSet trainTestSplittedExamples = new SplittedExampleSet(originalExampleSet, trainingRatio, samplingType, useLocalRandomSeed, localRandomSeed);
        int iteration = 0;
        double fraction = startFraction;
        while (fraction <= 1.0 + FRACTION_TOLERANCE) {
            // Clamp so that a value a rounding error above 1.0 is still a valid split ratio.
            this.lastFraction = Math.min(fraction, 1.0);
            trainTestSplittedExamples.selectSingleSubset(0);
            SplittedExampleSet growingTrainingSet = new SplittedExampleSet((ExampleSet)trainTestSplittedExamples, this.lastFraction, samplingType, useLocalRandomSeed, localRandomSeed);
            growingTrainingSet.selectSingleSubset(0);
            this.trainingSource.deliver(growingTrainingSet);
            this.getSubprocess(0).execute();
            trainTestSplittedExamples.selectSingleSubset(1);
            this.testSource.deliver(trainTestSplittedExamples);
            this.throughExtender.passDataThrough();
            this.getSubprocess(1).execute();
            PerformanceVector performance = (PerformanceVector)this.performanceInnerSink.getData();
            this.lastPerformance = performance.getMainCriterion().getAverage();
            this.lastDeviation = performance.getMainCriterion().getStandardDeviation();
            // Recompute from the start instead of accumulating, so rounding errors do not add up.
            ++iteration;
            fraction = startFraction + iteration * stepFraction;
            this.inApplyLoop();
        }
    }

    /**
     * Declares the split ratio, step/start fractions, sampling type and the random-generator
     * parameters in addition to the inherited ones.
     *
     * @return the parameter types of this operator
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeDouble type = new ParameterTypeDouble(PARAMETER_TRAINING_RATIO, "The fraction of examples which shall be maximal used for training (dynamically growing), the rest is used for testing (fixed)", 0.0, 1.0, 0.05);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_STEP_FRACTION, "The fraction of examples which would be additionally used in each step.", 0.0, 1.0, 0.05);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeDouble(PARAMETER_START_FRACTION, "Starts with this fraction of the training data and iteratively add step_fraction examples from the training data.", 0.0, 1.0, 0.05));
        types.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type used to split the data (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }
}

