/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.example.Tools;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.functions.FastLargeMargin;
import liblinear.FeatureNode;
import liblinear.Linear;
import liblinear.Model;

/**
 * Prediction model wrapping a liblinear {@link Model} trained by {@link FastLargeMargin}.
 *
 * <p>Each example is converted into liblinear feature nodes via
 * {@link FastLargeMargin#makeNodes} and classified with {@link Linear#predict}. When the label
 * is nominal with exactly two values, the first liblinear decision value is additionally mapped
 * through the logistic function and written into the label's confidence attributes.
 */
public class FastMarginModel extends PredictionModel {
    private static final long serialVersionUID = 7701199447666181333L;

    /** The trained liblinear model used for prediction and for {@link #toString()}. */
    private Model linearModel;

    /**
     * Forwarded to {@link FastLargeMargin#makeNodes}; {@link #toString()} treats the trailing
     * weight(s) as a bias (intercept) term when this is set.
     */
    private boolean useBias;

    /** Textual constructions of the regular attributes, in the order of their weights. */
    private String[] attributeConstructions;

    /**
     * Creates the model.
     *
     * @param headerSet example set describing the training attributes (used by the superclass
     *     and to record the regular attribute constructions)
     * @param linearModel the trained liblinear model
     * @param useBias whether a bias feature was used during training
     */
    public FastMarginModel(ExampleSet headerSet, Model linearModel, boolean useBias) {
        super(headerSet);
        this.linearModel = linearModel;
        this.useBias = useBias;
        this.attributeConstructions = Tools.getRegularAttributeConstructions(headerSet);
    }

    @Override
    public String getName() {
        return "Fast Linear Classification";
    }

    /**
     * Writes the predicted class into {@code predictedLabel} for every example and, for binary
     * nominal labels, fills the corresponding {@code confidence_<value>} attributes.
     *
     * @param exampleSet the examples to classify; modified in place and returned
     * @param predictedLabel the attribute receiving the predicted class index
     * @return {@code exampleSet}
     * @throws OperatorException if the superclass machinery reports an error
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        FastExample2SparseTransform ripper = new FastExample2SparseTransform(exampleSet);
        Attribute label = this.getLabel();
        // The mapping is only accessed for nominal labels; this is computed once instead of
        // being re-queried (unguarded) for every example.
        int numberOfLabelValues = label.isNominal() ? label.getMapping().size() : 0;
        boolean binaryConfidences = numberOfLabelValues == 2;

        Attribute[] confidenceAttributes = null;
        if (binaryConfidences) {
            // One confidence attribute per liblinear class, in the model's own label order.
            confidenceAttributes = new Attribute[this.linearModel.label.length];
            for (int j = 0; j < this.linearModel.label.length; ++j) {
                String labelName = label.getMapping().mapIndex(this.linearModel.label[j]);
                confidenceAttributes[j] = exampleSet.getAttributes().getSpecial("confidence_" + labelName);
            }
        }

        for (Example example : exampleSet) {
            FeatureNode[] nodes = FastLargeMargin.makeNodes(example, ripper, this.useBias);
            example.setValue(predictedLabel, Linear.predict(this.linearModel, nodes));
            if (binaryConfidences) {
                this.setBinaryConfidences(example, nodes, confidenceAttributes);
            }
        }
        return exampleSet;
    }

    /**
     * Maps the first liblinear decision value {@code d} to {@code 1 / (1 + exp(-d))} for the
     * first model label and {@code 1 / (1 + exp(d))} for the second one.
     *
     * @param example the example to update
     * @param nodes the example's liblinear feature nodes
     * @param confidenceAttributes confidence attributes in liblinear label order; entries may be
     *     {@code null} if the example set lacks that attribute
     */
    private void setBinaryConfidences(Example example, FeatureNode[] nodes, Attribute[] confidenceAttributes) {
        if (confidenceAttributes == null || confidenceAttributes.length == 0) {
            // Nothing to write; also avoids reading functionValues[0] when nr_class is 0.
            return;
        }
        double[] functionValues = new double[this.linearModel.nr_class];
        Linear.predictValues(this.linearModel, nodes, functionValues);
        double decisionValue = functionValues[0];
        if (confidenceAttributes[0] != null) {
            example.setValue(confidenceAttributes[0], 1.0 / (1.0 + Math.exp(-decisionValue)));
        }
        if (confidenceAttributes.length > 1 && confidenceAttributes[1] != null) {
            example.setValue(confidenceAttributes[1], 1.0 / (1.0 + Math.exp(decisionValue)));
        }
    }

    /**
     * Renders the linear model as {@code coefficient * attribute} lines, followed by the bias
     * coefficient when {@link #useBias} is set. Models with several weight vectors (e.g.
     * one-vs-rest multiclass) get one such equation per vector, each preceded by a header line.
     */
    @Override
    public String toString() {
        String lineSeparator = com.rapidminer.tools.Tools.getLineSeparator();
        double[] weights = this.linearModel.w;
        int vectorCount = this.getWeightVectorCount();
        StringBuilder result = new StringBuilder();
        for (int vector = 0; vector < vectorCount; ++vector) {
            if (vectorCount > 1) {
                if (vector > 0) {
                    result.append(lineSeparator);
                }
                result.append(this.getWeightVectorHeader(vector, vectorCount)).append(lineSeparator);
            }
            boolean first = true;
            for (int i = 0; i < this.attributeConstructions.length; ++i) {
                // liblinear stores the vectorCount weights of each feature contiguously.
                result.append(this.getCoefficientString(weights[i * vectorCount + vector], first))
                        .append(" * ")
                        .append(this.attributeConstructions[i])
                        .append(lineSeparator);
                first = false;
            }
            if (this.useBias) {
                // The bias row is the last feature row; for a single vector this is w[w.length - 1].
                result.append(this.getCoefficientString(weights[weights.length - vectorCount + vector], first));
            }
        }
        return result.toString();
    }

    /**
     * Infers how many weight vectors are interleaved in {@code linearModel.w}, falling back to 1
     * whenever the array length is not a clean multiple of the per-vector feature count.
     */
    private int getWeightVectorCount() {
        int weightsPerVector = this.attributeConstructions.length + (this.useBias ? 1 : 0);
        int totalWeights = this.linearModel.w.length;
        if (weightsPerVector == 0 || totalWeights <= weightsPerVector || totalWeights % weightsPerVector != 0) {
            return 1;
        }
        return totalWeights / weightsPerVector;
    }

    /**
     * Builds the header line for one weight vector, naming the class when the vectors
     * correspond one-to-one with the model's labels and the label is nominal.
     */
    private String getWeightVectorHeader(int vector, int vectorCount) {
        Attribute label = this.getLabel();
        if (label != null && label.isNominal() && vectorCount == this.linearModel.label.length) {
            return "Class " + label.getMapping().mapIndex(this.linearModel.label[vector]) + ":";
        }
        return "Weight vector " + (vector + 1) + ":";
    }

    /**
     * Formats a coefficient with a leading sign token: {@code "- "} for negative values,
     * {@code "+ "} for non-negative non-first terms, and two spaces for a non-negative first term.
     */
    private String getCoefficientString(double coefficient, boolean first) {
        String sign;
        if (coefficient >= 0.0) {
            sign = first ? "  " : "+ ";
        } else {
            sign = "- ";
        }
        return sign + com.rapidminer.tools.Tools.formatNumber(Math.abs(coefficient));
    }
}

