/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.tree;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.tree.RandomForestModel;
import com.rapidminer.operator.learner.tree.RandomTreeLearner;
import com.rapidminer.operator.learner.tree.TreeModel;
import com.rapidminer.operator.preprocessing.sampling.BootstrappingOperator;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.OperatorService;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/**
 * Learns a random forest: an ensemble of trees, each trained by {@link RandomTreeLearner} on its
 * own sample of the input drawn by a {@link BootstrappingOperator}.
 *
 * <p>The ensemble size is set by {@link #PARAMETER_NUMBER_OF_TREES}; all remaining parameters are
 * inherited from {@link RandomTreeLearner}.
 */
public class RandomForestLearner extends RandomTreeLearner {

    /** Parameter key for the number of trees to learn (integer, at least 1, default 10). */
    public static final String PARAMETER_NUMBER_OF_TREES = "number_of_trees";

    /**
     * Creates the operator.
     *
     * @param description the operator description, passed through to {@link RandomTreeLearner}
     */
    public RandomForestLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Returns the class of model produced by {@link #learn(ExampleSet)}.
     *
     * @return {@code RandomForestModel.class}
     */
    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return RandomForestModel.class;
    }

    /**
     * Trains {@code number_of_trees} trees and combines them into a {@link RandomForestModel}.
     *
     * <p>For every tree, a new sample is drawn from {@code exampleSet} by the bootstrapping
     * operator, and the inherited {@link RandomTreeLearner#learn(ExampleSet)} is trained on it.
     * Each resulting tree has its source set to this operator's name.
     *
     * @param exampleSet the training data; every sample is drawn from it, and it is also handed
     *     to the returned model
     * @return a {@link RandomForestModel} wrapping the learned trees, in training order
     * @throws OperatorException if the bootstrapping operator cannot be created, or if sampling
     *     or tree learning fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        final BootstrappingOperator bootstrapping = createBootstrappingOperator();
        final int numberOfTrees = getParameterAsInt(PARAMETER_NUMBER_OF_TREES);
        final List<TreeModel> baseModels = new ArrayList<TreeModel>();
        for (int i = 0; i < numberOfTrees; i++) {
            // NOTE(review): relies on each apply() call drawing a fresh sample; confirm in
            // BootstrappingOperator.
            TreeModel model = (TreeModel) super.learn(bootstrapping.apply(exampleSet));
            model.setSource(getName());
            baseModels.add(model);
        }
        return new RandomForestModel(exampleSet, baseModels);
    }

    /**
     * Creates the sampling operator that draws each tree's training set. The operator is
     * configured with {@code use_weights = false} and {@code sample_ratio = 1.0}.
     *
     * @return a configured bootstrapping operator
     * @throws OperatorException if the operator cannot be instantiated; the original
     *     {@link OperatorCreationException} is kept as the cause
     */
    private BootstrappingOperator createBootstrappingOperator() throws OperatorException {
        try {
            BootstrappingOperator bootstrapping = OperatorService.createOperator(BootstrappingOperator.class);
            bootstrapping.setParameter("use_weights", "false");
            bootstrapping.setParameter("sample_ratio", "1.0");
            return bootstrapping;
        } catch (OperatorCreationException e) {
            // Chain the cause so the underlying creation failure is not lost.
            throw new OperatorException(getName() + ": cannot construct bootstrapping operator: " + e.getMessage(), e);
        }
    }

    /**
     * Reports which data characteristics this learner accepts.
     *
     * <p>Supported capabilities are binominal, polynominal and numerical attributes, and
     * binominal and polynominal labels. Everything else is unsupported, including weighted
     * examples (the bootstrapping step runs with {@code use_weights = false}).
     *
     * @param capability the capability to query
     * @return {@code true} only for the capabilities listed above
     */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        return capability == OperatorCapability.BINOMINAL_ATTRIBUTES
                || capability == OperatorCapability.POLYNOMINAL_ATTRIBUTES
                || capability == OperatorCapability.NUMERICAL_ATTRIBUTES
                || capability == OperatorCapability.POLYNOMINAL_LABEL
                || capability == OperatorCapability.BINOMINAL_LABEL;
    }

    /**
     * Returns this operator's parameters: {@link #PARAMETER_NUMBER_OF_TREES} (range 1 to
     * {@code Integer.MAX_VALUE}, default 10, non-expert), followed by all parameters of
     * {@link RandomTreeLearner}.
     *
     * @return a new mutable list of parameter types
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = new LinkedList<ParameterType>();
        ParameterTypeInt type = new ParameterTypeInt(PARAMETER_NUMBER_OF_TREES, "The number of learned random trees.", 1, Integer.MAX_VALUE, 10);
        type.setExpert(false);
        types.add(type);
        types.addAll(super.getParameterTypes());
        return types;
    }
}

