/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.functions;

import Jama.Matrix;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.AbstractLearner;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.functions.OrdinalLogisticRegressionModel;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class OrdinalLogisticRegression
extends AbstractLearner {
    private double incr = 10.0;
    private double decr = 2.0;

    public OrdinalLogisticRegression(OperatorDescription description) {
        super(description);
    }

    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        if (lc == LearnerCapability.POLYNOMINAL_ATTRIBUTES) {
            return false;
        }
        if (lc == LearnerCapability.BINOMINAL_ATTRIBUTES) {
            return false;
        }
        if (lc == LearnerCapability.NUMERICAL_ATTRIBUTES) {
            return true;
        }
        if (lc == LearnerCapability.POLYNOMINAL_CLASS) {
            return false;
        }
        if (lc == LearnerCapability.BINOMINAL_CLASS) {
            return false;
        }
        return lc == LearnerCapability.NUMERICAL_CLASS;
    }

    private Matrix buildDesignMatrix(ExampleSet exampleSet) {
        Matrix x = new Matrix(exampleSet.size(), exampleSet.getAttributes().size());
        Iterator iterator = exampleSet.iterator();
        int i = 0;
        while (iterator.hasNext()) {
            Example example = (Example)iterator.next();
            int j = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                double value = example.getValue(attribute);
                x.set(i, j, value);
                ++j;
            }
            ++i;
        }
        return x;
    }

    private Matrix getResponseVector(ExampleSet exampleSet) {
        int numberOfRows = exampleSet.size();
        Matrix y = new Matrix(numberOfRows, 1);
        Iterator iterator = exampleSet.iterator();
        int i = 0;
        while (iterator.hasNext()) {
            Example example = (Example)iterator.next();
            Double value = example.getLabel();
            y.set(i, 0, value);
            ++i;
        }
        return y;
    }

    private Matrix[] getResponseMatrices(ExampleSet exampleSet) {
        Matrix responseVector = this.getResponseVector(exampleSet);
        int n = responseVector.getRowDimension();
        int q = (int)this.max(responseVector) - (int)this.min(responseVector) + 1;
        Matrix[] responseMatrices = new Matrix[]{new Matrix(n, q - 1), new Matrix(n, q - 1)};
        int i = 0;
        while (i < n) {
            double value = responseVector.get(i, 0);
            if (value >= 1.0 && value < (double)q) {
                responseMatrices[0].set(i, (int)value - 1, 1.0);
            }
            if (value > 1.0 && value <= (double)q) {
                responseMatrices[1].set(i, (int)value - 2, 1.0);
            }
            ++i;
        }
        return responseMatrices;
    }

    private Matrix sumMatrix(Matrix x) {
        Matrix s = new Matrix(1, x.getColumnDimension());
        int i = 0;
        while (i < x.getRowDimension()) {
            int j = 0;
            while (j < x.getColumnDimension()) {
                s.set(0, j, s.get(0, j) + x.get(i, j));
                ++j;
            }
            ++i;
        }
        return s;
    }

    private Matrix getInitialEstimates(Matrix z, int m, int n) {
        Matrix tb = new Matrix(1, z.getColumnDimension() + m);
        Matrix s = this.sumMatrix(z);
        double sum = 0.0;
        int i = 0;
        while (i < s.getColumnDimension()) {
            s.set(0, i, sum += s.get(0, i));
            ++i;
        }
        i = 0;
        while (i < s.getColumnDimension()) {
            double s0i = s.get(0, i) / (double)n;
            s0i = Math.log(s0i / (1.0 - s0i));
            s.set(0, i, s0i);
            tb.set(0, i, s0i);
            ++i;
        }
        return tb;
    }

    private Matrix linkMatrix(Matrix x) {
        int i = 0;
        while (i < x.getRowDimension()) {
            int j = 0;
            while (j < x.getColumnDimension()) {
                double e = Math.exp(x.get(i, j));
                x.set(i, j, e / (1.0 + e));
                ++j;
            }
            ++i;
        }
        return x;
    }

    private Matrix mergeMatrix(Matrix a, Matrix b) {
        Matrix c = new Matrix(a.getRowDimension(), a.getColumnDimension() + b.getColumnDimension());
        int i = 0;
        while (i < a.getRowDimension()) {
            int j = 0;
            while (j < a.getColumnDimension()) {
                c.set(i, j, a.get(i, j));
                ++j;
            }
            j = 0;
            while (j < b.getColumnDimension()) {
                c.set(i, j + a.getColumnDimension(), b.get(i, j));
                ++j;
            }
            ++i;
        }
        return c;
    }

    private Matrix ones(int n) {
        Matrix x = new Matrix(n, 1);
        int i = 0;
        while (i < n) {
            x.set(i, 0, 1.0);
            ++i;
        }
        return x;
    }

    private Matrix diag(Matrix x) {
        int l = x.getRowDimension();
        Matrix d = new Matrix(l, l);
        int i = 0;
        while (i < l) {
            d.set(i, i, x.get(i, 0));
            ++i;
        }
        return d;
    }

    private double std(Matrix x) {
        double s = 0.0;
        int i = 0;
        while (i < x.getRowDimension()) {
            s += x.get(i, 0);
            ++i;
        }
        s /= (double)x.getRowDimension();
        double stdev = 0.0;
        int i2 = 0;
        while (i2 < x.getRowDimension()) {
            stdev += (x.get(i2, 0) - s) * (x.get(i2, 0) - s);
            ++i2;
        }
        return Math.sqrt(stdev / (double)(x.getRowDimension() - 1));
    }

    private double max(Matrix x) {
        double m = 0.0;
        int i = 0;
        while (i < x.getRowDimension()) {
            double v = x.get(i, 0);
            if (v > m) {
                m = v;
            }
            ++i;
        }
        return m;
    }

    private double min(Matrix x) {
        double m = 2.147483647E9;
        int i = 0;
        while (i < x.getRowDimension()) {
            double v = x.get(i, 0);
            if (v < m) {
                m = v;
            }
            ++i;
        }
        return m;
    }

    /*
     * Unable to fully structure code
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        tol = this.getParameterAsDouble("tolerance");
        x = this.buildDesignMatrix(exampleSet);
        x = x.times(-1.0);
        y = this.getResponseVector(exampleSet);
        maxy = this.max(y);
        miny = this.min(y);
        z = this.getResponseMatrices(exampleSet);
        tb = this.getInitialEstimates(z[0], x.getColumnDimension(), x.getRowDimension()).transpose();
        g = this.linkMatrix(this.mergeMatrix(z[0], x).times(tb));
        g1 = this.linkMatrix(this.mergeMatrix(z[1], x).times(tb));
        i = 0;
        while (i < g.getRowDimension()) {
            if (y.get(i, 0) == maxy) {
                g.set(i, 0, 1.0);
            }
            if (y.get(i, 0) == miny) {
                g1.set(i, 0, 0.0);
            }
            ++i;
        }
        p = g.minus(g1);
        dev = 0.0;
        i = 0;
        while (i < p.getRowDimension()) {
            dev += -2.0 * Math.log(p.get(i, 0));
            ++i;
        }
        v = g.arrayTimes(this.ones(g.getRowDimension()).minus(g)).arrayRightDivide(p);
        v1 = g1.arrayTimes(this.ones(g1.getRowDimension()).minus(g1)).arrayRightDivide(p);
        dlogp = this.mergeMatrix(this.diag(v).times(z[0]).minus(this.diag(v1).times(z[1])), this.diag(v.minus(v1)).times(x));
        dl = this.sumMatrix(dlogp).transpose();
        w = v.arrayTimes(this.ones(g.getRowDimension()).minus(g.times(2.0)));
        w1 = v1.arrayTimes(this.ones(g1.getRowDimension()).minus(g1.times(2.0)));
        d2l = this.mergeMatrix(z[0], x).transpose().times(this.diag(w)).times(this.mergeMatrix(z[0], x)).minus(this.mergeMatrix(z[1], x).transpose().times(this.diag(w1)).times(this.mergeMatrix(z[1], x))).minus(dlogp.transpose().times(dlogp));
        epsilon = this.std(new Matrix(d2l.getColumnPackedCopy(), d2l.getColumnPackedCopy().length)) / 1000.0;
        iteration = 0;
        tbold = null;
        devold = 0.0;
        while (Math.abs(dl.transpose().times(d2l.inverse()).times(dl).get(0, 0) / (double)dl.getRowDimension()) > tol) {
            block15: {
                ++iteration;
                tbold = tb;
                devold = dev;
                tb = tbold.minus(d2l.inverse().times(dl));
                g = this.linkMatrix(this.mergeMatrix(z[0], x).times(tb));
                g1 = this.linkMatrix(this.mergeMatrix(z[1], x).times(tb));
                i = 0;
                while (i < g.getRowDimension()) {
                    if (y.get(i, 0) == maxy) {
                        g.set(i, 0, 1.0);
                    }
                    if (y.get(i, 0) == miny) {
                        g1.set(i, 0, 0.0);
                    }
                    ++i;
                }
                p = g.minus(g1);
                dev = 0.0;
                i = 0;
                while (i < p.getRowDimension()) {
                    dev += -2.0 * Math.log(p.get(i, 0));
                    ++i;
                }
                if (!((dev - devold) / dl.transpose().times(tb.minus(tbold)).get(0, 0) < 0.0)) ** GOTO lbl82
                epsilon /= this.decr;
                break block15;
lbl-1000:
                // 1 sources

                {
                    if ((epsilon *= this.incr) > 1.0E15) {
                        System.err.println("No convergence!");
                    }
                    tb = tbold.minus(d2l.minus(this.diag(this.ones(d2l.getColumnDimension())).times(epsilon)).inverse().times(dl));
                    g = this.linkMatrix(this.mergeMatrix(z[0], x).times(tb));
                    g1 = this.linkMatrix(this.mergeMatrix(z[1], x).times(tb));
                    i = 0;
                    while (i < g.getRowDimension()) {
                        if (y.get(i, 0) == maxy) {
                            g.set(i, 0, 1.0);
                        }
                        if (y.get(i, 0) == miny) {
                            g1.set(i, 0, 0.0);
                        }
                        ++i;
                    }
                    p = g.minus(g1);
                    dev = 0.0;
                    i = 0;
                    while (i < p.getRowDimension()) {
                        dev += -2.0 * Math.log(p.get(i, 0));
                        ++i;
                    }
lbl82:
                    // 2 sources

                    ** while ((dev - devold) / dl.transpose().times((Matrix)tb.minus((Matrix)tbold)).get((int)0, (int)0) < 0.0)
                }
            }
            v = g.arrayTimes(this.ones(g.getRowDimension()).minus(g)).arrayRightDivide(p);
            v1 = g1.arrayTimes(this.ones(g1.getRowDimension()).minus(g1)).arrayRightDivide(p);
            dlogp = this.mergeMatrix(this.diag(v).times(z[0]).minus(this.diag(v1).times(z[1])), this.diag(v.minus(v1)).times(x));
            dl = this.sumMatrix(dlogp).transpose();
            w = v.arrayTimes(this.ones(g.getRowDimension()).minus(g.times(2.0)));
            w1 = v1.arrayTimes(this.ones(g1.getRowDimension()).minus(g1.times(2.0)));
            d2l = this.mergeMatrix(z[0], x).transpose().times(this.diag(w)).times(this.mergeMatrix(z[0], x)).minus(this.mergeMatrix(z[1], x).transpose().times(this.diag(w1)).times(this.mergeMatrix(z[1], x))).minus(dlogp.transpose().times(dlogp));
        }
        return new OrdinalLogisticRegressionModel(exampleSet.getAttributes().getLabel(), tb);
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeDouble("tolerance", "Tolerance of convergence.", 0.0, Double.POSITIVE_INFINITY, 1.0E-6));
        return types;
    }
}

