/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.tree.criterions;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.tree.FrequencyCalculator;
import com.rapidminer.operator.learner.tree.MinimalGainHandler;
import com.rapidminer.operator.learner.tree.criterions.AbstractCriterion;

/**
 * Split criterion that scores a split by its information gain.
 *
 * <p>The gain is the entropy of the overall label distribution minus the weighted average
 * entropy of the label distributions inside the partitions produced by the split. Entropies are
 * measured in bits ({@link #LOG_FACTOR} converts natural logarithms to base 2).
 *
 * <p>{@link #getBenefit(double[][])} reports any gain smaller than {@code minimalGain} times the
 * unsplit label entropy as {@code 0}.
 */
public class InfoGainCriterion extends AbstractCriterion implements MinimalGainHandler {

    /** Multiplier turning {@code Math.log(x)} into {@code log2(x)}. */
    private static final double LOG_FACTOR = 1.0 / Math.log(2.0);

    private final FrequencyCalculator frequencyCalculator = new FrequencyCalculator();

    /**
     * Relative threshold. {@link #getBenefit(double[][])} returns 0 when the gain is below
     * {@code minimalGain * classEntropy}.
     */
    private double minimalGain = 0.1;

    /** Creates a criterion with the default minimal gain of {@code 0.1}. */
    public InfoGainCriterion() {
    }

    /**
     * Creates a criterion with the given minimal gain.
     *
     * @param minimalGain relative gain threshold; see {@link #setMinimalGain(double)}
     */
    public InfoGainCriterion(double minimalGain) {
        this.minimalGain = minimalGain;
    }

    /**
     * Sets the relative threshold below which {@link #getBenefit(double[][])} reports a gain of 0.
     *
     * @param minimalGain fraction of the unsplit label entropy that a split must gain
     */
    @Override
    public void setMinimalGain(double minimalGain) {
        this.minimalGain = minimalGain;
    }

    /**
     * Scores a split on a nominal attribute.
     *
     * @param exampleSet the examples to split
     * @param attribute the nominal attribute to split on
     * @return the information gain of the split, as computed by {@link #getBenefit(double[][])}
     */
    @Override
    public double getNominalBenefit(ExampleSet exampleSet, Attribute attribute) {
        double[][] weightCounts = frequencyCalculator.getNominalWeightCounts(exampleSet, attribute);
        return getBenefit(weightCounts);
    }

    /**
     * Scores a split of a numerical attribute at {@code splitValue}.
     *
     * @param exampleSet the examples to split
     * @param attribute the numerical attribute to split on
     * @param splitValue the threshold at which the attribute is split
     * @return the information gain of the split, as computed by {@link #getBenefit(double[][])}
     */
    @Override
    public double getNumericalBenefit(ExampleSet exampleSet, Attribute attribute, double splitValue) {
        double[][] weightCounts =
                frequencyCalculator.getNumericalWeightCounts(exampleSet, attribute, splitValue);
        return getBenefit(weightCounts);
    }

    /**
     * Computes the information gain of a split from its weight-count matrix.
     *
     * <p>{@code weightCounts[v][l]} is the total weight of label {@code l} in partition {@code v}.
     * The matrix must have at least one row, and all rows are assumed to have the length of row 0.
     *
     * <p>NOTE(review): if every count is zero, the partition-weight sum is 0, the weighted
     * information becomes {@code 0.0 / 0.0 = NaN}, and NaN is returned. The threshold comparison
     * is false for NaN, so the clamp to 0 does not apply.
     *
     * @param weightCounts per-partition, per-label weights
     * @return the information gain in bits, or {@code 0} if it is below
     *     {@code minimalGain * classEntropy}
     */
    @Override
    public double getBenefit(double[][] weightCounts) {
        int numberOfValues = weightCounts.length;
        int numberOfLabels = weightCounts[0].length;

        // Per-partition weight and label entropy.
        double[] entropies = new double[numberOfValues];
        double[] totalWeights = new double[numberOfValues];
        for (int v = 0; v < numberOfValues; v++) {
            for (int l = 0; l < numberOfLabels; l++) {
                totalWeights[v] += weightCounts[v][l];
            }
            entropies[v] = entropyInBits(weightCounts[v], totalWeights[v]);
        }

        // Remaining entropy after the split: partition entropies weighted by partition share.
        // Named to avoid shadowing the inherited totalWeight field.
        double partitionWeightSum = 0.0;
        for (double w : totalWeights) {
            partitionWeightSum += w;
        }
        double information = 0.0;
        for (int v = 0; v < numberOfValues; v++) {
            information += totalWeights[v] / partitionWeightSum * entropies[v];
        }

        // Entropy of the label distribution before the split.
        double[] classWeights = new double[numberOfLabels];
        for (int l = 0; l < numberOfLabels; l++) {
            for (int v = 0; v < numberOfValues; v++) {
                classWeights[l] += weightCounts[v][l];
            }
        }
        double totalClassWeight = 0.0;
        for (double w : classWeights) {
            totalClassWeight += w;
        }
        double classEntropy = entropyInBits(classWeights, totalClassWeight);

        double informationGain = classEntropy - information;
        // The threshold is relative to the unsplit entropy, not an absolute number of bits.
        if (informationGain < minimalGain * classEntropy) {
            informationGain = 0.0;
        }
        return informationGain;
    }

    /**
     * Computes the entropy, in bits, of the distribution given by {@code labelWeights}.
     * Non-positive weights contribute nothing.
     *
     * @param labelWeights weight per label
     * @param totalWeight the normalising total, normally the sum of {@code labelWeights}
     * @return the entropy in bits
     */
    protected double getEntropy(double[] labelWeights, double totalWeight) {
        return entropyInBits(labelWeights, totalWeight);
    }

    /** Shared entropy kernel; private and static so getBenefit never calls an overridable method. */
    private static double entropyInBits(double[] weights, double total) {
        double entropy = 0.0;
        for (double weight : weights) {
            if (weight > 0.0) {
                double proportion = weight / total;
                entropy -= Math.log(proportion) * LOG_FACTOR * proportion;
            }
        }
        return entropy;
    }

    /**
     * Reports that this criterion can be evaluated incrementally.
     *
     * @return always {@code true}
     */
    @Override
    public boolean supportsIncrementalCalculation() {
        return true;
    }

    /**
     * Computes the information gain of the current left/right split from the label-weight
     * fields inherited from {@code AbstractCriterion}. Those fields are presumably maintained by
     * its incremental-update methods, which are not visible here.
     *
     * <p>NOTE(review): unlike {@link #getBenefit(double[][])}, this path does not apply the
     * {@code minimalGain} threshold. Confirm whether callers rely on that difference.
     *
     * @return total entropy minus the size-weighted entropies of the left and right partitions
     */
    @Override
    public double getIncrementalBenefit() {
        double totalEntropy = getEntropy(totalLabelWeights, totalWeight);
        double conditionalEntropy =
                getEntropy(leftLabelWeights, leftWeight) * leftWeight / totalWeight;
        conditionalEntropy += getEntropy(rightLabelWeights, rightWeight) * rightWeight / totalWeight;
        return totalEntropy - conditionalEntropy;
    }
}

