/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.tree;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.Tools;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

/**
 * Learns a decision stump (a single test on a single attribute) for a binominal label.
 *
 * <p>Each nominal attribute is tested as {@code attribute == value} for every value of its mapping.
 * Each non-nominal attribute is tested as {@code attribute <= threshold}, with a threshold at the
 * midpoint between consecutive distinct sorted values. The test with the highest score under the
 * selected utility function is kept. An attribute-free default stump serves as the baseline.
 *
 * <p>All class counts are stored in two-slot arrays. Slot {@link #POS} holds the weight of
 * examples whose label equals the positive index. Slot {@link #NEG} holds the weight of all
 * other examples.
 */
public class MultiCriterionDecisionStumps extends AbstractLearner {
    private static final String ACC = "accuracy";
    private static final String ENTROPY = "entropy";
    private static final String SQRT_PN = "sqrt(TP*FP) + sqrt(FN*TN)";
    private static final String GINI = "gini index";
    private static final String CHI_SQUARE = "chi square test";
    /** Selectable utility functions; the category parameter value is an index into this array. */
    private static final String[] UTILITY_FUNCTION_LIST = new String[]{ENTROPY, ACC, SQRT_PN, GINI, CHI_SQUARE};
    private static final String PARAMETER_UTILITY_FUNCTION = "utility_function";

    /** Count-array slot for the positive class. */
    private static final int POS = 0;
    /** Count-array slot for every example that is not labelled with the positive index. */
    private static final int NEG = 1;

    /**
     * Orders {@code {value, exampleNumber}} pairs by ascending value. {@link Double#compare} sorts
     * NaN after positive infinity, so missing values end up at the tail.
     */
    private static final Comparator<double[]> BY_VALUE = new Comparator<double[]>() {
        @Override
        public int compare(double[] first, double[] second) {
            return Double.compare(first[0], second[0]);
        }
    };

    /** Positive index of the label mapping, captured at the start of {@link #learn}. */
    private int posIndex;
    /** Total positive weight of the training set, from {@link #computePriors}. */
    private double globalP;
    /** Total non-positive weight of the training set, from {@link #computePriors}. */
    private double globalN;
    private Model bestModel;
    private double bestScore;
    /** Name of the selected utility function (one of {@link #UTILITY_FUNCTION_LIST}). */
    private String utilityFunction;

    public MultiCriterionDecisionStumps(OperatorDescription description) {
        super(description);
    }

    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return DecisionStumpModel.class;
    }

    /** Supports numerical, polynominal and binominal attributes, a binominal label and weights. */
    @Override
    public boolean supportsCapability(OperatorCapability lc) {
        return lc == OperatorCapability.NUMERICAL_ATTRIBUTES
                || lc == OperatorCapability.POLYNOMINAL_ATTRIBUTES
                || lc == OperatorCapability.BINOMINAL_ATTRIBUTES
                || lc == OperatorCapability.BINOMINAL_LABEL
                || lc == OperatorCapability.WEIGHTED_EXAMPLES;
    }

    /** Resets the running best model and its score before a new search. */
    protected void initHighscore() {
        this.bestModel = null;
        this.bestScore = Double.NEGATIVE_INFINITY;
    }

    protected Model getBestModel() {
        return this.bestModel;
    }

    private void setBestModel(DecisionStumpModel model, double score) {
        this.bestModel = model;
        this.bestScore = score;
    }

    /**
     * Searches all nominal and numerical single-attribute tests and returns the best stump.
     *
     * @param exampleSet training data with a binominal label
     * @return the highest-scoring {@link DecisionStumpModel}; this may be the attribute-free default
     * @throws OperatorException if the parameter cannot be read or model construction fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.utilityFunction = UTILITY_FUNCTION_LIST[this.getParameterAsInt(PARAMETER_UTILITY_FUNCTION)];
        this.initHighscore();
        this.posIndex = exampleSet.getAttributes().getLabel().getMapping().getPositiveIndex();
        double[] globalCounts = this.computePriors(exampleSet);
        this.globalP = globalCounts[POS];
        this.globalN = globalCounts[NEG];
        // Baseline: a stump without a test that covers every example and predicts the better class.
        boolean defaultPrediction = this.getScore(globalCounts, true) >= this.getScore(globalCounts, false);
        this.setBestModel(new DecisionStumpModel(null, 0.0, exampleSet, defaultPrediction, true),
                this.getScore(globalCounts, defaultPrediction));
        this.evaluateNominalAttributes(exampleSet);
        this.evaluateNumericalAttributes(exampleSet);
        return this.getBestModel();
    }

    /** Returns, in iteration order, the attributes whose {@code isNominal()} equals {@code nominal}. */
    private static Attribute[] selectAttributes(ExampleSet exampleSet, boolean nominal) {
        Attribute[] buffer = new Attribute[exampleSet.getAttributes().size()];
        int count = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (attribute.isNominal() == nominal) {
                buffer[count++] = attribute;
            }
        }
        return Arrays.copyOf(buffer, count);
    }

    /** Maps the example's label to {@link #POS} (equals the positive index) or {@link #NEG}. */
    private int classSlotOf(Example example) {
        return example.getLabel() == (double) this.posIndex ? POS : NEG;
    }

    /**
     * Tries every {@code attribute <= midpoint} threshold on non-nominal attributes and records
     * any stump that beats the current best score.
     */
    private void evaluateNumericalAttributes(ExampleSet exampleSet) throws OperatorException {
        Attribute[] attributes = selectAttributes(exampleSet, false);
        if (attributes.length == 0) {
            return;
        }
        boolean hasWeight = exampleSet.getAttributes().getWeight() != null;
        int numExamples = exampleSet.size();
        int[] labelSlots = new int[numExamples];
        double[] weights = new double[numExamples];
        // values[a][e] = {value of attribute a in example e, e}; sorted per attribute below.
        double[][][] values = new double[attributes.length][numExamples][];
        double[] weightedPriors = new double[2];
        int exampleNum = 0;
        for (Example example : exampleSet) {
            int slot = this.classSlotOf(example);
            double weight = hasWeight ? example.getWeight() : 1.0;
            labelSlots[exampleNum] = slot;
            weights[exampleNum] = weight;
            weightedPriors[slot] += weight;
            for (int a = 0; a < attributes.length; ++a) {
                values[a][exampleNum] = new double[]{example.getValue(attributes[a]), exampleNum};
            }
            ++exampleNum;
        }
        // Missing values are sent to whichever side predicts the weighted majority class.
        boolean predictNaN = weightedPriors[POS] >= weightedPriors[NEG];
        for (int a = 0; a < attributes.length; ++a) {
            double[][] sorted = values[a];
            Arrays.sort(sorted, BY_VALUE);
            // Weight per class of examples with value <= lastValue (the covered side).
            double[] counts = new double[2];
            double lastValue = Double.NEGATIVE_INFINITY;
            double lastScore = Double.NEGATIVE_INFINITY;
            boolean betterPrediction = false;
            for (double[] entry : sorted) {
                double value = entry[0];
                if (Double.isNaN(value) || value == Double.POSITIVE_INFINITY) {
                    break; // only NaN / +Infinity remain after this point in the sorted order
                }
                // A threshold is only valid between two distinct values; counts then hold
                // every example <= lastValue.
                if (value != lastValue && lastScore > this.bestScore) {
                    double testValue = (value + lastValue) / 2.0;
                    boolean includeNaNs = predictNaN == betterPrediction;
                    this.setBestModel(new DecisionStumpModel(attributes[a], testValue, exampleSet, betterPrediction, includeNaNs), lastScore);
                }
                int example = (int) entry[1];
                counts[labelSlots[example]] += weights[example];
                double scorePos = this.getScore(counts, true);
                double scoreNeg = this.getScore(counts, false);
                lastScore = Math.max(scorePos, scoreNeg);
                betterPrediction = scorePos >= scoreNeg;
                lastValue = value;
            }
        }
    }

    /**
     * Tries every {@code attribute == value} test on nominal attributes and records any stump
     * that beats the current best score. Missing values may be assigned to either side.
     */
    private void evaluateNominalAttributes(ExampleSet exampleSet) throws OperatorException {
        Attribute[] attributes = selectAttributes(exampleSet, true);
        if (attributes.length == 0) {
            return;
        }
        // counter[a][v][slot]: weight of examples whose attribute a has value index v.
        double[][][] counter = new double[attributes.length][][];
        // countNaNs[a][slot]: weight of examples whose attribute a is missing.
        double[][] countNaNs = new double[attributes.length][2];
        for (int a = 0; a < attributes.length; ++a) {
            counter[a] = new double[attributes[a].getMapping().size()][2];
        }
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        for (Example example : exampleSet) {
            double weight = weightAttr == null ? 1.0 : example.getWeight();
            int slot = this.classSlotOf(example);
            for (int a = 0; a < attributes.length; ++a) {
                double attributeValue = example.getValue(attributes[a]);
                double[] target = Double.isNaN(attributeValue) ? countNaNs[a] : counter[a][(int) attributeValue];
                target[slot] += weight;
            }
        }
        for (int a = 0; a < attributes.length; ++a) {
            double[][] valueCounts = counter[a];
            for (int v = 0; v < valueCounts.length; ++v) {
                ScoreNaNInfo best = this.getScore(valueCounts[v], countNaNs[a]);
                if (best.score > this.bestScore) {
                    this.setBestModel(new DecisionStumpModel(attributes[a], v, exampleSet, best.predicted, best.includeNaNs), best.score);
                }
            }
        }
    }

    /**
     * Scores a nominal test in both prediction directions. When missing values exist, it also
     * scores both directions with those values added to the covered side.
     * Neither argument is modified.
     *
     * @return the best combination; ties keep the earlier candidate in evaluation order
     */
    private ScoreNaNInfo getScore(double[] counts, double[] countNaNs) {
        ScoreNaNInfo best = new ScoreNaNInfo(this.getScore(counts, true), false, true)
                .max(new ScoreNaNInfo(this.getScore(counts, false), false, false));
        if (countNaNs[POS] > 0.0 || countNaNs[NEG] > 0.0) {
            double[] withNaNs = new double[]{counts[POS] + countNaNs[POS], counts[NEG] + countNaNs[NEG]};
            best = best.max(new ScoreNaNInfo(this.getScore(withNaNs, true), true, true));
            best = best.max(new ScoreNaNInfo(this.getScore(withNaNs, false), true, false));
        }
        return best;
    }

    /**
     * Scores a stump whose covered side holds {@code counts} and the uncovered side the remainder
     * of the global priors. Higher is better.
     *
     * <p>Except for accuracy, every measure returns negative infinity when {@code predictPositives}
     * disagrees with the covered majority (ties favour positive).
     *
     * @param counts two-slot covered weights ({@link #POS}, {@link #NEG})
     * @param predictPositives whether the covered side would be predicted positive
     * @return the score; NaN (with a logged warning) for an unknown utility function
     */
    protected double getScore(double[] counts, boolean predictPositives) {
        double p = counts[POS];
        double n = counts[NEG];
        if (ACC.equals(this.utilityFunction)) {
            return predictPositives ? p - n : n - p;
        }
        if (ENTROPY.equals(this.utilityFunction)) {
            if (contradictsMajority(p, n, predictPositives)) {
                return Double.NEGATIVE_INFINITY;
            }
            double cov = p + n;
            double uncov = this.globalP + this.globalN - cov;
            double scoreCovered = cov == 0.0 ? 0.0 : this.entropyLog2(p / cov) + this.entropyLog2(n / cov);
            double scoreUncovered = uncov == 0.0 ? 0.0 : this.entropyLog2((this.globalP - p) / uncov) + this.entropyLog2((this.globalN - n) / uncov);
            return -((cov * scoreCovered + uncov * scoreUncovered) / (cov + uncov));
        }
        if (SQRT_PN.equals(this.utilityFunction)) {
            if (contradictsMajority(p, n, predictPositives)) {
                return Double.NEGATIVE_INFINITY;
            }
            return -(Math.sqrt(p * n) + Math.sqrt((this.globalP - p) * (this.globalN - n)));
        }
        if (GINI.equals(this.utilityFunction)) {
            if (contradictsMajority(p, n, predictPositives)) {
                return Double.NEGATIVE_INFINITY;
            }
            double cov = p + n;
            double uncov = this.globalP + this.globalN - cov;
            double scoreCovered = cov == 0.0 ? 0.0 : p / cov * (n / cov);
            double scoreUncovered = uncov == 0.0 ? 0.0 : (this.globalP - p) / uncov * ((this.globalN - n) / uncov);
            return -((cov * scoreCovered + uncov * scoreUncovered) / (cov + uncov));
        }
        if (CHI_SQUARE.equals(this.utilityFunction)) {
            // Chi-square is symmetric in the classes. Without this guard both directions tie and
            // the covered side would always be predicted positive, whatever its composition.
            if (contradictsMajority(p, n, predictPositives)) {
                return Double.NEGATIVE_INFINITY;
            }
            double q = this.globalP - p;
            double r = this.globalN - n;
            double cov = p + n;
            double uncov = q + r;
            double total = cov + uncov;
            double c11 = cov * this.globalP / total;
            double c12 = cov * this.globalN / total;
            double c21 = uncov * this.globalP / total;
            double c22 = uncov * this.globalN / total;
            if (!(cov > 0.0 && uncov > 0.0)) {
                return 0.0;
            }
            return (p - c11) * (p - c11) / c11 + (n - c12) * (n - c12) / c12
                    + (q - c21) * (q - c21) / c21 + (r - c22) * (r - c22) / c22;
        }
        this.logWarning("Found unknown utility function: " + this.utilityFunction);
        return Double.NaN;
    }

    /** True when {@code predictPositives} disagrees with the covered majority (a tie counts as positive). */
    private static boolean contradictsMajority(double p, double n, boolean predictPositives) {
        return (p - n >= 0.0) ^ predictPositives;
    }

    /** Returns {@code -p * log2(p)}, treating {@code p == 0} and NaN as contributing 0. */
    private double entropyLog2(double p) {
        if (Double.isNaN(p) || p == 0.0) {
            return 0.0;
        }
        return -p * Math.log(p) / Math.log(2.0);
    }

    /**
     * Sums example weights per class slot (weight 1.0 when there is no weight attribute).
     *
     * @return {@code {positiveWeight, otherWeight}}
     */
    protected double[] computePriors(ExampleSet exampleSet) {
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        double[] priors = new double[2];
        for (Example example : exampleSet) {
            double weight = weightAttr == null ? 1.0 : example.getValue(weightAttr);
            priors[this.classSlotOf(example)] += weight;
        }
        return priors;
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeCategory(PARAMETER_UTILITY_FUNCTION, "The function to be optimized by the rule.", UTILITY_FUNCTION_LIST, 0));
        return types;
    }

    /** A candidate score together with the NaN-handling and prediction choices that produced it. */
    private static class ScoreNaNInfo {
        final double score;
        final boolean includeNaNs;
        final boolean predicted;

        ScoreNaNInfo(double score, boolean includeNaNs, boolean predicted) {
            this.score = score;
            this.includeNaNs = includeNaNs;
            this.predicted = predicted;
        }

        /** Returns the candidate with the higher score; on a tie this one is kept. */
        ScoreNaNInfo max(ScoreNaNInfo other) {
            return this.score >= other.score ? this : other;
        }
    }

    /**
     * A single-test stump. Examples that pass the test ("covered") receive the positive class iff
     * {@code prediction} is true. The other examples receive the opposite class. A missing test
     * value counts as covered iff {@code includeNaNs}. A {@code null} test attribute covers
     * everything.
     */
    public static class DecisionStumpModel extends SimplePredictionModel {
        private static final long serialVersionUID = -261158567126510415L;
        private final Attribute testAttribute;
        private final double testValue;
        private final boolean prediction;
        private final boolean includeNaNs;
        private final boolean numerical;

        public DecisionStumpModel(Attribute attribute, double testValue, ExampleSet exampleSet, boolean prediction, boolean includeNaNs) {
            super(exampleSet);
            this.prediction = prediction;
            this.includeNaNs = includeNaNs;
            this.testAttribute = attribute;
            this.testValue = testValue;
            this.numerical = attribute == null || !attribute.isNominal();
        }

        /** Evaluates the stump's test on one example. */
        private boolean covers(Example example) {
            if (this.testAttribute == null) {
                return true;
            }
            double exampleValue = example.getValue(this.testAttribute);
            if (Double.isNaN(exampleValue)) {
                return this.includeNaNs;
            }
            return this.numerical ? exampleValue <= this.testValue : exampleValue == this.testValue;
        }

        @Override
        public double predict(Example example) {
            if (this.covers(example) == this.prediction) {
                return this.getLabel().getMapping().getPositiveIndex();
            }
            return this.getLabel().getMapping().getNegativeIndex();
        }

        @Override
        public String toString() {
            String posIndexS = this.getLabel().getMapping().getPositiveString();
            String negIndexS = this.getLabel().getMapping().getNegativeString();
            // predict() sends a missing value through covers() -> includeNaNs, so the class shown
            // for unknowns must be derived the same way, not from includeNaNs alone.
            boolean unknownPredictsPositive = this.includeNaNs == this.prediction;
            StringBuilder result = new StringBuilder(super.toString());
            result.append(Tools.getLineSeparator()).append(" (").append(this.getLabel().getName()).append("=");
            result.append(this.prediction ? posIndexS : negIndexS).append(") <-- ");
            if (this.testAttribute != null) {
                result.append(this.testAttribute.getName());
                if (this.numerical) {
                    result.append(" <= ").append(this.testValue);
                } else {
                    result.append(" = ").append(this.testAttribute.getMapping().mapIndex((int) this.testValue));
                }
            }
            result.append(Tools.getLineSeparator()).append(" unknown: predict '")
                    .append(unknownPredictsPositive ? posIndexS : negIndexS).append("'.");
            return result.toString();
        }
    }
}

