/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.features.weighting;

import com.rapidminer.example.AttributeWeights;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.tree.AbstractTreeLearner;
import com.rapidminer.operator.learner.tree.Edge;
import com.rapidminer.operator.learner.tree.RandomForestModel;
import com.rapidminer.operator.learner.tree.Tree;
import com.rapidminer.operator.learner.tree.TreeModel;
import com.rapidminer.operator.learner.tree.criterions.AbstractCriterion;
import com.rapidminer.operator.learner.tree.criterions.Criterion;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ModelMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeStringCategory;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

public class ForestBasedWeighting
extends Operator {
    public static final String PARAMETER_CRITERION = "criterion";
    private InputPort forestInput = (InputPort)this.getInputPorts().createPort("random forest");
    private OutputPort weightsOutput = (OutputPort)this.getOutputPorts().createPort("weights");
    private OutputPort forestOutput = (OutputPort)this.getOutputPorts().createPort("random forest");

    public ForestBasedWeighting(OperatorDescription description) {
        super(description);
        this.forestInput.addPrecondition(new SimplePrecondition(this.forestInput, new ModelMetaData(RandomForestModel.class, new ExampleSetMetaData()), true));
        this.getTransformer().addPassThroughRule(this.forestInput, this.forestOutput);
        this.getTransformer().addGenerationRule(this.weightsOutput, AttributeWeights.class);
    }

    @Override
    public void doWork() throws OperatorException {
        RandomForestModel forest = this.forestInput.getData(RandomForestModel.class);
        String[] labelValues = forest.getTrainingHeader().getAttributes().getLabel().getMapping().getValues().toArray(new String[0]);
        Criterion criterion = AbstractCriterion.createCriterion(this, 0.0);
        HashMap<String, Double> attributeBenefitMap = new HashMap<String, Double>();
        for (Model model : forest.getModels()) {
            TreeModel treeModel = (TreeModel)model;
            this.extractWeights(attributeBenefitMap, criterion, treeModel.getRoot(), labelValues);
        }
        AttributeWeights weights = new AttributeWeights();
        int n = forest.getModels().size();
        for (Map.Entry entry : attributeBenefitMap.entrySet()) {
            weights.setWeight((String)entry.getKey(), (Double)entry.getValue() / (double)n);
        }
        if (this.getParameterAsBoolean("normalize_weights")) {
            weights.normalize();
        }
        this.weightsOutput.deliver(weights);
        this.forestOutput.deliver(forest);
    }

    private void extractWeights(HashMap<String, Double> attributeBenefitMap, Criterion criterion, Tree root, String[] labelValues) {
        if (!root.isLeaf()) {
            int numberOfChildren = root.getNumberOfChildren();
            double[][] weights = new double[numberOfChildren][];
            String attributeName = null;
            Iterator<Edge> childIterator = root.childIterator();
            int i = 0;
            while (childIterator.hasNext()) {
                Edge edge = childIterator.next();
                attributeName = edge.getCondition().getAttributeName();
                Map<String, Integer> subtreeCounterMap = edge.getChild().getSubtreeCounterMap();
                weights[i] = new double[labelValues.length];
                for (int j = 0; j < labelValues.length; ++j) {
                    Integer weight = subtreeCounterMap.get(labelValues[j]);
                    double weightValue = 0.0;
                    if (weight != null) {
                        weightValue = weight.intValue();
                    }
                    weights[i][j] = weightValue;
                }
                ++i;
            }
            double benefit = criterion.getBenefit(weights);
            Double knownBenefit = attributeBenefitMap.get(attributeName);
            if (knownBenefit != null) {
                benefit += knownBenefit.doubleValue();
            }
            attributeBenefitMap.put(attributeName, benefit);
            childIterator = root.childIterator();
            while (childIterator.hasNext()) {
                Tree child = childIterator.next().getChild();
                this.extractWeights(attributeBenefitMap, criterion, child, labelValues);
            }
        }
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeStringCategory type = new ParameterTypeStringCategory(PARAMETER_CRITERION, "Specifies the used criterion for weighting attributes.", AbstractTreeLearner.CRITERIA_NAMES, AbstractTreeLearner.CRITERIA_NAMES[0], false);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeBoolean("normalize_weights", "Activates the normalization of all weights.", true, false));
        return types;
    }
}

