/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.splitLearner.tree;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.AbstractTreeLearner;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.RegressionTreeLearnerChain;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.RegressionTreeNode;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.TreeNode;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.TreePruner;
import edu.udo.cs.yale.tools.LogService;
import java.util.Iterator;

/**
 * Bottom-up pruner for regression trees.
 *
 * <p>The tree is walked depth-first. For every inner node whose children are all
 * leaves, a regression model is fitted on the node's own training examples and
 * its mean relative error is compared with the mean relative error achieved by
 * the leaf models. If the node's error, reduced by {@code complexityWeight}, is
 * not worse than the leaves' error, the children are removed and the node itself
 * becomes a leaf.
 *
 * <p>The pruner keeps per-call counters and is therefore not safe for concurrent
 * use of the same instance.
 */
public class RegressionPruner implements TreePruner {
    /** Tolerance subtracted from a node's mean relative error before it is compared with its leaves. */
    private final double complexityWeight;
    /** Learner chain used to estimate a regression model for an inner node. */
    private final RegressionTreeLearnerChain chain;
    /** Number of nodes (excluding the root) visited by the current {@link #prune(TreeNode)} call. */
    private int numberOfNodes = 0;
    /** Number of child nodes removed by the current {@link #prune(TreeNode)} call. */
    private int prunedNodes = 0;

    /**
     * Creates a pruner.
     *
     * @param complexityWeight tolerance in favour of the simpler (pruned) node
     * @param chain learner chain used to estimate regression models for inner nodes
     * @param learner not used by this pruner; kept so existing callers keep compiling
     */
    public RegressionPruner(double complexityWeight, RegressionTreeLearnerChain chain, AbstractTreeLearner learner) {
        this.complexityWeight = complexityWeight;
        this.chain = chain;
    }

    /**
     * Prunes the subtrees below {@code root} in place and logs how many nodes were removed.
     * The children of {@code root} are expected to be {@link RegressionTreeNode}s.
     *
     * @param root root of the tree to prune; the root itself is never a pruning candidate
     */
    public void prune(TreeNode root) {
        // Reset the statistics so that a reused pruner reports per-call numbers
        // instead of totals accumulated over all previous calls.
        this.numberOfNodes = 0;
        this.prunedNodes = 0;
        Iterator<TreeNode> children = root.childIterator();
        while (children.hasNext()) {
            this.pruneChild((RegressionTreeNode)children.next());
        }
        // Message text and log level are kept exactly as before.
        LogService.logMessage("Prunned " + this.prunedNodes + " Nodes of " + this.numberOfNodes, 2);
    }

    /**
     * Recursively prunes the subtree rooted at {@code currentNode} (post-order), then
     * decides whether {@code currentNode} itself should be collapsed into a leaf.
     */
    private void pruneChild(RegressionTreeNode currentNode) {
        ++this.numberOfNodes;
        if (!currentNode.hasChildren()) {
            return;
        }
        // Handle the descendants first: pruning a child may turn it into a leaf,
        // which in turn can make this node a pruning candidate.
        Iterator<TreeNode> children = currentNode.childIterator();
        while (children.hasNext()) {
            this.pruneChild((RegressionTreeNode)children.next());
        }
        if (!this.allChildrenAreLeaves(currentNode)) {
            return;
        }

        // Summed relative error of the leaf models on their own training examples.
        double leavesRelativeError = 0.0;
        children = currentNode.childIterator();
        while (children.hasNext()) {
            RegressionTreeNode leaf = (RegressionTreeNode)children.next();
            ExampleSet leafExamples = leaf.getTrainExampleSet();
            if (!this.applyModel(leaf.getRegressionModel(), leafExamples)) {
                // Without predictions the comparison is meaningless; keep the subtree.
                return;
            }
            leavesRelativeError += this.calculateRelativeError(leafExamples);
        }

        // Fetched after the leaf models were applied, as in the original statement order.
        ExampleSet nodeExamples = currentNode.getTrainExampleSet();
        double leavesMeanRelativeError = leavesRelativeError / (double)nodeExamples.size();

        // Fit a single model on all examples of this node and measure it the same way.
        // The model is stored on the node whether or not the node ends up pruned.
        currentNode.setRegressionModel(this.calculateModel(nodeExamples));
        if (!this.applyModel(currentNode.getRegressionModel(), nodeExamples)) {
            return;
        }
        double nodeMeanRelativeError = this.calculateRelativeError(nodeExamples) / (double)nodeExamples.size();

        if (nodeMeanRelativeError - this.complexityWeight <= leavesMeanRelativeError) {
            this.prunedNodes += currentNode.getNumberOfChildren();
            currentNode.pruneChildren();
            currentNode.setStringLabel("Regression Model (P)");
        }
    }

    /**
     * Returns {@code true} if no child of {@code node} reports an additional graphic.
     *
     * <p>NOTE(review): leaf-ness is decided via {@code hasAdditionalGraphic()}, exactly as
     * before; presumably that flag marks inner nodes. Confirm against {@code TreeNode}.
     */
    private boolean allChildrenAreLeaves(RegressionTreeNode node) {
        Iterator<TreeNode> children = node.childIterator();
        while (children.hasNext()) {
            if (children.next().hasAdditionalGraphic()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Applies {@code model} to {@code set}.
     *
     * @return {@code true} if the model was applied, {@code false} if the application
     *         failed with an {@link OperatorException} (the stack trace is printed)
     */
    private boolean applyModel(Model model, ExampleSet set) {
        try {
            model.apply(set);
            return true;
        }
        catch (OperatorException e) {
            // Report the failure to the caller so it does not evaluate
            // predictions that were never written.
            e.printStackTrace();
            return false;
        }
    }

    /** Estimates a regression model for {@code set} using the learner chain. */
    private Model calculateModel(ExampleSet set) {
        return this.chain.estimateRegressionModel(set);
    }

    /**
     * Sums {@code |label - prediction| / |label|} over all examples of {@code set}.
     *
     * <p>The denominator uses the absolute label value so that negative labels do not
     * produce negative error terms that cancel out positive ones. Examples whose label
     * is zero contribute an infinite or NaN term, as before.
     *
     * @return the summed (not averaged) relative error
     */
    private double calculateRelativeError(ExampleSet set) {
        Attribute prediction = set.getAttributes().getPredictedLabel();
        Attribute label = set.getAttributes().getLabel();
        double relativeError = 0.0;
        for (Example example : set) {
            double actual = example.getValue(label);
            relativeError += Math.abs(actual - example.getValue(prediction)) / Math.abs(actual);
        }
        return relativeError;
    }
}

