/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.bayes;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.tools.Tools;

/**
 * Prediction model that assigns an example to the class with the highest
 * discriminant score.
 *
 * <p>For every class {@code k} the score of an input row vector {@code x}
 * (built from the example's numerical attribute values) is
 * <pre>
 *   x * inverseCovariances[k] * meanVectors[k]^T
 *     - 0.5 * meanVectors[k] * inverseCovariances[k] * meanVectors[k]^T
 *     + ln(aprioriProbabilities[k])
 * </pre>
 * The part that does not depend on {@code x} is computed once in the
 * constructor and cached.
 *
 * <p>NOTE(review): the score is linear in {@code x} for every value of
 * {@code alpha}. When the per-class inverse covariances differ, a textbook
 * quadratic discriminant would additionally contain
 * {@code -0.5 * x * inverseCovariance * x^T} and a log-determinant term per
 * class. Confirm whether omitting them here is intended before relying on the
 * "Quadratic"/"Regularized" variants.
 */
public class DiscriminantModel
extends SimplePredictionModel {
    private static final long serialVersionUID = 3793343069512113817L;

    /** Regularization parameter; within this class it only selects the name returned by {@link #getName()}. */
    private double alpha;
    /** Class labels; the position in this array is the value returned by {@link #predict(Example)}. */
    private String[] labels;
    /** One mean vector per class, used as a row vector (it is transposed before being right-multiplied). */
    private Matrix[] meanVectors;
    /** One inverse covariance matrix per class, index-aligned with {@link #labels}. */
    private Matrix[] inverseCovariances;
    /** One a-priori probability per class, index-aligned with {@link #labels}. */
    private double[] aprioriProbabilities;
    /** Cached input-independent part of each class score, see the constructor. */
    private double[] constClassValues;

    /**
     * Creates the model and pre-computes the input-independent part of every
     * class score.
     *
     * <p>All arrays are stored by reference (no copies are made) and are
     * indexed {@code 0 .. labels.length - 1}.
     *
     * @param exampleSet           passed through to the super constructor
     * @param labels               class labels
     * @param meanVectors          per-class mean row vectors
     * @param inverseCovariances   per-class inverse covariance matrices
     * @param aprioriProbabilities per-class a-priori probabilities
     * @param alpha                regularization parameter (0 and 1 are treated specially by {@link #getName()})
     */
    public DiscriminantModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities, double alpha) {
        super(exampleSet);
        this.alpha = alpha;
        this.labels = labels;
        this.meanVectors = meanVectors;
        this.inverseCovariances = inverseCovariances;
        this.aprioriProbabilities = aprioriProbabilities;
        this.constClassValues = new double[labels.length];
        for (int i = 0; i < labels.length; ++i) {
            Matrix mean = meanVectors[i];
            // mean * inverseCovariance * mean^T is a 1x1 matrix; take its single entry.
            double quadraticForm = mean.times(inverseCovariances[i]).times(mean.transpose()).get(0, 0);
            this.constClassValues[i] = -0.5 * quadraticForm + Math.log(aprioriProbabilities[i]);
        }
    }

    /**
     * Scores the example against every class and returns the index of the
     * best-scoring class.
     *
     * <p>The numerical attributes of the example are copied, in iteration
     * order, into a row vector whose length is the column dimension of the
     * first mean vector. Non-numerical attributes are skipped. If the example
     * supplies fewer numerical attributes than that length, the remaining
     * entries stay {@code 0.0}.
     *
     * <p>On equal scores the class with the higher index wins, because the
     * comparison is {@code >=}. A NaN score is never selected; if no score
     * qualifies, index {@code 0} is returned.
     *
     * @param example the example to classify
     * @return the index into the label array of the winning class, as a double
     * @throws OperatorException declared by the overridden method
     * @throws ArrayIndexOutOfBoundsException if the example has more numerical
     *         attributes than the model's mean vectors have columns
     */
    @Override
    public double predict(Example example) throws OperatorException {
        int numberOfAttributes = this.meanVectors[0].getColumnDimension();
        double[] vector = new double[numberOfAttributes];
        int i = 0;
        for (Attribute attribute : example.getAttributes()) {
            if (!attribute.isNumerical()) {
                continue;
            }
            if (i >= numberOfAttributes) {
                // Same exception type the raw array store would raise, but with a usable message.
                throw new ArrayIndexOutOfBoundsException("Example has more numerical attributes than the " + numberOfAttributes + " columns of the model's mean vectors");
            }
            vector[i] = example.getValue(attribute);
            ++i;
        }
        // Jama packs the array into a matrix with one row, i.e. a row vector.
        Matrix xVector = new Matrix(vector, 1);

        // Score and arg-max in a single pass; no intermediate score array is needed.
        double maximalValue = Double.NEGATIVE_INFINITY;
        int bestValue = 0;
        for (int labelIndex = 0; labelIndex < this.labels.length; ++labelIndex) {
            double score = xVector.times(this.inverseCovariances[labelIndex]).times(this.meanVectors[labelIndex].transpose()).get(0, 0) + this.constClassValues[labelIndex];
            // ">=" (not ">") keeps the original tie-breaking and lets a score of
            // negative infinity still be chosen over the initial sentinel.
            if (score >= maximalValue) {
                bestValue = labelIndex;
                maximalValue = score;
            }
        }
        return bestValue;
    }

    /**
     * Returns a display name chosen from {@code alpha}: exactly {@code 0.0}
     * yields the quadratic name, exactly {@code 1.0} the linear name, and any
     * other value the regularized name.
     *
     * @return the model name
     */
    @Override
    public String getName() {
        if (this.alpha == 0.0) {
            return "Quadratic Discriminant Model";
        }
        if (this.alpha == 1.0) {
            return "Linear Discriminant Model";
        }
        return "Regularized Discriminant Model";
    }

    /**
     * Lists every label with its a-priori probability, one per line, as
     * {@code label<TAB>probability}, preceded by a header line. The
     * probability is rendered by {@code Tools.formatNumber(value, 4)}.
     *
     * @return the textual summary
     */
    @Override
    public String toString() {
        // Local, single-threaded buffer: StringBuilder avoids StringBuffer's synchronization.
        StringBuilder buffer = new StringBuilder();
        buffer.append("Apriori probabilities:\n");
        for (int i = 0; i < this.labels.length; ++i) {
            buffer.append(this.labels[i]).append("\t");
            buffer.append(Tools.formatNumber(this.aprioriProbabilities[i], 4)).append("\n");
        }
        return buffer.toString();
    }
}

