/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.myRVM;

import Jama.Matrix;
import edu.udo.cs.myRVM.ClassificationProblem;
import edu.udo.cs.myRVM.Kernel.KernelBasisFunction;
import edu.udo.cs.myRVM.Kernel.KernelRadial;
import edu.udo.cs.myRVM.Model;
import edu.udo.cs.myRVM.Parameter;
import edu.udo.cs.myRVM.RVMBase;
import edu.udo.cs.myRVM.Util.SECholeskyDecomposition;
import edu.udo.cs.yale.tools.LogService;
import java.util.Iterator;
import java.util.LinkedList;

/**
 * Relevance vector machine learner for classification problems.
 *
 * <p>{@link #learn()} alternates between two steps until the alpha
 * hyperparameters stop changing or the iteration budget is used up:
 * <ol>
 *   <li>for the basis functions whose alpha is still below
 *       {@code parameter.alpha_max}, find the weights with an IRLS (Newton)
 *       search on the penalised logistic error;</li>
 *   <li>re-estimate every remaining alpha from the weights and the diagonal
 *       of the inverse Hessian.</li>
 * </ol>
 */
public class RVMClassification
extends RVMBase {
    /** Upper bound on the number of IRLS (Newton) steps per weight search. */
    private static final int MAX_IRLS_ITERATIONS = 25;

    /** IRLS stops once the gradient norm divided by the number of weights falls below this. */
    private static final double MIN_GRADIENT_CHANGE = 1.0E-6;

    /** Step-size multipliers at or below this value are not tried during back-off. */
    private static final double MIN_LAMBDA = Math.pow(2.0, -8.0);

    /** Verbosity level passed to {@code LogService.logMessage} for every message of this class. */
    private static final int LOG_LEVEL = 2;

    /**
     * Creates a learner for the given problem and parameter set; both are
     * handed to the {@code RVMBase} constructor unchanged.
     *
     * @param problem   the classification problem to learn from
     * @param parameter the learning parameters
     */
    public RVMClassification(ClassificationProblem problem, Parameter parameter) {
        super(problem, parameter);
    }

    /**
     * Trains the model.
     *
     * @return a model holding the weights and kernels of all basis functions
     *         that were not pruned in the last iteration; its bias flag is
     *         {@code true} when the constant column (index 0) survived
     * @throws IllegalStateException if {@code parameter.maxIterations} is
     *         smaller than 1, because no iteration could run and there would
     *         be nothing to build a model from
     */
    public Model learn() {
        if (this.parameter.maxIterations < 1) {
            throw new IllegalStateException(
                    "maxIterations must be at least 1, but was " + this.parameter.maxIterations);
        }

        ClassificationProblem problem = (ClassificationProblem)this.problem;
        int numExamples = problem.getProblemSize();
        int numBases = numExamples + 1;
        double[][] x = problem.getInputVectors();
        KernelBasisFunction[] kernels = problem.getKernels();

        Matrix matrixPHI = new Matrix(this.buildDesignMatrix(x, kernels, numExamples, numBases));

        double[] alphas = new double[numBases];
        for (int j = 0; j < numBases; ++j) {
            alphas[j] = this.parameter.initAlpha;
        }
        Matrix vectorAlpha = new Matrix(alphas, numBases);

        // NOTE(review): vectorWeights is never written to after this point, so
        // every outer iteration starts its IRLS search from all-zero weights.
        // Kept as in the original; confirm whether a warm start was intended.
        Matrix vectorWeights = new Matrix(numBases, 1, 0.0);

        Matrix prunedVectorWeights = null;
        Matrix matrixUInv = null;
        int[] unprunedIndicesArray = null;
        int monIts = 1;

        for (int i = 1; i <= this.parameter.maxIterations; ++i) {
            unprunedIndicesArray = this.collectUnprunedIndices(vectorAlpha, numBases);

            // Log less often as the iteration count grows.
            if (i > 100) {
                monIts = 5;
            }
            if (i > 1000) {
                monIts = 100;
            }
            if (i % monIts == 0) {
                LogService.logMessage("it: " + i + " ; bases: " + unprunedIndicesArray.length, LOG_LEVEL);
            }

            // Restrict everything to the basis functions that are still active.
            Matrix prunedMatrixPHI = matrixPHI.getMatrix(0, matrixPHI.getRowDimension() - 1, unprunedIndicesArray);
            Matrix prunedVectorAlpha = vectorAlpha.getMatrix(unprunedIndicesArray, 0, 0);
            prunedVectorWeights = vectorWeights.getMatrix(unprunedIndicesArray, 0, 0);
            int numActive = prunedVectorAlpha.getRowDimension();

            Matrix matrixAlphaDiag = new Matrix(numActive, numActive, 0.0);
            for (int j = 0; j < numActive; ++j) {
                matrixAlphaDiag.set(j, j, prunedVectorAlpha.get(j, 0));
            }

            int[] t = problem.getTargetVectors();
            Matrix vectorY = this.applySigmoid(prunedMatrixPHI.times(prunedVectorWeights));
            double error = this.penalisedError(vectorY, t, prunedVectorAlpha, prunedVectorWeights, numExamples);

            for (int irls = 0; irls < MAX_IRLS_ITERATIONS; ++irls) {
                // Hessian = PHI' * diag(y * (1 - y)) * PHI + diag(alpha)
                int numRows = prunedMatrixPHI.getRowDimension();
                Matrix matrixIRLSWeights = new Matrix(numRows, numRows, 0.0);
                for (int k = 0; k < numRows; ++k) {
                    double y = vectorY.get(k, 0);
                    matrixIRLSWeights.set(k, k, y * (1.0 - y));
                }
                Matrix matrixHessian = prunedMatrixPHI.transpose().times(matrixIRLSWeights).times(prunedMatrixPHI);
                matrixHessian.plusEquals(matrixAlphaDiag);

                // Gradient = PHI' * (t - y) - alpha .* w
                Matrix vectorE = new Matrix(vectorY.getRowDimension(), 1, 0.0);
                for (int k = 0; k < vectorY.getRowDimension(); ++k) {
                    vectorE.set(k, 0, (double)t[k] - vectorY.get(k, 0));
                }
                Matrix vectorPenalty = new Matrix(numActive, 1, 0.0);
                for (int k = 0; k < numActive; ++k) {
                    vectorPenalty.set(k, 0, prunedVectorAlpha.get(k, 0) * prunedVectorWeights.get(k, 0));
                }
                Matrix vectorGradient = prunedMatrixPHI.transpose().times(vectorE).minus(vectorPenalty);

                LogService.logMessage("(IRLS) error = " + error, LOG_LEVEL);

                // The factor U = PTR * L is inverted explicitly because its
                // inverse is needed again after the loop for the alpha update.
                SECholeskyDecomposition CD = new SECholeskyDecomposition(matrixHessian.getArray());
                Matrix matrixU = CD.getPTR().times(CD.getL());
                matrixUInv = matrixU.inverse();

                if (irls >= 2 && vectorGradient.normF() / (double)prunedVectorWeights.getRowDimension() < MIN_GRADIENT_CHANGE) {
                    LogService.logMessage("(IRLS) converged after " + irls + " iterations, error = " + error, LOG_LEVEL);
                    break;
                }

                Matrix vectorDeltaWeights = matrixUInv.transpose().times(matrixUInv.times(vectorGradient));

                // Back off: halve the step until the error no longer increases.
                boolean stepAccepted = false;
                for (double lambda = 1.0; lambda > MIN_LAMBDA; lambda /= 2.0) {
                    Matrix vectorNewWeights = prunedVectorWeights.plus(vectorDeltaWeights.times(lambda));
                    Matrix candidateY = this.applySigmoid(prunedMatrixPHI.times(vectorNewWeights));
                    double errorNew = this.penalisedError(candidateY, t, prunedVectorAlpha, vectorNewWeights, numExamples);
                    if (errorNew > error) {
                        LogService.logMessage("(IRLS) max overshot, backing off", LOG_LEVEL);
                    } else {
                        // Commit weights, predictions and error together so the
                        // next step is measured against the current point and
                        // the next Hessian is built from matching predictions.
                        prunedVectorWeights = vectorNewWeights;
                        vectorY = candidateY;
                        error = errorNew;
                        stepAccepted = true;
                        break;
                    }
                }
                if (!stepAccepted) {
                    // No step size improved the error, so the weights did not
                    // change and further iterations would repeat the same work.
                    break;
                }
            }

            // Diagonal of the inverse Hessian: squared column norms of U^-1.
            int sigmaSize = matrixUInv.getRowDimension();
            double[] diagSIGMA = new double[sigmaSize];
            for (int j = 0; j < sigmaSize; ++j) {
                double value = 0.0;
                for (int k = 0; k < sigmaSize; ++k) {
                    value += matrixUInv.get(k, j) * matrixUInv.get(k, j);
                }
                diagSIGMA[j] = value;
            }

            // Re-estimate each alpha as gamma / w^2 with gamma = 1 - alpha * SIGMA_jj
            // and track the largest change of log(alpha).
            double maxLogAlphaChange = 0.0;
            for (int j = 0; j < numActive; ++j) {
                double oldAlpha = prunedVectorAlpha.get(j, 0);
                double gamma = 1.0 - oldAlpha * diagSIGMA[j];
                double weight = prunedVectorWeights.get(j, 0);
                double newAlpha = gamma / (weight * weight);
                prunedVectorAlpha.set(j, 0, newAlpha);
                double change = Math.abs(Math.log(oldAlpha) - Math.log(newAlpha));
                if (change > maxLogAlphaChange) {
                    maxLogAlphaChange = change;
                }
            }
            if (maxLogAlphaChange < this.parameter.min_delta_log_alpha) {
                break;
            }

            // Write the new alphas back so the next iteration can prune on them.
            for (int j = 0; j < numActive; ++j) {
                vectorAlpha.set(unprunedIndicesArray[j], 0, prunedVectorAlpha.get(j, 0));
            }
        }

        double[] finalWeights = new double[unprunedIndicesArray.length];
        KernelBasisFunction[] finalKernels = new KernelBasisFunction[unprunedIndicesArray.length];
        boolean bias = false;
        for (int j = 0; j < unprunedIndicesArray.length; ++j) {
            finalWeights[j] = prunedVectorWeights.get(j, 0);
            if (unprunedIndicesArray[j] == 0) {
                // Index 0 is the constant column; a fresh kernel object fills
                // its slot (presumably unused by Model when bias is set - confirm).
                bias = true;
                finalKernels[j] = new KernelBasisFunction(new KernelRadial());
            } else {
                finalKernels[j] = kernels[unprunedIndicesArray[j]];
            }
        }
        // The meaning of the last constructor flag is defined by Model, which
        // is not visible here; it is passed as false exactly as before.
        return new Model(finalWeights, finalKernels, bias, false);
    }

    /**
     * Builds the design matrix: column 0 is constant 1.0, column {@code j}
     * (for {@code j >= 1}) holds {@code kernels[j].eval(x[i])} for every example {@code i}.
     */
    private double[][] buildDesignMatrix(double[][] x, KernelBasisFunction[] kernels, int numExamples, int numBases) {
        double[][] phi = new double[numExamples][numBases];
        for (int j = 1; j < numBases; ++j) {
            for (int i = 0; i < numExamples; ++i) {
                phi[i][j] = kernels[j].eval(x[i]);
            }
        }
        for (int i = 0; i < numExamples; ++i) {
            phi[i][0] = 1.0;
        }
        return phi;
    }

    /**
     * Returns, in ascending order, the indices of all bases whose alpha has
     * not reached {@code parameter.alpha_max}.
     */
    private int[] collectUnprunedIndices(Matrix vectorAlpha, int numBases) {
        int[] candidates = new int[numBases];
        int count = 0;
        for (int j = 0; j < numBases; ++j) {
            // Deliberately written as a negated >= : a NaN alpha compares
            // false and is therefore kept, exactly as before.
            if (!(vectorAlpha.get(j, 0) >= this.parameter.alpha_max)) {
                candidates[count++] = j;
            }
        }
        int[] result = new int[count];
        System.arraycopy(candidates, 0, result, 0, count);
        return result;
    }

    /**
     * Replaces every entry of the given column vector by its sigmoid, in place.
     *
     * @return the same matrix instance, for call chaining
     */
    private Matrix applySigmoid(Matrix activations) {
        for (int k = 0; k < activations.getRowDimension(); ++k) {
            activations.set(k, 0, this.sigmoid(activations.get(k, 0)));
        }
        return activations;
    }

    /**
     * Computes the penalised error
     * {@code (-sum log-likelihood + 0.5 * sum(alpha * w^2)) / numExamples}
     * for the given predictions and weights. Targets equal to 1 contribute
     * {@code -log(y)}, all other targets contribute {@code -log(1 - y)}.
     */
    private double penalisedError(Matrix vectorY, int[] t, Matrix alpha, Matrix weights, int numExamples) {
        double dataTerm = 0.0;
        for (int k = 0; k < t.length; ++k) {
            if (t[k] == 1) {
                dataTerm -= Math.log(vectorY.get(k, 0));
            } else {
                dataTerm -= Math.log(1.0 - vectorY.get(k, 0));
            }
        }
        double penaltyTerm = 0.0;
        for (int k = 0; k < alpha.getRowDimension(); ++k) {
            double w = weights.get(k, 0);
            penaltyTerm += alpha.get(k, 0) * w * w;
        }
        return (dataTerm + penaltyTerm / 2.0) / (double)numExamples;
    }

    /**
     * Logistic sigmoid.
     *
     * @param x the activation
     * @return {@code 1 / (1 + exp(-x))}
     */
    public double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Returns the fixed display name {@code "Classification-RVM"}. */
    @Override
    public String toString() {
        return "Classification-RVM";
    }
}

