/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.rules;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.rules.NumericalSplitter;
import com.rapidminer.operator.learner.rules.Rule;
import com.rapidminer.operator.learner.rules.RuleModel;
import com.rapidminer.operator.learner.rules.Split;
import com.rapidminer.operator.learner.tree.LessEqualsSplitCondition;
import com.rapidminer.operator.learner.tree.NominalSplitCondition;
import java.util.ArrayList;
import java.util.Collection;

/**
 * Learns a rule model that uses a single attribute.
 *
 * <p>For every attribute of the training set a separate {@link RuleModel} is built:
 * nominal attributes get one rule per value, and numerical attributes get a sequence of
 * "less-or-equal" threshold rules plus an optional default rule. The candidate with the
 * lowest (optionally example-weighted) training error is returned.
 */
public class SingleRuleLearner
extends AbstractLearner {
    /** Finds split points for numerical attributes; never reassigned. */
    private final NumericalSplitter splitter = new NumericalSplitter();

    public SingleRuleLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Builds one candidate rule model per attribute and returns the one with the lowest
     * training error.
     *
     * @param inputSet the training data; it is cloned and not modified
     * @return the best single-attribute {@link RuleModel}
     * @throws OperatorException if no candidate model could be selected. This happens when
     *         the example set has no attributes to iterate, or when every candidate's
     *         weighted error is undefined (NaN).
     */
    @Override
    public Model learn(ExampleSet inputSet) throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) inputSet.clone();
        Collection<RuleModel> models = new ArrayList<RuleModel>();
        // NOTE(review): presumably iterates only the regular (non-special) attributes; confirm.
        for (Attribute attribute : exampleSet.getAttributes()) {
            // Each candidate model is built on its own clone of the example set.
            ExampleSet trainingSet = (ExampleSet) exampleSet.clone();
            if (attribute.isNominal()) {
                models.add(this.createNominalRuleModel(trainingSet, attribute));
            } else {
                models.add(this.createNumericalRuleModel(trainingSet, attribute));
            }
        }
        RuleModel bestModel = this.getBestModel(models, exampleSet, true);
        if (bestModel == null) {
            // Returning null would only defer the failure to whoever consumes the model.
            throw new OperatorException("SingleRuleLearner: no rule model could be selected "
                    + "(the example set has no usable attributes or the training errors are undefined).");
        }
        return bestModel;
    }

    /**
     * Builds a rule model for a numerical attribute by repeatedly splitting the remaining
     * examples.
     *
     * <p>Each iteration asks the splitter for the best split point on the examples that are
     * still uncovered. It then creates a "attribute &lt;= split point" rule that predicts the
     * majority label of subset 0 of the split, and removes the examples that rule covers.
     * The loop stops when no examples remain, when the last rule removed nothing, or when
     * the splitter reports no split point (NaN). Any examples still uncovered afterwards are
     * handled by a default rule that predicts their majority label.
     *
     * @param trainingSet the examples to learn from
     * @param attribute   the numerical attribute to split on
     * @return the rule model for {@code attribute}
     */
    private RuleModel createNumericalRuleModel(ExampleSet trainingSet, Attribute attribute) {
        RuleModel model = new RuleModel(trainingSet);
        int oldSize = -1;
        while (trainingSet.size() > 0 && trainingSet.size() != oldSize) {
            ExampleSet exampleSet = (ExampleSet) trainingSet.clone();
            Split bestSplit = this.splitter.getBestSplit(exampleSet, attribute, null);
            double bestSplitValue = bestSplit.getSplitPoint();
            if (Double.isNaN(bestSplitValue)) {
                break;
            }
            SplittedExampleSet splittedSet = SplittedExampleSet.splitByAttribute(exampleSet, attribute, bestSplitValue);
            Attribute label = splittedSet.getAttributes().getLabel();
            // NOTE(review): subset 0 is assumed to be the side matching the <= condition below.
            splittedSet.selectSingleSubset(0);
            LessEqualsSplitCondition condition = new LessEqualsSplitCondition(attribute, bestSplitValue);
            splittedSet.recalculateAttributeStatistics(label);
            Rule rule = new Rule(modeLabelName(splittedSet, label), condition);
            rule.setFrequencies(countLabelFrequencies(splittedSet, label));
            model.addRule(rule);
            // Remember the size so that the loop stops once a rule no longer removes anything.
            oldSize = trainingSet.size();
            trainingSet = rule.removeCovered(trainingSet);
        }
        if (trainingSet.size() > 0) {
            Attribute label = trainingSet.getAttributes().getLabel();
            trainingSet.recalculateAttributeStatistics(label);
            Rule defaultRule = new Rule(modeLabelName(trainingSet, label));
            defaultRule.setFrequencies(countLabelFrequencies(trainingSet, label));
            model.addRule(defaultRule);
        }
        return model;
    }

    /**
     * Builds a rule model for a nominal attribute with one rule per subset produced by
     * splitting on that attribute. Each rule predicts the majority label of its subset.
     *
     * @param exampleSet the examples to learn from
     * @param attribute  the nominal attribute to split on
     * @return the rule model for {@code attribute}
     */
    private RuleModel createNominalRuleModel(ExampleSet exampleSet, Attribute attribute) {
        RuleModel model = new RuleModel(exampleSet);
        SplittedExampleSet splittedSet = SplittedExampleSet.splitByAttribute(exampleSet, attribute);
        Attribute label = splittedSet.getAttributes().getLabel();
        for (int i = 0; i < splittedSet.getNumberOfSubsets(); ++i) {
            splittedSet.selectSingleSubset(i);
            splittedSet.recalculateAttributeStatistics(label);
            // NOTE(review): assumes subset i holds the examples whose value has mapping index i.
            NominalSplitCondition term = new NominalSplitCondition(attribute, attribute.getMapping().mapIndex(i));
            Rule rule = new Rule(modeLabelName(splittedSet, label), term);
            rule.setFrequencies(countLabelFrequencies(splittedSet, label));
            model.addRule(rule);
        }
        return model;
    }

    /**
     * Returns the name of the most frequent label value, as given by the {@code "mode"}
     * statistic. Label statistics must already have been recalculated for {@code examples}.
     */
    private static String modeLabelName(ExampleSet examples, Attribute label) {
        int labelIndex = (int) examples.getStatistics(label, "mode");
        return label.getMapping().mapIndex(labelIndex);
    }

    /**
     * Returns, for each value of the label mapping (in mapping order), the {@code "count"}
     * statistic of that value truncated to an int. Label statistics must already have been
     * recalculated for {@code examples}.
     */
    private static int[] countLabelFrequencies(ExampleSet examples, Attribute label) {
        int[] frequencies = new int[label.getMapping().size()];
        int counter = 0;
        for (String value : label.getMapping().getValues()) {
            frequencies[counter++] = (int) examples.getStatistics(label, "count", value);
        }
        return frequencies;
    }

    /**
     * Selects the candidate with the lowest training error.
     *
     * <p>An example adds its weight to a model's error whenever the model's prediction
     * differs from the example's label. The weight is the value of the special
     * {@code "weight"} attribute when {@code useExampleWeights} is set and that attribute
     * exists; otherwise it is 1.0. Ties go to the earlier model.
     *
     * @param models            the candidate models, in iteration order
     * @param exampleSet        the examples used for evaluation
     * @param useExampleWeights whether to use the {@code "weight"} attribute if present
     * @return the best model, or {@code null} if {@code models} is empty or every error is NaN
     */
    private RuleModel getBestModel(Collection<RuleModel> models, ExampleSet exampleSet, boolean useExampleWeights) {
        Attribute exampleWeightAttribute = exampleSet.getAttributes().getSpecial("weight");
        useExampleWeights = useExampleWeights && exampleWeightAttribute != null;
        double[] weightedError = new double[models.size()];
        for (Example example : exampleSet) {
            double currentWeight = useExampleWeights ? example.getValue(exampleWeightAttribute) : 1.0;
            double currentLabel = example.getLabel();
            int i = 0;
            for (RuleModel currentModel : models) {
                if (currentLabel != currentModel.getPrediction(example)) {
                    weightedError[i] += currentWeight;
                }
                ++i;
            }
        }
        int i = 0;
        double bestError = Double.POSITIVE_INFINITY;
        RuleModel bestModel = null;
        for (RuleModel currentModel : models) {
            if (weightedError[i] < bestError) {
                bestError = weightedError[i];
                bestModel = currentModel;
            }
            ++i;
        }
        return bestModel;
    }

    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return RuleModel.class;
    }

    /**
     * Supports nominal and numerical attributes, binominal and polynominal labels, and
     * weighted examples.
     */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        return capability == OperatorCapability.BINOMINAL_ATTRIBUTES
                || capability == OperatorCapability.POLYNOMINAL_ATTRIBUTES
                || capability == OperatorCapability.NUMERICAL_ATTRIBUTES
                || capability == OperatorCapability.POLYNOMINAL_LABEL
                || capability == OperatorCapability.BINOMINAL_LABEL
                || capability == OperatorCapability.WEIGHTED_EXAMPLES;
    }
}

