/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.splitLearner.tree;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.NominalAttributeStatistics;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.AbstractTreeLearner;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.TreeNode;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.TreePruner;
import edu.udo.cs.yale.tools.math.MathFunctions;
import java.util.Iterator;

/**
 * Bottom-up pessimistic-error pruner for trees built by an {@link AbstractTreeLearner}.
 *
 * <p>Nodes are visited post-order (children first). A node whose children are all leaves is a
 * pruning candidate. Its pessimistic error estimate is computed as if it were a single leaf labelled
 * with the mode of its training labels. That estimate is compared with the weighted estimate of its
 * current leaves. If the node estimate minus {@code prunePreferenceLevel} does not exceed the
 * leaves' estimate, the node is relabelled through {@link AbstractTreeLearner#setLabel}. Presumably
 * this turns the node into a leaf; confirm in the learner.
 *
 * <p>The root passed to {@link #prune(TreeNode)} is never itself a candidate; only its descendants
 * are examined.
 */
public class PessimisticPruner
implements TreePruner {
    /** Confidence factor handed to {@link #pessimisticErrors(double, double, double)}. */
    private final double confidenceLevel;
    /** Bias subtracted from the node estimate; larger values make pruning more likely. */
    private final double prunePreferenceLevel;
    /** Learner used to relabel a node that is pruned. */
    private final AbstractTreeLearner learner;

    public PessimisticPruner(double confidenceLevel, double prunePreferenceLevel, AbstractTreeLearner learner) {
        this.confidenceLevel = confidenceLevel;
        this.prunePreferenceLevel = prunePreferenceLevel;
        this.learner = learner;
    }

    /**
     * Prunes every subtree below {@code root}. The root node itself is not considered for pruning.
     *
     * @param root the tree to prune in place
     */
    public void prune(TreeNode root) {
        Iterator<TreeNode> childIterator = root.childIterator();
        while (childIterator.hasNext()) {
            this.pruneChild(childIterator.next());
        }
    }

    /**
     * Recursively prunes the subtree rooted at {@code currentNode}, processing children first.
     *
     * @param currentNode the node whose subtree is examined
     */
    private void pruneChild(TreeNode currentNode) {
        if (!currentNode.hasChildren()) {
            return; // a leaf is only ever judged as part of its parent
        }
        Iterator<TreeNode> childIterator = currentNode.childIterator();
        while (childIterator.hasNext()) {
            this.pruneChild(childIterator.next());
        }
        if (this.childrenHaveChildren(currentNode)) {
            return; // only nodes whose children are all leaves are candidates
        }
        ExampleSet currentNodeExampleSet = currentNode.getTrainExampleSet();
        int nodeExamples = currentNodeExampleSet.size();
        if (nodeExamples == 0) {
            // Without training examples every ratio below would be 0/0 (NaN). A NaN makes the
            // pruning comparison false, so leaving the node untouched is the only consistent outcome.
            return;
        }
        double leafsErrorEstimate = this.leafsErrorEstimate(currentNode, nodeExamples);
        double currentNodeLabel = this.prunedLabel(currentNodeExampleSet);
        double currentErrorRate = (double)this.getErrorNumber(currentNodeExampleSet, currentNodeLabel) / (double)nodeExamples;
        double nodeErrorEstimate = this.pessimisticErrors(nodeExamples, currentErrorRate, this.confidenceLevel);
        if (nodeErrorEstimate - this.prunePreferenceLevel <= leafsErrorEstimate) {
            this.learner.setLabel(currentNode, currentNodeExampleSet.getAttributes().getLabel(), currentNodeLabel, currentNodeExampleSet);
        }
    }

    /**
     * Sums the pessimistic estimates of {@code node}'s leaves. Each estimate is weighted by the
     * leaf's share of the node's training examples.
     *
     * @param node         a node whose children are all leaves
     * @param nodeExamples number of training examples at {@code node}; must be positive
     * @return the weighted leaf estimate
     */
    private double leafsErrorEstimate(TreeNode node, int nodeExamples) {
        double estimate = 0.0;
        Iterator<TreeNode> leaves = node.childIterator();
        while (leaves.hasNext()) {
            TreeNode leafNode = leaves.next();
            ExampleSet leafExampleSet = leafNode.getTrainExampleSet();
            int leafExamples = leafExampleSet.size();
            if (leafExamples == 0) {
                // An empty leaf has zero weight. Computing its rate (0/0 = NaN) would poison the sum
                // and silently block pruning of the parent.
                continue;
            }
            double leafErrorRate = (double)this.getErrorNumber(leafExampleSet, leafNode.getLabel()) / (double)leafExamples;
            estimate += this.pessimisticErrors(leafExamples, leafErrorRate, this.confidenceLevel) * ((double)leafExamples / (double)nodeExamples);
        }
        return estimate;
    }

    /**
     * Tells whether at least one child of {@code node} has children of its own.
     *
     * @param node the node to inspect
     * @return {@code true} if some child is not a leaf
     */
    private boolean childrenHaveChildren(TreeNode node) {
        Iterator<TreeNode> iterator = node.childIterator();
        while (iterator.hasNext()) {
            if (iterator.next().hasChildren()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Counts the examples whose label value differs from {@code label}.
     *
     * @param exampleSet the examples to check
     * @param label      the label value treated as correct
     * @return the number of misclassified examples
     */
    private int getErrorNumber(ExampleSet exampleSet, double label) {
        int errors = 0;
        for (Iterator iterator = exampleSet.iterator(); iterator.hasNext();) {
            Example example = (Example)iterator.next();
            if (example.getLabel() != label) {
                ++errors;
            }
        }
        return errors;
    }

    /**
     * Returns the label a node would get if it were collapsed into a leaf.
     *
     * <p>Side effect: recalculates the label attribute's statistics on {@code exampleSet} before
     * reading the mode.
     *
     * @param exampleSet the node's training examples; the label is assumed to be nominal
     * @return the mode of the label statistics
     */
    public double prunedLabel(ExampleSet exampleSet) {
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        exampleSet.recalculateAttributeStatistics(labelAttribute);
        double mode = ((NominalAttributeStatistics)labelAttribute.getStatistics()).getMode();
        return mode;
    }

    /**
     * Returns a pessimistic error estimate for a node covering {@code numberOfExamples} examples.
     *
     * <ul>
     *   <li>If {@code errorRate < 1e-6}: {@code errorRate + N * (1 - CF^(1/N))}.</li>
     *   <li>Else, if {@code errorRate + 0.5 >= N}: {@code errorRate + 0.67 * (N - errorRate)}.</li>
     *   <li>Otherwise: {@code N * p}, where {@code p} is an upper bound built from
     *       {@code z^2 = normalInverse(1 - CF)^2} with a 0.5 continuity correction.
     *       {@code normalInverse} is presumably the standard-normal quantile.</li>
     * </ul>
     *
     * <p>NOTE(review): the 0.5 continuity correction and the {@code errorRate + 0.5 >= N} test look
     * like they expect an error count, as in C4.5's AddErrs. However, {@code pruneChild} passes a rate
     * (errors / examples). With a rate in [0, 1], the second branch is only reachable when N &lt;= 1.5.
     * Confirm which semantics are intended before relying on absolute values.
     *
     * @param numberOfExamples N, the number of examples covered
     * @param errorRate        observed error value (a rate, as currently called)
     * @param confidenceLevel  CF, the confidence factor
     * @return the pessimistic estimate
     */
    public double pessimisticErrors(double numberOfExamples, double errorRate, double confidenceLevel) {
        if (errorRate < 1.0E-6) {
            return errorRate + numberOfExamples * (1.0 - Math.exp(Math.log(confidenceLevel) / numberOfExamples));
        }
        if (errorRate + 0.5 >= numberOfExamples) {
            return errorRate + 0.67 * (numberOfExamples - errorRate);
        }
        double coefficient = MathFunctions.normalInverse(1.0 - confidenceLevel);
        coefficient *= coefficient;
        double pessimisticRate = (errorRate + 0.5 + coefficient / 2.0 + Math.sqrt(coefficient * ((errorRate + 0.5) * (1.0 - (errorRate + 0.5) / numberOfExamples) + coefficient / 4.0))) / (numberOfExamples + coefficient);
        return numberOfExamples * pessimisticRate;
    }
}

