/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.bayes;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.bayes.DiscriminantModel;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.math.MathFunctions;
import com.rapidminer.tools.math.matrix.CovarianceMatrix;

public class LinearDiscriminantAnalysis
extends AbstractLearner {
    public LinearDiscriminantAnalysis(OperatorDescription description) {
        super(description);
    }

    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        int numberOfNumericalAttributes = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (!attribute.isNumerical()) continue;
            ++numberOfNumericalAttributes;
        }
        NominalMapping labelMapping = exampleSet.getAttributes().getLabel().getMapping();
        String[] labelValues = new String[labelMapping.size()];
        for (int i = 0; i < labelMapping.size(); ++i) {
            labelValues[i] = labelMapping.mapIndex(i);
        }
        Matrix[] meanVectors = this.getMeanVectors(exampleSet, numberOfNumericalAttributes, labelValues);
        Matrix[] inverseCovariance = this.getInverseCovarianceMatrices(exampleSet, labelValues);
        return this.getModel(exampleSet, labelValues, meanVectors, inverseCovariance, this.getAprioriProbabilities(exampleSet, labelValues));
    }

    protected DiscriminantModel getModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities) throws UndefinedParameterError {
        return new DiscriminantModel(exampleSet, labels, meanVectors, inverseCovariances, aprioriProbabilities, 0.0);
    }

    private double[] getAprioriProbabilities(ExampleSet exampleSet, String[] labels) {
        double[] aprioriProbabilites = new double[labels.length];
        double totalSize = exampleSet.size();
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        SplittedExampleSet labelSet = SplittedExampleSet.splitByAttribute(exampleSet, exampleSet.getAttributes().getLabel());
        int labelIndex = 0;
        for (String label : labels) {
            for (int i = 0; i < labels.length; ++i) {
                labelSet.selectSingleSubset(i);
                if (labelSet.getExample(0).getNominalValue(labelAttribute).equals(label)) break;
            }
            aprioriProbabilites[labelIndex] = (double)labelSet.size() / totalSize;
            ++labelIndex;
        }
        return aprioriProbabilites;
    }

    protected Matrix[] getMeanVectors(ExampleSet exampleSet, int numberOfAttributes, String[] labels) throws UserError {
        Matrix[] classMeanVectors = new Matrix[labels.length];
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        SplittedExampleSet labelSet = SplittedExampleSet.splitByAttribute(exampleSet, exampleSet.getAttributes().getLabel());
        if (labelSet.getNumberOfSubsets() != labels.length) {
            throw new UserError((Operator)this, 118, labelAttribute, labelSet.getNumberOfSubsets(), 2);
        }
        int labelIndex = 0;
        for (String label : labels) {
            for (int i = 0; i < labels.length; ++i) {
                labelSet.selectSingleSubset(i);
                if (labelSet.getExample(0).getNominalValue(labelAttribute).equals(label)) break;
            }
            labelSet.recalculateAllAttributeStatistics();
            double[] meanValues = new double[numberOfAttributes];
            int i = 0;
            for (Attribute attribute : labelSet.getAttributes()) {
                if (attribute.isNumerical()) {
                    meanValues[i] = labelSet.getStatistics(attribute, "average");
                }
                ++i;
            }
            classMeanVectors[labelIndex] = new Matrix(meanValues, 1);
            ++labelIndex;
        }
        return classMeanVectors;
    }

    protected Matrix[] getInverseCovarianceMatrices(ExampleSet exampleSet, String[] labels) throws UndefinedParameterError {
        Matrix[] classInverseCovariances = new Matrix[labels.length];
        Matrix inverse = MathFunctions.invertMatrix(CovarianceMatrix.getCovarianceMatrix(exampleSet));
        for (int i = 0; i < labels.length; ++i) {
            classInverseCovariances[i] = inverse;
        }
        return classInverseCovariances;
    }

    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return DiscriminantModel.class;
    }

    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        if (capability.equals((Object)OperatorCapability.NUMERICAL_ATTRIBUTES)) {
            return true;
        }
        if (capability.equals((Object)OperatorCapability.BINOMINAL_LABEL)) {
            return true;
        }
        return capability.equals((Object)OperatorCapability.POLYNOMINAL_LABEL);
    }
}

