/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.AttributeFactory;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorCreationException;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.Learner;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.lazy.DefaultLearner;
import edu.udo.cs.yale.operator.learner.meta.AbstractMetaLearner;
import edu.udo.cs.yale.operator.learner.meta.AdditiveRegressionModel;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.tools.OperatorService;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Meta learner that builds an additive regression model by repeatedly fitting the single
 * inner learner to residuals.
 *
 * <p>A {@link DefaultLearner} model is trained first, and its predictions are subtracted
 * from a working copy of the label. Then, for {@code iterations} rounds, the inner learner
 * is trained on the current residuals. Its predictions, multiplied by {@code shrinkage},
 * are subtracted from those residuals. The default model, the residual models and the
 * shrinkage factor are combined into an {@link AdditiveRegressionModel}.
 */
public class AdditiveRegression extends AbstractMetaLearner {

    /** Parameter key: number of residual models to train. */
    private static final String PARAMETER_ITERATIONS = "iterations";

    /** Parameter key: factor applied to every residual model's predictions. */
    private static final String PARAMETER_SHRINKAGE = "shrinkage";

    public AdditiveRegression(OperatorDescription description) {
        super(description);
    }

    /**
     * Trains the default model and {@code iterations} residual models on a working copy
     * of the label.
     *
     * <p>The temporary {@code working_label} attribute is removed from the example table
     * whether or not training succeeds. This matters because the table is the one reached
     * through {@code exampleSet.clone()}, so a leftover attribute would be visible to the
     * caller.
     *
     * @param exampleSet the training data; its label is read but not overwritten
     * @return an {@link AdditiveRegressionModel} over the original label
     * @throws OperatorException if the default learner cannot be created or any learner fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        double shrinkage = getParameterAsDouble(PARAMETER_SHRINKAGE);
        ExampleSet workingExampleSet = (ExampleSet) exampleSet.clone();
        Attribute workingLabel = AttributeFactory.createAttribute(
                workingExampleSet.getAttributes().getLabel(), "working_label");
        workingExampleSet.getExampleTable().addAttribute(workingLabel);

        Model defaultModel;
        Model[] residualModels;
        try {
            // Copy the original label values into the working attribute, then promote it to
            // be the label, so residuals can be written without touching the original label.
            workingExampleSet.getAttributes().addRegular(workingLabel);
            for (Example example : workingExampleSet) {
                example.setValue(workingLabel, example.getLabel());
            }
            workingExampleSet.getAttributes().remove(workingLabel);
            workingExampleSet.getAttributes().setLabel(workingLabel);

            // Trained on the original example set; its label values equal the working
            // label at this point. Its predictions are subtracted unscaled.
            defaultModel = createDefaultLearner().learn(exampleSet);
            residualReplace(workingExampleSet, defaultModel, 1.0);

            residualModels = new Model[getParameterAsInt(PARAMETER_ITERATIONS)];
            for (int i = 0; i < residualModels.length; i++) {
                residualModels[i] = applyInnerLearner(workingExampleSet);
                residualReplace(workingExampleSet, residualModels[i], shrinkage);
            }
        } finally {
            // Always drop the temporary attribute, also when a learner throws.
            workingExampleSet.getAttributes().remove(workingLabel);
            workingExampleSet.getExampleTable().removeAttribute(workingLabel);
        }
        return new AdditiveRegressionModel(
                exampleSet.getAttributes().getLabel(), defaultModel, residualModels, shrinkage);
    }

    /**
     * Creates the {@link DefaultLearner} used for the initial model.
     *
     * @throws OperatorException wrapping the creation failure
     */
    private Learner createDefaultLearner() throws OperatorException {
        try {
            return (Learner) OperatorService.createOperator(DefaultLearner.class);
        } catch (OperatorCreationException e) {
            throw new OperatorException(getName() + ": not able to create default classifier!", e);
        }
    }

    /**
     * Applies {@code model} to {@code exampleSet} and overwrites each example's label with
     * {@code label - predictionScale * predictedLabel}.
     *
     * @param exampleSet      example set whose label holds the current residuals
     * @param model           model whose predictions are subtracted
     * @param predictionScale factor applied to each prediction (1.0 means unscaled)
     */
    private void residualReplace(ExampleSet exampleSet, Model model, double predictionScale)
            throws OperatorException {
        // NOTE(review): model.apply may add a predicted-label attribute to the shared example
        // table on every call; confirm whether it should be removed after use.
        model.apply(exampleSet);
        Attribute label = exampleSet.getAttributes().getLabel();
        for (Example example : exampleSet) {
            double residual = example.getLabel() - predictionScale * example.getPredictedLabel();
            example.setValue(label, residual);
        }
    }

    /** Exactly one inner learner is required. */
    @Override
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    /** Exactly one inner learner is allowed. */
    @Override
    public int getMaxNumberOfInnerOperators() {
        return 1;
    }

    /** Adds support for numerical labels; all other capabilities are decided by the superclass. */
    @Override
    public boolean supportsCapability(LearnerCapability capability) {
        if (capability.equals(LearnerCapability.NUMERICAL_CLASS)) {
            return true;
        }
        return super.supportsCapability(capability);
    }

    /** Extends the inherited parameters with {@code iterations} and {@code shrinkage}. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeInt(PARAMETER_ITERATIONS, "The number of iterations.", 1, Integer.MAX_VALUE, 10));
        types.add(new ParameterTypeDouble(PARAMETER_SHRINKAGE, "Reducing this learning rate prevent overfitting but increases the learning time.", 0.0, 1.0, 1.0));
        return types;
    }
}

