/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.tree.criterions;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.tree.FrequencyCalculator;
import com.rapidminer.operator.learner.tree.criterions.AbstractCriterion;

/**
 * Split criterion that scores a split by the reduction in Gini impurity it achieves.
 *
 * <p>The benefit is the Gini impurity of the combined label weights minus the
 * weight-averaged Gini impurity of the individual partitions. Partitions that carry
 * no weight contribute nothing to the average (instead of producing {@code NaN}).
 */
public class GiniIndexCriterion extends AbstractCriterion {

    private final FrequencyCalculator frequencyCalculator = new FrequencyCalculator();

    /**
     * Computes the Gini benefit of splitting on a nominal attribute.
     *
     * @param exampleSet the examples to evaluate
     * @param attribute the nominal attribute to split on
     * @return the Gini impurity reduction of the split
     */
    @Override
    public double getNominalBenefit(ExampleSet exampleSet, Attribute attribute) {
        double[][] weightCounts = frequencyCalculator.getNominalWeightCounts(exampleSet, attribute);
        return getBenefit(weightCounts);
    }

    /**
     * Computes the Gini benefit of splitting a numerical attribute at {@code splitValue}.
     *
     * @param exampleSet the examples to evaluate
     * @param attribute the numerical attribute to split on
     * @param splitValue the threshold of the split
     * @return the Gini impurity reduction of the split
     */
    @Override
    public double getNumericalBenefit(ExampleSet exampleSet, Attribute attribute, double splitValue) {
        double[][] weightCounts =
                frequencyCalculator.getNumericalWeightCounts(exampleSet, attribute, splitValue);
        return getBenefit(weightCounts);
    }

    /**
     * Computes the Gini impurity reduction for a partition table.
     *
     * @param weightCounts one row per partition; each row holds the per-label weights of
     *     that partition (all rows are assumed to have the same length as row 0)
     * @return impurity of the column-summed label weights minus the weight-averaged
     *     impurity of the partitions
     */
    @Override
    public double getBenefit(double[][] weightCounts) {
        // Combined label weights over all partitions (column sums).
        double[] classWeights = new double[weightCounts[0].length];
        for (double[] partitionWeights : weightCounts) {
            for (int label = 0; label < classWeights.length; ++label) {
                classWeights[label] += partitionWeights[label];
            }
        }
        double totalClassWeight = frequencyCalculator.getTotalWeight(classWeights);
        double totalImpurity = getGiniIndex(classWeights, totalClassWeight);

        double weightedPartitionImpurity = 0.0;
        for (double[] partitionWeights : weightCounts) {
            double partitionWeight = frequencyCalculator.getTotalWeight(partitionWeights);
            weightedPartitionImpurity +=
                    getWeightedGiniIndex(partitionWeights, partitionWeight, totalClassWeight);
        }
        return totalImpurity - weightedPartitionImpurity;
    }

    /**
     * Returns the Gini impurity {@code 1 - sum(p_i^2)} of the given label weights.
     *
     * @param labelWeights per-label weights
     * @param totalWeight the sum of {@code labelWeights}
     * @return the Gini impurity, or {@code 0} for an empty (zero-weight) distribution,
     *     which would otherwise evaluate to {@code NaN} via {@code 0.0 / 0.0}
     */
    private double getGiniIndex(double[] labelWeights, double totalWeight) {
        if (totalWeight == 0.0) {
            return 0.0;
        }
        double sumOfSquares = 0.0;
        for (double labelWeight : labelWeights) {
            double frequency = labelWeight / totalWeight;
            sumOfSquares += frequency * frequency;
        }
        return 1.0 - sumOfSquares;
    }

    /**
     * Returns the Gini impurity of one partition, scaled by that partition's share of
     * the total weight.
     *
     * <p>An empty partition contributes exactly {@code 0}. Without this check,
     * {@code x * 0 / 0} would turn the whole benefit into {@code NaN} when the total
     * weight is also zero.
     *
     * @param labelWeights per-label weights of the partition
     * @param weight total weight of the partition
     * @param totalWeight total weight over all partitions
     * @return the weighted Gini impurity of the partition
     */
    private double getWeightedGiniIndex(double[] labelWeights, double weight, double totalWeight) {
        if (weight == 0.0) {
            return 0.0;
        }
        return getGiniIndex(labelWeights, weight) * weight / totalWeight;
    }

    /**
     * Indicates that {@link #getIncrementalBenefit()} may be used.
     *
     * @return always {@code true}
     */
    @Override
    public boolean supportsIncrementalCalculation() {
        return true;
    }

    /**
     * Computes the Gini benefit of a binary split from the running left/right weights.
     *
     * <p>NOTE(review): assumes the superclass keeps {@code totalLabelWeights},
     * {@code leftLabelWeights} and {@code rightLabelWeights} and their summed weights up
     * to date during incremental evaluation. Confirm against {@code AbstractCriterion}.
     *
     * @return impurity of the total label weights minus the weighted impurity of the
     *     left and right sides
     */
    @Override
    public double getIncrementalBenefit() {
        double totalImpurity = getGiniIndex(this.totalLabelWeights, this.totalWeight);
        double leftImpurity =
                getWeightedGiniIndex(this.leftLabelWeights, this.leftWeight, this.totalWeight);
        double rightImpurity =
                getWeightedGiniIndex(this.rightLabelWeights, this.rightWeight, this.totalWeight);
        return totalImpurity - (leftImpurity + rightImpurity);
    }
}

