/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.meta;

import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.SplittedExampleSet;
import edu.udo.cs.yale.operator.IOContainer;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorChain;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.Value;
import edu.udo.cs.yale.operator.condition.CombinedInnerOperatorCondition;
import edu.udo.cs.yale.operator.condition.InnerOperatorCondition;
import edu.udo.cs.yale.operator.condition.SpecificInnerOperatorCondition;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeBoolean;
import edu.udo.cs.yale.operator.parameter.ParameterTypeCategory;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.tools.LogService;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Operator chain that produces cross-validated (out-of-fold) predictions for its input.
 * <p>
 * The input {@link ExampleSet} is partitioned into {@code number_of_validations} subsets, or into
 * one subset per example when {@code leave_one_out} is set. For every subset:
 * <ul>
 * <li>the first inner operator ("Training") is applied to all remaining subsets;</li>
 * <li>the second inner operator ("Testing") is applied to the held-out subset together with the
 * first operator's output.</li>
 * </ul>
 * The predicted label of every held-out example is collected; for nominal labels, the per-value
 * confidences are collected too. Finally, everything collected is written back into the input
 * example set, which is returned as this operator's only output.
 * <p>
 * The index of the fold currently being processed is exposed through the operator value
 * {@code "iteration"}.
 */
public class XVPrediction
extends OperatorChain {
    private static final Class[] INPUT_CLASSES = new Class[]{ExampleSet.class};
    private static final Class[] OUTPUT_CLASSES = new Class[]{ExampleSet.class};
    /** Number of folds used by the current (or last) run of {@link #apply()}. */
    private int number;
    /** Index of the fold currently being processed; reported via the "iteration" value. */
    private int iteration;

    /**
     * Creates the operator and registers the "iteration" value, which reports the current fold index.
     *
     * @param description the operator description passed on to {@link OperatorChain}
     */
    public XVPrediction(OperatorDescription description) {
        super(description);
        this.addValue(new Value("iteration", "The number of the current iteration."){

            public double getValue() {
                return XVPrediction.this.iteration;
            }
        });
    }

    /** Exactly two inner operators are required: a learner and an applier. */
    @Override
    public int getMaxNumberOfInnerOperators() {
        return 2;
    }

    /** Exactly two inner operators are required: a learner and an applier. */
    @Override
    public int getMinNumberOfInnerOperators() {
        return 2;
    }

    /** This operator consumes one {@link ExampleSet}. */
    @Override
    public Class[] getInputClasses() {
        return INPUT_CLASSES;
    }

    /** This operator delivers the input {@link ExampleSet}, enriched with the predictions. */
    @Override
    public Class[] getOutputClasses() {
        return OUTPUT_CLASSES;
    }

    /**
     * Declares the inner operators' IO contract:
     * <ul>
     * <li>inner operator 0 ("Training") turns an ExampleSet into a Model;</li>
     * <li>inner operator 1 ("Testing") turns an ExampleSet plus a Model into an ExampleSet.</li>
     * </ul>
     */
    @Override
    public InnerOperatorCondition getInnerOperatorCondition() {
        CombinedInnerOperatorCondition condition = new CombinedInnerOperatorCondition();
        condition.addCondition(new SpecificInnerOperatorCondition("Training", 0, new Class[]{ExampleSet.class}, new Class[]{Model.class}));
        condition.addCondition(new SpecificInnerOperatorCondition("Testing", 1, new Class[]{ExampleSet.class, Model.class}, new Class[]{ExampleSet.class}));
        return condition;
    }

    /** Returns the first inner operator, which is trained on all folds but one. */
    private Operator getLearner() {
        return this.getOperator(0);
    }

    /** Returns the second inner operator, which is applied to the held-out fold. */
    private Operator getApplier() {
        return this.getOperator(1);
    }

    /**
     * Runs the cross-validated prediction and writes the out-of-fold predictions into the input.
     *
     * @return a one-element array holding the input example set. Its predicted label, and for
     *         nominal labels its confidence values, now carry the out-of-fold predictions.
     * @throws OperatorException if the input example set has no label attribute, or if an inner
     *         operator or parameter lookup fails
     */
    @Override
    public IOObject[] apply() throws OperatorException {
        ExampleSet inputSet = this.getInput(ExampleSet.class);
        if (inputSet.getAttributes().getLabel() == null) {
            // Report the problem explicitly instead of failing later with a NullPointerException.
            throw new OperatorException(this.getName() + ": cross validation prediction requires an example set with a label attribute.");
        }
        this.number = this.getParameterAsBoolean("leave_one_out") ? inputSet.size() : this.getParameterAsInt("number_of_validations");
        LogService.logMessage(this.getName() + ": Starting " + this.number + "-fold cross validation prediction", 2);
        int samplingType = this.getParameterAsInt("sampling_type");
        SplittedExampleSet splittedES = new SplittedExampleSet(inputSet, this.number, samplingType, this.getParameterAsInt("local_random_seed"));

        double[] predictions = new double[inputSet.size()];
        // Snapshot of the label values, taken once. It is null for non-nominal labels, in which
        // case no confidences are collected.
        String[] labelValues = inputSet.getAttributes().getLabel().isNominal() ? getLabelValues(inputSet) : null;
        double[][] confidences = labelValues == null ? null : new double[inputSet.size()][labelValues.length];

        for (this.iteration = 0; this.iteration < this.number; ++this.iteration) {
            splittedES.selectAllSubsetsBut(this.iteration);
            IOContainer learnResult = this.getLearner().apply(new IOContainer(new IOObject[]{splittedES}));
            splittedES.selectSingleSubset(this.iteration);
            // The applier's return value is discarded; the predictions are read back from the
            // examples of splittedES. NOTE(review): this relies on the applier writing its
            // predictions into those examples in place — confirm for all supported appliers.
            this.getApplier().apply(learnResult.append(new IOObject[]{splittedES}));
            collectPredictions(splittedES, labelValues, predictions, confidences);
            this.inApplyLoop();
        }
        writePredictions(inputSet, labelValues, predictions, confidences);
        return new IOObject[]{inputSet};
    }

    /**
     * Returns the values of the label's nominal mapping, in the mapping's iteration order.
     * The length of the result equals the mapping's reported size.
     */
    private static String[] getLabelValues(ExampleSet exampleSet) {
        String[] values = new String[exampleSet.getAttributes().getLabel().getMapping().size()];
        int counter = 0;
        for (String value : exampleSet.getAttributes().getLabel().getMapping().getValues()) {
            values[counter++] = value;
        }
        return values;
    }

    /**
     * Reads the predictions for every example in the currently selected subset and stores them in
     * the given arrays.
     * <p>
     * Each value is stored at the example's index in the parent example set. The predicted label is
     * always copied; the confidences are copied only when {@code labelValues} is non-null.
     */
    private static void collectPredictions(SplittedExampleSet subset, String[] labelValues, double[] predictions, double[][] confidences) {
        for (int i = 0; i < subset.size(); ++i) {
            Example example = subset.getExample(i);
            double predicted = example.getPredictedLabel();
            int parentIndex = subset.getActualParentIndex(i);
            predictions[parentIndex] = predicted;
            if (labelValues != null) {
                for (int j = 0; j < labelValues.length; ++j) {
                    confidences[parentIndex][j] = example.getConfidence(labelValues[j]);
                }
            }
        }
    }

    /**
     * Writes the collected predictions back into the examples of {@code exampleSet}, in iteration
     * order.
     * <p>
     * The predicted label is always written; the confidences are written only when
     * {@code labelValues} is non-null. NOTE(review): this assumes the examples' attributes contain a
     * predicted label attribute after the applier has run — confirm.
     */
    private static void writePredictions(ExampleSet exampleSet, String[] labelValues, double[] predictions, double[][] confidences) {
        int index = 0;
        for (Example example : exampleSet) {
            example.setValue(example.getAttributes().getPredictedLabel(), predictions[index]);
            if (labelValues != null) {
                for (int j = 0; j < labelValues.length; ++j) {
                    example.setConfidence(labelValues[j], confidences[index][j]);
                }
            }
            ++index;
        }
    }

    /**
     * Adds this operator's parameters to those of the superclass:
     * <ul>
     * <li>{@code number_of_validations}: the fold count, at least 2, default 10;</li>
     * <li>{@code leave_one_out}: use one fold per example, default false;</li>
     * <li>{@code sampling_type}: an index into {@code SplittedExampleSet.SAMPLING_NAMES}, default 2;</li>
     * <li>{@code local_random_seed}: -1 means use the global random numbers, and is the default.</li>
     * </ul>
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeInt type = new ParameterTypeInt("number_of_validations", "Number of subsets for the crossvalidation.", 2, Integer.MAX_VALUE, 10);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeBoolean("leave_one_out", "Set the number of validations to the number of examples. If set to true, number_of_validations is ignored.", false));
        types.add(new ParameterTypeCategory("sampling_type", "Defines the sampling type of the cross validation.", SplittedExampleSet.SAMPLING_NAMES, 2));
        types.add(new ParameterTypeInt("local_random_seed", "Use the given random seed instead of global random numbers (-1: use global).", -1, Integer.MAX_VALUE, -1));
        return types;
    }
}

