/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.meta.ContingencyMatrix;
import edu.udo.cs.yale.tools.LogService;
import java.util.Iterator;

public class WeightedPerformanceMeasures {
    /** Value returned by {@link #getLift(int, int)} when the prediction never occurs. */
    public static final double RULE_DOES_NOT_APPLY = Double.NaN;

    /** Weighted marginal probability of each predicted class, indexed by prediction. */
    private final double[] predictions;
    /** Weighted marginal probability of each true class, indexed by label. */
    private final double[] labels;
    /** Weighted joint probability, indexed {@code [prediction][label]}. */
    private final double[][] predLabel;
    /** Unweighted example counts, indexed {@code [prediction][label]}. */
    private final int[][] unweightedNumPredLabel;

    /**
     * Computes weighted label/prediction distributions from an example set.
     *
     * <p>Label and predicted-label values are interpreted as class indices in
     * {@code [0, numberOfClasses)}, where {@code numberOfClasses} is the size of the label
     * attribute's value mapping. An example whose label or prediction is missing (NaN) or
     * outside that range is excluded: its weight, label and predicted label are reset to 0,
     * a message is logged, and its weight does not contribute to the normalizing total.
     * If the counted examples have no positive total weight, uniform distributions are
     * used instead (joint probability = product of the uniform marginals).
     *
     * @param exampleSet examples carrying a label and a predicted label; invalid examples
     *     are modified in place as described above
     * @throws OperatorException propagated from the example-set accessors
     */
    public WeightedPerformanceMeasures(ExampleSet exampleSet) throws OperatorException {
        int numberOfClasses = exampleSet.getAttributes().getLabel().getMapping().getValues().size();
        this.labels = new double[numberOfClasses];
        this.predictions = new double[numberOfClasses];
        this.predLabel = new double[numberOfClasses][numberOfClasses];
        this.unweightedNumPredLabel = new int[numberOfClasses][numberOfClasses];

        double sumOfWeights = this.accumulateCounts(exampleSet);
        if (sumOfWeights > 0.0) {
            this.normalize(sumOfWeights);
        } else {
            this.setUniformDistributions();
        }
    }

    /**
     * Adds the (unnormalized) weights and counts of all valid examples to the internal
     * arrays and neutralizes invalid examples.
     *
     * @return the total weight of the examples that were actually counted
     */
    private double accumulateCounts(ExampleSet exampleSet) throws OperatorException {
        double sumOfWeights = 0.0;
        Iterator reader = exampleSet.iterator();
        while (reader.hasNext()) {
            Example exa = (Example) reader.next();
            double exaW = exa.getWeight();
            double labelValue = exa.getLabel();
            double predValue = exa.getPredictedLabel();
            // (int) NaN == 0 in Java, so missing values must be rejected before the cast.
            boolean missing = Double.isNaN(labelValue) || Double.isNaN(predValue);
            int eLabel = (int) labelValue;
            int ePred = (int) predValue;
            boolean inRange = ePred >= 0 && ePred < this.predictions.length
                    && eLabel >= 0 && eLabel < this.labels.length;
            if (!missing && inRange) {
                this.unweightedNumPredLabel[ePred][eLabel]++;
                this.labels[eLabel] += exaW;
                this.predictions[ePred] += exaW;
                this.predLabel[ePred][eLabel] += exaW;
                // Only counted examples may contribute to the normalizer, otherwise the
                // resulting distributions would not sum to 1.
                sumOfWeights += exaW;
            } else {
                exa.setWeight(0.0);
                exa.setLabel(0.0);
                exa.setPredictedLabel(0.0);
                LogService.logMessage("WeightedPerformanceMeasures: Deleted example with illegal label or prediction (" + labelValue + ", " + predValue + ")!", 4);
            }
        }
        return sumOfWeights;
    }

    /** Divides all accumulated weights by {@code sumOfWeights} to obtain probabilities. */
    private void normalize(double sumOfWeights) {
        for (int i = 0; i < this.predictions.length; i++) {
            this.predictions[i] /= sumOfWeights;
            for (int j = 0; j < this.labels.length; j++) {
                this.predLabel[i][j] /= sumOfWeights;
            }
        }
        for (int j = 0; j < this.labels.length; j++) {
            this.labels[j] /= sumOfWeights;
        }
    }

    /** Fallback when no positive weight was seen: uniform marginals, independent joint. */
    private void setUniformDistributions() {
        double defaultPredProb = 1.0 / (double) this.predictions.length;
        double defaultLabelProb = 1.0 / (double) this.labels.length;
        double defaultPredLabelProb = defaultPredProb * defaultLabelProb;
        for (int i = 0; i < this.predictions.length; i++) {
            this.predictions[i] = defaultPredProb;
            for (int j = 0; j < this.labels.length; j++) {
                this.predLabel[i][j] = defaultPredLabelProb;
            }
        }
        for (int j = 0; j < this.labels.length; j++) {
            this.labels[j] = defaultLabelProb;
        }
    }

    /**
     * Returns the unweighted number of counted examples per true label that received the
     * given prediction.
     *
     * @param prediction predicted class index
     * @return a fresh array indexed by label; all zeros if {@code prediction} is out of range
     */
    public int[] getCoveredExamplesNumForPred(int prediction) {
        if (prediction >= 0 && prediction < this.unweightedNumPredLabel.length) {
            // Copy so callers cannot modify the internal statistics.
            return this.unweightedNumPredLabel[prediction].clone();
        }
        return new int[this.labels.length];
    }

    /** @return the number of label classes */
    public int getNumberOfLabels() {
        return this.labels.length;
    }

    /** @return the number of prediction classes */
    public int getNumberOfPredictions() {
        return this.predictions.length;
    }

    /** @return the weighted joint probability P(label, prediction) */
    public double getProbability(int label, int prediction) {
        return this.predLabel[prediction][label];
    }

    /** @return the weighted marginal probability P(label) */
    public double getProbabilityLabel(int label) {
        return this.labels[label];
    }

    /** @return the weighted marginal probability P(prediction) */
    public double getProbabilityPrediction(int premise) {
        return this.predictions[premise];
    }

    /**
     * Returns the lift P(label, prediction) / (P(label) * P(prediction)).
     *
     * @return {@link #RULE_DOES_NOT_APPLY} (NaN) if the prediction has probability 0;
     *     0 if the label never co-occurs with the prediction;
     *     {@code +Infinity} if every example with this prediction has this label
     *     ({@link #getPnRatios(int)} relies on this marker); otherwise the lift
     */
    public double getLift(int label, int prediction) {
        double prLabel = this.getProbabilityLabel(label);
        double prPred = this.getProbabilityPrediction(prediction);
        double prJoint = this.getProbability(label, prediction);
        if (prPred == 0.0) {
            return RULE_DOES_NOT_APPLY;
        }
        if (prJoint == 0.0) {
            return 0.0;
        }
        if (prJoint == prPred) {
            return Double.POSITIVE_INFINITY;
        }
        return prJoint / (prLabel * prPred);
    }

    /**
     * For each label, returns the ratio of its lift to the lift of its complement
     * ("not label") under the given prediction.
     *
     * @param prediction predicted class index
     * @return array indexed by label; entries with lift 0 or +Infinity are returned as-is
     */
    public double[] getPnRatios(int prediction) {
        double[] ratios = new double[this.labels.length];
        for (int label = 0; label < ratios.length; label++) {
            double lift = this.getLift(label, prediction);
            if (lift == 0.0 || lift == Double.POSITIVE_INFINITY) {
                ratios[label] = lift;
                continue;
            }
            double probNegLabel = 1.0 - this.getProbabilityLabel(label);
            double probPred = this.getProbabilityPrediction(prediction);
            double probNegLabelAndPred = probPred - this.getProbability(label, prediction);
            double oppositeLift = probNegLabelAndPred / (probNegLabel * probPred);
            ratios[label] = lift / oppositeLift;
        }
        return ratios;
    }

    /** @return matrix whose row {@code i} is {@link #getPnRatios(int) getPnRatios(i)} */
    public double[][] createLiftRatioMatrix() {
        int numPredictions = this.getNumberOfPredictions();
        double[][] liftRatioMatrix = new double[numPredictions][];
        for (int i = 0; i < numPredictions; i++) {
            liftRatioMatrix[i] = this.getPnRatios(i);
        }
        return liftRatioMatrix;
    }

    /** @return a fresh array with the marginal probability of each label */
    public double[] getLabelPriors() {
        double[] priors = new double[this.getNumberOfLabels()];
        for (int i = 0; i < priors.length; i++) {
            priors[i] = this.getProbabilityLabel(i);
        }
        return priors;
    }

    /** @return the number of labels with strictly positive marginal probability */
    public int getNumberOfNonEmptyClasses() {
        int nonEmpty = 0;
        for (int i = 0; i < this.getNumberOfLabels(); i++) {
            if (this.getProbabilityLabel(i) > 0.0) {
                nonEmpty++;
            }
        }
        return nonEmpty;
    }

    /**
     * Builds a contingency matrix of joint probabilities indexed {@code [label][prediction]}
     * (the transpose of the internal storage). Entries that are NaN or outside [0, 1] are
     * logged but still copied.
     */
    public ContingencyMatrix getContingencyMatrix() {
        if (this.predLabel.length == 0 || this.predLabel[0].length == 0) {
            return new ContingencyMatrix(new double[0][0]);
        }
        double[][] matrix = new double[this.predLabel[0].length][this.predLabel.length];
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                double predLabelJi = this.predLabel[j][i];
                if (Double.isNaN(predLabelJi) || predLabelJi < 0.0 || predLabelJi > 1.0) {
                    LogService.logMessage("Found illegal value in contingency matrix!", 4);
                }
                matrix[i][j] = predLabelJi;
            }
        }
        return new ContingencyMatrix(matrix);
    }

    /**
     * Reweights every example according to the lift of its (label, prediction) cell in
     * {@code cm}.
     *
     * <p>Per example:
     * <ul>
     *   <li>lift NaN or negative: logged, example left unchanged and not added to the total;</li>
     *   <li>lift 0 or infinite: weight set to 0;</li>
     *   <li>current weight NaN, infinite or negative: logged, weight set to 0;</li>
     *   <li>current weight 0: left unchanged;</li>
     *   <li>otherwise the new weight is {@code weight * sqrt((1 - prec) / prec)} when
     *       {@code allowMarginalSkews} is set (prec = {@code cm.getPrecision}), else
     *       {@code weight / lift}.</li>
     * </ul>
     *
     * @return the sum of the weights assigned in this pass (see the exclusions above)
     * @throws OperatorException propagated from the example-set accessors
     */
    public static double reweightExamples(ExampleSet exampleSet, ContingencyMatrix cm, boolean allowMarginalSkews) throws OperatorException {
        Iterator reader = exampleSet.iterator();
        double totalWeight = 0.0;
        while (reader.hasNext()) {
            Example example = (Example) reader.next();
            int label = (int) example.getLabel();
            int predicted = (int) example.getPredictedLabel();
            double lift = cm.getLift(label, predicted);
            if (Double.isNaN(lift) || lift < 0.0) {
                // NOTE(review): the example keeps its old weight but is excluded from the
                // returned total; confirm callers expect that.
                LogService.logMessage("Applied rule with an illegal lift of " + lift + " during reweighting!", 4);
                continue;
            }
            if (lift == 0.0 || Double.isInfinite(lift)) {
                example.setWeight(0.0);
                continue;
            }
            double weight = example.getWeight();
            double newWeight;
            if (Double.isNaN(weight) || Double.isInfinite(weight) || weight < 0.0) {
                LogService.logMessage("Found illegal weight: " + weight, 4);
                newWeight = 0.0;
            } else {
                if (weight == 0.0) {
                    continue;
                }
                if (allowMarginalSkews) {
                    double prec = cm.getPrecision(label, predicted);
                    double invPrec = 1.0 - prec;
                    double beta = invPrec / prec;
                    if (prec <= 0.0 || invPrec < 0.0 || Double.isInfinite(beta) || Double.isNaN(beta)) {
                        // NOTE(review): the invalid beta is still used below and can yield a
                        // NaN/infinite weight; confirm this cannot happen for consistent cm.
                        LogService.logMessage("Reweighting uses invalid value:Precision is " + prec + ", inverse precision is " + invPrec + ", beta is " + beta, 4);
                    }
                    newWeight = weight * Math.sqrt(beta);
                } else {
                    newWeight = weight / lift;
                }
            }
            example.setWeight(newWeight);
            totalWeight += newWeight;
        }
        return totalWeight;
    }
}

