/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.annotation.ResourceConsumptionEstimator;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.functions.VectorRegressionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.tools.OperatorResourceConsumptionHandler;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Multi-target linear regression learner with ridge regularization.
 *
 * <p>Every special attribute whose role name starts with {@code "label"} is treated as one target
 * column; all regular attributes form the design matrix. The coefficients for all targets are
 * obtained at once by solving {@code (X^T X + ridge * I) B = X^T Y}. When that system is singular
 * the ridge is increased and the solve is retried.
 */
public class VectorLinearRegression
extends AbstractLearner {
    /** Parameter key: whether a constant intercept column is added to the design matrix. */
    public static final String PARAMETER_USE_BIAS = "use_bias";
    /** Parameter key: the ridge value added to the diagonal of {@code X^T X}. */
    public static final String PARAMETER_RIDGE = "ridge";

    /** Ridge used for a retry when the configured ridge is 0 and the system is singular. */
    private static final double MINIMUM_RETRY_RIDGE = 1.0E-8;

    /** Factor by which the ridge is multiplied after each failed (singular) solve. */
    private static final double RIDGE_GROWTH_FACTOR = 10.0;

    public VectorLinearRegression(OperatorDescription description) {
        super(description);
    }

    /**
     * Fits one linear model per label attribute and returns them as a single vector model.
     *
     * @param exampleSet training data; regular attributes are the inputs, every special role
     *     starting with {@code "label"} is an output
     * @return a {@link VectorRegressionModel} holding the coefficient matrix
     * @throws OperatorException if parameters cannot be read, or if the regression system cannot
     *     be made solvable by increasing the ridge
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        boolean useBias = this.getParameterAsBoolean(PARAMETER_USE_BIAS);
        double ridge = this.getParameterAsDouble(PARAMETER_RIDGE);
        List<Attribute> labels = collectLabels(exampleSet);

        // NOTE(review): the width is always attributes + 1. With the bias disabled the last column
        // is never written and stays zero, so its coefficient comes out as 0. The layout is kept
        // as-is because VectorRegressionModel may rely on this coefficient row count — confirm.
        int width = exampleSet.getAttributes().size() + 1;
        Matrix x = new Matrix(exampleSet.size(), width);
        Matrix y = new Matrix(exampleSet.size(), labels.size());
        fillDataMatrices(exampleSet, labels, useBias, x, y);

        Matrix result = solveRegularized(x, y, ridge);

        String[] labelNames = new String[labels.size()];
        int index = 0;
        for (Attribute label : labels) {
            labelNames[index++] = label.getName();
        }
        return new VectorRegressionModel(exampleSet, labelNames, result, useBias);
    }

    /** Collects, in role iteration order, every attribute whose special role starts with "label". */
    private static List<Attribute> collectLabels(ExampleSet exampleSet) {
        List<Attribute> labels = new LinkedList<Attribute>();
        Iterator<AttributeRole> roles = exampleSet.getAttributes().allAttributeRoles();
        while (roles.hasNext()) {
            AttributeRole role = roles.next();
            String specialName = role.getSpecialName();
            if (specialName != null && specialName.startsWith("label")) {
                labels.add(role.getAttribute());
            }
        }
        return labels;
    }

    /**
     * Writes one row per example into {@code x} (optional leading 1.0 bias, then the regular
     * attribute values) and {@code y} (the label values, in {@code labels} order).
     */
    private static void fillDataMatrices(ExampleSet exampleSet, List<Attribute> labels, boolean useBias,
            Matrix x, Matrix y) {
        int firstAttributeColumn = useBias ? 1 : 0;
        int row = 0;
        for (Example example : exampleSet) {
            if (useBias) {
                x.set(row, 0, 1.0);
            }
            int column = firstAttributeColumn;
            for (Attribute attribute : exampleSet.getAttributes()) {
                x.set(row, column++, example.getValue(attribute));
            }
            int labelColumn = 0;
            for (Attribute label : labels) {
                y.set(row, labelColumn++, example.getValue(label));
            }
            ++row;
        }
    }

    /**
     * Solves {@code (X^T X + ridge * I) B = X^T Y}, growing the ridge until the system is solvable.
     *
     * @throws OperatorException if the ridge overflows before a solution is found
     */
    private static Matrix solveRegularized(Matrix x, Matrix y, double initialRidge) throws OperatorException {
        Matrix xTransposed = x.transpose();
        // Neither product depends on the ridge, so compute them once instead of on every retry.
        Matrix gram = xTransposed.times(x);
        Matrix xTy = xTransposed.times(y);
        int numberOfColumns = gram.getColumnDimension();
        double ridge = initialRidge;
        while (true) {
            Matrix xTx = gram.copy();
            for (int i = 0; i < numberOfColumns; ++i) {
                xTx.set(i, i, xTx.get(i, i) + ridge);
            }
            try {
                return xTx.solve(xTy);
            } catch (RuntimeException singular) {
                // Multiplying a zero ridge never changes it, which used to loop forever; seed it
                // with a small positive value instead so the retries actually regularize.
                ridge = ridge > 0.0 ? ridge * RIDGE_GROWTH_FACTOR : MINIMUM_RETRY_RIDGE;
                if (Double.isInfinite(ridge) || Double.isNaN(ridge)) {
                    throw new OperatorException(
                            "Linear regression system remains singular even with a very large ridge.", singular);
                }
            }
        }
    }

    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return VectorRegressionModel.class;
    }

    /** Supports numerical attributes and numerical labels only; example weights are not used. */
    @Override
    public boolean supportsCapability(OperatorCapability lc) {
        return lc == OperatorCapability.NUMERICAL_ATTRIBUTES || lc == OperatorCapability.NUMERICAL_LABEL;
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(PARAMETER_USE_BIAS, "Indicates if an intercept value should be calculated.", true));
        types.add(new ParameterTypeDouble(PARAMETER_RIDGE, "The ridge parameter.", 0.0, Double.POSITIVE_INFINITY, 1.0E-8));
        return types;
    }

    @Override
    public ResourceConsumptionEstimator getResourceConsumptionEstimator() {
        return OperatorResourceConsumptionHandler.getResourceConsumptionEstimator(this.getExampleSetInputPort(), VectorLinearRegression.class, null);
    }
}

