/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.performance;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.performance.MeasuredPerformance;
import edu.udo.cs.yale.tools.math.Averagable;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collection;
import java.util.Iterator;

/**
 * Performance criterion for multi-class classification based on a confusion
 * matrix. Depending on its type it reports accuracy, classification error or
 * the kappa statistic.
 *
 * <p>The matrix is stored as {@code counter[trueIndex][predictedIndex]}. The
 * indices are obtained by casting {@code Example.getLabel()} and
 * {@code Example.getPredictedLabel()} to {@code int}.
 * NOTE(review): this presumes those label values are class indices matching
 * the iteration order of {@code Attribute.getValues()} — confirm against
 * {@code Attribute}.
 */
public class MultiClassificationPerformance
extends MeasuredPerformance {
    public static final int UNDEFINED = -1;
    public static final int ACCURACY = 0;
    public static final int ERROR = 1;
    public static final int KAPPA = 2;
    public static final String[] NAME = new String[]{"accuracy", "classification_error", "kappa"};
    public static final String[] DESCRIPTION = new String[]{"Relative number of correctly classified examples", "Relative number of misclassified examples", "The kappa statistics for the classification"};

    /** Confusion matrix: {@code counter[trueIndex][predictedIndex]} = number of examples. */
    private int[][] counter;
    /** Class names in the iteration order of the label attribute's values. */
    private String[] classNames;
    /** One of {@link #ACCURACY}, {@link #ERROR}, {@link #KAPPA} or {@link #UNDEFINED}. */
    private int type;

    /**
     * Creates a criterion of type {@link #UNDEFINED}. The type can later be set
     * by {@link #readCriterionData(BufferedReader)}. Calling {@link #getName()}
     * or {@link #getDescription()} while the type is still undefined indexes
     * {@link #NAME}/{@link #DESCRIPTION} with -1.
     */
    public MultiClassificationPerformance() {
        this(UNDEFINED);
    }

    /**
     * Creates a criterion of the given type.
     *
     * @param type one of {@link #ACCURACY}, {@link #ERROR} or {@link #KAPPA}
     */
    public MultiClassificationPerformance(int type) {
        this.type = type;
    }

    /**
     * Creates a criterion from its name as listed in {@link #NAME}.
     *
     * @param name criterion name, e.g. {@code "accuracy"}
     * @return a new criterion, or {@code null} if {@code name} is not a known name
     */
    public static MultiClassificationPerformance newInstance(String name) {
        for (int i = 0; i < NAME.length; ++i) {
            if (NAME[i].equals(name)) {
                return new MultiClassificationPerformance(i);
            }
        }
        return null;
    }

    /**
     * Writes the type, the class names and the confusion matrix as
     * {@code key: value} lines, in the order expected by
     * {@link #readCriterionData(BufferedReader)}.
     *
     * @param out destination writer
     * @throws IOException declared for the criterion-data contract
     */
    public void writeCriterionData(PrintWriter out) throws IOException {
        out.println("type: " + this.type);
        out.println("number_of_classes: " + this.classNames.length);
        for (int i = 0; i < this.classNames.length; ++i) {
            out.println("classname_" + i + ": " + this.classNames[i]);
        }
        for (int i = 0; i < this.classNames.length; ++i) {
            for (int j = 0; j < this.classNames.length; ++j) {
                out.println("counter_" + i + "_" + j + ": " + this.counter[i][j]);
            }
        }
    }

    /**
     * Restores type, class names and confusion matrix from lines written by
     * {@link #writeCriterionData(PrintWriter)}.
     *
     * @param in source reader, positioned at the "type:" line
     * @throws IOException if reading fails or the input ends prematurely
     * @throws NumberFormatException if a numeric field cannot be parsed
     */
    public void readCriterionData(BufferedReader in) throws IOException {
        this.type = Integer.parseInt(nextCriterionField(in));
        int numberOfClasses = Integer.parseInt(nextCriterionField(in));
        this.classNames = new String[numberOfClasses];
        for (int i = 0; i < numberOfClasses; ++i) {
            this.classNames[i] = nextCriterionField(in);
        }
        this.counter = new int[numberOfClasses][numberOfClasses];
        for (int i = 0; i < numberOfClasses; ++i) {
            for (int j = 0; j < numberOfClasses; ++j) {
                this.counter[i][j] = Integer.parseInt(nextCriterionField(in));
            }
        }
    }

    /**
     * Reads the next line and returns the trimmed text after its first ':'.
     * If the line has no ':', the whole trimmed line is returned.
     *
     * @throws IOException if the stream ends before the expected line
     */
    private static String nextCriterionField(BufferedReader in) throws IOException {
        String line = in.readLine();
        if (line == null) {
            // readLine() signals end of stream with null; report truncated data instead of an NPE.
            throw new IOException("Unexpected end of input while reading multi class performance criterion data!");
        }
        return line.substring(line.indexOf(":") + 1).trim();
    }

    /** Returns the total number of counted examples (sum of all matrix cells). */
    public int getExampleCount() {
        int total = 0;
        for (int i = 0; i < this.counter.length; ++i) {
            for (int j = 0; j < this.counter[i].length; ++j) {
                total += this.counter[i][j];
            }
        }
        return total;
    }

    /**
     * Resets the confusion matrix to an n x n zero matrix, where n is the
     * number of label values. Stores the label values' names as class names.
     *
     * @param eSet example set whose label attribute defines the classes
     */
    public void startCounting(ExampleSet eSet) {
        Attribute label = eSet.getLabel();
        Collection values = label.getValues();
        this.counter = new int[values.size()][values.size()];
        this.classNames = new String[values.size()];
        Iterator i = values.iterator();
        int n = 0;
        while (i.hasNext()) {
            this.classNames[n++] = (String)i.next();
        }
    }

    /**
     * Increments the cell for the example's (true label, predicted label) pair.
     *
     * @param example example carrying both a label and a predicted label
     */
    public void countExample(Example example) {
        int label = (int)example.getLabel();
        int plabel = (int)example.getPredictedLabel();
        this.counter[label][plabel]++;
    }

    /**
     * Computes the criterion value from the confusion matrix.
     *
     * <ul>
     * <li>{@link #ACCURACY}: diagonal / total</li>
     * <li>{@link #ERROR}: 1 - accuracy</li>
     * <li>{@link #KAPPA}: (pa - pe) / (1 - pe). Here pa is the accuracy and
     *     pe = sum_i rowSum_i * columnSum_i / total^2 is the chance agreement.
     *     The result is NaN when pe == 1 and pa == 1.</li>
     * </ul>
     *
     * @return the value, or {@code NaN} if no example has been counted
     * @throws RuntimeException if the type is not one of the three above
     */
    public double getValue() {
        int diagonal = 0;
        int total = 0;
        for (int i = 0; i < this.counter.length; ++i) {
            diagonal += this.counter[i][i];
            for (int j = 0; j < this.counter[i].length; ++j) {
                total += this.counter[i][j];
            }
        }
        if (total == 0) {
            return Double.NaN;
        }
        double accuracy = (double)diagonal / (double)total;
        switch (this.type) {
            case ACCURACY:
                return accuracy;
            case ERROR:
                return 1.0 - accuracy;
            case KAPPA: {
                // Computed in double so that large example counts cannot overflow int.
                double totalSquared = (double)total * (double)total;
                double pe = 0.0;
                for (int i = 0; i < this.counter.length; ++i) {
                    double row = 0.0;
                    double column = 0.0;
                    for (int j = 0; j < this.counter[i].length; ++j) {
                        row += (double)this.counter[i][j];
                        column += (double)this.counter[j][i];
                    }
                    // Chance agreement for class i: P(true = i) * P(predicted = i).
                    // The denominator is total^2 regardless of the number of classes.
                    pe += row * column / totalSquared;
                }
                return (accuracy - pe) / (1.0 - pe);
            }
            default:
                throw new RuntimeException("Unknown type " + this.type + " for multi class performance criterion!");
        }
    }

    /** Returns true for accuracy and error (shown as percentages), false for kappa. */
    public boolean formatPercent() {
        return this.type != KAPPA;
    }

    /** Variance is not tracked by this criterion; always returns {@code NaN}. */
    public double getVariance() {
        return Double.NaN;
    }

    /** Returns the entry of {@link #NAME} for this criterion's type. */
    public String getName() {
        return NAME[this.type];
    }

    /** Returns the entry of {@link #DESCRIPTION} for this criterion's type. */
    public String getDescription() {
        return DESCRIPTION[this.type];
    }

    /**
     * Returns a value where larger is better. For {@link #ERROR} this is
     * 1 - error; for the other types it is the value itself.
     */
    public double getFitness() {
        if (this.type == ERROR) {
            return 1.0 - this.getValue();
        }
        return this.getValue();
    }

    /** Returns the maximal fitness, 1.0. */
    public double getMaxFitness() {
        return 1.0;
    }

    /**
     * Copies the state of {@code newPC} into this instance. The class-name and
     * counter arrays are deep-copied, so the two instances do not share arrays.
     *
     * @param newPC the criterion to copy from; must be a {@code MultiClassificationPerformance}
     */
    protected void cloneAveragable(Averagable newPC) {
        MultiClassificationPerformance other = (MultiClassificationPerformance)newPC;
        this.type = other.type;
        this.classNames = new String[other.classNames.length];
        System.arraycopy(other.classNames, 0, this.classNames, 0, other.classNames.length);
        this.counter = new int[other.counter.length][];
        for (int i = 0; i < other.counter.length; ++i) {
            this.counter[i] = new int[other.counter[i].length];
            System.arraycopy(other.counter[i], 0, this.counter[i], 0, other.counter[i].length);
        }
    }

    /**
     * Adds the other criterion's confusion matrix cell-wise into this one.
     * Both matrices are assumed to have the same dimensions.
     *
     * @param performance a {@code MultiClassificationPerformance} to merge
     */
    public void buildSingleAverage(Averagable performance) {
        MultiClassificationPerformance other = (MultiClassificationPerformance)performance;
        for (int i = 0; i < this.counter.length; ++i) {
            for (int j = 0; j < this.counter[i].length; ++j) {
                this.counter[i][j] += other.counter[i][j];
            }
        }
    }

    /**
     * Returns the superclass text followed by the confusion matrix, with rows
     * as predicted classes and columns as true classes. A "per class" line
     * divides each diagonal cell by the number of examples predicted as that
     * class (its column sum over true labels).
     * NOTE(review): {@link #toHTML()} divides by the true-class count (row sum)
     * instead — confirm which per-class measure is intended.
     */
    public String toString() {
        StringBuffer result = new StringBuffer(super.toString());
        result.append("\nConfusionMatrix:\nTrue:");
        for (int i = 0; i < this.counter.length; ++i) {
            result.append("\t" + this.classNames[i]);
        }
        for (int predicted = 0; predicted < this.counter.length; ++predicted) {
            result.append("\n" + this.classNames[predicted] + ":");
            for (int actual = 0; actual < this.counter[predicted].length; ++actual) {
                result.append("\t" + this.counter[actual][predicted]);
            }
        }
        result.append("\nper class:");
        for (int predicted = 0; predicted < this.counter.length; ++predicted) {
            int total = 0;
            for (int actual = 0; actual < this.counter[predicted].length; ++actual) {
                total += this.counter[actual][predicted];
            }
            result.append("\t" + this.formatValue((double)this.counter[predicted][predicted] / (double)total));
        }
        return result.toString();
    }

    /**
     * Returns the superclass text followed by an HTML table of the confusion
     * matrix. Rows are predicted classes and columns are true classes. Unless
     * the type is {@link #KAPPA}, a "per class" row is appended. For each true
     * class it shows diagonal / row sum ({@link #ACCURACY}) or 1 minus that
     * (other types).
     */
    public String toHTML() {
        StringBuffer result = new StringBuffer(super.toString());
        result.append("<table bgcolor=\"#E3D8C3\" border=\"1\"><tr bgcolor=\"#ccccff\"><td></td>");
        for (int i = 0; i < this.counter.length; ++i) {
            result.append("<td><b>true " + this.classNames[i] + "</b></td>");
        }
        result.append("</tr>");
        for (int predicted = 0; predicted < this.counter.length; ++predicted) {
            result.append("<tr><td bgcolor=\"#ccccff\"><b>pred. " + this.classNames[predicted] + "</b></td>");
            for (int actual = 0; actual < this.counter[predicted].length; ++actual) {
                result.append("<td>" + this.counter[actual][predicted] + "</td>");
            }
            result.append("</tr>");
        }
        if (this.type != KAPPA) {
            result.append("<tr bgcolor=\"#ccccff\"><td><b>per class:</b></td>");
            for (int actual = 0; actual < this.counter.length; ++actual) {
                int total = 0;
                for (int predicted = 0; predicted < this.counter[actual].length; ++predicted) {
                    total += this.counter[actual][predicted];
                }
                double classAccuracy = (double)this.counter[actual][actual] / (double)total;
                if (this.type == ACCURACY) {
                    result.append("<td>" + this.formatValue(classAccuracy) + "</td>");
                } else {
                    result.append("<td>" + this.formatValue(1.0 - classAccuracy) + "</td>");
                }
            }
            result.append("</tr>");
        }
        result.append("</table>");
        return result.toString();
    }
}

