/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.kernel.logistic;

import com.rapidminer.operator.Operator;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExample;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.Kernel;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.svm.SVMInterface;
import com.rapidminer.parameter.UndefinedParameterError;

/**
 * Kernel-based logistic model trained by repeatedly optimising pairs of
 * coefficients ({@code alphas}) over a box {@code [mu, C - mu]}.
 *
 * <p>Training ({@link #train()}) maintains a per-example cache {@code hCache}
 * and keeps stepping on the pair with the largest / smallest cached value
 * until their gap falls below {@code 2 * tol} or the iteration budget is
 * used up. Prediction squashes the kernel expansion through the logistic
 * function {@code 1 / (1 + exp(-x))}.
 *
 * <p>NOTE(review): the structure looks like an SMO-style dual solver for
 * kernel logistic regression (see the package name) - confirm against the
 * original, non-decompiled sources.
 */
public class KLR implements SVMInterface {
    /** Kernel used for all kernel rows and kernel evaluations. */
    protected Kernel kernel;
    /** Training examples; also owns the alpha array and the stored offset b. */
    protected SVMExamples examples;
    /** Number of training examples. */
    int n;
    /** Labels of the training examples; only their sign is branched on. */
    double[] target;
    /** Number of positive examples as reported by {@code count_pos_examples()}. */
    int n1;
    /** {@code n - n1}. */
    int n2;
    /** Coefficient array obtained from {@code examples.get_alphas()}. */
    double[] alphas;
    /** Per-example cached value updated incrementally by {@link #takeStep(int, int)}. */
    double[] hCache;
    /** Marks examples whose alpha was clipped to {@code mu} or {@code C - mu}. */
    boolean[] atBound;
    /** Index of the non-bound example with the largest {@code hCache} value. */
    int iUp;
    /** Index of the non-bound example with the smallest {@code hCache} value. */
    int iLow;
    /** Offset; set to the midpoint of {@code bLow} and {@code bUp} by {@link #klr()}. */
    double b;
    /** Largest {@code hCache} value among non-bound examples. */
    double bUp;
    /** Smallest {@code hCache} value among non-bound examples. */
    double bLow;
    /** Convergence tolerance; training stops once {@code bUp - bLow <= 2 * tol}. */
    double tol = 0.001;
    /** Upper end of the alpha box. */
    double C = 1.0;
    /** Minimum width of a step interval / bracket; set to {@code C * 1e-10} in {@link #init}. */
    double epsilon;
    /** Margin kept between alphas and the box ends; set to {@code C * 1e-10} in {@link #init}. */
    double mu;
    /** Budget of {@link #takeStep(int, int)} attempts for one training run. */
    int maxIterations = 100000;

    /** Creates a learner with the default settings. */
    public KLR() {
    }

    /**
     * Creates a learner configured from an operator's parameters.
     *
     * @param paramOperator operator providing the parameters {@code "C"},
     *                      {@code "convergence_epsilon"} and {@code "max_iterations"}
     * @throws UndefinedParameterError if one of these parameters is not defined
     */
    public KLR(Operator paramOperator) throws UndefinedParameterError {
        this.C = paramOperator.getParameterAsDouble("C");
        this.tol = paramOperator.getParameterAsDouble("convergence_epsilon");
        this.maxIterations = paramOperator.getParameterAsInt("max_iterations");
    }

    /**
     * Binds the learner to a kernel and a training set and derives the
     * size-dependent settings from them.
     *
     * @param new_kernel   kernel to use
     * @param new_examples training examples
     */
    @Override
    public void init(Kernel new_kernel, SVMExamples new_examples) {
        this.kernel = new_kernel;
        this.examples = new_examples;
        this.target = this.examples.get_ys();
        this.alphas = this.examples.get_alphas();
        this.n = this.examples.count_examples();
        this.epsilon = this.C * 1.0E-10;
        this.mu = this.epsilon;
        this.n1 = this.examples.count_pos_examples();
        this.n2 = this.n - this.n1;
    }

    /**
     * Returns {@code log(alpha / (C - alpha))}.
     *
     * @param alpha coefficient value
     * @return the log-odds style term for {@code alpha}
     */
    final double dG(double alpha) {
        return Math.log(alpha / (this.C - alpha));
    }

    /**
     * Evaluates the quantity {@link #takeStep(int, int)} drives to zero, at
     * step size {@code t} for the pair {@code (i, j)}: the current cache
     * difference {@code hCache[i] - hCache[j]} plus the kernel term
     * {@code t * (Kii - 2 Kij + Kjj)} plus the change of the {@link #dG}
     * terms caused by moving both alphas by {@code t}.
     *
     * @param t   step size
     * @param i   first example index
     * @param j   second example index
     * @param ai  current alpha of {@code i}
     * @param aj  current alpha of {@code j}
     * @param Kii kernel value K(i, i)
     * @param Kij kernel value K(i, j)
     * @param Kjj kernel value K(j, j)
     * @return the value at step {@code t}
     */
    final double dPhi(double t, int i, int j, double ai, double aj, double Kii, double Kij, double Kjj) {
        double ydG;
        // A positive label moves alpha_i by +t, any other label by -t.
        if (this.target[i] > 0.0) {
            ydG = this.dG(ai + t) - this.dG(ai);
        } else {
            ydG = this.dG(ai) - this.dG(ai - t);
        }
        // alpha_j moves in the opposite direction relative to its label.
        if (this.target[j] > 0.0) {
            ydG -= this.dG(aj - t) - this.dG(aj);
        } else {
            ydG -= this.dG(aj) - this.dG(aj + t);
        }
        return t * (Kii - 2.0 * Kij + Kjj) + ydG + (this.hCache[i] - this.hCache[j]);
    }

    /**
     * Evaluates the companion of {@link #dPhi} that is used as the divisor of
     * the Newton-style update in {@link #takeStep(int, int)}:
     * {@code C / (a_i (C - a_i)) + C / (a_j (C - a_j)) + Kii - 2 Kij + Kjj},
     * with both alphas taken after a step of size {@code t}.
     *
     * @param t   step size
     * @param i   first example index
     * @param j   second example index
     * @param ai  current alpha of {@code i}
     * @param aj  current alpha of {@code j}
     * @param Kii kernel value K(i, i)
     * @param Kij kernel value K(i, j)
     * @param Kjj kernel value K(j, j)
     * @return the value at step {@code t}
     */
    final double d2Phi(double t, int i, int j, double ai, double aj, double Kii, double Kij, double Kjj) {
        double steppedI = this.target[i] > 0.0 ? ai + t : ai - t;
        double curvatureI = this.C / (steppedI * (this.C - steppedI));
        double steppedJ = this.target[j] > 0.0 ? aj - t : aj + t;
        double curvatureJ = this.C / (steppedJ * (this.C - steppedJ));
        return curvatureI + curvatureJ + (Kii - 2.0 * Kij + Kjj);
    }

    /**
     * Tries to improve the pair {@code (i, j)} by moving both alphas by a
     * common step {@code t} ({@code alpha_i += t / y_i},
     * {@code alpha_j -= t / y_j}).
     *
     * <p>The admissible interval for {@code t} keeps both alphas inside
     * {@code [mu / 2, C - mu / 2]}. If the zero of {@link #dPhi} lies inside
     * the interval it is located by safeguarded Newton iterations (falling
     * back to bisection whenever the Newton point leaves the bracket);
     * otherwise the step goes to the interval end. Afterwards alphas within
     * {@code mu} of a box end are clipped, {@code hCache} is updated for all
     * examples and {@code bUp / bLow / iUp / iLow} are recomputed.
     *
     * @param i first example index
     * @param j second example index
     * @return {@code true} if the alphas were changed, {@code false} if no
     *         step was possible (in which case no state is modified)
     */
    protected boolean takeStep(int i, int j) {
        double[] kernel_row_i = this.kernel.get_row(i);
        double[] kernel_row_j = this.kernel.get_row(j);
        double aio = this.alphas[i];
        double ajo = this.alphas[j];
        double yi = this.target[i];
        double yj = this.target[j];
        double Hio = this.hCache[i];
        double Hjo = this.hCache[j];
        double Kii = kernel_row_i[i];
        double Kij = kernel_row_i[j];
        double Kjj = kernel_row_j[j];

        // Admissible range of t so that both alphas stay in [mu/2, C - mu/2].
        double lowerEnd = this.mu / 2.0;
        double upperEnd = this.C - this.mu / 2.0;
        double t_min;
        double t_max;
        if (yi > 0.0) {
            t_min = lowerEnd - aio;
            t_max = upperEnd - aio;
        } else {
            t_max = aio - lowerEnd;
            t_min = aio - upperEnd;
        }
        double t_tmp;
        if (yj > 0.0) {
            t_tmp = ajo - lowerEnd;
            if (t_tmp < t_max) {
                t_max = t_tmp;
            }
            t_tmp = ajo - upperEnd;
            if (t_tmp > t_min) {
                t_min = t_tmp;
            }
        } else {
            t_tmp = lowerEnd - ajo;
            if (t_tmp > t_min) {
                t_min = t_tmp;
            }
            t_tmp = upperEnd - ajo;
            if (t_tmp < t_max) {
                t_max = t_tmp;
            }
        }
        if (t_max - t_min <= this.epsilon) {
            return false;
        }

        double t = 0.0;
        double the_dPhi = Hio - Hjo;
        double the_d2Phi = Kii - 2.0 * Kij + Kjj + this.C / (aio * (this.C - aio)) + this.C / (ajo * (this.C - ajo));
        double dPhi_left = 0.0;
        double d2Phi_left = 0.0;
        double dPhi_right = 0.0;
        double d2Phi_right = 0.0;
        double t_left = 0.0;
        double t_right = 0.0;
        // true: the zero of dPhi is bracketed and must be searched for;
        // false: the step goes straight to an end of the admissible range.
        boolean bracketed = true;
        if (the_dPhi > 0.0) {
            dPhi_left = this.dPhi(t_min, i, j, aio, ajo, Kii, Kij, Kjj);
            d2Phi_left = this.d2Phi(t_min, i, j, aio, ajo, Kii, Kij, Kjj);
            if (dPhi_left < 0.0) {
                t_left = t_min;
                t_right = t;
                dPhi_right = the_dPhi;
                d2Phi_right = the_d2Phi;
            } else {
                t = t_min;
                bracketed = false;
            }
        } else if (the_dPhi < 0.0) {
            dPhi_right = this.dPhi(t_max, i, j, aio, ajo, Kii, Kij, Kjj);
            d2Phi_right = this.d2Phi(t_max, i, j, aio, ajo, Kii, Kij, Kjj);
            if (dPhi_right > 0.0) {
                t_left = t;
                t_right = t_max;
                dPhi_left = the_dPhi;
                d2Phi_left = the_d2Phi;
            } else {
                t = t_max;
                bracketed = false;
            }
        } else {
            // The pair already satisfies the stationarity condition.
            return false;
        }

        double ai = 0.0;
        double aj = 0.0;
        if (bracketed) {
            // Start Newton from the bracket end with the smaller residual.
            double t0;
            if (Math.abs(dPhi_left) < Math.abs(dPhi_right)) {
                t0 = t_left;
                the_dPhi = dPhi_left;
                the_d2Phi = d2Phi_left;
            } else {
                t0 = t_right;
                the_dPhi = dPhi_right;
                the_d2Phi = d2Phi_right;
            }
            do {
                t = t0 + (-the_dPhi / the_d2Phi);
                // Newton left the bracket: bisect instead.
                if (t <= t_left || t >= t_right) {
                    t = (t_left + t_right) / 2.0;
                }
                ai = aio + t / yi;
                aj = ajo - t / yj;
                double trialHi = Hio + t * (Kii - Kij) + yi * (this.dG(ai) - this.dG(aio));
                double trialHj = Hjo + t * (Kij - Kjj) + yj * (this.dG(aj) - this.dG(ajo));
                the_dPhi = trialHi - trialHj;
                the_d2Phi = Kii - 2.0 * Kij + Kjj + this.C / (ai * (this.C - ai)) + this.C / (aj * (this.C - aj));
                // Shrink the bracket on the side that has the same sign.
                if (the_dPhi * dPhi_left > 0.0) {
                    t_left = t;
                    dPhi_left = the_dPhi;
                } else {
                    t_right = t;
                    dPhi_right = the_dPhi;
                }
                t0 = t;
            } while (Math.abs(the_dPhi) > 0.1 * this.tol && t_left + this.epsilon < t_right);
        } else {
            ai = aio + t / yi;
            aj = ajo - t / yj;
        }
        if (t == 0.0) {
            return false;
        }

        this.alphas[i] = ai;
        this.alphas[j] = aj;
        boolean sameLabel = this.target[i] > 0.0 && this.target[j] > 0.0
                || this.target[i] < 0.0 && this.target[j] < 0.0;
        // The decision for j is based on the unclipped aj, while the shift is
        // measured from alphas[j], which the clipping of i may have adjusted.
        this.clipToBox(i, j, ai, sameLabel);
        this.clipToBox(j, i, aj, sameLabel);

        // Effective step after clipping, averaged over both coefficients.
        t = ((this.alphas[i] - aio) * yi + (ajo - this.alphas[j]) * yj) / 2.0;
        double Hi = Hio + t * (Kii - Kij) + yi * (this.dG(this.alphas[i]) - this.dG(aio));
        double Hj = Hjo + t * (Kij - Kjj) + yj * (this.dG(this.alphas[j]) - this.dG(ajo));
        for (int k = 0; k < this.n; ++k) {
            this.hCache[k] += t * (kernel_row_i[k] - kernel_row_j[k]);
        }
        // The two stepped entries additionally carry the dG change.
        this.hCache[i] = Hi;
        this.hCache[j] = Hj;

        // Re-select the extreme non-bound examples.
        this.bUp = Double.NEGATIVE_INFINITY;
        this.bLow = Double.POSITIVE_INFINITY;
        this.iUp = 0;
        this.iLow = 0;
        for (int l = 0; l < this.n; ++l) {
            if (this.atBound[l]) {
                continue;
            }
            if (this.hCache[l] > this.bUp) {
                this.bUp = this.hCache[l];
                this.iUp = l;
            }
            if (this.hCache[l] < this.bLow) {
                this.bLow = this.hCache[l];
                this.iLow = l;
            }
        }
        return true;
    }

    /**
     * Clips the alpha of {@code clipped} to {@code mu} or {@code C - mu} when
     * {@code proposed} lies at or beyond that value, shifts the partner's
     * alpha by the same amount (subtracted for equal labels, added otherwise)
     * and records the result in {@code atBound}.
     */
    private void clipToBox(int clipped, int partner, double proposed, boolean sameLabel) {
        double bound;
        if (proposed <= this.mu) {
            bound = this.mu;
        } else if (proposed >= this.C - this.mu) {
            bound = this.C - this.mu;
        } else {
            this.atBound[clipped] = false;
            return;
        }
        double shift = bound - this.alphas[clipped];
        if (sameLabel) {
            this.alphas[partner] -= shift;
        } else {
            this.alphas[partner] += shift;
        }
        this.alphas[clipped] = bound;
        this.atBound[clipped] = true;
    }

    /**
     * Runs the optimisation: initialises all alphas to {@code C / n1}
     * (positive label) or {@code C / n2} (otherwise), fills {@code hCache},
     * then alternates between stepping on the extreme pair
     * {@code (iLow, iUp)} and a sweep that tries to pull bound examples back
     * into the interior. Finally sets {@code b} to the midpoint of
     * {@code bLow} and {@code bUp}.
     *
     * <p>At most {@code maxIterations} steps are spent on the extreme pair;
     * the sweep that is running when the budget runs out is still completed.
     *
     * <p>NOTE(review): if {@code n1} or {@code n2} is 0 or 1 the initial
     * alphas / {@code hCache} values are not finite (division by zero, or
     * {@code log(C / 0)}); callers presumably guarantee at least two
     * examples per class - confirm.
     */
    public void klr() {
        this.examples.clearAlphas();
        this.hCache = new double[this.n];
        this.atBound = new boolean[this.n];
        double alpha_pos = this.C / (double) this.n1;
        double alpha_neg = this.C / (double) this.n2;
        for (int i = 0; i < this.n; ++i) {
            this.alphas[i] = this.target[i] > 0.0 ? alpha_pos : alpha_neg;
            this.atBound[i] = false;
        }

        this.bUp = Double.NEGATIVE_INFINITY;
        this.bLow = Double.POSITIVE_INFINITY;
        this.iUp = 0;
        this.iLow = 0;
        for (int i = 0; i < this.n; ++i) {
            double sum_pos_K = 0.0;
            double sum_neg_K = 0.0;
            double[] kernel_row = this.kernel.get_row(i);
            for (int j = 0; j < this.n; ++j) {
                if (this.target[j] > 0.0) {
                    sum_pos_K += kernel_row[j];
                } else {
                    sum_neg_K += kernel_row[j];
                }
            }
            this.hCache[i] = alpha_pos * sum_pos_K - alpha_neg * sum_neg_K + this.target[i] * this.dG(this.alphas[i]);
            if (this.hCache[i] > this.bUp) {
                this.bUp = this.hCache[i];
                this.iUp = i;
            }
            if (this.hCache[i] < this.bLow) {
                this.bLow = this.hCache[i];
                this.iLow = i;
            }
        }

        int it = this.maxIterations;
        while (true) {
            // Step on the most violating pair until the gap closes, a step
            // fails, or the iteration budget is exhausted. The budget check
            // is required here: successful steps never reach the check at
            // the bottom of the outer loop.
            while (it > 0 && 2.0 * this.tol < this.bUp - this.bLow) {
                --it;
                if (!this.takeStep(this.iLow, this.iUp)) {
                    break;
                }
            }

            // Try to move every bound example, pairing it first with the
            // extreme example that is farther away in hCache.
            int numChange = 0;
            for (int i = 0; i < this.n; ++i) {
                if (!this.atBound[i]) {
                    continue;
                }
                double Hi = this.hCache[i];
                if (Math.abs(Hi - this.bLow) >= Math.abs(Hi - this.bUp)) {
                    --it;
                    if (!this.takeStep(i, this.iLow)) {
                        --it;
                        this.takeStep(i, this.iUp);
                    }
                } else {
                    --it;
                    if (!this.takeStep(i, this.iUp)) {
                        --it;
                        this.takeStep(i, this.iLow);
                    }
                }
                if (!this.atBound[i]) {
                    ++numChange;
                }
            }
            if (numChange == 0 || it <= 0) {
                break;
            }
        }
        this.b = (this.bLow + this.bUp) / 2.0;
    }

    /**
     * Predicts a single example: the stored offset plus the alpha-weighted
     * kernel values against all training examples, passed through
     * {@code 1 / (1 + exp(-x))}.
     *
     * @param sVMExample example to score
     * @return value in {@code [0, 1]}
     */
    @Override
    public double predict(SVMExample sVMExample) {
        double the_sum = this.examples.get_b();
        for (int i = 0; i < this.n; ++i) {
            double alpha = this.alphas[i];
            if (alpha == 0.0) {
                // No contribution; skip the kernel evaluation.
                continue;
            }
            SVMExample sv = this.examples.get_example(i);
            the_sum += alpha * this.kernel.calculate_K(sv, sVMExample);
        }
        return 1.0 / (1.0 + Math.exp(-the_sum));
    }

    /**
     * Predicts every example of the given set and stores each prediction as
     * that example's y value.
     *
     * @param to_predict examples to score; modified in place
     */
    @Override
    public void predict(SVMExamples to_predict) {
        int size = to_predict.count_examples();
        for (int i = 0; i < size; ++i) {
            SVMExample sVMExample = to_predict.get_example(i);
            double prediction = this.predict(sVMExample);
            to_predict.set_y(i, prediction);
        }
    }

    /**
     * Trains the model: runs {@link #klr()}, stores the negated offset in the
     * example set and negates the alphas of all examples with a negative
     * label so that {@link #predict(SVMExample)} can use them directly.
     */
    @Override
    public void train() {
        this.klr();
        this.b = -this.b;
        this.examples.set_b(this.b);
        for (int i = 0; i < this.n; ++i) {
            if (this.target[i] < 0.0) {
                this.alphas[i] = -this.alphas[i];
            }
        }
    }

    /**
     * Returns the alpha-weighted sum of the dense training example vectors.
     *
     * @return array of length {@code examples.get_dim()}
     */
    @Override
    public double[] getWeights() {
        int dim = this.examples.get_dim();
        int examples_total = this.examples.count_examples();
        // A new double[] is zero-initialised, so no explicit clearing is needed.
        double[] w = new double[dim];
        for (int i = 0; i < examples_total; ++i) {
            double[] x = this.examples.get_example(i).toDense(dim);
            double alpha = this.alphas[i];
            for (int j = 0; j < dim; ++j) {
                w[j] += alpha * x[j];
            }
        }
        return w;
    }

    /**
     * Returns the offset stored in the example set.
     *
     * @return {@code examples.get_b()}
     */
    @Override
    public double getB() {
        return this.examples.get_b();
    }
}

