/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.BayBoostBaseModelInfo;
import com.rapidminer.operator.learner.meta.ContingencyMatrix;
import com.rapidminer.operator.learner.meta.MetaModel;
import com.rapidminer.tools.LogService;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

public class BayBoostModel
extends PredictionModel
implements MetaModel {
    private static final long serialVersionUID = 5821921049035718838L;
    private final List<BayBoostBaseModelInfo> modelInfo;
    private final double[] priors;
    private int maxModelNumber = -1;
    private static final String MAX_MODEL_NUMBER = "iteration";
    private static final String CONV_TO_CRISP = "crisp";
    private double threshold = 0.5;

    public BayBoostModel(ExampleSet exampleSet, List<BayBoostBaseModelInfo> modelInfos, double[] priors) {
        super(exampleSet);
        this.modelInfo = modelInfos;
        this.priors = priors;
    }

    public BayBoostBaseModelInfo getBayBoostBaseModelInfo(int index) {
        return this.modelInfo.get(index);
    }

    public void setParameter(String name, String value) throws OperatorException {
        if (name.equalsIgnoreCase(MAX_MODEL_NUMBER)) {
            try {
                this.maxModelNumber = Integer.parseInt(value);
                return;
            }
            catch (NumberFormatException numberFormatException) {}
        } else if (name.equalsIgnoreCase(CONV_TO_CRISP)) {
            this.threshold = Double.parseDouble(value.trim());
            return;
        }
        super.setParameter(name, value);
    }

    public void setMaxModelNumber(int numModels) {
        this.maxModelNumber = numModels;
    }

    @Override
    public String toString() {
        StringBuffer result = new StringBuffer(super.toString() + com.rapidminer.tools.Tools.getLineSeparator() + "Number of inner models: " + this.getNumberOfModels() + com.rapidminer.tools.Tools.getLineSeparators(2));
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            Model model = this.getModel(i);
            result.append((i > 0 ? com.rapidminer.tools.Tools.getLineSeparator() : "") + "Embedded model #" + i + ":" + com.rapidminer.tools.Tools.getLineSeparator() + model.toResultString());
        }
        return result.toString();
    }

    public int getNumberOfModels() {
        if (this.maxModelNumber >= 0) {
            return Math.min(this.maxModelNumber, this.modelInfo.size());
        }
        return this.modelInfo.size();
    }

    private double[] getFactorsForModel(int modelNr, int predicted) {
        ContingencyMatrix cm = this.modelInfo.get(modelNr).getContingencyMatrix();
        return cm.getLiftRatiosForPrediction(predicted);
    }

    private double getPriorOfClass(int classIndex) {
        return this.priors[classIndex];
    }

    public double[] getPriors() {
        double[] result = new double[this.priors.length];
        System.arraycopy(this.priors, 0, result, 0, result.length);
        return result;
    }

    public Model getModel(int index) {
        return this.modelInfo.get(index).getModel();
    }

    public ContingencyMatrix getContingencyMatrix(int index) {
        return this.modelInfo.get(index).getContingencyMatrix();
    }

    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        Attribute[] specialAttributes = this.createSpecialAttributes(exampleSet);
        this.initIntermediateResultAttributes(exampleSet, specialAttributes);
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            Model model = this.getModel(i);
            ExampleSet clonedExampleSet = (ExampleSet)exampleSet.clone();
            clonedExampleSet = model.apply(clonedExampleSet);
            this.updateEstimates(clonedExampleSet, this.getContingencyMatrix(i), specialAttributes);
            PredictionModel.removePredictedLabel(clonedExampleSet);
        }
        for (Example example : exampleSet) {
            this.translateOddsIntoPredictions(example, specialAttributes, this.getTrainingHeader().getAttributes().getLabel());
        }
        this.cleanUpSpecialAttributes(exampleSet, specialAttributes);
        return exampleSet;
    }

    private Attribute[] createSpecialAttributes(ExampleSet exampleSet) throws OperatorException {
        String attributePrefix = "BayBoostModelPrediction";
        Attribute[] specialAttributes = new Attribute[this.getLabel().getMapping().size()];
        for (int i = 0; i < specialAttributes.length; ++i) {
            specialAttributes[i] = Tools.createSpecialAttribute(exampleSet, "BayBoostModelPrediction" + i, 2);
        }
        return specialAttributes;
    }

    private void cleanUpSpecialAttributes(ExampleSet exampleSet, Attribute[] specialAttributes) throws OperatorException {
        for (int i = 0; i < specialAttributes.length; ++i) {
            exampleSet.getAttributes().remove(specialAttributes[i]);
            exampleSet.getExampleTable().removeAttribute(specialAttributes[i]);
        }
    }

    private void initIntermediateResultAttributes(ExampleSet exampleSet, Attribute[] specAttrib) {
        double[] priorOdds = new double[this.priors.length];
        for (int i = 0; i < priorOdds.length; ++i) {
            priorOdds[i] = this.priors[i] == 1.0 ? Double.POSITIVE_INFINITY : this.priors[i] / (1.0 - this.priors[i]);
        }
        for (Example example : exampleSet) {
            for (int i = 0; i < specAttrib.length; ++i) {
                example.setValue(specAttrib[i], priorOdds[i]);
            }
        }
    }

    private void translateOddsIntoPredictions(Example example, Attribute[] specAttrib, Attribute trainingSetLabel) {
        String bestLabel;
        double probSum = 0.0;
        double[] classProb = new double[specAttrib.length];
        int bestIndex = 0;
        for (int n = 0; n < classProb.length; ++n) {
            double odds = example.getValue(specAttrib[n]);
            if (Double.isNaN(odds)) {
                this.logWarning("Found NaN odd ratio estimate.");
                classProb[n] = 1.0;
            } else {
                classProb[n] = Double.isInfinite(odds) ? 1.0 : odds / (1.0 + odds);
            }
            probSum += classProb[n];
            if (!(classProb[n] > classProb[bestIndex])) continue;
            bestIndex = n;
        }
        if (probSum != 1.0) {
            int k = 0;
            while (k < classProb.length) {
                int n = k++;
                classProb[n] = classProb[n] / probSum;
            }
        }
        if (this.getLabel().isNominal() && this.getLabel().getMapping().size() == 2 && this.threshold != 0.5) {
            int posIndex = this.getLabel().getMapping().getPositiveIndex();
            int negIndex = this.getLabel().getMapping().getNegativeIndex();
            this.threshold = this.threshold >= 0.0 && this.threshold <= 1.0 ? this.threshold : 0.5;
            bestLabel = this.getLabel().getMapping().mapIndex(classProb[posIndex] >= this.threshold ? posIndex : negIndex);
        } else {
            bestLabel = this.getLabel().getMapping().mapIndex(bestIndex);
        }
        example.setValue(example.getAttributes().getPredictedLabel(), trainingSetLabel.getMapping().mapString(bestLabel));
        for (int k = 0; k < classProb.length; ++k) {
            if (Double.isNaN(classProb[k]) || classProb[k] < 0.0 || classProb[k] > 1.0) {
                this.logWarning("Found illegal confidence value: " + classProb[k]);
            }
            example.setConfidence(this.getLabel().getMapping().mapIndex(k), classProb[k]);
        }
    }

    private void updateEstimates(ExampleSet exampleSet, ContingencyMatrix cm, Attribute[] specialAttributes) {
        for (Example example : exampleSet) {
            int predicted = (int)example.getPredictedLabel();
            for (int j = 0; j < cm.getNumberOfClasses(); ++j) {
                double liftRatioCurrent = cm.getLiftRatio(j, predicted);
                if (Double.isNaN(liftRatioCurrent)) {
                    this.logWarning("Ignoring non-applicable model.");
                    continue;
                }
                if (Double.isInfinite(liftRatioCurrent)) {
                    if (example.getValue(specialAttributes[j]) == 0.0) continue;
                    for (int k = 0; k < specialAttributes.length; ++k) {
                        example.setValue(specialAttributes[k], 0.0);
                    }
                    example.setValue(specialAttributes[j], liftRatioCurrent);
                    continue;
                }
                double oldValue = example.getValue(specialAttributes[j]);
                if (Double.isNaN(oldValue)) {
                    this.logWarning("Found NaN value in intermediate odds ratio estimates!");
                }
                if (Double.isInfinite(oldValue)) continue;
                example.setValue(specialAttributes[j], oldValue * liftRatioCurrent);
            }
        }
    }

    public static boolean adjustIntermediateProducts(double[] products, double[] liftFactors) {
        for (int i = 0; i < liftFactors.length; ++i) {
            if (Double.isNaN(liftFactors[i])) {
                LogService.getGlobal().log("Ignoring non-applicable model.", 5);
                continue;
            }
            if (Double.isInfinite(liftFactors[i])) {
                if (products[i] == 0.0) continue;
                for (int j = 0; j < products.length; ++j) {
                    products[j] = 0.0;
                }
                products[i] = liftFactors[i];
                return true;
            }
            int n = i;
            products[n] = products[n] * liftFactors[i];
            if (!Double.isNaN(products[i])) continue;
            LogService.getGlobal().log("Found NaN value in intermediate odds ratio estimates!", 5);
        }
        return false;
    }

    public double[] getModelWeights() throws OperatorException {
        if (this.getLabel().getMapping().size() != 2) {
            throw new UserError(null, 114, "BayBoostModel", this.getLabel());
        }
        int maxWeight = 10;
        int pos = this.getLabel().getMapping().getPositiveIndex();
        int neg = this.getLabel().getMapping().getNegativeIndex();
        double[] weights = new double[this.getNumberOfModels() + 1];
        double odds = this.getPriorOfClass(pos) / this.getPriorOfClass(neg);
        weights[0] = Math.log(odds);
        for (int i = 1; i < weights.length; ++i) {
            double[] liftRatiosPos = this.getFactorsForModel(i - 1, pos);
            double logPosRatio = Math.log(liftRatiosPos[pos]);
            logPosRatio = Math.min((double)maxWeight, Math.max((double)(-maxWeight), logPosRatio));
            double[] liftRatiosNeg = this.getFactorsForModel(i - 1, neg);
            double logNegRatio = Math.log(liftRatiosNeg[pos]);
            double indep = (logPosRatio + (logNegRatio = Math.min((double)maxWeight, Math.max((double)(-maxWeight), logNegRatio)))) / 2.0;
            if (com.rapidminer.tools.Tools.isEqual(indep, maxWeight) || com.rapidminer.tools.Tools.isEqual(indep, -maxWeight)) {
                logPosRatio = 10.0 * indep;
                indep = 0.0;
            }
            weights[0] = weights[0] + indep;
            weights[i] = logPosRatio -= indep;
        }
        return weights;
    }

    @Override
    public List<String> getModelNames() {
        LinkedList<String> names = new LinkedList<String>();
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            names.add("Model " + (i + 1));
        }
        return names;
    }

    public List<Model> getModels() {
        LinkedList<Model> models = new LinkedList<Model>();
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            models.add(this.getModel(i));
        }
        return models;
    }
}

