/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.rules;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.rules.AccuracyCriterion;
import com.rapidminer.operator.learner.rules.Criterion;
import com.rapidminer.operator.learner.rules.InfoGainCriterion;
import com.rapidminer.operator.learner.rules.Rule;
import com.rapidminer.operator.learner.rules.RuleModel;
import com.rapidminer.operator.learner.rules.TermDetermination;
import com.rapidminer.operator.learner.tree.EmptyTermination;
import com.rapidminer.operator.learner.tree.NoAttributeLeftTermination;
import com.rapidminer.operator.learner.tree.SplitCondition;
import com.rapidminer.operator.learner.tree.Terminator;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeSingle;
import com.rapidminer.parameter.ParameterTypeStringCategory;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
import java.util.LinkedList;
import java.util.List;

/**
 * Learner operator that builds a {@link RuleModel} by repeatedly growing one rule for the
 * currently most frequent label value, checking each added term against a hold-out pruning
 * subset, and then removing the examples covered by the finished rule from the training data.
 *
 * <p>The data is split into a growing and a pruning subset according to the
 * {@code sample_ratio} parameter. A final rule without terms (default rule) is appended when
 * uncovered examples remain.</p>
 */
public class RuleLearner extends AbstractLearner {
    private static final String PARAMETER_SAMPLE_RATIO = "sample_ratio";
    private static final String PARAMETER_MINIMAL_PRUNE_BENEFIT = "minimal_prune_benefit";
    private static final String PARAMETER_PURENESS = "pureness";
    private static final String PARAMETER_CRITERION = "criterion";
    private static final String PARAMETER_USE_LOCAL_RANDOM_SEED = "use_local_random_seed";
    private static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";

    /** Names of the built-in criteria; index-aligned with {@link #CRITERIA_CLASSES}. */
    public static final String[] CRITERIA_NAMES = new String[]{"information_gain", "accuracy"};
    /** Implementation classes of the built-in criteria; index-aligned with {@link #CRITERIA_NAMES}. */
    public static final Class[] CRITERIA_CLASSES = new Class[]{InfoGainCriterion.class, AccuracyCriterion.class};
    public static final int CRITERION_INFO_GAIN = 0;
    public static final int CRITERION_ACCURACY = 1;

    public RuleLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Learns a rule model from the given example set.
     *
     * @param exampleSet the training data; its label attribute is read via
     *                   {@code getAttributes().getLabel()}
     * @return the learned {@link RuleModel}
     * @throws OperatorException if a parameter cannot be read or {@code checkForStop()} aborts
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        // Build the stop conditions per invocation. Keeping them in an instance field that is
        // appended to on every call made the list grow each time the operator was executed.
        List<Terminator> terminators = new LinkedList<Terminator>();
        terminators.add(new EmptyTermination());
        terminators.add(new NoAttributeLeftTermination());

        double pureness = this.getParameterAsDouble(PARAMETER_PURENESS);
        double sampleRatio = this.getParameterAsDouble(PARAMETER_SAMPLE_RATIO);
        double minimalPruneBenefit = this.getParameterAsDouble(PARAMETER_MINIMAL_PRUNE_BENEFIT);
        Attribute label = exampleSet.getAttributes().getLabel();
        RuleModel ruleModel = new RuleModel(exampleSet);
        TermDetermination termDetermination = new TermDetermination(this.createCriterion());
        ExampleSet trainingSet = (ExampleSet)exampleSet.clone();
        trainingSet.recalculateAttributeStatistics(label);

        while (!this.shouldStop(trainingSet, terminators)) {
            // Each rule predicts the label value that is currently the mode of the remaining data.
            String labelName = this.getNextLabel(trainingSet);
            Rule rule = new Rule(labelName);
            ExampleSet oldTrainingSet = (ExampleSet)trainingSet.clone();

            // Partition the remaining data: subset 0 is used for growing, subset 1 for pruning.
            SplittedExampleSet growPruneSet = new SplittedExampleSet(trainingSet, sampleRatio, 2, this.getParameterAsBoolean(PARAMETER_USE_LOCAL_RANDOM_SEED), this.getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));
            SplittedExampleSet growingSet = (SplittedExampleSet)growPruneSet.clone();
            growingSet.selectSingleSubset(0);
            SplittedExampleSet pruneSet = (SplittedExampleSet)growPruneSet.clone();
            pruneSet.selectSingleSubset(1);
            boolean hasPruneExamples = pruneSet.size() > 0;

            int growOldSize = -1;
            ExampleSet growSet = (ExampleSet)growingSet.clone();
            // Grow the rule until the covered set is empty, stopped shrinking, is pure enough,
            // has no attributes left, or no further term can be determined.
            while (growSet.size() > 0 && growSet.size() != growOldSize && !rule.isPure(growSet, pureness) && growSet.getAttributes().size() > 0) {
                SplitCondition term = termDetermination.getBestTerm(growSet, labelName);
                if (term == null) {
                    break;
                }
                double benefitWithoutTerm = 0.0;
                if (hasPruneExamples) {
                    benefitWithoutTerm = this.getPruneBenefit(rule, pruneSet);
                }
                rule.addTerm(term);
                // Reject the new term (and finish this rule) if it lowers the benefit on the
                // pruning subset by more than the configured margin.
                if (hasPruneExamples && this.getPruneBenefit(rule, pruneSet) < benefitWithoutTerm - minimalPruneBenefit) {
                    rule.removeLastTerm();
                    break;
                }
                growOldSize = growSet.size();
                growSet = rule.getCovered(growSet);
                Attribute splitAttribute = growSet.getAttributes().get(term.getAttributeName());
                // A nominal attribute that was already used is removed from the growing set
                // so it is not offered to the term determination again for this rule.
                if (splitAttribute.isNominal()) {
                    growSet.getAttributes().remove(splitAttribute);
                }
                this.checkForStop();
            }

            // No term could be added: no further rule can be learned.
            if (rule.getTerms().size() <= 0) {
                break;
            }

            // Record the per-class counts of all remaining training examples covered by the rule.
            growSet = rule.getCovered(trainingSet);
            growSet.recalculateAttributeStatistics(label);
            int[] frequencies = new int[label.getMapping().size()];
            int counter = 0;
            for (String value : label.getMapping().getValues()) {
                frequencies[counter++] = (int)growSet.getStatistics(label, "count", value);
            }
            rule.setFrequencies(frequencies);
            ruleModel.addRule(rule);

            // Continue with the examples the new rule does not cover.
            trainingSet = rule.removeCovered(oldTrainingSet);
            trainingSet.recalculateAttributeStatistics(label);
            this.checkForStop();
        }

        // Default rule: predict the mode of whatever is left uncovered.
        if (trainingSet.size() > 0) {
            trainingSet.recalculateAttributeStatistics(label);
            int index = (int)trainingSet.getStatistics(label, "mode");
            String defaultLabel = label.getMapping().mapIndex(index);
            Rule defaultRule = new Rule(defaultLabel);
            int[] frequencies = new int[label.getMapping().size()];
            int counter = 0;
            for (String value : label.getMapping().getValues()) {
                // NOTE(review): counts are scaled by sampleRatio here but not for regular rules;
                // the intent is not visible in this file - confirm before changing.
                frequencies[counter++] = (int)(trainingSet.getStatistics(label, "count", value) * sampleRatio);
            }
            defaultRule.setFrequencies(frequencies);
            ruleModel.addRule(defaultRule);
        }
        return ruleModel;
    }

    /**
     * Computes the weighted fraction of examples the rule treats correctly: covered examples
     * carrying the rule's label plus uncovered examples carrying any other label, divided by
     * the total weight. Every example counts with weight 1 when no weight attribute is set.
     *
     * @param rule       the rule to evaluate
     * @param exampleSet the examples to evaluate on (the pruning subset in {@link #learn})
     * @return {@code (p + nTotal - n) / (pTotal + nTotal)}; NaN if the total weight is zero
     */
    private double getPruneBenefit(Rule rule, ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        Attribute weight = exampleSet.getAttributes().getWeight();
        // The rule's label does not change while iterating, so resolve its index only once.
        double ruleLabelIndex = label.getMapping().getIndex(rule.getLabel());
        double pTotal = 0.0;
        double nTotal = 0.0;
        double p = 0.0;
        double n = 0.0;
        for (Example e : exampleSet) {
            double currentWeight = weight != null ? e.getValue(weight) : 1.0;
            boolean positive = e.getValue(label) == ruleLabelIndex;
            if (positive) {
                pTotal += currentWeight;
            } else {
                nTotal += currentWeight;
            }
            if (!rule.coversExample(e)) {
                continue;
            }
            if (positive) {
                p += currentWeight;
            } else {
                n += currentWeight;
            }
        }
        return (p + nTotal - n) / (pTotal + nTotal);
    }

    /**
     * Returns the label value stored as the "mode" statistic of the label attribute. The
     * statistics must have been recalculated for the given set beforehand.
     */
    private String getNextLabel(ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        int index = (int)exampleSet.getStatistics(label, "mode");
        return label.getMapping().mapIndex(index);
    }

    /**
     * Returns true as soon as any of the given terminators signals a stop for the example set.
     * The depth argument passed to the terminators is always 0.
     */
    private boolean shouldStop(ExampleSet exampleSet, List<Terminator> terminators) {
        for (Terminator terminator : terminators) {
            if (terminator.shouldStop(exampleSet, 0)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Creates the term selection criterion named by the {@code criterion} parameter. The value
     * is first matched against {@link #CRITERIA_NAMES}; otherwise it is tried as a class name.
     * Falls back to {@link InfoGainCriterion} whenever the criterion cannot be resolved or
     * instantiated.
     *
     * @throws UndefinedParameterError if the parameter cannot be read
     */
    private Criterion createCriterion() throws UndefinedParameterError {
        String criterionName = this.getParameterAsString(PARAMETER_CRITERION);
        Class<?> criterionClass = null;
        for (int i = 0; i < CRITERIA_NAMES.length; ++i) {
            if (CRITERIA_NAMES[i].equals(criterionName)) {
                criterionClass = CRITERIA_CLASSES[i];
                break;
            }
        }
        if (criterionClass == null && criterionName != null) {
            try {
                criterionClass = Tools.classForName(criterionName);
            }
            catch (ClassNotFoundException e) {
                this.logWarning("Cannot find criterion '" + criterionName + "' and cannot instantiate a class with this name. Using information gain criterion instead.");
            }
        }
        if (criterionClass != null) {
            try {
                return (Criterion)criterionClass.newInstance();
            }
            catch (InstantiationException e) {
                this.logWarning("Cannot instantiate criterion class '" + criterionClass.getName() + "'. Using information gain criterion instead.");
                return new InfoGainCriterion();
            }
            catch (IllegalAccessException e) {
                this.logWarning("Cannot access criterion class '" + criterionClass.getName() + "'. Using information gain criterion instead.");
                return new InfoGainCriterion();
            }
        }
        this.log("No relevance criterion defined, using information gain...");
        return new InfoGainCriterion();
    }

    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return RuleModel.class;
    }

    /**
     * Reports support for nominal and numerical attributes, nominal labels, weighted examples
     * and missing values; every other capability is reported as unsupported.
     */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        switch (capability) {
            case BINOMINAL_ATTRIBUTES:
            case POLYNOMINAL_ATTRIBUTES:
            case NUMERICAL_ATTRIBUTES:
            case POLYNOMINAL_LABEL:
            case BINOMINAL_LABEL:
            case WEIGHTED_EXAMPLES:
            case MISSING_VALUES:
                return true;
            default:
                return false;
        }
    }

    /**
     * Extends the inherited parameter list with the criterion, sample ratio, pureness and
     * minimal prune benefit parameters, followed by the random generator parameters.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeSingle type = new ParameterTypeStringCategory(PARAMETER_CRITERION, "Specifies the used criterion for selecting attributes and numerical splits.", CRITERIA_NAMES, CRITERIA_NAMES[0], false);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_SAMPLE_RATIO, "The sample ratio of training data used for growing and pruning.", 0.0, 1.0, 0.9);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeDouble(PARAMETER_PURENESS, "The desired pureness, i.e. the necessary amount of the major class in a covered subset in order become pure.", 0.0, 1.0, 0.9, false));
        types.add(new ParameterTypeDouble(PARAMETER_MINIMAL_PRUNE_BENEFIT, "The minimum amount of benefit which must be exceeded over unpruned benefit in order to be pruned.", 0.0, 1.0, 0.25));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }
}

