/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.validation;

import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.SplittedExampleSet;
import edu.udo.cs.yale.operator.IOContainer;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeCategory;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.operator.validation.Tools;
import edu.udo.cs.yale.operator.validation.ValidationChain;
import edu.udo.cs.yale.tools.math.AverageVector;
import java.util.LinkedList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Validation chain that estimates performance from one single split of the
 * input example set: the learner is applied to subset 0 of the split and the
 * evaluation is run on subset 1.
 *
 * <p>The split is configured through the parameters {@code split_ratio},
 * {@code sampling_type} and {@code local_random_seed}.
 */
public class RandomSplitValidationChain
extends ValidationChain {
    /** Key of the parameter that holds the ratio handed to the split. */
    private static final String KEY_SPLIT_RATIO = "split_ratio";

    /** Key of the parameter that selects the sampling type of the split. */
    private static final String KEY_SAMPLING_TYPE = "sampling_type";

    /** Key of the parameter that holds the local random seed. */
    private static final String KEY_LOCAL_RANDOM_SEED = "local_random_seed";

    /**
     * Creates the operator from its description.
     *
     * @param description the operator description, passed on to the superclass
     */
    public RandomSplitValidationChain(OperatorDescription description) {
        super(description);
    }

    /**
     * Splits the input set, learns on the first subset, evaluates on the
     * second one and returns the collected average vectors.
     *
     * <p>If a performance vector can be obtained from the evaluation output,
     * its main criterion is handed to {@code setResult}.
     *
     * @param inputSet the example set to split
     * @return the average vectors gathered from the evaluation output
     * @throws OperatorException if reading a parameter, learning or evaluating fails
     */
    @Override
    public IOObject[] estimatePerformance(ExampleSet inputSet) throws OperatorException {
        // Parameters are read in the same order in which they are consumed below.
        final double ratio = this.getParameterAsDouble(KEY_SPLIT_RATIO);
        final int samplingType = this.getParameterAsInt(KEY_SAMPLING_TYPE);
        final int randomSeed = this.getParameterAsInt(KEY_LOCAL_RANDOM_SEED);
        SplittedExampleSet split = new SplittedExampleSet(inputSet, ratio, samplingType, randomSeed);

        // Subset 0 is the learning part of the split.
        split.selectSingleSubset(0);
        this.learn(split);

        // Subset 1 is the part the evaluation is run on.
        split.selectSingleSubset(1);
        IOContainer evaluationOutput = this.evaluate(split);

        LinkedList<AverageVector> collected = new LinkedList<AverageVector>();
        Tools.handleAverages(evaluationOutput, collected);

        PerformanceVector performance = Tools.getPerformanceVector(collected);
        if (performance != null) {
            this.setResult(performance.getMainCriterion());
        }

        // The returned array keeps the runtime component type AverageVector.
        return collected.toArray(new AverageVector[collected.size()]);
    }

    /**
     * Extends the parameter list of the superclass by the three parameters
     * that configure the split.
     *
     * @return the list obtained from the superclass with the split parameters appended
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();

        // NOTE(review): the numeric arguments look like (min, max, default) - confirm against ParameterTypeDouble.
        ParameterTypeDouble splitRatioType = new ParameterTypeDouble(KEY_SPLIT_RATIO, "Relative size of the training set", 0.0, 1.0, 0.7);
        splitRatioType.setExpert(false);
        parameterTypes.add(splitRatioType);

        // NOTE(review): the trailing 2 is presumably the default index into SAMPLING_NAMES - confirm.
        ParameterTypeCategory samplingTypeType = new ParameterTypeCategory(KEY_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2);
        parameterTypes.add(samplingTypeType);

        ParameterTypeInt randomSeedType = new ParameterTypeInt(KEY_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1);
        parameterTypes.add(randomSeedType);

        return parameterTypes;
    }
}

