/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.kernel.rvm;

import Jama.Matrix;
import com.rapidminer.operator.learner.functions.kernel.rvm.ClassificationProblem;
import com.rapidminer.operator.learner.functions.kernel.rvm.Model;
import com.rapidminer.operator.learner.functions.kernel.rvm.Parameter;
import com.rapidminer.operator.learner.functions.kernel.rvm.RVMBase;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelBasisFunction;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelRadial;
import com.rapidminer.operator.learner.functions.kernel.rvm.util.SECholeskyDecomposition;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;

public class RVMClassification
extends RVMBase {
    public RVMClassification(ClassificationProblem problem, Parameter parameter) {
        super(problem, parameter);
    }

    @Override
    public Model learn() {
        int i;
        int j;
        ClassificationProblem problem = (ClassificationProblem)this.problem;
        int numExamples = problem.getProblemSize();
        int numBases = numExamples + 1;
        int prune_point = 50;
        prune_point = this.parameter.maxIterations * prune_point / 100;
        double[][] x = problem.getInputVectors();
        KernelBasisFunction[] kernels = problem.getKernels();
        double[][] PHI = new double[numExamples][numBases];
        for (j = 0; j < numBases - 1; ++j) {
            for (i = 0; i < numExamples; ++i) {
                PHI[i][j + 1] = kernels[j + 1].eval(x[i]);
            }
        }
        for (i = 0; i < numExamples; ++i) {
            PHI[i][0] = 1.0;
        }
        double[] alphas = new double[numBases];
        for (j = 0; j < numBases; ++j) {
            alphas[j] = this.parameter.initAlpha;
        }
        Matrix matrixPHI = new Matrix(PHI);
        Matrix vectorAlpha = new Matrix(alphas, numBases);
        Matrix vectorWeights = new Matrix(numBases, 1, 0.0);
        Matrix prunedVectorWeights = null;
        Matrix matrixU = null;
        Matrix matrixUInv = null;
        LinkedList<Integer> unprunedIndicesList = null;
        int[] unprunedIndicesArray = null;
        for (i = 1; i <= this.parameter.maxIterations; ++i) {
            unprunedIndicesList = new LinkedList<Integer>();
            for (j = 0; j < numBases; ++j) {
                if (vectorAlpha.get(j, 0) >= this.parameter.alpha_max) continue;
                unprunedIndicesList.add(j);
            }
            unprunedIndicesArray = new int[unprunedIndicesList.size()];
            Iterator iter = unprunedIndicesList.iterator();
            for (j = 0; j < unprunedIndicesList.size(); ++j) {
                unprunedIndicesArray[j] = (Integer)iter.next();
            }
            Matrix prunedMatrixPHI = matrixPHI.getMatrix(0, matrixPHI.getRowDimension() - 1, unprunedIndicesArray);
            Matrix prunedVectorAlpha = vectorAlpha.getMatrix(unprunedIndicesArray, 0, 0);
            prunedVectorWeights = vectorWeights.getMatrix(unprunedIndicesArray, 0, 0);
            double minGradientChange = 1.0E-6;
            double minLambda = Math.pow(2.0, -8.0);
            Matrix matrixAlphaDiag = new Matrix(prunedVectorAlpha.getRowDimension(), prunedVectorAlpha.getRowDimension(), 0.0);
            for (j = 0; j < prunedVectorAlpha.getRowDimension(); ++j) {
                matrixAlphaDiag.set(j, j, prunedVectorAlpha.get(j, 0));
            }
            Matrix vectorY = prunedMatrixPHI.times(prunedVectorWeights);
            for (int k = 0; k < vectorY.getRowDimension(); ++k) {
                vectorY.set(k, 0, this.sigmoid(vectorY.get(k, 0)));
            }
            double dataTerm = 0.0;
            int[] t = problem.getTargetVectors();
            for (int k = 0; k < t.length; ++k) {
                if (t[k] == 1) {
                    dataTerm -= Math.log(vectorY.get(k, 0));
                    continue;
                }
                dataTerm -= Math.log(1.0 - vectorY.get(k, 0));
            }
            double penaltyTerm = 0.0;
            for (int k = 0; k < prunedVectorAlpha.getRowDimension(); ++k) {
                penaltyTerm += prunedVectorAlpha.get(k, 0) * prunedVectorWeights.get(k, 0) * prunedVectorWeights.get(k, 0);
            }
            double error = (dataTerm + penaltyTerm / 2.0) / (double)problem.getProblemSize();
            block11: for (j = 0; j < 25; ++j) {
                Matrix matrixIRLSWeights = new Matrix(prunedMatrixPHI.getRowDimension(), prunedMatrixPHI.getRowDimension(), 0.0);
                for (int k = 0; k < matrixIRLSWeights.getRowDimension(); ++k) {
                    matrixIRLSWeights.set(k, k, vectorY.get(k, 0) * (1.0 - vectorY.get(k, 0)));
                }
                Matrix matrixHessian = prunedMatrixPHI.transpose().times(matrixIRLSWeights).times(prunedMatrixPHI);
                matrixHessian.plusEquals(matrixAlphaDiag);
                Matrix vectorE = new Matrix(vectorY.getRowDimension(), 1, 0.0);
                for (int k = 0; k < vectorY.getRowDimension(); ++k) {
                    vectorE.set(k, 0, (double)t[k] - vectorY.get(k, 0));
                }
                Matrix vectorPenalty = (Matrix)prunedVectorAlpha.clone();
                for (int k = 0; k < vectorPenalty.getRowDimension(); ++k) {
                    vectorPenalty.set(k, 0, vectorPenalty.get(k, 0) * prunedVectorWeights.get(k, 0));
                }
                Matrix vectorGradient = prunedMatrixPHI.transpose().times(vectorE).minus(vectorPenalty);
                SECholeskyDecomposition CD = new SECholeskyDecomposition(matrixHessian.getArray());
                matrixU = CD.getPTR().times(CD.getL());
                matrixUInv = matrixU.inverse();
                if (j >= 2 && vectorGradient.normF() / (double)prunedVectorWeights.getRowDimension() < minGradientChange) break;
                Matrix vectorDeltaWeights = matrixUInv.transpose().times(matrixUInv.times(vectorGradient));
                for (double lambda = 1.0; lambda > minLambda; lambda /= 2.0) {
                    int k;
                    Matrix vectorNewWeights = ((Matrix)prunedVectorWeights.clone()).plus(vectorDeltaWeights.times(lambda));
                    vectorY = prunedMatrixPHI.times(vectorNewWeights);
                    for (k = 0; k < vectorY.getRowDimension(); ++k) {
                        vectorY.set(k, 0, this.sigmoid(vectorY.get(k, 0)));
                    }
                    dataTerm = 0.0;
                    for (k = 0; k < t.length; ++k) {
                        if (t[k] == 1) {
                            dataTerm -= Math.log(vectorY.get(k, 0));
                            continue;
                        }
                        dataTerm -= Math.log(1.0 - vectorY.get(k, 0));
                    }
                    penaltyTerm = 0.0;
                    for (k = 0; k < prunedVectorAlpha.getRowDimension(); ++k) {
                        penaltyTerm += prunedVectorAlpha.get(k, 0) * vectorNewWeights.get(k, 0) * vectorNewWeights.get(k, 0);
                    }
                    double error_new = (dataTerm + penaltyTerm / 2.0) / (double)problem.getProblemSize();
                    if (error_new > error) {
                        continue;
                    }
                    prunedVectorWeights = vectorNewWeights;
                    continue block11;
                }
            }
            double[] diagSIGMA = new double[matrixUInv.getRowDimension()];
            for (j = 0; j < diagSIGMA.length; ++j) {
                double value = 0.0;
                for (int k = 0; k < diagSIGMA.length; ++k) {
                    value += matrixUInv.get(k, j) * matrixUInv.get(k, j);
                }
                diagSIGMA[j] = value;
            }
            double[] gammas = new double[diagSIGMA.length];
            for (j = 0; j < gammas.length; ++j) {
                gammas[j] = 1.0 - prunedVectorAlpha.get(j, 0) * diagSIGMA[j];
            }
            double[] logAlphas = new double[prunedVectorAlpha.getRowDimension()];
            for (j = 0; j < logAlphas.length; ++j) {
                logAlphas[j] = Math.log(prunedVectorAlpha.get(j, 0));
            }
            for (j = 0; j < prunedVectorAlpha.getRowDimension(); ++j) {
                double newAlpha = gammas[j] / (prunedVectorWeights.get(j, 0) * prunedVectorWeights.get(j, 0));
                prunedVectorAlpha.set(j, 0, newAlpha);
            }
            double maxLogAlphaChange = 0.0;
            for (j = 0; j < logAlphas.length; ++j) {
                double change = Math.abs(logAlphas[j] - Math.log(prunedVectorAlpha.get(j, 0)));
                if (!(change > maxLogAlphaChange)) continue;
                maxLogAlphaChange = change;
            }
            if (maxLogAlphaChange < this.parameter.min_delta_log_alpha) break;
            for (j = 0; j < prunedVectorAlpha.getRowDimension(); ++j) {
                vectorAlpha.set(unprunedIndicesArray[j], 0, prunedVectorAlpha.get(j, 0));
            }
        }
        double[] finalWeights = new double[unprunedIndicesArray.length];
        KernelBasisFunction[] finalKernels = new KernelBasisFunction[unprunedIndicesArray.length];
        boolean bias = false;
        for (j = 0; j < unprunedIndicesArray.length; ++j) {
            finalWeights[j] = prunedVectorWeights.get(j, 0);
            if (unprunedIndicesArray[j] == 0) {
                bias = true;
                finalKernels[j] = new KernelBasisFunction(new KernelRadial());
                continue;
            }
            finalKernels[j] = kernels[unprunedIndicesArray[j]];
        }
        Model model = new Model(finalWeights, finalKernels, bias, false);
        return model;
    }

    public double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    public String toString() {
        return "Classification-RVM";
    }
}

