/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.validation;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.AttributeWeightedExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.validation.WrapperValidationChain;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeSingle;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

/**
 * Cross-validation operator that evaluates an attribute weighting method together with a learner.
 *
 * <p>For every fold it does the following:
 * <ol>
 *   <li>Computes attribute weights on the training subsets with {@code useWeightingMethod}.</li>
 *   <li>Learns a model on the weighted training data.</li>
 *   <li>Evaluates that model on the held-out subset, weighted the same way.</li>
 * </ol>
 * The per-fold performance criteria are folded into one {@link PerformanceVector} using
 * {@code buildAverage}. The per-fold attribute weights are summed and divided by the number
 * of folds. Both results are delivered on the operator's output ports.
 */
public class WrapperXValidation
extends WrapperValidationChain {
    public static final String PARAMETER_NUMBER_OF_VALIDATIONS = "number_of_validations";
    public static final String PARAMETER_LEAVE_ONE_OUT = "leave_one_out";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    /** Number of folds used by the current run of {@link #doWork()}. */
    private int number;
    /** Index of the fold currently being processed; exposed as the "iteration" value. */
    private int iteration;

    /**
     * Creates the operator and registers the "iteration" value.
     *
     * @param description the operator description
     */
    public WrapperXValidation(OperatorDescription description) {
        super(description);
        this.addValue(new ValueDouble("iteration", "The number of the current iteration."){

            @Override
            public double getDoubleValue() {
                return WrapperXValidation.this.iteration;
            }
        });
    }

    /**
     * Runs the wrapper cross-validation and delivers the averaged performance and attribute weights.
     *
     * @throws OperatorException if fewer than two folds would be used, for example because
     *         leave-one-out is enabled on an example set with fewer than two examples.
     *         Also thrown for any error raised by the inner weighting, learning or evaluation steps.
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet eSet = (ExampleSet)this.exampleSetInput.getData();
        this.number = this.getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT) ? eSet.size() : this.getParameterAsInt(PARAMETER_NUMBER_OF_VALIDATIONS);
        // Without this guard, an empty example set (leave-one-out) would leave performanceVector
        // null and make the weight averaging divide by zero. A single fold would train on an
        // empty set. number_of_validations itself already enforces a minimum of 2.
        if (this.number < 2) {
            throw new OperatorException("Cross validation needs at least 2 subsets, but " + this.number
                    + " were requested; with leave_one_out enabled the example set must contain at least 2 examples.");
        }
        int samplingType = this.getParameterAsInt(PARAMETER_SAMPLING_TYPE);
        SplittedExampleSet inputSet = new SplittedExampleSet(eSet, this.number, samplingType, this.getParameterAsBoolean("use_local_random_seed"), this.getParameterAsInt("local_random_seed"));
        this.log("Starting " + this.number + "-fold method cross validation");
        PerformanceVector performanceVector = null;
        // Start every regular attribute at 0 so that the fold weights can be summed into it.
        AttributeWeights globalWeights = new AttributeWeights();
        for (Attribute attribute : eSet.getAttributes()) {
            globalWeights.setWeight(attribute.getName(), 0.0);
        }
        this.iteration = 0;
        while (this.iteration < this.number) {
            // Training data: every subset except the current fold.
            inputSet.selectAllSubsetsBut(this.iteration);
            AttributeWeights weights = this.useWeightingMethod(inputSet);
            // Work on a clone so that switching to the test fold does not disturb inputSet.
            SplittedExampleSet newInputSet = (SplittedExampleSet)inputSet.clone();
            // NOTE(review): the 0.0 argument presumably is a weight threshold; confirm in AttributeWeightedExampleSet.
            Model model = this.learn(new AttributeWeightedExampleSet(newInputSet, weights, 0.0).createCleanClone());
            // Test data: only the held-out fold, weighted the same way as the training data.
            newInputSet.selectSingleSubset(this.iteration);
            PerformanceVector iterationPerformance = this.evaluate(new AttributeWeightedExampleSet(newInputSet, weights, 0.0).createCleanClone(), model);
            if (performanceVector == null) {
                performanceVector = iterationPerformance;
            } else {
                for (int i = 0; i < performanceVector.size(); ++i) {
                    performanceVector.getCriterion(i).buildAverage(iterationPerformance.getCriterion(i));
                }
            }
            this.handleWeights(globalWeights, weights);
            this.setResult(iterationPerformance.getMainCriterion());
            this.inApplyLoop();
            ++this.iteration;
        }
        averageWeights(globalWeights, this.number);
        this.setResult(performanceVector.getMainCriterion());
        this.performanceOutput.deliver(performanceVector);
        this.attributeWeightsOutput.deliver(globalWeights);
    }

    /**
     * Divides every accumulated weight by the number of folds, in place.
     *
     * @param globalWeights the summed weights of all folds
     * @param folds number of folds; the caller guarantees it is at least 2
     */
    private static void averageWeights(AttributeWeights globalWeights, int folds) {
        for (String currentName : globalWeights.getAttributeNames()) {
            globalWeights.setWeight(currentName, globalWeights.getWeight(currentName) / (double)folds);
        }
    }

    /**
     * Adds the weights of one fold to the running sums in {@code globalWeights}.
     *
     * <p>If the stored global weight is NaN, the fold's weight replaces it instead of being added.
     * NOTE(review): presumably {@code getWeight} returns NaN for names not yet present, so this
     * also covers attributes that were not in the input set.
     *
     * @param globalWeights running per-attribute sums (modified in place)
     * @param currentWeights weights produced for the current fold
     */
    private static void handleWeights(AttributeWeights globalWeights, AttributeWeights currentWeights) {
        for (String currentName : currentWeights.getAttributeNames()) {
            double globalWeight = globalWeights.getWeight(currentName);
            double currentWeight = currentWeights.getWeight(currentName);
            if (Double.isNaN(globalWeight)) {
                globalWeights.setWeight(currentName, currentWeight);
            } else {
                globalWeights.setWeight(currentName, globalWeight + currentWeight);
            }
        }
    }

    /**
     * Adds the parameters that control the cross-validation to those of the super class:
     * leave-one-out, number of folds, sampling type, and the random generator parameters.
     * The fold count and sampling type are only shown while leave-one-out is disabled.
     *
     * @return the list of parameter types
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(PARAMETER_LEAVE_ONE_OUT, "Set the number of validations to the number of examples. If set to true, number_of_validations is ignored", false, false));
        ParameterTypeSingle type = new ParameterTypeInt(PARAMETER_NUMBER_OF_VALIDATIONS, "Number of subsets for the crossvalidation", 2, Integer.MAX_VALUE, 10);
        type.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, true, false));
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2);
        type.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, true, false));
        types.add(type);
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }
}

