/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.validation;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ProcessStoppedException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.ports.metadata.MDInteger;
import com.rapidminer.operator.validation.ValidationChain;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeSingle;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

public class XValidation
extends ValidationChain {
    public static final String PARAMETER_NUMBER_OF_VALIDATIONS = "number_of_validations";
    public static final String PARAMETER_LEAVE_ONE_OUT = "leave_one_out";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_AVERAGE_PERFORMANCES_ONLY = "average_performances_only";
    private int iteration;

    public XValidation(OperatorDescription description) {
        super(description);
        this.addValue(new ValueDouble("iteration", "The number of the current iteration."){

            @Override
            public double getDoubleValue() {
                return XValidation.this.iteration;
            }
        });
    }

    @Override
    public void estimatePerformance(ExampleSet inputSet) throws OperatorException {
        int number = this.getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT) ? inputSet.size() : this.getParameterAsInt(PARAMETER_NUMBER_OF_VALIDATIONS);
        this.getLogger().fine("Starting " + number + "-fold cross validation");
        int samplingType = this.getParameterAsInt(PARAMETER_SAMPLING_TYPE);
        SplittedExampleSet splittedES = new SplittedExampleSet(inputSet, number, samplingType, this.getParameterAsBoolean("use_local_random_seed"), this.getParameterAsInt("local_random_seed"));
        this.iteration = 0;
        while (this.iteration < number) {
            this.performIteration(splittedES, this.iteration);
            ++this.iteration;
        }
    }

    protected void performIteration(SplittedExampleSet splittedES, int iteration) throws OperatorException, ProcessStoppedException {
        splittedES.selectAllSubsetsBut(iteration);
        this.learn(splittedES);
        splittedES.selectSingleSubset(iteration);
        this.evaluate(splittedES);
        this.inApplyLoop();
    }

    @Override
    protected MDInteger getTestSetSize(MDInteger originalSize) throws UndefinedParameterError {
        if (this.getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) {
            return new MDInteger(1);
        }
        return originalSize.multiply(1.0 / this.getParameterAsDouble(PARAMETER_NUMBER_OF_VALIDATIONS));
    }

    @Override
    protected MDInteger getTrainingSetSize(MDInteger originalSize) throws UndefinedParameterError {
        if (this.getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) {
            return originalSize.add(-1);
        }
        return originalSize.multiply(1.0 - 1.0 / this.getParameterAsDouble(PARAMETER_NUMBER_OF_VALIDATIONS));
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(PARAMETER_AVERAGE_PERFORMANCES_ONLY, "Indicates if only performance vectors should be averaged or all types of averagable result vectors", true));
        types.add(new ParameterTypeBoolean(PARAMETER_LEAVE_ONE_OUT, "Set the number of validations to the number of examples. If set to true, number_of_validations is ignored", false, false));
        ParameterTypeSingle type = new ParameterTypeInt(PARAMETER_NUMBER_OF_VALIDATIONS, "Number of subsets for the crossvalidation.", 2, Integer.MAX_VALUE, 10);
        type.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2);
        type.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        types.add(type);
        for (ParameterType addType : RandomGenerator.getRandomGeneratorParameters(this)) {
            addType.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
            addType.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_SAMPLING_TYPE, SplittedExampleSet.SAMPLING_NAMES, false, 1, 2));
            types.add(addType);
        }
        return types;
    }
}

