/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.bayes;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.learner.bayes.DiscriminantModel;
import com.rapidminer.operator.learner.bayes.LinearDiscriminantAnalysis;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.math.MathFunctions;
import com.rapidminer.tools.math.matrix.CovarianceMatrix;
import java.util.List;

/**
 * Regularized discriminant analysis.
 *
 * <p>Combines the inverse covariance matrices produced by {@link LinearDiscriminantAnalysis}
 * (the "global" matrices) with inverse covariance matrices estimated separately on the examples
 * of each label value. The weight of the global part is the {@value #PARAMETER_ALPHA} parameter:
 * {@code 1} keeps only the global matrices, {@code 0} keeps only the per-class matrices.
 */
public class RegularizedDiscriminantAnalysis
extends LinearDiscriminantAnalysis {
    /** Name of the parameter holding the regularization strength, a value in [0, 1]. */
    public static final String PARAMETER_ALPHA = "alpha";

    /**
     * Creates the operator.
     *
     * @param description the operator description passed on to the superclass
     */
    public RegularizedDiscriminantAnalysis(OperatorDescription description) {
        super(description);
    }

    /**
     * Computes one regularized inverse covariance matrix per label as
     * {@code alpha * globalInverse[i] + (1 - alpha) * classInverse[i]}.
     *
     * <p>{@code globalInverse} comes from {@link LinearDiscriminantAnalysis}; {@code classInverse[i]}
     * is the inverted covariance matrix of the examples whose label equals {@code labels[i]}.
     * NOTE(review): the blend is taken over the inverse matrices, not over the covariance
     * matrices themselves.
     *
     * @param exampleSet the training examples (must have a label attribute)
     * @param labels the label values, in the order the returned matrices are indexed
     * @return the blended inverse covariance matrices, index-aligned with {@code labels}
     * @throws UndefinedParameterError if the alpha parameter cannot be read
     */
    @Override
    protected Matrix[] getInverseCovarianceMatrices(ExampleSet exampleSet, String[] labels) throws UndefinedParameterError {
        double alpha = this.getParameterAsDouble(PARAMETER_ALPHA);
        // NOTE(review): assumed to contain at least labels.length matrices, index-aligned with
        // labels — TODO confirm against LinearDiscriminantAnalysis.
        Matrix[] globalInverseCovariances = super.getInverseCovarianceMatrices(exampleSet, labels);

        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        SplittedExampleSet labelSet = SplittedExampleSet.splitByAttribute(exampleSet, labelAttribute);
        Matrix[] classInverseCovariances = new Matrix[labels.length];
        for (int labelIndex = 0; labelIndex < labels.length; ++labelIndex) {
            String label = labels[labelIndex];
            // Select the subset whose examples carry this label value (identified via its first example).
            // NOTE(review): if no subset matches, the last subset tried stays selected.
            for (int i = 0; i < labels.length; ++i) {
                labelSet.selectSingleSubset(i);
                if (labelSet.getExample(0).getNominalValue(labelAttribute).equals(label)) {
                    break;
                }
            }
            classInverseCovariances[labelIndex] = MathFunctions.invertMatrix(CovarianceMatrix.getCovarianceMatrix(labelSet));
        }

        Matrix[] regularizedMatrices = new Matrix[labels.length];
        for (int i = 0; i < labels.length; ++i) {
            regularizedMatrices[i] = globalInverseCovariances[i].times(alpha).plus(classInverseCovariances[i].times(1.0 - alpha));
        }
        // Return the blended matrices (not the raw per-class ones) so that alpha actually takes effect.
        return regularizedMatrices;
    }

    /**
     * Builds the discriminant model from the given statistics, additionally passing the current
     * value of the alpha parameter to the {@link DiscriminantModel} constructor.
     *
     * @throws UndefinedParameterError if the alpha parameter cannot be read
     */
    @Override
    protected DiscriminantModel getModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities) throws UndefinedParameterError {
        return new DiscriminantModel(exampleSet, labels, meanVectors, inverseCovariances, aprioriProbabilities, this.getParameterAsDouble(PARAMETER_ALPHA));
    }

    /**
     * Returns the inherited parameter types plus the alpha parameter (range [0, 1], default 0.5).
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> list = super.getParameterTypes();
        list.add(new ParameterTypeDouble(PARAMETER_ALPHA, "This is the strength of regularization. 1: Only global covariance is used, 0: Only per class covariance is used.", 0.0, 1.0, 0.5, false));
        return list;
    }
}

