/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.tree.criterions;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.tree.FrequencyCalculator;
import com.rapidminer.operator.learner.tree.criterions.AbstractCriterion;

/**
 * Split criterion that scores a split by the accuracy of a majority-label prediction per branch.
 *
 * <p>A weight-count matrix has one row per branch and one column per label. For each branch, the
 * branch accuracy is the weight of its heaviest label divided by the total weight of that branch.
 * The benefit of a split is the plain sum of these per-branch accuracies (it is not weighted by
 * branch size). A branch that carries no weight at all contributes {@code 0} instead of producing
 * {@code NaN} (from {@code 0.0 / 0.0}) or failing on an empty row.
 */
public class AccuracyCriterion
extends AbstractCriterion {
    private final FrequencyCalculator frequencyCalculator = new FrequencyCalculator();

    /**
     * Computes the benefit of splitting {@code exampleSet} on the nominal {@code attribute}.
     *
     * @param exampleSet the examples to evaluate
     * @param attribute the nominal attribute to split on
     * @return the summed per-branch majority accuracy of the split
     */
    @Override
    public double getNominalBenefit(ExampleSet exampleSet, Attribute attribute) {
        double[][] weightCounts = this.frequencyCalculator.getNominalWeightCounts(exampleSet, attribute);
        return this.getBenefit(weightCounts);
    }

    /**
     * Computes the benefit of splitting {@code exampleSet} on the numerical {@code attribute} at
     * {@code splitValue}.
     *
     * @param exampleSet the examples to evaluate
     * @param attribute the numerical attribute to split on
     * @param splitValue the threshold at which the attribute is split
     * @return the summed per-branch majority accuracy of the split
     */
    @Override
    public double getNumericalBenefit(ExampleSet exampleSet, Attribute attribute, double splitValue) {
        double[][] weightCounts = this.frequencyCalculator.getNumericalWeightCounts(exampleSet, attribute, splitValue);
        return this.getBenefit(weightCounts);
    }

    /**
     * Sums the majority-label accuracy of every branch.
     *
     * @param weightCounts one row per branch; each row holds the accumulated weight per label
     * @return the sum over all rows of (max weight in row / total weight in row); rows whose total
     *     weight is zero (including empty rows) contribute {@code 0}
     */
    @Override
    public double getBenefit(double[][] weightCounts) {
        double sum = 0.0;
        for (double[] branchWeights : weightCounts) {
            sum += majorityFraction(branchWeights);
        }
        return sum;
    }

    /** This criterion can be evaluated incrementally from the left/right label weights. */
    @Override
    public boolean supportsIncrementalCalculation() {
        return true;
    }

    /**
     * Computes the benefit from the current left and right label weights.
     *
     * <p>NOTE(review): {@code leftLabelWeights}/{@code rightLabelWeights} are inherited from
     * {@code AbstractCriterion} and presumably maintained by its incremental-update methods;
     * confirm there that they are {@code double[]} per-label weights.
     *
     * @return majority accuracy of the left side plus majority accuracy of the right side; a side
     *     with zero total weight contributes {@code 0}
     */
    @Override
    public double getIncrementalBenefit() {
        return majorityFraction(this.leftLabelWeights) + majorityFraction(this.rightLabelWeights);
    }

    /**
     * Returns the share of the total weight carried by the heaviest label of one branch.
     *
     * @param labelWeights accumulated weight per label for a single branch
     * @return {@code max(labelWeights) / sum(labelWeights)}, or {@code 0} when the total weight is
     *     zero (which also covers an empty array)
     */
    private static double majorityFraction(double[] labelWeights) {
        double maxValue = Double.NEGATIVE_INFINITY;
        double total = 0.0;
        for (double weight : labelWeights) {
            if (weight > maxValue) {
                maxValue = weight;
            }
            total += weight;
        }
        // An empty branch has no examples to classify; without this guard it yields 0.0/0.0 = NaN,
        // which poisons the summed benefit and makes every comparison against it false.
        if (total == 0.0) {
            return 0.0;
        }
        return maxValue / total;
    }
}

