/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.kernel;

import edu.udo.cs.myRVM.ClassificationProblem;
import edu.udo.cs.myRVM.ConstructiveRegression;
import edu.udo.cs.myRVM.Kernel.KernelBasisFunction;
import edu.udo.cs.myRVM.Kernel.KernelCauchy;
import edu.udo.cs.myRVM.Kernel.KernelEpanechnikov;
import edu.udo.cs.myRVM.Kernel.KernelGaussianCombination;
import edu.udo.cs.myRVM.Kernel.KernelLaplace;
import edu.udo.cs.myRVM.Kernel.KernelMultiquadric;
import edu.udo.cs.myRVM.Kernel.KernelPoly;
import edu.udo.cs.myRVM.Kernel.KernelRadial;
import edu.udo.cs.myRVM.Kernel.KernelSigmoid;
import edu.udo.cs.myRVM.Parameter;
import edu.udo.cs.myRVM.RVMClassification;
import edu.udo.cs.myRVM.RVMRegression;
import edu.udo.cs.myRVM.RegressionProblem;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.learner.AbstractLearner;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.kernel.RVMModel;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeCategory;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.operator.parameter.ParameterTypeSingle;
import edu.udo.cs.yale.tools.LogService;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learner operator that trains an RVM model (presumably "relevance vector machine") with the
 * {@code edu.udo.cs.myRVM} library and wraps the result in an {@link RVMModel}.
 *
 * <p>Nominal labels are handled by {@link RVMClassification}; all other labels are handled by
 * {@link RVMRegression} or {@link ConstructiveRegression}, depending on the {@code rvm_type}
 * parameter.
 */
public class RVMLearner extends AbstractLearner {
    private static final String RVM_TYPE_REGRESSION = "Regression-RVM";
    private static final String RVM_TYPE_CLASSIFICATION = "Classification-RVM";
    private static final String RVM_TYPE_CONSTRUCTIVE_REGRESSION = "Constructive-Regression-RVM";

    /** Names of the RVM variants; the {@code rvm_type} parameter is an index into this array. */
    public static final String[] RVM_TYPES = new String[]{RVM_TYPE_REGRESSION, RVM_TYPE_CLASSIFICATION, RVM_TYPE_CONSTRUCTIVE_REGRESSION};

    /**
     * Names of the kernels; the {@code kernel_type} parameter is an index into this array and the
     * positions match the {@code switch} in {@link #createKernels(double[][], int)}.
     */
    public static final String[] KERNEL_TYPES = new String[]{"rbf", "cauchy", "laplace", "poly", "sigmoid", "Epanechnikov", "gaussian combination", "multiquadric"};

    // Parameter keys, shared by the declarations in getParameterTypes() and the lookups below.
    private static final String PARAMETER_RVM_TYPE = "rvm_type";
    private static final String PARAMETER_KERNEL_TYPE = "kernel_type";
    private static final String PARAMETER_MAX_ITERATION = "max_iteration";
    private static final String PARAMETER_MIN_DELTA_LOG_ALPHA = "min_delta_log_alpha";
    private static final String PARAMETER_ALPHA_MAX = "alpha_max";
    private static final String PARAMETER_KERNEL_LENGTHSCALE = "kernel_lengthscale";
    private static final String PARAMETER_KERNEL_DEGREE = "kernel_degree";
    private static final String PARAMETER_KERNEL_BIAS = "kernel_bias";
    private static final String PARAMETER_KERNEL_SIGMA1 = "kernel_sigma1";
    private static final String PARAMETER_KERNEL_SIGMA2 = "kernel_sigma2";
    private static final String PARAMETER_KERNEL_SIGMA3 = "kernel_sigma3";
    private static final String PARAMETER_KERNEL_SHIFT = "kernel_shift";
    private static final String PARAMETER_KERNEL_A = "kernel_a";
    private static final String PARAMETER_KERNEL_B = "kernel_b";

    /** Level passed to {@link LogService#logMessage}; its meaning is defined by LogService (not visible here). */
    private static final int LOG_LEVEL = 2;

    /**
     * Creates the operator.
     *
     * @param description operator description forwarded to the superclass
     */
    public RVMLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Reports which learner capabilities this operator declares.
     *
     * @param lc the capability to query
     * @return {@code true} for numerical attributes and for binominal, polynominal and numerical
     *         labels; {@code false} otherwise
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        // NOTE(review): POLYNOMINAL_CLASS is declared here, but learn() rejects every nominal label
        // whose mapping does not have exactly two values. Confirm whether this capability is intended.
        return lc == LearnerCapability.NUMERICAL_ATTRIBUTES
                || lc == LearnerCapability.BINOMINAL_CLASS
                || lc == LearnerCapability.POLYNOMINAL_CLASS
                || lc == LearnerCapability.NUMERICAL_CLASS;
    }

    /**
     * Trains an RVM on the given example set.
     *
     * <p>Every example becomes one input vector (via {@link RVMModel#makeInputVector}) and one
     * target value (its label). One kernel basis function is created per example. A nominal label
     * is trained as a classification problem, any other label as a regression problem.
     *
     * @param exampleSet the training data
     * @return an {@link RVMModel} wrapping the learned myRVM model and the label attribute
     * @throws OperatorException a {@link UserError} with code 114 if the label is nominal but does
     *         not have exactly two values, with code 207 if {@code rvm_type} does not fit the kind
     *         of label, or with code 924 if classification training fails with an
     *         {@link ArrayIndexOutOfBoundsException}; also whatever the parameter lookups throw
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        LogService.logMessage("Creating RVM.", LOG_LEVEL);
        Parameter parameter = new Parameter();
        int numExamples = exampleSet.size();
        parameter.min_delta_log_alpha = this.getParameterAsDouble(PARAMETER_MIN_DELTA_LOG_ALPHA);
        parameter.alpha_max = this.getParameterAsDouble(PARAMETER_ALPHA_MAX);
        parameter.maxIterations = this.getParameterAsInt(PARAMETER_MAX_ITERATION);

        LogService.logMessage("=> Creating input / output vectors.", LOG_LEVEL);
        double[][] x = new double[numExamples][exampleSet.getAttributes().size()];
        double[][] t = new double[numExamples][1];
        Iterator<?> reader = exampleSet.iterator();
        for (int k = 0; reader.hasNext(); ++k) {
            Example example = (Example) reader.next();
            t[k][0] = example.getLabel();
            x[k] = RVMModel.makeInputVector(example);
        }

        Attribute label = exampleSet.getAttributes().getLabel();
        parameter.initAlpha = Math.pow(1.0 / (double) numExamples, 2.0);
        parameter.initSigma = 0.1;

        LogService.logMessage("=> Creating kernel basis functions [" + KERNEL_TYPES[this.getParameterAsInt(PARAMETER_KERNEL_TYPE)] + "].", LOG_LEVEL);
        // One basis function per example plus one extra slot at index 0 (see createKernels).
        KernelBasisFunction[] kernels = this.createKernels(x, numExamples + 1);
        String rvmType = RVM_TYPES[this.getParameterAsInt(PARAMETER_RVM_TYPE)];

        if (label.isNominal()) {
            if (label.getMapping().size() != 2) {
                throw new UserError((Operator)this, 114, this.getName(), (Object)label.getName());
            }
            // The classifier takes integer class indices; truncate the numeric label values.
            int[] classes = new int[numExamples];
            for (int k = 0; k < numExamples; ++k) {
                classes[k] = (int) t[k][0];
            }
            ClassificationProblem problem = new ClassificationProblem(x, classes, kernels);
            if (!rvmType.equals(RVM_TYPE_CLASSIFICATION)) {
                throw new UserError((Operator)this, 207, new Object[]{rvmType, PARAMETER_RVM_TYPE, "only Classification-RVM can be used for the given two class classification problem"});
            }
            RVMClassification rvm = new RVMClassification(problem, parameter);
            try {
                return new RVMModel(label, rvm.learn());
            }
            catch (ArrayIndexOutOfBoundsException e) {
                // Reported to the user as error 924; the triggering condition lies inside myRVM.
                throw new UserError(this, 924);
            }
        }

        RegressionProblem problem = new RegressionProblem(x, t, kernels);
        edu.udo.cs.myRVM.Model model;
        if (rvmType.equals(RVM_TYPE_REGRESSION)) {
            model = new RVMRegression(problem, parameter).learn();
        } else if (rvmType.equals(RVM_TYPE_CONSTRUCTIVE_REGRESSION)) {
            model = new ConstructiveRegression(problem, parameter).learn();
        } else {
            throw new UserError((Operator)this, 207, new Object[]{rvmType, PARAMETER_RVM_TYPE, "only one of the regression types can be used for the given regression problem"});
        }
        return new RVMModel(label, model);
    }

    /**
     * Creates the kernel basis functions for the given input vectors.
     *
     * <p>The kernel is selected by the {@code kernel_type} parameter (an index into
     * {@link #KERNEL_TYPES}); an unknown index falls back to the radial kernel. The basis function
     * for {@code x[j]} is stored at index {@code j + 1}.
     *
     * @param x input vectors; must contain at least {@code numKernels - 1} rows
     * @param numKernels length of the returned array
     * @return an array of length {@code numKernels} whose element 0 is {@code null} and whose
     *         elements {@code 1 .. numKernels - 1} are the basis functions centred on
     *         {@code x[0] .. x[numKernels - 2]}
     * @throws OperatorException if a parameter lookup fails
     */
    public KernelBasisFunction[] createKernels(double[][] x, int numKernels) throws OperatorException {
        KernelBasisFunction[] kernels = new KernelBasisFunction[numKernels];
        double lengthScale = this.getParameterAsDouble(PARAMETER_KERNEL_LENGTHSCALE);
        double bias = this.getParameterAsDouble(PARAMETER_KERNEL_BIAS);
        double degree = this.getParameterAsDouble(PARAMETER_KERNEL_DEGREE);
        double a = this.getParameterAsDouble(PARAMETER_KERNEL_A);
        double b = this.getParameterAsDouble(PARAMETER_KERNEL_B);
        double sigma1 = this.getParameterAsDouble(PARAMETER_KERNEL_SIGMA1);
        double sigma2 = this.getParameterAsDouble(PARAMETER_KERNEL_SIGMA2);
        double sigma3 = this.getParameterAsDouble(PARAMETER_KERNEL_SIGMA3);
        double shift = this.getParameterAsDouble(PARAMETER_KERNEL_SHIFT);
        // The kernel type does not change between basis functions, so look it up only once.
        int kernelType = this.getParameterAsInt(PARAMETER_KERNEL_TYPE);

        // NOTE(review): kernels[0] is never assigned; presumably the myRVM library treats slot 0 as
        // a bias/constant basis - confirm against the library.
        for (int j = 0; j < numKernels - 1; ++j) {
            double[] input = x[j];
            KernelBasisFunction kernel;
            switch (kernelType) {
                case 1: {
                    kernel = new KernelBasisFunction(new KernelCauchy(lengthScale), input);
                    break;
                }
                case 2: {
                    kernel = new KernelBasisFunction(new KernelLaplace(lengthScale), input);
                    break;
                }
                case 3: {
                    kernel = new KernelBasisFunction(new KernelPoly(lengthScale, bias, degree), input);
                    break;
                }
                case 4: {
                    kernel = new KernelBasisFunction(new KernelSigmoid(a, b), input);
                    break;
                }
                case 5: {
                    kernel = new KernelBasisFunction(new KernelEpanechnikov(sigma1, degree), input);
                    break;
                }
                case 6: {
                    kernel = new KernelBasisFunction(new KernelGaussianCombination(sigma1, sigma2, sigma3), input);
                    break;
                }
                case 7: {
                    kernel = new KernelBasisFunction(new KernelMultiquadric(sigma1, shift), input);
                    break;
                }
                case 0:
                default: {
                    // Index 0 ("rbf") and any unknown index both use the radial kernel.
                    kernel = new KernelBasisFunction(new KernelRadial(lengthScale), input);
                }
            }
            kernels[j + 1] = kernel;
        }
        return kernels;
    }

    /**
     * Declares the parameters of this operator in addition to those of the superclass.
     *
     * <p>{@code max_iteration}, {@code min_delta_log_alpha} and {@code alpha_max} keep their
     * default expert flag; every other parameter is explicitly marked as non-expert.
     *
     * @return the superclass parameter list with this operator's parameters appended
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        // NOTE(review): several descriptions below look inaccurate compared with createKernels()
        // (e.g. "SVM kernel parameter", lengthscale "used in all kernels"); left unchanged because
        // they are user-facing strings.
        addNonExpert(types, new ParameterTypeCategory(PARAMETER_RVM_TYPE, "Regression RVM", RVM_TYPES, 0));
        addNonExpert(types, new ParameterTypeCategory(PARAMETER_KERNEL_TYPE, "The type of the kernel functions.", KERNEL_TYPES, 0));
        types.add(new ParameterTypeInt(PARAMETER_MAX_ITERATION, "The maximum number of iterations used.", 1, Integer.MAX_VALUE, 100));
        types.add(new ParameterTypeDouble(PARAMETER_MIN_DELTA_LOG_ALPHA, "Abort iteration if largest log alpha change is smaller than this", 0.0, Double.POSITIVE_INFINITY, 0.001));
        types.add(new ParameterTypeDouble(PARAMETER_ALPHA_MAX, "Prune basis function if its alpha is bigger than this", 0.0, Double.POSITIVE_INFINITY, 1.0E12));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_LENGTHSCALE, "The lengthscale used in all kernels.", 0.0, Double.POSITIVE_INFINITY, 3.0));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_DEGREE, "The degree used in the poly kernel.", 0.0, Double.POSITIVE_INFINITY, 2.0));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_BIAS, "The bias used in the poly kernel.", 0.0, Double.POSITIVE_INFINITY, 1.0));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_SIGMA1, "The SVM kernel parameter sigma1 (Epanechnikov, Gaussian Combination, Multiquadric).", 0.0, Double.POSITIVE_INFINITY, 1.0));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_SIGMA2, "The SVM kernel parameter sigma2 (Gaussian Combination).", 0.0, Double.POSITIVE_INFINITY, 0.0));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_SIGMA3, "The SVM kernel parameter sigma3 (Gaussian Combination).", 0.0, Double.POSITIVE_INFINITY, 2.0));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_SHIFT, "The SVM kernel parameter shift (polynomial, Multiquadric).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_A, "The SVM kernel parameter a (neural).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0));
        addNonExpert(types, new ParameterTypeDouble(PARAMETER_KERNEL_B, "The SVM kernel parameter b (neural).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0));
        return types;
    }

    /** Marks {@code type} as non-expert and appends it to {@code types}. */
    private static void addNonExpert(List<ParameterType> types, ParameterTypeSingle type) {
        type.setExpert(false);
        types.add(type);
    }
}

