/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.tree.criterions;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.tree.FrequencyCalculator;
import com.rapidminer.operator.learner.tree.criterions.InfoGainCriterion;

/**
 * Split criterion that divides the benefit computed by {@link InfoGainCriterion} by the
 * split information of the candidate partition (the base-2 entropy of the partition
 * weights), i.e. the gain ratio. When the split information is exactly zero the
 * superclass benefit is returned unnormalized.
 */
public class GainRatioCriterion
extends InfoGainCriterion {
    /** Multiplying a natural logarithm by this factor converts it to a base-2 logarithm. */
    private static final double LOG_FACTOR = 1.0 / Math.log(2.0);

    /** Produces the weight-count tables consumed by {@link #getBenefit(double[][])}. */
    private final FrequencyCalculator frequencyCalculator = new FrequencyCalculator();

    public GainRatioCriterion() {
    }

    /**
     * @param minimalGain forwarded unchanged to the {@link InfoGainCriterion} constructor
     */
    public GainRatioCriterion(double minimalGain) {
        super(minimalGain);
    }

    /**
     * Computes the gain ratio of splitting {@code exampleSet} on every value of the nominal
     * {@code attribute}.
     */
    @Override
    public double getNominalBenefit(ExampleSet exampleSet, Attribute attribute) {
        double[][] weightCounts = this.frequencyCalculator.getNominalWeightCounts(exampleSet, attribute);
        return this.getBenefit(weightCounts);
    }

    /**
     * Computes the gain ratio of splitting {@code exampleSet} on the numerical
     * {@code attribute} at {@code splitValue}.
     */
    @Override
    public double getNumericalBenefit(ExampleSet exampleSet, Attribute attribute, double splitValue) {
        double[][] weightCounts = this.frequencyCalculator.getNumericalWeightCounts(exampleSet, attribute, splitValue);
        return this.getBenefit(weightCounts);
    }

    /**
     * Divides the superclass benefit for {@code weightCounts} by the split information of
     * the same table.
     *
     * @param weightCounts weight table whose first index selects the partition; the inner
     *     entries of a row are summed to obtain that partition's weight (presumably per-label
     *     weights — verify against {@link FrequencyCalculator})
     * @return the gain ratio, or the unnormalized superclass benefit if the split
     *     information is zero
     */
    @Override
    public double getBenefit(double[][] weightCounts) {
        double gain = super.getBenefit(weightCounts);
        return toGainRatio(gain, this.getSplitInfo(weightCounts));
    }

    /**
     * Computes the split information (base-2 entropy of the partition weights) of a weight
     * table.
     *
     * @param weightCounts weight table; row {@code v} holds the weights of partition {@code v}
     * @return the split information in bits; {@code 0.0} when no partition has positive weight
     */
    protected double getSplitInfo(double[][] weightCounts) {
        double[] partitionWeights = new double[weightCounts.length];
        double totalWeight = 0.0;
        for (int v = 0; v < weightCounts.length; ++v) {
            double partitionWeight = 0.0;
            for (double weight : weightCounts[v]) {
                partitionWeight += weight;
            }
            partitionWeights[v] = partitionWeight;
            // Same summation order as summing the finished row totals afterwards.
            totalWeight += partitionWeight;
        }
        return log2Entropy(partitionWeights, totalWeight);
    }

    /**
     * Computes the split information (base-2 entropy) of the given partition weights.
     *
     * @param partitionWeights weight of each partition; non-positive entries are ignored
     * @param totalWeight normalizer used to turn each weight into a proportion
     * @return the split information in bits; {@code 0.0} when no partition has positive weight
     */
    protected double getSplitInfo(double[] partitionWeights, double totalWeight) {
        return log2Entropy(partitionWeights, totalWeight);
    }

    @Override
    public boolean supportsIncrementalCalculation() {
        return true;
    }

    /**
     * Computes the gain ratio of the current two-way (left/right) split from the label
     * weights and weights held in the inherited incremental-state fields (assumed to be kept
     * up to date by the superclass's incremental update methods — verify in
     * {@link InfoGainCriterion}).
     *
     * @return the gain ratio, or the unnormalized gain if the split information is zero
     */
    @Override
    public double getIncrementalBenefit() {
        double gain = this.getEntropy(this.totalLabelWeights, this.totalWeight);
        gain -= this.getEntropy(this.leftLabelWeights, this.leftWeight) * this.leftWeight / this.totalWeight;
        gain -= this.getEntropy(this.rightLabelWeights, this.rightWeight) * this.rightWeight / this.totalWeight;
        double splitInfo = this.getSplitInfo(new double[]{this.leftWeight, this.rightWeight}, this.totalWeight);
        return toGainRatio(gain, splitInfo);
    }

    /**
     * Normalizes {@code gain} by {@code splitInfo}. A zero split information would divide
     * by zero, so the gain is then returned unchanged.
     */
    private static double toGainRatio(double gain, double splitInfo) {
        return splitInfo == 0.0 ? gain : gain / splitInfo;
    }

    /**
     * Shared entropy kernel for both {@code getSplitInfo} overloads, so the batch and the
     * incremental code paths compute bit-identical split information for the same partition.
     */
    private static double log2Entropy(double[] weights, double totalWeight) {
        double entropy = 0.0;
        for (double weight : weights) {
            // Skips zero, negative and NaN weights (p * log p is taken as 0 for p = 0).
            if (!(weight > 0.0)) {
                continue;
            }
            double proportion = weight / totalWeight;
            entropy -= proportion * Math.log(proportion) * LOG_FACTOR;
        }
        return entropy;
    }
}

