/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.kernel;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.FormulaProvider;
import com.rapidminer.operator.learner.functions.kernel.KernelModel;
import com.rapidminer.operator.learner.functions.kernel.LibSVMLearner;
import com.rapidminer.operator.learner.functions.kernel.SupportVector;
import com.rapidminer.tools.Tools;
import libsvm.Svm;
import libsvm.svm_model;
import libsvm.svm_node;

public class LibSVMModel
extends KernelModel
implements FormulaProvider {
    private static final long serialVersionUID = -2654603017217487365L;
    private svm_model model;
    private int numberOfAttributes;
    private boolean confidenceForMultiClass = true;

    public LibSVMModel(ExampleSet exampleSet, svm_model model, int numberOfAttributes, boolean confidenceForMultiClass) {
        super(exampleSet);
        this.model = model;
        this.numberOfAttributes = numberOfAttributes;
        this.confidenceForMultiClass = confidenceForMultiClass;
    }

    @Override
    public boolean isClassificationModel() {
        return this.getLabel().isNominal();
    }

    @Override
    public double getAlpha(int index) {
        return this.model.sv_coef[0][index];
    }

    @Override
    public String getId(int index) {
        return null;
    }

    @Override
    public int getNumberOfSupportVectors() {
        return this.model.SV.length;
    }

    @Override
    public int getNumberOfAttributes() {
        return this.numberOfAttributes;
    }

    @Override
    public double getBias() {
        if (this.model.rho.length > 0) {
            return this.model.rho[0];
        }
        return 0.0;
    }

    @Override
    public SupportVector getSupportVector(int index) {
        svm_node[] nodes = this.model.SV[index];
        double[] x = new double[this.getNumberOfAttributes()];
        for (int i = 0; i < nodes.length; ++i) {
            x[nodes[i].index] = nodes[i].value;
        }
        return new SupportVector(x, this.getRegressionLabel(index), Math.abs(this.getAlpha(index)));
    }

    @Override
    public double getAttributeValue(int exampleIndex, int attributeIndex) {
        double[] dense = new double[this.numberOfAttributes];
        svm_node[] node = this.model.SV[exampleIndex];
        for (int i = 0; i < node.length; ++i) {
            dense[node[i].index] = node[i].value;
        }
        return dense[attributeIndex];
    }

    @Override
    public String getClassificationLabel(int index) {
        double functionValue = this.getRegressionLabel(index);
        if (!Double.isNaN(functionValue)) {
            return this.getLabel().getMapping().mapIndex((int)functionValue);
        }
        return "?";
    }

    @Override
    public double getRegressionLabel(int index) {
        if (this.model.labelValues != null) {
            return this.model.labelValues[index];
        }
        return Double.NaN;
    }

    @Override
    public double getFunctionValue(int index) {
        if (this.getLabel().isNominal()) {
            double[] classProbs = new double[this.getLabel().getMapping().size()];
            Svm.svm_predict_probability(this.model, this.model.SV[index], classProbs);
            return classProbs[0];
        }
        return Svm.svm_predict(this.model, this.model.SV[index]);
    }

    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws UserError {
        FastExample2SparseTransform ripper = new FastExample2SparseTransform(exampleSet);
        Attribute label = this.getLabel();
        if (this.model.param.svm_type == 2) {
            predictedLabel.getMapping().getValues().clear();
            predictedLabel.getMapping().getValues().add("outside");
            predictedLabel.getMapping().getValues().add("inside");
            Attribute confidenceAttribute = AttributeFactory.createAttribute("confidence(inside)", 4);
            exampleSet.getExampleTable().addAttribute(confidenceAttribute);
            AttributeRole confidenceRole = new AttributeRole(confidenceAttribute);
            confidenceRole.setSpecial("confidence_inside");
            exampleSet.getAttributes().add(confidenceRole);
            int counter = 0;
            double[] allConfidences = new double[exampleSet.size()];
            int[] allLabels = new int[exampleSet.size()];
            double maxConfidence = Double.NEGATIVE_INFINITY;
            double minConfidence = Double.POSITIVE_INFINITY;
            for (Example example : exampleSet) {
                svm_node[] currentNodes = LibSVMLearner.makeNodes(example, ripper);
                double[] prob = new double[1];
                Svm.svm_predict_values(this.model, currentNodes, prob);
                allLabels[counter] = prob[0] >= 0.0 ? 1 : 0;
                allConfidences[counter] = prob[0];
                minConfidence = Math.min(minConfidence, prob[0]);
                maxConfidence = Math.max(maxConfidence, prob[0]);
                ++counter;
            }
            counter = 0;
            for (Example example : exampleSet) {
                double confidence = allConfidences[counter];
                example.setValue(predictedLabel, allLabels[counter]);
                example.setValue(confidenceAttribute, confidence);
                ++counter;
            }
        } else {
            Attribute[] confidenceAttributes = null;
            if (label.isNominal() && label.getMapping().size() >= 2) {
                confidenceAttributes = new Attribute[this.model.label.length];
                for (int j = 0; j < this.model.label.length; ++j) {
                    String labelName = label.getMapping().mapIndex(this.model.label[j]);
                    confidenceAttributes[j] = exampleSet.getAttributes().getSpecial("confidence_" + labelName);
                }
            }
            for (Example example : exampleSet) {
                if (label.isNominal()) {
                    svm_node[] currentNodes = LibSVMLearner.makeNodes(example, ripper);
                    if (this.model.probA != null && this.model.probB != null) {
                        double predictedClass;
                        double[] classProbs = new double[this.model.nr_class];
                        int nr_class = this.model.nr_class;
                        double[] dec_values = new double[nr_class * (nr_class - 1) / 2];
                        Svm.svm_predict_values(this.model, currentNodes, dec_values);
                        double min_prob = 1.0E-7;
                        double[][] pairwise_prob = new double[nr_class][nr_class];
                        int k = 0;
                        for (int a = 0; a < nr_class; ++a) {
                            for (int j = a + 1; j < nr_class; ++j) {
                                pairwise_prob[a][j] = Math.min(Math.max(Svm.sigmoid_predict(dec_values[k], this.model.probA[k], this.model.probB[k]), min_prob), 1.0 - min_prob);
                                pairwise_prob[j][a] = 1.0 - pairwise_prob[a][j];
                                ++k;
                            }
                        }
                        Svm.multiclass_probability(nr_class, pairwise_prob, classProbs);
                        for (k = 0; k < nr_class; ++k) {
                            example.setValue(confidenceAttributes[k], classProbs[k]);
                        }
                        if (this.confidenceForMultiClass) {
                            predictedClass = Svm.svm_predict_probability(this.model, currentNodes, classProbs);
                            example.setValue(predictedLabel, predictedClass);
                            continue;
                        }
                        predictedClass = Svm.svm_predict(this.model, currentNodes);
                        example.setValue(predictedLabel, predictedClass);
                        continue;
                    }
                    double predictedClass = Svm.svm_predict(this.model, currentNodes);
                    example.setValue(predictedLabel, predictedClass);
                    if (label.getMapping().size() == 2) {
                        double[] functionValues = new double[this.model.nr_class];
                        Svm.svm_predict_values(this.model, currentNodes, functionValues);
                        double prediction = functionValues[0];
                        if (confidenceAttributes == null || confidenceAttributes.length <= 0) continue;
                        example.setValue(confidenceAttributes[0], 1.0 / (1.0 + Math.exp(-prediction)));
                        if (confidenceAttributes.length <= 1) continue;
                        example.setValue(confidenceAttributes[1], 1.0 / (1.0 + Math.exp(prediction)));
                        continue;
                    }
                    example.setConfidence(this.getLabel().getMapping().mapIndex((int)predictedClass), 1.0);
                    continue;
                }
                example.setValue(predictedLabel, Svm.svm_predict(this.model, LibSVMLearner.makeNodes(example, ripper)));
            }
        }
        return exampleSet;
    }

    @Override
    protected boolean supportsConfidences(Attribute label) {
        return super.supportsConfidences(label) && this.model.param.svm_type != 2;
    }

    @Override
    public String toString() {
        StringBuffer result = new StringBuffer(super.toString() + Tools.getLineSeparator());
        result.append("number of classes: " + this.model.nr_class + Tools.getLineSeparator());
        if (this.getLabel().isNominal() && this.getLabel().getMapping().size() >= 2 && this.model.nSV != null) {
            for (int i = 0; i < this.model.nSV.length; ++i) {
                result.append("number of support vectors for class " + this.getLabel().getMapping().mapIndex(this.model.label[i]) + ": " + this.model.nSV[i] + Tools.getLineSeparator());
            }
        } else {
            result.append("number of support vectors: " + this.model.l + Tools.getLineSeparator());
        }
        return result.toString();
    }

    @Override
    public String getFormula() {
        StringBuffer result = new StringBuffer();
        int kernelType = this.model.param.kernel_type;
        if (kernelType == 4) {
            return "Precomputed kernel, no formula possible.";
        }
        if (kernelType == 2) {
            return "RBF kernel, no formula possible.";
        }
        boolean first = true;
        for (int i = 0; i < this.getNumberOfSupportVectors(); ++i) {
            double alpha;
            SupportVector sv = this.getSupportVector(i);
            if (sv == null || Tools.isZero(alpha = sv.getAlpha())) continue;
            result.append(Tools.getLineSeparator());
            double[] x = sv.getX();
            double y = sv.getY();
            double factor = y * alpha;
            if (factor < 0.0) {
                if (first) {
                    result.append("- " + Math.abs(factor));
                } else {
                    result.append("- " + Math.abs(factor));
                }
            } else if (first) {
                result.append("  " + factor);
            } else {
                result.append("+ " + factor);
            }
            result.append(" * (" + this.getDistanceFormula(x, this.getAttributeConstructions()) + ")");
            first = false;
        }
        double bias = this.getBias();
        if (!Tools.isZero(bias)) {
            result.append(Tools.getLineSeparator());
            if (bias < 0.0) {
                if (first) {
                    result.append("- " + Math.abs(bias));
                } else {
                    result.append("- " + Math.abs(bias));
                }
            } else if (first) {
                result.append(bias);
            } else {
                result.append("+ " + bias);
            }
        }
        return result.toString();
    }

    private String getDistanceFormula(double[] x, String[] attributeConstructions) {
        int kernelType = this.model.param.kernel_type;
        switch (kernelType) {
            case 0: {
                StringBuffer result = new StringBuffer();
                boolean first = true;
                for (int i = 0; i < x.length; ++i) {
                    double value = x[i];
                    if (Tools.isZero(value)) continue;
                    if (value < 0.0) {
                        if (first) {
                            result.append("-" + Math.abs(value) + " * " + attributeConstructions[i]);
                        } else {
                            result.append(" - " + Math.abs(value) + " * " + attributeConstructions[i]);
                        }
                    } else if (first) {
                        result.append(value + " * " + attributeConstructions[i]);
                    } else {
                        result.append(" + " + value + " * " + attributeConstructions[i]);
                    }
                    first = false;
                }
                return result.toString();
            }
            case 1: {
                StringBuffer dotResult = new StringBuffer();
                boolean first = true;
                for (int i = 0; i < x.length; ++i) {
                    double value = x[i];
                    if (Tools.isZero(value)) continue;
                    if (value < 0.0) {
                        if (first) {
                            dotResult.append("-" + Math.abs(value) + " * " + attributeConstructions[i]);
                        } else {
                            dotResult.append(" - " + Math.abs(value) + " * " + attributeConstructions[i]);
                        }
                    } else if (first) {
                        dotResult.append(value + " * " + attributeConstructions[i]);
                    } else {
                        dotResult.append(" + " + value + " * " + attributeConstructions[i]);
                    }
                    first = false;
                }
                return "pow((" + this.model.param.gamma + " * (" + dotResult.toString() + ") + " + this.model.param.coef0 + "), " + this.model.param.degree + ")";
            }
            case 3: {
                StringBuffer dotResult = new StringBuffer();
                boolean first = true;
                for (int i = 0; i < x.length; ++i) {
                    double value = x[i];
                    if (Tools.isZero(value)) continue;
                    if (value < 0.0) {
                        if (first) {
                            dotResult.append("-" + Math.abs(value) + " * " + attributeConstructions[i]);
                        } else {
                            dotResult.append(" - " + Math.abs(value) + " * " + attributeConstructions[i]);
                        }
                    } else if (first) {
                        dotResult.append(value + " * " + attributeConstructions[i]);
                    } else {
                        dotResult.append(" + " + value + " * " + attributeConstructions[i]);
                    }
                    first = false;
                }
                return "tanh(" + this.model.param.gamma + " * (" + dotResult.toString() + ") + " + this.model.param.coef0 + ")";
            }
        }
        return "";
    }
}

