/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.splitLearner.tree;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.Partition;
import edu.udo.cs.yale.example.SplittedExampleSet;
import edu.udo.cs.yale.operator.learner.splitLearner.AbstractSplitter;
import edu.udo.cs.yale.operator.learner.splitLearner.AttributeManager;
import edu.udo.cs.yale.operator.learner.splitLearner.NumericalSplit;
import edu.udo.cs.yale.operator.learner.splitLearner.Split;
import edu.udo.cs.yale.operator.learner.splitLearner.SplitCondition;
import edu.udo.cs.yale.operator.learner.splitLearner.SplitCriterion;
import edu.udo.cs.yale.operator.learner.splitLearner.SplitTripel;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.EqualNodeCondition;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.GreaterEqualNodeCondition;
import edu.udo.cs.yale.operator.learner.splitLearner.tree.LessNodeCondition;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;

/**
 * Splitter for regression trees. A candidate split is scored by the reduction of the label's
 * standard deviation: the standard deviation of the whole example set minus the size-weighted
 * standard deviations of the resulting subsets. Numerical attributes yield binary threshold
 * splits; nominal attributes yield one subset per attribute value observed in the example set.
 */
public class RegressionSplitter
extends AbstractSplitter {
    /** Stored by the constructor; not read by this class. */
    SplitCriterion criterion;
    /** Lower bound on the number of examples on each side of a numerical split. */
    private int minimalSetSize;

    /**
     * @param criterion stored in {@link #criterion}; not used by this class
     * @param minimalSetSize lower bound on the number of examples on each side of a numerical
     *        split (a side is additionally never allowed to be empty)
     */
    public RegressionSplitter(SplitCriterion criterion, int minimalSetSize) {
        this.criterion = criterion;
        this.minimalSetSize = minimalSetSize;
    }

    /**
     * Evaluates every attribute supplied by {@code attributeManager} and offers each candidate
     * split to a {@link NumericalSplit} via {@code testSplit}; candidates it accepts are stored
     * with {@code setSplit}. NOTE(review): presumably {@code testSplit} accepts only benefits
     * better than the best seen so far — confirm in NumericalSplit.
     *
     * @param exampleSet the examples to split
     * @param attributeManager source of candidate attributes; it is reset before iteration
     * @return the split object holding the last accepted candidate (unchanged from a fresh
     *         {@code NumericalSplit} if no candidate was accepted)
     */
    public Split getBestSplit(ExampleSet exampleSet, AttributeManager attributeManager) {
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        // Identical for every candidate, so compute it once instead of per attribute.
        double completeStandardDeviation =
            this.getStandardDeviation(exampleSet, labelAttribute, this.getMean(exampleSet, labelAttribute));
        attributeManager.reset();
        NumericalSplit split = new NumericalSplit();
        while (attributeManager.hasNext()) {
            Attribute currentAttribute = attributeManager.next();
            if (currentAttribute.isNominal()) {
                this.evaluateNominalSplit(exampleSet, labelAttribute, currentAttribute, completeStandardDeviation, split);
            } else {
                this.evaluateNumericalSplits(exampleSet, labelAttribute, currentAttribute, completeStandardDeviation, split);
            }
        }
        return split;
    }

    /**
     * Offers every admissible binary threshold split on a numerical attribute to {@code split}.
     * Examples are sorted by attribute value; the candidate ending at sorted index
     * {@code lastLeft} puts sorted[0..lastLeft] on the left and the rest on the right, with
     * threshold {@code value[lastLeft + 1]} so that the conditions "less than threshold" and
     * "greater or equal threshold" reproduce exactly that partition.
     */
    private void evaluateNumericalSplits(ExampleSet exampleSet, Attribute labelAttribute, Attribute attribute,
                                         double completeStandardDeviation, NumericalSplit split) {
        int completeCount = exampleSet.size();
        // Exclusive upper bound for lastLeft: the right side keeps at least minimalSetSize
        // examples and is never empty (an empty side would give a NaN standard deviation).
        int lastLeftLimit = completeCount - Math.max(this.minimalSetSize, 1);
        if (this.minimalSetSize >= lastLeftLimit) {
            // Too few examples for any admissible split; also keeps the prefill below in bounds.
            return;
        }

        SplitTripel[] sorted = new SplitTripel[completeCount];
        int position = 0;
        for (Example example : exampleSet) {
            sorted[position] = new SplitTripel(example.getValue(attribute), example.getLabel(), position);
            ++position;
        }
        Arrays.sort(sorted);

        // partition[p] == 0: example p is on the left; 1: on the right.
        int[] partition = new int[completeCount];
        Arrays.fill(partition, 1);
        for (int i = 0; i < this.minimalSetSize; ++i) {
            partition[sorted[i].getPosition()] = 0;
        }

        for (int lastLeft = this.minimalSetSize; lastLeft < lastLeftLimit; ++lastLeft) {
            partition[sorted[lastLeft].getPosition()] = 0;
            double threshold = sorted[lastLeft + 1].getValue();
            if (sorted[lastLeft].getValue() == threshold) {
                // No threshold can separate two equal values, so this partition is not a
                // representable split.
                continue;
            }
            Partition leftPartition = new Partition(partition, 2);
            Partition rightPartition = new Partition(partition, 2);
            SplittedExampleSet leftSet = new SplittedExampleSet(exampleSet, leftPartition);
            SplittedExampleSet rightSet = new SplittedExampleSet(exampleSet, rightPartition);
            leftSet.selectSingleSubset(0);
            rightSet.selectSingleSubset(1);
            int leftCount = lastLeft + 1;
            int rightCount = completeCount - leftCount;
            double benefit = completeStandardDeviation;
            benefit -= this.getStandardDeviation(leftSet, labelAttribute, this.getMean(leftSet, labelAttribute))
                * (double) leftCount / (double) completeCount;
            benefit -= this.getStandardDeviation(rightSet, labelAttribute, this.getMean(rightSet, labelAttribute))
                * (double) rightCount / (double) completeCount;
            if (split.testSplit(benefit)) {
                ArrayList<SplitCondition> conditions = new ArrayList<SplitCondition>(2);
                conditions.add(new LessNodeCondition(attribute, threshold));
                conditions.add(new GreaterEqualNodeCondition(attribute, threshold));
                // Later iterations keep modifying 'partition'; store a snapshot of this candidate.
                split.setSplit(attribute, true, benefit, partition.clone(), null, 2, conditions);
            }
        }
    }

    /**
     * Offers the multi-way split on a nominal attribute to {@code split}: one subset per attribute
     * value observed in {@code exampleSet}, numbered in order of first occurrence. Attributes with
     * fewer than two observed values are skipped, since a single subset separates nothing.
     */
    private void evaluateNominalSplit(ExampleSet exampleSet, Attribute labelAttribute, Attribute attribute,
                                      double completeStandardDeviation, NumericalSplit split) {
        int completeCount = exampleSet.size();
        int[] partitionNumbers = new int[completeCount];
        // Observed attribute value -> partition number; insertion order == partition number order.
        LinkedHashMap<Double, Integer> partitionMapping = new LinkedHashMap<Double, Integer>();
        int index = 0;
        for (Example example : exampleSet) {
            Double value = Double.valueOf(example.getValue(attribute));
            Integer partitionNumber = partitionMapping.get(value);
            if (partitionNumber == null) {
                partitionNumber = Integer.valueOf(partitionMapping.size());
                partitionMapping.put(value, partitionNumber);
            }
            partitionNumbers[index] = partitionNumber.intValue();
            ++index;
        }

        // Only observed values get a partition, so no subset is empty (an empty subset would
        // make mean and standard deviation NaN).
        int numberOfPartitions = partitionMapping.size();
        if (numberOfPartitions < 2) {
            return;
        }
        SplittedExampleSet splitSet =
            new SplittedExampleSet(exampleSet, new Partition(partitionNumbers, numberOfPartitions));
        double benefit = completeStandardDeviation;
        for (int i = 0; i < numberOfPartitions; ++i) {
            splitSet.selectSingleSubset(i);
            // Weight by the subset's actual share of the examples.
            benefit -= this.getStandardDeviation(splitSet, labelAttribute, this.getMean(splitSet, labelAttribute))
                * (double) splitSet.size() / (double) completeCount;
        }
        if (!split.testSplit(benefit)) {
            return;
        }
        // Condition i must select the value assigned to partition i; iterating the keys of the
        // LinkedHashMap yields them in exactly that order.
        // NOTE(review): a missing value (NaN) would be cast to 0 here — confirm how missing
        // nominal values should be routed.
        ArrayList<SplitCondition> conditions = new ArrayList<SplitCondition>(numberOfPartitions);
        for (Double value : partitionMapping.keySet()) {
            conditions.add(new EqualNodeCondition(attribute, (int) value.doubleValue()));
        }
        split.setSplit(attribute, false, benefit, partitionNumbers, null, numberOfPartitions, conditions);
    }

    /**
     * Population standard deviation (divides by the number of examples) of {@code attribute}
     * around the given mean. Returns NaN for an empty example set.
     */
    private double getStandardDeviation(ExampleSet exampleSet, Attribute attribute, double mean) {
        double variance = 0.0;
        for (Example example : exampleSet) {
            double deviation = example.getValue(attribute) - mean;
            variance += deviation * deviation;
        }
        return Math.sqrt(variance / (double) exampleSet.size());
    }

    /** Arithmetic mean of {@code attribute} over the example set; NaN for an empty set. */
    private double getMean(ExampleSet exampleSet, Attribute attribute) {
        double sum = 0.0;
        for (Example example : exampleSet) {
            sum += example.getValue(attribute);
        }
        return sum / (double) exampleSet.size();
    }
}

