/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.myGP;

import Jama.Matrix;
import edu.udo.cs.myGP.GPBase;
import edu.udo.cs.myGP.Model;
import edu.udo.cs.myGP.Parameter;
import edu.udo.cs.myGP.RegressionProblem;
import edu.udo.cs.myRVM.Kernel.Kernel;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.Tools;
import java.util.TreeSet;

/**
 * Sparse online Gaussian-process regression learner.
 *
 * <p>The training examples are processed in a single pass. After each example the learner
 * holds at most {@code parameter.maxBasisVectors} basis vectors together with three
 * quantities:
 * <ul>
 *   <li>{@code alpha} &ndash; weights giving the mean {@code m = k^T alpha},</li>
 *   <li>{@code C} &ndash; the matrix used in the variance term {@code k* + k^T C k},</li>
 *   <li>{@code Q} &ndash; the inverse Gram matrix of the active basis vectors.</li>
 * </ul>
 * They are stored in matrices of the fixed size {@code maxBasisVectors + 1}; only the
 * leading {@code d} rows/columns are active and entries outside that block are expected to
 * stay zero (every update below only touches the active block, and removals clear it).
 */
public class Regression extends GPBase {

    /**
     * Creates a learner for the given problem.
     *
     * @param problem   supplies the input/target vectors, the kernel and {@code sigma_0_2}
     * @param parameter supplies {@code maxBasisVectors}, {@code epsilon_tol} and
     *                  {@code geometrical_tol}
     */
    public Regression(RegressionProblem problem, Parameter parameter) {
        super(problem, parameter);
    }

    /**
     * Returns the dot product of the first {@code d} entries of two column vectors stored
     * as {@code n x 1} arrays.
     *
     * @throws IllegalArgumentException if either vector has fewer than {@code d} rows
     */
    private double scalarProduct(double[][] x, double[][] y, int d) {
        if (x.length < d || y.length < d) {
            throw new IllegalArgumentException("At least one vector has a too small dimension!");
        }
        double result = 0.0;
        for (int i = 0; i < d; i++) {
            result += x[i][0] * y[i][0];
        }
        return result;
    }

    /**
     * Swaps rows {@code i} and {@code j} and then columns {@code i} and {@code j} of the
     * square matrix {@code A} in place.
     */
    private void swapRowsAndColumns(double[][] A, int i, int j) {
        double[] rowTmp = A[i];
        A[i] = A[j];
        A[j] = rowTmp;
        for (double[] row : A) {
            double tmp = row[i];
            row[i] = row[j];
            row[j] = tmp;
        }
    }

    /**
     * Trains the model with one pass over all examples.
     *
     * <p>For every example {@code (x, y)}:
     * <ol>
     *   <li>compute the kernel vector {@code k} against the active basis vectors,
     *       {@code m = k^T alpha}, {@code sigma2 = k* + k^T C k} and
     *       {@code gamma = k* - k^T Q k} (with {@code Q} the inverse Gram matrix this is the
     *       residual of projecting {@code x} onto the span of the basis in feature space);</li>
     *   <li>if {@code gamma < epsilon_tol}, apply the update projected onto the current
     *       basis; otherwise add {@code x} as a new basis vector and refresh {@code Q};</li>
     *   <li>if the basis-vector budget is exceeded, remove the vector with the smallest
     *       KL-based score;</li>
     *   <li>while the smallest geometrical score {@code 1 / Q_jj} is at most
     *       {@code geometrical_tol}, remove that vector.</li>
     * </ol>
     *
     * @return a model over the final {@code d} basis vectors
     * @throws Exception if the kernel or a matrix operation fails
     */
    public Model learn() throws Exception {
        RegressionProblem problem = (RegressionProblem) this.problem;
        int numExamples = problem.getProblemSize();
        int inputDim = problem.getInputDimension();
        double[][] x = problem.getInputVectors();
        double[][] y = problem.getTargetVectors();
        Kernel kernel = problem.getKernel();

        // One spare slot: a new basis vector is added first, and the KL pruning below
        // then brings the count back to maxBasisVectors.
        int dMax = this.parameter.maxBasisVectors + 1;
        int d = 0;
        Matrix alpha = new Matrix(dMax, 1);
        Matrix C = new Matrix(dMax, dMax);
        Matrix Q = new Matrix(dMax, dMax);
        Matrix k = new Matrix(dMax, 1);
        Matrix u = new Matrix(dMax, 1);
        double[][] basisVectors = new double[dMax][inputDim];

        int gammaProjections = 0;
        int klProjections = 0;
        int geometricalProjections = 0;

        for (int i = 0; i < numExamples; i++) {
            LogService.logMessage("* processing input vector nr: [" + i + "]", 2);
            double[] xNew = x[i];
            double yNew = y[i][0];

            // Kernel vector against the active basis. Inactive slots are cleared so that
            // values left over from an earlier, larger basis cannot enter the products.
            double[][] kArr = k.getArray();
            for (int j = 0; j < dMax; j++) {
                kArr[j][0] = (j < d) ? kernel.eval(basisVectors[j], xNew) : 0.0;
            }
            double kStar = kernel.eval(xNew, xNew);
            double m = this.scalarProduct(kArr, alpha.getArray(), d);
            Matrix cTimesK = C.times(k);
            double sigma2 = kStar + this.scalarProduct(kArr, cTimesK.getArray(), d);
            double q = (yNew - m) / (problem.sigma_0_2 + sigma2);
            double r = -1.0 / (problem.sigma_0_2 + sigma2);
            Matrix e = Q.times(k);
            double gamma = kStar - this.scalarProduct(kArr, e.getArray(), d);

            if (gamma < 0.0) {
                LogService.logMessage(" -> OOOOPPS: gamma < 0 !!! [gamma = " + gamma + "]", 2);
            }
            if (gamma < this.parameter.epsilon_tol) {
                // x is (numerically) in the span of the basis: project the update onto it.
                LogService.logMessage(" -> gamma induced projection [gamma = " + gamma + "]", 2);
                double nabla = 1.0 / (1.0 + gamma * r);
                Matrix s = cTimesK.plus(e);
                alpha = alpha.plus(s.times(q * nabla));
                C = C.plus(s.times(s.transpose()).times(r * nabla));
                ++gammaProjections;
            } else {
                LogService.logMessage(" -> adding new bv [nr = " + (d + 1) + "]", 2);
                // u is the unit vector selecting the slot the new basis vector goes into.
                double[][] uArr = u.getArray();
                for (int j = 0; j < dMax; j++) {
                    uArr[j][0] = 0.0;
                }
                uArr[d][0] = 1.0;
                Matrix s = cTimesK.plus(u);
                alpha = alpha.plus(s.times(q));
                C = C.plus(s.times(s.transpose()).times(r));
                Matrix t = e.minus(u);
                basisVectors[d] = xNew;
                ++d;
                Q = this.updateInverseGram(Q, kernel, basisVectors, d, t, gamma);
            }

            if (d >= dMax) {
                LogService.logMessage(" -> maximum number of bvs exceeded, projecting...", 2);
                int minIndex = this.minScoreKLApprox(alpha, C, Q, d).getIndex();
                this.deleteBV(alpha, C, Q, basisVectors, d - 1, minIndex);
                --d;
                ++klProjections;
            }
            while (d > 0) {
                Score minScore = this.minScoreGeometrical(Q, d);
                LogService.logMessage(" -> geometrical min score: " + minScore.getScore(), 2);
                if (minScore.getScore() > this.parameter.geometrical_tol) {
                    break;
                }
                LogService.logMessage(" -> pruning based on geometrical score...", 2);
                this.deleteBV(alpha, C, Q, basisVectors, d - 1, minScore.getIndex());
                --d;
                ++geometricalProjections;
            }
            LogService.logMessage("", 2);
        }
        LogService.logMessage("Number of bvs: [" + d + "]" + Tools.getLineSeparator(), 2);
        LogService.logMessage("Number of gamma-projections      : [" + gammaProjections + "]", 2);
        LogService.logMessage("Number of kl-projections         : [" + klProjections + "]", 2);
        LogService.logMessage("Number of geometrical-projections: [" + geometricalProjections + "]", 2);
        return new Model(kernel, basisVectors, alpha.getMatrix(0, d - 1, 0, 0),
                C.getMatrix(0, d - 1, 0, d - 1), Q.getMatrix(0, d - 1, 0, d - 1), d, true);
    }

    /**
     * Returns the inverse Gram matrix of the first {@code d} basis vectors, placed in the
     * top-left {@code d x d} block of a matrix of the same size as {@code Q}.
     *
     * <p>The inverse is recomputed from the Cholesky factor {@code L} of the Gram matrix as
     * {@code L^-T L^-1}. If any pivot of {@code L} is non-positive or NaN, the Gram matrix is
     * numerically not positive definite and {@code L} cannot be inverted. In that case the
     * rank-one update {@code Q + t t^T / gamma} of the previous inverse is returned instead
     * of aborting training.
     *
     * @param Q            inverse Gram matrix of the first {@code d - 1} basis vectors;
     *                     overwritten in place on the Cholesky path
     * @param kernel       kernel used to build the Gram matrix
     * @param basisVectors basis vectors; the first {@code d} are active
     * @param d            number of active basis vectors, including the one just added
     * @param t            {@code Q k - u} for the newly added vector (rank-one fallback)
     * @param gamma        projection residual of the newly added vector (rank-one fallback)
     * @return the updated inverse Gram matrix
     * @throws Exception if the kernel evaluation fails
     */
    private Matrix updateInverseGram(Matrix Q, Kernel kernel, double[][] basisVectors, int d,
                                     Matrix t, double gamma) throws Exception {
        Matrix gram = new Matrix(d, d);
        double[][] g = gram.getArray();
        // The Cholesky factorisation reads only the lower triangle, so evaluate that part
        // and mirror it instead of evaluating the kernel twice per pair.
        for (int row = 0; row < d; row++) {
            for (int col = 0; col <= row; col++) {
                double value = kernel.eval(basisVectors[row], basisVectors[col]);
                g[row][col] = value;
                g[col][row] = value;
            }
        }
        Matrix L = gram.chol().getL();
        for (int j = 0; j < d; j++) {
            if (!(L.get(j, j) > 0.0)) {
                LogService.logMessage(" -> Gram matrix not positive definite, using rank-one update of Q", 2);
                return Q.plus(t.times(t.transpose()).times(1.0 / gamma));
            }
        }
        Matrix invL = L.inverse();
        Q.setMatrix(0, d - 1, 0, d - 1, invL.transpose().times(invL));
        return Q;
    }

    /**
     * Returns the active basis vector with the smallest KL-based score
     * {@code alpha_j^2 / (Q_jj + C_jj)}. Ties go to the lowest index.
     *
     * @param d number of active basis vectors; must be at least 1
     */
    private Score minScoreKLApprox(Matrix alpha, Matrix C, Matrix Q, int d) {
        Score best = null;
        for (int j = 0; j < d; j++) {
            double a = alpha.get(j, 0);
            Score candidate = new Score(a * a / (Q.get(j, j) + C.get(j, j)), j);
            if (best == null || candidate.compareTo(best) < 0) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Returns the active basis vector with the smallest geometrical score {@code 1 / Q_jj}.
     * Ties go to the lowest index.
     *
     * @param d number of active basis vectors; must be at least 1
     */
    private Score minScoreGeometrical(Matrix Q, int d) {
        Score best = null;
        for (int j = 0; j < d; j++) {
            Score candidate = new Score(1.0 / Q.get(j, j), j);
            if (best == null || candidate.compareTo(best) < 0) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Removes basis vector {@code index} from the active set {@code 0..d} and adjusts
     * {@code alpha}, {@code C} and {@code Q} in place.
     *
     * <p>The vector is first swapped into slot {@code d}. Then, with
     * {@code alpha* = alpha_d}, {@code c* = C_dd}, {@code q* = Q_dd} and the column vectors
     * {@code C*}, {@code Q*} (rows {@code 0..d-1} of column {@code d}):
     * <pre>
     *   alpha -= alpha* (Q* + C*) / (q* + c*)
     *   C     += Q* Q*^T / q*  -  (Q* + C*)(Q* + C*)^T / (q* + c*)
     *   Q     -= Q* Q*^T / q*
     * </pre>
     * Finally entry {@code d} of {@code alpha} and row/column {@code d} of {@code C} and
     * {@code Q} are zeroed. The caller is responsible for decrementing its count.
     *
     * @param d     index of the last active basis vector (the active count minus one)
     * @param index index of the basis vector to remove
     */
    private void deleteBV(Matrix alpha, Matrix C, Matrix Q, double[][] basisVectors, int d, int index) {
        int size = basisVectors.length;

        double[][] a = alpha.getArray();
        double alphaTmp = a[index][0];
        a[index][0] = a[d][0];
        a[d][0] = alphaTmp;
        this.swapRowsAndColumns(C.getArray(), index, d);
        this.swapRowsAndColumns(Q.getArray(), index, d);
        double[] bvTmp = basisVectors[index];
        basisVectors[index] = basisVectors[d];
        basisVectors[d] = bvTmp;

        double alphaStar = a[d][0];
        double cStar = C.get(d, d);
        double qStar = Q.get(d, d);
        Matrix qCol = new Matrix(size, 1);
        Matrix cCol = new Matrix(size, 1);
        for (int j = 0; j < d; j++) {
            qCol.set(j, 0, Q.get(j, d));
            cCol.set(j, 0, C.get(j, d));
        }
        Matrix qPlusC = qCol.plus(cCol);
        Matrix qOuter = qCol.times(qCol.transpose());

        alpha.minusEquals(qPlusC.times(alphaStar / (cStar + qStar)));
        C.plusEquals(qOuter.times(1.0 / qStar));
        C.minusEquals(qPlusC.times(qPlusC.transpose()).times(1.0 / (qStar + cStar)));
        Q.minusEquals(qOuter.times(1.0 / qStar));

        // Deactivate slot d so that it is zero for subsequent products.
        a[d][0] = 0.0;
        double[][] cArr = C.getArray();
        double[][] qArr = Q.getArray();
        for (int j = 0; j <= d; j++) {
            qArr[d][j] = 0.0;
            cArr[d][j] = 0.0;
            qArr[j][d] = 0.0;
            cArr[j][d] = 0.0;
        }
    }

    @Override
    public String toString() {
        return "Gauss-Regression-GP";
    }

    /**
     * A (score, basis-vector index) pair.
     *
     * <p>Ordering and equality use only the score, compared with {@link Double#compare}, so
     * {@code compareTo}, {@code equals} and {@code hashCode} agree with each other
     * (including for {@code -0.0} and {@code NaN}).
     */
    protected static class Score implements Comparable<Score> {
        double score;
        int index;

        Score(double score, int index) {
            this.score = score;
            this.index = index;
        }

        @Override
        public int compareTo(Score other) {
            return Double.compare(this.score, other.score);
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof Score && Double.compare(this.score, ((Score) o).score) == 0;
        }

        @Override
        public int hashCode() {
            return Double.valueOf(this.score).hashCode();
        }

        public double getScore() {
            return this.score;
        }

        public int getIndex() {
            return this.index;
        }
    }
}

