/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.MetaModel;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Ensemble model built from a list of inner models, each paired with a weight.
 *
 * <p>Prediction is a weighted vote: for every example, each inner model adds its weight to the
 * score of the class it predicts. The class with the highest total score becomes the prediction,
 * and the confidences are the softmax of the per-class scores.
 *
 * <p>The number of inner models used can be limited via {@link #setMaxModelNumber(int)} or the
 * {@code "iteration"} parameter; a negative limit (the default, {@code -1}) uses all models.
 */
public class AdaBoostModel
extends PredictionModel
implements MetaModel {
    private static final long serialVersionUID = -4145136493164813582L;
    /** Name of the parameter that limits the number of inner models used for prediction. */
    private static final String MAX_MODEL_NUMBER = "iteration";
    /** Name prefix of the temporary per-class score attributes created during prediction. */
    private static final String PREDICTION_ATTRIBUTE_PREFIX = "AdaBoostModelPrediction";

    private List<Model> models;
    /** {@code weights.get(i)} is the voting weight of {@code models.get(i)}. */
    private List<Double> weights;
    /** Upper bound on the number of models used; negative means "use all models". */
    private int maxModelNumber = -1;

    /**
     * Creates the ensemble.
     *
     * @param exampleSet example set handed to {@link PredictionModel} (presumably the training set)
     * @param models the inner models
     * @param weights the weight of each inner model, index-aligned with {@code models}
     */
    public AdaBoostModel(ExampleSet exampleSet, List<Model> models, List<Double> weights) {
        super(exampleSet);
        this.models = models;
        this.weights = weights;
        for (double weight : weights) {
            // NaN or infinite weights make the softmax in evaluateSpecialAttributes degenerate.
            if (Double.isNaN(weight) || Double.isInfinite(weight)) {
                this.logWarning("Found model weight " + weight);
            }
        }
    }

    /**
     * Sets a model parameter. The {@code "iteration"} parameter (matched case-insensitively) limits
     * the number of inner models used for prediction. All other parameters, and "iteration" values
     * that are not valid integers, are passed on to the superclass.
     *
     * @param name the parameter name
     * @param value the parameter value as a string
     * @throws OperatorException if the superclass rejects the parameter
     */
    public void setParameter(String name, String value) throws OperatorException {
        if (name.equalsIgnoreCase(MAX_MODEL_NUMBER)) {
            try {
                this.maxModelNumber = Integer.parseInt(value);
                return;
            } catch (NumberFormatException e) {
                // Not an integer: make the problem visible instead of swallowing it, then let the
                // superclass handle the parameter as before.
                this.logWarning("Cannot parse value '" + value + "' of parameter '" + name + "' as an integer.");
            }
        }
        super.setParameter(name, value);
    }

    /**
     * Limits the number of inner models used for prediction.
     *
     * @param numModels maximum number of models; a negative value means "use all models"
     */
    public void setMaxModelNumber(int numModels) {
        this.maxModelNumber = numModels;
    }

    @Override
    public String toString() {
        String lineSeparator = com.rapidminer.tools.Tools.getLineSeparator();
        int numberOfModels = this.getNumberOfModels();
        StringBuilder result = new StringBuilder(super.toString());
        result.append(lineSeparator)
              .append("Number of inner models: ")
              .append(numberOfModels)
              .append(com.rapidminer.tools.Tools.getLineSeparators(2));
        for (int i = 0; i < numberOfModels; ++i) {
            if (i > 0) {
                result.append(lineSeparator);
            }
            result.append("Embedded model #").append(i)
                  .append(" (weight: ")
                  .append(com.rapidminer.tools.Tools.formatNumber(this.getWeightForModel(i)))
                  .append("): ")
                  .append(lineSeparator)
                  .append(this.getModel(i).toResultString());
        }
        return result.toString();
    }

    /**
     * Returns the number of inner models used for prediction.
     *
     * @return {@code min(maxModelNumber, models.size())} if a non-negative limit is set, otherwise
     *     the number of stored models
     */
    public int getNumberOfModels() {
        if (this.maxModelNumber >= 0) {
            return Math.min(this.maxModelNumber, this.models.size());
        }
        return this.models.size();
    }

    /** Returns the voting weight of the model at {@code modelNr}. */
    private double getWeightForModel(int modelNr) {
        return this.weights.get(modelNr);
    }

    /** Returns the inner model at {@code index}. */
    public Model getModel(int index) {
        return this.models.get(index);
    }

    /**
     * Predicts by weighted voting of the inner models.
     *
     * <p>One temporary score attribute per class of {@code predictedLabel} is added to
     * {@code origExampleSet}. The attributes are created with value type {@code 2} (presumably
     * {@code Ontology.NUMERICAL}; confirm). They are always removed again, even if an inner model
     * fails.
     *
     * @param origExampleSet the examples to score; predictions and confidences are written into it
     * @param predictedLabel the predicted-label attribute whose mapping defines the classes
     * @return {@code origExampleSet}
     * @throws OperatorException if applying an inner model fails
     */
    @Override
    public ExampleSet performPrediction(ExampleSet origExampleSet, Attribute predictedLabel) throws OperatorException {
        int numLabels = predictedLabel.getMapping().size();
        Attribute[] specialAttributes = new Attribute[numLabels];
        try {
            for (int i = 0; i < numLabels; ++i) {
                specialAttributes[i] = Tools.createSpecialAttribute(origExampleSet, PREDICTION_ATTRIBUTE_PREFIX + i, 2);
            }
            for (Example example : origExampleSet) {
                for (Attribute scoreAttribute : specialAttributes) {
                    example.setValue(scoreAttribute, 0.0);
                }
            }
            int numberOfModels = this.getNumberOfModels();
            for (int modelNr = 0; modelNr < numberOfModels; ++modelNr) {
                // Each inner model predicts on its own clone so its predicted label can be dropped afterwards.
                ExampleSet exampleSet = (ExampleSet) origExampleSet.clone();
                exampleSet = this.getModel(modelNr).apply(exampleSet);
                this.updateEstimates(exampleSet, modelNr, specialAttributes);
                PredictionModel.removePredictedLabel(exampleSet);
            }
            this.evaluateSpecialAttributes(origExampleSet, specialAttributes);
        } finally {
            // Clean up the temporary score attributes even when an inner model throws.
            for (Attribute scoreAttribute : specialAttributes) {
                if (scoreAttribute == null) {
                    continue;
                }
                origExampleSet.getAttributes().remove(scoreAttribute);
                origExampleSet.getExampleTable().removeAttribute(scoreAttribute);
            }
        }
        return origExampleSet;
    }

    /**
     * Adds the weight of model {@code modelNr} to the score attribute of the class it predicted
     * for each example.
     */
    private void updateEstimates(ExampleSet exampleSet, int modelNr, Attribute[] specialAttributes) {
        double weight = this.getWeightForModel(modelNr);
        for (Example example : exampleSet) {
            double predictedValue = example.getPredictedLabel();
            if (Double.isNaN(predictedValue)) {
                // No prediction from this model. (int) NaN would be 0, which would wrongly vote for
                // the first class.
                continue;
            }
            Attribute scoreAttribute = specialAttributes[(int) predictedValue];
            double oldValue = example.getValue(scoreAttribute);
            if (Double.isNaN(oldValue)) {
                this.logWarning("Found NaN confidence as intermediate prediction.");
                oldValue = 0.0;
            }
            if (Double.isInfinite(oldValue)) {
                continue;
            }
            example.setValue(scoreAttribute, oldValue + weight);
        }
    }

    /**
     * Turns the accumulated per-class scores into a predicted label and confidences.
     *
     * <p>The class with the highest score is predicted. Confidences are
     * {@code exp(score - best) / sum}. If that sum is NaN or infinite, all confidence goes to the
     * predicted class.
     */
    private void evaluateSpecialAttributes(ExampleSet exampleSet, Attribute[] specialAttributes) {
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        for (Example example : exampleSet) {
            double[] confidences = new double[specialAttributes.length];
            // Start below every finite score so that classes with negative totals can still win.
            double bestConf = Double.NEGATIVE_INFINITY;
            int bestLabel = 0;
            for (int n = 0; n < confidences.length; ++n) {
                confidences[n] = example.getValue(specialAttributes[n]);
                if (confidences[n] > bestConf) {
                    bestConf = confidences[n];
                    bestLabel = n;
                }
            }
            // Encode the winner with the predicted label's own mapping. The example set being
            // scored need not have a label attribute at all.
            example.setValue(predictedLabel, predictedLabel.getMapping().mapString(this.getLabel().getMapping().mapIndex(bestLabel)));

            // Softmax shifted by the best score, so the largest exponent is exp(0).
            double sum = 0.0;
            for (int n = 0; n < confidences.length; ++n) {
                confidences[n] = Math.exp(confidences[n] - bestConf);
                sum += confidences[n];
            }
            if (Double.isInfinite(sum) || Double.isNaN(sum)) {
                // Degenerate (NaN/infinite) scores: all confidence to the predicted class.
                for (int n = 0; n < confidences.length; ++n) {
                    confidences[n] = 0.0;
                }
                confidences[bestLabel] = 1.0;
            } else {
                for (int n = 0; n < confidences.length; ++n) {
                    confidences[n] /= sum;
                }
            }
            // Written in both cases; previously the degenerate branch skipped this step.
            for (int n = 0; n < confidences.length; ++n) {
                example.setConfidence(predictedLabel.getMapping().mapIndex(n), confidences[n]);
            }
        }
    }

    /** Returns the list of all stored inner models (not limited by the model-number limit). */
    public List<Model> getModels() {
        return this.models;
    }

    /** Returns a display name per used inner model, including its formatted weight. */
    @Override
    public List<String> getModelNames() {
        List<String> names = new LinkedList<String>();
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            names.add("Model " + (i + 1) + " [w = " + com.rapidminer.tools.Tools.formatNumber(this.getWeightForModel(i)) + "]");
        }
        return names;
    }
}

