/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.functions;

import Jama.Matrix;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.AbstractLearner;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.functions.LogisticRegressionModel;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeBoolean;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Logistic regression learner fitted with Newton-Raphson iterations
 * (equivalently, iteratively reweighted least squares).
 *
 * <p>The design matrix holds the raw attribute values of every example; no
 * intercept column is added. The label value of each example is used directly
 * as the response. NOTE(review): the likelihood being maximized assumes the
 * label is 0/1 encoded; this is not checked here, confirm against callers.
 */
public class LogisticRegression
extends AbstractLearner {
    /**
     * Hard upper bound on Newton iterations. Newton's method normally converges
     * in a few dozen steps; the cap only guarantees termination, e.g. when
     * eps_bound is 0 and the relative change never becomes exactly zero.
     */
    private static final int MAX_ITERATIONS = 1000;

    private double epsBound = 1.0E-8;
    private boolean startWithOLS = false;

    public LogisticRegression(OperatorDescription description) {
        super(description);
    }

    /**
     * Reports the supported data: numerical attributes and a numerical label
     * only; nominal attributes and nominal labels are rejected.
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        if (lc == LearnerCapability.POLYNOMINAL_ATTRIBUTES) {
            return false;
        }
        if (lc == LearnerCapability.BINOMINAL_ATTRIBUTES) {
            return false;
        }
        if (lc == LearnerCapability.NUMERICAL_ATTRIBUTES) {
            return true;
        }
        if (lc == LearnerCapability.POLYNOMINAL_CLASS) {
            return false;
        }
        if (lc == LearnerCapability.BINOMINAL_CLASS) {
            return false;
        }
        return lc == LearnerCapability.NUMERICAL_CLASS;
    }

    /**
     * Builds the m x n design matrix: one row per example, one column per
     * regular attribute, in the iteration order of the example's attributes.
     */
    private Matrix buildDesignMatrix(ExampleSet exampleSet) {
        Matrix x = new Matrix(exampleSet.size(), exampleSet.getAttributes().size());
        Iterator iterator = exampleSet.iterator();
        int i = 0;
        while (iterator.hasNext()) {
            Example example = (Example)iterator.next();
            int j = 0;
            for (Attribute attribute : example.getAttributes()) {
                x.set(i, j, example.getValue(attribute));
                ++j;
            }
            ++i;
        }
        return x;
    }

    /** Builds the m x 1 response vector from the label value of each example. */
    private Matrix getResponseVector(ExampleSet exampleSet) {
        Matrix y = new Matrix(exampleSet.size(), 1);
        Iterator iterator = exampleSet.iterator();
        int i = 0;
        while (iterator.hasNext()) {
            Example example = (Example)iterator.next();
            y.set(i, 0, example.getLabel());
            ++i;
        }
        return y;
    }

    /**
     * Computes the fitted probabilities pi_i = 1 / (1 + exp(-x_i . beta)).
     *
     * <p>This form is used instead of e / (1 + e) with e = exp(x_i . beta)
     * because the latter overflows to Infinity / Infinity = NaN for large
     * linear predictors, while this one saturates cleanly to 0 or 1.
     *
     * @param x    m x n design matrix
     * @param beta 1 x n coefficient row vector
     * @return m x 1 column vector of probabilities
     */
    private Matrix computeProbabilities(Matrix x, Matrix beta) {
        Matrix linearPredictor = x.times(beta.transpose());
        int m = linearPredictor.getRowDimension();
        Matrix pi = new Matrix(m, 1);
        for (int i = 0; i < m; ++i) {
            pi.set(i, 0, 1.0 / (1.0 + Math.exp(-linearPredictor.get(i, 0))));
        }
        return pi;
    }

    /**
     * Returns W X, where W = diag(pi_i * (1 - pi_i)), by scaling each row of x.
     * This avoids materializing the dense m x m diagonal matrix, which would
     * cost O(m^2) memory and O(n m^2) time per iteration.
     */
    private Matrix weightRows(Matrix x, Matrix pi) {
        int m = x.getRowDimension();
        int n = x.getColumnDimension();
        Matrix weighted = new Matrix(m, n);
        for (int i = 0; i < m; ++i) {
            double p = pi.get(i, 0);
            double w = p * (1.0 - p);
            for (int j = 0; j < n; ++j) {
                weighted.set(i, j, w * x.get(i, j));
            }
        }
        return weighted;
    }

    /**
     * Relative size of a Newton step: ||step||_F / ||betaOld||_F.
     *
     * <p>From the all-zero start the denominator is 0; any nonzero step then
     * counts as an unbounded change (as plain x / 0 == Infinity would), while a
     * zero or NaN step yields 0 so the iteration stops.
     */
    private double getRelativeChange(Matrix step, Matrix betaOld) {
        double stepNorm = step.normF();
        double oldNorm = betaOld.normF();
        if (oldNorm == 0.0) {
            return stepNorm > 0.0 ? Double.POSITIVE_INFINITY : 0.0;
        }
        return stepNorm / oldNorm;
    }

    /**
     * Fits the coefficients by Newton-Raphson:
     * beta <- beta + (X^T W X)^{-1} X^T (y - pi).
     *
     * <p>Iteration starts from the zero vector, or from the ordinary least
     * squares estimate if start_with_ols_estimate is set. It stops once the
     * relative change of beta is no longer greater than eps_bound (a NaN change
     * also stops it), or after {@link #MAX_ITERATIONS} iterations.
     *
     * @param exampleSet training data with numerical attributes and label
     * @return a model holding the label attribute and the n x 1 coefficient vector
     * @throws OperatorException if reading the parameters fails
     * @throws RuntimeException  (from Jama) if X^T X or X^T W X is singular
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.startWithOLS = this.getParameterAsBoolean("start_with_ols_estimate");
        this.epsBound = this.getParameterAsDouble("eps_bound");
        Matrix x = this.buildDesignMatrix(exampleSet);
        Matrix y = this.getResponseVector(exampleSet);
        // Loop-invariant; previously recomputed several times per iteration.
        Matrix xt = x.transpose();
        // beta is a 1 x n row vector throughout; transposed for the model.
        Matrix beta = new Matrix(1, x.getColumnDimension());
        if (this.startWithOLS) {
            // Normal equations (X^T X) beta = X^T y; solve() avoids an explicit inverse.
            beta = xt.times(x).solve(xt.times(y)).transpose();
        }
        int iteration = 0;
        double relativeChange;
        do {
            ++iteration;
            Matrix betaOld = beta;
            Matrix pi = this.computeProbabilities(x, betaOld);
            Matrix gradient = xt.times(y.minus(pi));
            Matrix hessian = xt.times(this.weightRows(x, pi));
            Matrix step = hessian.solve(gradient).transpose();
            beta = betaOld.plus(step);
            relativeChange = this.getRelativeChange(step, betaOld);
        } while (relativeChange > this.epsBound && iteration < MAX_ITERATIONS);
        return new LogisticRegressionModel(exampleSet.getAttributes().getLabel(), beta.transpose());
    }

    /**
     * Declares the learner parameters: whether to start from the OLS estimate,
     * and the relative-change bound used as the convergence criterion.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean("start_with_ols_estimate", "Start with OLS estimate instead of null.", false));
        types.add(new ParameterTypeDouble("eps_bound", "Bound for convergence criterion", 0.0, Double.POSITIVE_INFINITY, 1.0E-8));
        return types;
    }
}

