/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.performance;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.performance.ClassWeightedPerformance;
import com.rapidminer.operator.performance.MeasuredPerformance;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.Averagable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Performance criterion that reports the class-weighted mean of either the
 * per-class recall or the per-class precision of a nominal prediction.
 *
 * <p>Examples are accumulated into a square matrix {@code counter} indexed as
 * {@code counter[trueClass][predictedClass]}. The class weights supplied via
 * {@link #setWeights(double[])} are applied when the average is computed; they
 * must be set before {@link #getMikroAverage()} is called.
 */
public class WeightedMultiClassPerformance
extends MeasuredPerformance
implements ClassWeightedPerformance {
    private static final long serialVersionUID = 8734250559680229116L;

    /** Type value used by the no-argument constructor; not a valid index into {@link #NAMES}. */
    public static final int UNDEFINED = -1;
    /** Type value selecting the weighted mean of the per-class recalls. */
    public static final int WEIGHTED_RECALL = 0;
    /** Type value selecting the weighted mean of the per-class precisions. */
    public static final int WEIGHTED_PRECISION = 1;

    /** Criterion names, indexed by type. */
    public static final String[] NAMES = new String[]{"weighted_mean_recall", "weighted_mean_precision"};
    /** Criterion descriptions, indexed by type. */
    public static final String[] DESCRIPTIONS = new String[]{"The weighted mean of all per class recall measurements.", "The weighted mean of all per class precision measurements."};

    /** Accumulated example weights, indexed as {@code [true class][predicted class]}. */
    private double[][] counter;
    /** Class names in the order of the label attribute's mapping. */
    private String[] classNames;
    /** Maps a class name to its index in {@link #classNames} and {@link #counter}. */
    private Map<String, Integer> classNameMap = new HashMap<String, Integer>();
    private int type = 0;
    /** One weight per class, as handed to {@link #setWeights(double[])}. */
    private double[] classWeights;
    /** Sum of {@link #classWeights}; the divisor of the weighted mean. */
    private double weightSum;
    private Attribute labelAttribute;
    private Attribute predictedLabelAttribute;
    /** Example weight attribute, or {@code null} when examples are counted with weight 1. */
    private Attribute weightAttribute;

    /** Creates a criterion of type {@link #UNDEFINED}. */
    public WeightedMultiClassPerformance() {
        this(UNDEFINED);
    }

    /**
     * Creates a criterion of the given type.
     *
     * @param type {@link #WEIGHTED_RECALL} or {@link #WEIGHTED_PRECISION}
     */
    public WeightedMultiClassPerformance(int type) {
        this.type = type;
    }

    /**
     * Copy constructor. Copies the type, class names, counts, class weights and
     * clones of the attributes, so that the copy yields the same average as the
     * source. State that the source has not initialised yet stays unset in the
     * copy as well.
     *
     * @param m the criterion to copy
     */
    public WeightedMultiClassPerformance(WeightedMultiClassPerformance m) {
        super(m);
        this.type = m.type;
        if (m.classNames != null) {
            this.classNames = m.classNames.clone();
            for (int i = 0; i < this.classNames.length; ++i) {
                this.classNameMap.put(this.classNames[i], i);
            }
        }
        if (m.counter != null) {
            this.counter = new double[m.counter.length][];
            for (int i = 0; i < m.counter.length; ++i) {
                this.counter[i] = m.counter[i].clone();
            }
        }
        // The weights take part in getMikroAverage(); without them the copy would be unusable.
        if (m.classWeights != null) {
            this.classWeights = m.classWeights.clone();
        }
        this.weightSum = m.weightSum;
        if (m.labelAttribute != null) {
            this.labelAttribute = (Attribute)m.labelAttribute.clone();
        }
        if (m.predictedLabelAttribute != null) {
            this.predictedLabelAttribute = (Attribute)m.predictedLabelAttribute.clone();
        }
        if (m.weightAttribute != null) {
            this.weightAttribute = (Attribute)m.weightAttribute.clone();
        }
    }

    /**
     * Creates the criterion whose name equals the given name.
     *
     * @param name one of the entries of {@link #NAMES}
     * @return a new criterion of the matching type, or {@code null} if no name matches
     */
    public static WeightedMultiClassPerformance newInstance(String name) {
        for (int i = 0; i < NAMES.length; ++i) {
            if (NAMES[i].equals(name)) {
                return new WeightedMultiClassPerformance(i);
            }
        }
        return null;
    }

    /**
     * Sets the class weights and recomputes their sum. The array is stored by
     * reference, not copied.
     *
     * @param weights one weight per class, in class index order; must not be {@code null}
     */
    @Override
    public void setWeights(double[] weights) {
        this.weightSum = 0.0;
        this.classWeights = weights;
        for (double w : weights) {
            this.weightSum += w;
        }
    }

    /**
     * Prepares counting for the given example set: resolves the label, predicted
     * label and (optionally) weight attribute and resets the count matrix and the
     * class name lookup to the label's nominal values.
     *
     * @param eSet the example set to evaluate
     * @param useExampleWeights whether the example set's weight attribute should be used
     * @throws OperatorException a {@link UserError} if the label or predicted label
     *         is missing or not nominal
     */
    @Override
    public void startCounting(ExampleSet eSet, boolean useExampleWeights) throws OperatorException {
        super.startCounting(eSet, useExampleWeights);
        this.labelAttribute = eSet.getAttributes().getLabel();
        if (this.labelAttribute == null) {
            throw new UserError(null, 101, "calculation of classification performance criteria", "label attribute");
        }
        if (!this.labelAttribute.isNominal()) {
            throw new UserError(null, 101, "calculation of classification performance criteria", this.labelAttribute.getName());
        }
        this.predictedLabelAttribute = eSet.getAttributes().getPredictedLabel();
        if (this.predictedLabelAttribute == null || !this.predictedLabelAttribute.isNominal()) {
            throw new UserError(null, 101, "calculation of classification performance criteria", "predicted label attribute");
        }
        // Always reassign so a weight attribute from an earlier run is not reused.
        this.weightAttribute = useExampleWeights ? eSet.getAttributes().getWeight() : null;

        List<String> values = this.labelAttribute.getMapping().getValues();
        int classCount = values.size();
        this.counter = new double[classCount][classCount];
        this.classNames = new String[classCount];
        // Drop entries of a previous run; their indices may not fit the new matrix.
        this.classNameMap.clear();
        int index = 0;
        for (String value : values) {
            this.classNames[index] = value;
            this.classNameMap.put(value, index);
            ++index;
        }
    }

    /**
     * Adds one example to the count matrix, using its weight if a weight
     * attribute is in use and 1 otherwise.
     *
     * @param example the example to count; both its label and its predicted label
     *        value must be among the class names registered by {@code startCounting}
     */
    @Override
    public void countExample(Example example) {
        int label = this.classNameMap.get(example.getNominalValue(this.labelAttribute));
        int plabel = this.classNameMap.get(example.getNominalValue(this.predictedLabelAttribute));
        double weight = 1.0;
        if (this.weightAttribute != null) {
            weight = example.getValue(this.weightAttribute);
        }
        this.counter[label][plabel] += weight;
    }

    /**
     * Returns the sum of all entries of the count matrix.
     */
    @Override
    public double getExampleCount() {
        double total = 0.0;
        for (double[] row : this.counter) {
            for (double cell : row) {
                total += cell;
            }
        }
        return total;
    }

    /**
     * Computes the weighted mean of the per-class recall or precision,
     * depending on the type of this criterion.
     *
     * @return the weighted mean, divided by the sum of the class weights
     * @throws RuntimeException if the type is neither {@link #WEIGHTED_RECALL}
     *         nor {@link #WEIGHTED_PRECISION}
     */
    @Override
    public double getMikroAverage() {
        int classCount = this.classNames.length;
        switch (this.type) {
            case WEIGHTED_RECALL: {
                // Recall denominator: everything whose true class is t.
                double[] trueClassSums = new double[classCount];
                for (int t = 0; t < classCount; ++t) {
                    for (int p = 0; p < this.counter[t].length; ++p) {
                        trueClassSums[t] += this.counter[t][p];
                    }
                }
                return this.weightedMeanOfDiagonalRatios(trueClassSums);
            }
            case WEIGHTED_PRECISION: {
                // Precision denominator: everything that was predicted as p.
                double[] predictedClassSums = new double[classCount];
                for (int p = 0; p < this.counter.length; ++p) {
                    for (int t = 0; t < this.counter[p].length; ++t) {
                        predictedClassSums[p] += this.counter[t][p];
                    }
                }
                return this.weightedMeanOfDiagonalRatios(predictedClassSums);
            }
            default:
                throw new RuntimeException("Unknown type " + this.type + " for weighted multi class performance criterion!");
        }
    }

    /**
     * Divides each diagonal count by its denominator, weights the ratio with the
     * class weight and returns the sum divided by the weight sum. A ratio that
     * is NaN (0 / 0, i.e. a class without any counts) contributes 0.
     */
    private double weightedMeanOfDiagonalRatios(double[] denominators) {
        double result = 0.0;
        for (int c = 0; c < denominators.length; ++c) {
            double ratio = this.counter[c][c] / denominators[c];
            result += this.classWeights[c] * (Double.isNaN(ratio) ? 0.0 : ratio);
        }
        return result / this.weightSum;
    }

    /** Always {@code true}. */
    @Override
    public boolean formatPercent() {
        return true;
    }

    /** Always {@link Double#NaN}. */
    @Override
    public double getMikroVariance() {
        return Double.NaN;
    }

    /**
     * Returns the entry of {@link #NAMES} for this criterion's type. Throws an
     * {@link ArrayIndexOutOfBoundsException} for type {@link #UNDEFINED}.
     */
    @Override
    public String getName() {
        return NAMES[this.type];
    }

    /**
     * Returns the entry of {@link #DESCRIPTIONS} for this criterion's type. Throws
     * an {@link ArrayIndexOutOfBoundsException} for type {@link #UNDEFINED}.
     */
    @Override
    public String getDescription() {
        return DESCRIPTIONS[this.type];
    }

    /** Returns the value of {@code getAverage()}. */
    @Override
    public double getFitness() {
        return this.getAverage();
    }

    /** Always {@code 1.0}. */
    @Override
    public double getMaxFitness() {
        return 1.0;
    }

    /**
     * Adds the counts of the given criterion element-wise to this one.
     *
     * @param performance must be a {@code WeightedMultiClassPerformance} whose
     *        count matrix is at least as large as this one's
     */
    @Override
    public void buildSingleAverage(Averagable performance) {
        WeightedMultiClassPerformance other = (WeightedMultiClassPerformance)performance;
        for (int i = 0; i < this.counter.length; ++i) {
            for (int j = 0; j < this.counter[i].length; ++j) {
                this.counter[i][j] += other.counter[i][j];
            }
        }
    }

    /**
     * Returns the superclass string representation followed by the
     * comma-separated class weights.
     */
    public String toWeightString() {
        StringBuilder result = new StringBuilder(super.toString());
        result.append(", weights: ");
        String separator = "";
        for (double w : this.classWeights) {
            result.append(separator).append(Tools.formatIntegerIfPossible(w));
            separator = ", ";
        }
        return result.toString();
    }

    /**
     * Returns {@link #toWeightString()} followed by the confusion matrix, with
     * one tab-separated column per true class and one line per predicted class.
     */
    @Override
    public String toString() {
        String lineSeparator = Tools.getLineSeparator();
        StringBuilder result = new StringBuilder(this.toWeightString());
        result.append(lineSeparator).append("ConfusionMatrix:").append(lineSeparator).append("True:");
        for (int t = 0; t < this.counter.length; ++t) {
            result.append("\t").append(this.classNames[t]);
        }
        for (int p = 0; p < this.counter.length; ++p) {
            result.append(lineSeparator).append(this.classNames[p]).append(":");
            for (int t = 0; t < this.counter[p].length; ++t) {
                // Row = predicted class p, column = true class t.
                result.append("\t").append(Tools.formatIntegerIfPossible(this.counter[t][p]));
            }
        }
        return result.toString();
    }

    /** Returns the internal class name array (not a copy). */
    public String[] getClassNames() {
        return this.classNames;
    }

    /** Returns the internal count matrix (not a copy), indexed as {@code [true class][predicted class]}. */
    public double[][] getCounter() {
        return this.counter;
    }
}

