/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.myRVM;

import Jama.Matrix;
import edu.udo.cs.myRVM.Kernel.KernelBasisFunction;
import edu.udo.cs.myRVM.Kernel.KernelRadial;
import edu.udo.cs.myRVM.Model;
import edu.udo.cs.myRVM.Parameter;
import edu.udo.cs.myRVM.RVMBase;
import edu.udo.cs.myRVM.RegressionProblem;
import edu.udo.cs.myRVM.Util.SECholeskyDecomposition;
import edu.udo.cs.yale.tools.LogService;
import java.util.Iterator;
import java.util.LinkedList;

/**
 * Regression RVM learner.
 *
 * <p>Builds a design matrix with a constant first column followed by one
 * column per kernel basis function, then iteratively re-estimates the
 * per-basis hyperparameters (alpha) and the noise precision (beta). Basis
 * functions whose alpha reaches {@code parameter.alpha_max} are left out of
 * the linear system in subsequent iterations.
 */
public class RVMRegression
extends RVMBase {
    /**
     * Creates a learner for the given regression problem.
     *
     * @param problem   the regression problem supplying inputs, targets and kernels
     * @param parameter the learning parameters read by {@link #learn()}
     */
    public RVMRegression(RegressionProblem problem, Parameter parameter) {
        super(problem, parameter);
    }

    /**
     * Runs the iterative alpha/beta re-estimation and returns the resulting model.
     *
     * <p>The loop stops after {@code parameter.maxIterations} iterations or as
     * soon as the largest absolute change of {@code log(alpha)} in one
     * iteration falls below {@code parameter.min_delta_log_alpha}.
     *
     * @return a model holding the weights and kernels of the basis functions
     *         that were still unpruned in the last executed iteration
     * @throws IllegalStateException if {@code parameter.maxIterations} is less
     *         than 1, since no weights can be estimated without an iteration
     */
    public Model learn() {
        // Without at least one iteration there are no weights to build a model from;
        // fail with a clear message instead of a NullPointerException further down.
        if (this.parameter.maxIterations < 1) {
            throw new IllegalStateException("maxIterations must be at least 1, but was " + this.parameter.maxIterations);
        }

        RegressionProblem problem = (RegressionProblem)this.problem;
        int numExamples = problem.getProblemSize();
        // One basis function per example plus the constant (bias) column at index 0.
        int numBases = numExamples + 1;
        KernelBasisFunction[] kernels = problem.getKernels();

        // Noise precision; starts at 1 / initSigma^2 and is re-estimated each iteration.
        double beta = 1.0 / Math.pow(this.parameter.initSigma, 2.0);

        Matrix matrixPHI = new Matrix(RVMRegression.buildDesignMatrix(problem.getInputVectors(), kernels, numExamples, numBases));
        Matrix vectorT = new Matrix(problem.getTargetVectors());
        // Column vector (numBases x 1) with every alpha set to the initial value.
        Matrix vectorAlpha = new Matrix(numBases, 1, this.parameter.initAlpha);
        Matrix vectorPHI_T = matrixPHI.transpose().times(vectorT);

        int[] unprunedIndices = null;
        Matrix prunedVectorWeights = null;
        // Logging interval: every iteration at first, then every 5th, then every 100th.
        int monIts = 1;

        for (int iteration = 1; iteration <= this.parameter.maxIterations; ++iteration) {
            unprunedIndices = RVMRegression.collectUnprunedIndices(vectorAlpha, numBases, this.parameter.alpha_max);

            if (iteration > 100) {
                monIts = 5;
            }
            if (iteration > 1000) {
                monIts = 100;
            }
            if (iteration % monIts == 0) {
                LogService.logMessage("it: " + iteration + " ; bases: " + unprunedIndices.length, 2);
            }

            // Restrict the problem to the basis functions that are still active.
            Matrix prunedMatrixPHI = matrixPHI.getMatrix(0, matrixPHI.getRowDimension() - 1, unprunedIndices);
            Matrix prunedVectorPHI_T = vectorPHI_T.getMatrix(unprunedIndices, 0, 0);
            Matrix prunedVectorAlpha = vectorAlpha.getMatrix(unprunedIndices, 0, 0);
            int numActive = prunedVectorAlpha.getRowDimension();

            Matrix matrixAlphaDiag = new Matrix(numActive, numActive, 0.0);
            for (int j = 0; j < numActive; ++j) {
                matrixAlphaDiag.set(j, j, prunedVectorAlpha.get(j, 0));
            }

            // SIGMA^-1 = beta * PHI^T * PHI + diag(alpha)
            Matrix matrixSIGMAInv = prunedMatrixPHI.transpose().times(prunedMatrixPHI);
            matrixSIGMAInv.timesEquals(beta);
            matrixSIGMAInv.plusEquals(matrixAlphaDiag);

            // SIGMA is never formed explicitly: with UInv = (P^T * L)^-1 the code uses
            // UInv^T * UInv in its place, both for the weights and for the diagonal.
            SECholeskyDecomposition CD = new SECholeskyDecomposition(matrixSIGMAInv.getArray());
            Matrix matrixU = CD.getPTR().times(CD.getL());
            Matrix matrixUInv = matrixU.inverse();
            prunedVectorWeights = matrixUInv.transpose().times(matrixUInv.times(prunedVectorPHI_T)).times(beta);

            // gamma_j = 1 - alpha_j * SIGMA_jj, where SIGMA_jj is the squared norm of column j of UInv.
            int dim = matrixUInv.getRowDimension();
            double[] gammas = new double[dim];
            double sumGammas = 0.0;
            for (int j = 0; j < dim; ++j) {
                double diagSigma = 0.0;
                for (int k = 0; k < dim; ++k) {
                    diagSigma += matrixUInv.get(k, j) * matrixUInv.get(k, j);
                }
                gammas[j] = 1.0 - prunedVectorAlpha.get(j, 0) * diagSigma;
                sumGammas += gammas[j];
            }

            // Re-estimate alpha_j = gamma_j / w_j^2 and track the largest change in log(alpha).
            double maxLogAlphaChange = 0.0;
            for (int j = 0; j < numActive; ++j) {
                double oldLogAlpha = Math.log(prunedVectorAlpha.get(j, 0));
                double weight = prunedVectorWeights.get(j, 0);
                double newAlpha = gammas[j] / (weight * weight);
                prunedVectorAlpha.set(j, 0, newAlpha);
                double change = Math.abs(oldLogAlpha - Math.log(newAlpha));
                // Plain comparison on purpose: a NaN change is ignored rather than propagated.
                if (change > maxLogAlphaChange) {
                    maxLogAlphaChange = change;
                }
            }
            // Converged: stop before beta and the global alpha vector are updated, so the
            // weights computed above stay consistent with unprunedIndices.
            if (maxLogAlphaChange < this.parameter.min_delta_log_alpha) {
                break;
            }

            // Re-estimate beta = (N - sum(gamma)) / squared residual error.
            Matrix dataDelta = vectorT.minus(prunedMatrixPHI.times(prunedVectorWeights));
            double dataError = 0.0;
            for (int j = 0; j < numExamples; ++j) {
                dataError += dataDelta.get(j, 0) * dataDelta.get(j, 0);
            }
            beta = ((double)numExamples - sumGammas) / dataError;

            // Write the new alphas back to their positions in the full alpha vector.
            for (int j = 0; j < numActive; ++j) {
                vectorAlpha.set(unprunedIndices[j], 0, prunedVectorAlpha.get(j, 0));
            }
        }

        double[] finalWeights = new double[unprunedIndices.length];
        KernelBasisFunction[] finalKernels = new KernelBasisFunction[unprunedIndices.length];
        boolean bias = false;
        for (int j = 0; j < unprunedIndices.length; ++j) {
            finalWeights[j] = prunedVectorWeights.get(j, 0);
            int basisIndex = unprunedIndices[j];
            if (basisIndex == 0) {
                // Index 0 is the constant column; a default radial kernel is stored in its slot.
                // NOTE(review): presumably Model skips this kernel when bias is true - confirm.
                bias = true;
                finalKernels[j] = new KernelBasisFunction(new KernelRadial());
            } else {
                finalKernels[j] = kernels[basisIndex];
            }
        }
        return new Model(finalWeights, finalKernels, bias, true);
    }

    /**
     * Builds the design matrix: column 0 is constant 1.0, column c (c >= 1)
     * holds {@code kernels[c].eval(inputs[row])} for every example row.
     */
    private static double[][] buildDesignMatrix(double[][] inputs, KernelBasisFunction[] kernels, int numExamples, int numBases) {
        double[][] phi = new double[numExamples][numBases];
        // Kernel columns are filled column by column, in the same order as before.
        for (int col = 1; col < numBases; ++col) {
            for (int row = 0; row < numExamples; ++row) {
                phi[row][col] = kernels[col].eval(inputs[row]);
            }
        }
        for (int row = 0; row < numExamples; ++row) {
            phi[row][0] = 1.0;
        }
        return phi;
    }

    /**
     * Returns, in ascending order, the indices of all basis functions whose
     * alpha has not reached {@code alphaMax}.
     */
    private static int[] collectUnprunedIndices(Matrix vectorAlpha, int numBases, double alphaMax) {
        int[] buffer = new int[numBases];
        int count = 0;
        for (int j = 0; j < numBases; ++j) {
            // Written as !(a >= b) rather than a < b so that a NaN alpha is kept, as before.
            if (!(vectorAlpha.get(j, 0) >= alphaMax)) {
                buffer[count++] = j;
            }
        }
        int[] indices = new int[count];
        System.arraycopy(buffer, 0, indices, 0, count);
        return indices;
    }

    public String toString() {
        return "Regression-RVM";
    }
}

