/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.MetaModel;
import com.rapidminer.operator.learner.meta.SDRulesetInduction;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.container.Pair;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

public class SDEnsemble
extends PredictionModel
implements MetaModel {
    private static final long serialVersionUID = 1320495411014477089L;
    public static final short RULE_COMBINE_ADDITIVE = 1;
    public static final short RULE_COMBINE_MULTIPLY = 2;
    private List<Pair<Model, double[][]>> modelInfo;
    private int maxModelNumber = -1;
    private static final String MAX_MODEL_NUMBER = "iteration";
    private static final String PRED_TO_FILE = "predictions_to_file";
    private File predictionsFile = null;
    private boolean print_to_stdout = false;
    private double[] priors;

    public SDEnsemble(ExampleSet exampleSet, List<Pair<Model, double[][]>> modelInfo, double[] priors, short combinationMethod) {
        super(exampleSet);
        this.modelInfo = modelInfo;
        this.priors = priors;
    }

    @Override
    public String toString() {
        StringBuffer result = new StringBuffer(super.toString() + Tools.getLineSeparator() + "Number of inner models: " + this.getNumberOfModels());
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            Model model = this.getModel(i);
            result.append((i > 0 ? Tools.getLineSeparator() : "") + "(Embedded model #" + i + "):" + model.toResultString());
        }
        return result.toString();
    }

    public void setParameter(String name, String value) throws OperatorException {
        if (name.equalsIgnoreCase("print_to_stdout")) {
            this.print_to_stdout = true;
            return;
        }
        if (name.equalsIgnoreCase(PRED_TO_FILE)) {
            if (value != null) {
                boolean result;
                String filename = value;
                File file = new File(filename);
                if (file.exists() && !(result = file.delete())) {
                    LogService.getGlobal().logError("Cannot delete file: " + file);
                }
                try {
                    file.createNewFile();
                }
                catch (IOException e) {
                    throw new UserError(null, 303, filename, e.getMessage());
                }
                this.predictionsFile = file;
                return;
            }
        } else {
            try {
                if (name.equalsIgnoreCase(MAX_MODEL_NUMBER)) {
                    this.maxModelNumber = Integer.parseInt(value);
                    return;
                }
            }
            catch (NumberFormatException numberFormatException) {
                // empty catch block
            }
        }
        super.setParameter(name, value);
    }

    public int getNumberOfModels() {
        if (this.maxModelNumber >= 0) {
            return Math.min(this.maxModelNumber, this.modelInfo.size());
        }
        return this.modelInfo.size();
    }

    private double[] getWeightsForModel(int modelNr, int predicted) {
        return this.modelInfo.get(modelNr).getSecond()[predicted];
    }

    private double getPriorOfClass(int classIndex) {
        return this.priors[classIndex];
    }

    public Model getModel(int index) {
        return this.modelInfo.get(index).getFirst();
    }

    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabelAttribute) throws OperatorException {
        PrintStream predOut = null;
        if (this.predictionsFile != null) {
            try {
                predOut = new PrintStream(new BufferedOutputStream(new FileOutputStream(this.predictionsFile)));
            }
            catch (IOException e) {
                throw new UserError(null, 303, this.predictionsFile.getName(), e.getMessage());
            }
            finally {
                if (predOut != null) {
                    predOut.close();
                }
            }
        }
        ExampleSet[] eSet = new ExampleSet[this.getNumberOfModels()];
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            Model model = this.getModel(i);
            eSet[i] = (ExampleSet)exampleSet.clone();
            eSet[i] = model.apply(eSet[i]);
        }
        ArrayList reader = new ArrayList(eSet.length);
        for (int r = 0; r < eSet.length; ++r) {
            reader.add(eSet[r].iterator());
        }
        Iterator originalReader = exampleSet.iterator();
        int posIndex = SDRulesetInduction.getPosIndex(exampleSet.getAttributes().getLabel());
        int[] numCovered = new int[this.getNumberOfModels()];
        int[] posCovered = new int[this.getNumberOfModels()];
        int posTotal = 0;
        while (originalReader.hasNext()) {
            Example example = (Example)originalReader.next();
            double sumPos = 0.0;
            double sumTotal = 0.0;
            for (int k = 0; k < reader.size(); ++k) {
                Example e = (Example)((Iterator)reader.get(k)).next();
                if (predOut != null) {
                    predOut.print(e.getPredictedLabel() + " ");
                }
                int predicted = (int)e.getPredictedLabel();
                double[] modelWeights = this.getWeightsForModel(k, predicted);
                for (int i = 0; i < modelWeights.length; ++i) {
                    sumTotal += modelWeights[i];
                }
                sumPos += modelWeights[posIndex];
                if (!this.print_to_stdout) continue;
                int label = (int)e.getLabel();
                if (k == 0 && label == posIndex) {
                    ++posTotal;
                }
                if (predicted != posIndex) continue;
                int n = k;
                numCovered[n] = numCovered[n] + 1;
                if (label != predicted) continue;
                int n2 = k;
                posCovered[n2] = posCovered[n2] + 1;
            }
            if (predOut != null) {
                predOut.println(example.getLabel());
            }
            sumPos = sumTotal > 0.0 ? (sumPos /= sumTotal) : this.getPriorOfClass(posIndex);
            example.setPredictedLabel(sumPos);
        }
        if (predOut != null) {
            predOut.close();
        }
        return exampleSet;
    }

    protected Attribute createPredictedLabel(ExampleSet exampleSet) {
        Attribute predictedLabel = PredictionModel.createPredictedLabel(exampleSet, this.getLabel());
        return exampleSet.getAttributes().replace(predictedLabel, AttributeFactory.changeValueType(predictedLabel, 4));
    }

    @Override
    public List<String> getModelNames() {
        LinkedList<String> names = new LinkedList<String>();
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            names.add("Model " + (i + 1));
        }
        return names;
    }

    public List<Model> getModels() {
        LinkedList<Model> models = new LinkedList<Model>();
        for (Pair<Model, double[][]> pair : this.modelInfo) {
            models.add(pair.getFirst());
        }
        return models;
    }
}

