/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.AttributeFactory;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.learner.PredictionModel;
import edu.udo.cs.yale.operator.learner.meta.SDRulesetInduction;
import edu.udo.cs.yale.tools.Tools;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Ensemble of subgroup-discovery rule models whose votes are combined into a single
 * numerical score for the positive class.
 *
 * <p>Each entry of {@code modelInfo} is an {@code Object[]} whose element 0 is the inner
 * {@link PredictionModel} and whose element 1 is a {@code double[][]} weight table, indexed
 * first by the class the inner model predicted and then by class index. For every example
 * the weight rows selected by all inner models are summed; the resulting prediction is the
 * positive class's share of the total weight, or the positive-class prior when the total
 * weight is not positive.
 *
 * <p>Model parameters understood by {@link #setParameter(String, String)}:
 * {@code iteration} limits the number of inner models used, {@code predictions_to_file}
 * writes the raw inner-model predictions to a file, and {@code print_to_stdout} prints the
 * average coverage, lift and WRAcc of the inner rule sets after each prediction run.
 */
public class SDEnsemble extends PredictionModel {
    /** Combination-method constant (additive); not read anywhere in this class. */
    public static final short RULE_COMBINE_ADDITIVE = 1;
    /** Combination-method constant (multiplicative); not read anywhere in this class. */
    public static final short RULE_COMBINE_MULTIPLY = 2;

    /** One {@code Object[]{PredictionModel, double[][] weights}} entry per inner model. */
    private List modelInfo;
    /** Upper bound on the number of inner models used; a negative value means "all". */
    private int maxModelNumber = -1;
    private static final String MAX_MODEL_NUMBER = "iteration";
    private static final String PRED_TO_FILE = "predictions_to_file";
    /** File receiving raw inner-model predictions, or {@code null} when disabled. */
    private File predictionsFile = null;
    /** Whether ruleset performance statistics are printed to standard output. */
    private boolean print_to_stdout = false;
    /** Class priors indexed by class index; used when no inner model contributes weight. */
    private double[] priors;

    public SDEnsemble() {
    }

    public SDEnsemble(Attribute label) {
        super(label);
    }

    /**
     * Creates an ensemble over the given inner models.
     *
     * @param label the label attribute
     * @param modelInfo one {@code Object[]{PredictionModel, double[][]}} entry per inner model
     * @param priors class priors indexed by class index
     * @param combinationMethod accepted for API compatibility; currently ignored
     */
    public SDEnsemble(Attribute label, List modelInfo, double[] priors, short combinationMethod) {
        super(label);
        this.modelInfo = modelInfo;
        this.priors = priors;
    }

    /**
     * Returns the superclass description followed by the number of inner models and the
     * result string of each inner model, each embedded model on its own line.
     */
    public String toString() {
        int numModels = this.getNumberOfModels();
        StringBuffer result = new StringBuffer();
        result.append(super.toString());
        result.append(Tools.getLineSeparator());
        result.append("Number of inner models: ").append(numModels);
        for (int i = 0; i < numModels; i++) {
            // Every embedded model starts on a fresh line, including the first one.
            result.append(Tools.getLineSeparator());
            result.append("(Embedded model #").append(i).append("):");
            result.append(this.getModel(i).toResultString());
        }
        return result.toString();
    }

    /**
     * Sets a model parameter.
     * <ul>
     * <li>{@code print_to_stdout}: enables statistics printing; an explicit value of
     * {@code "false"} (case-insensitive) disables it, any other value enables it.</li>
     * <li>{@code predictions_to_file}: deletes and re-creates the named file (empty) and
     * remembers it as the target for raw predictions. A {@code null} value is passed on to
     * the superclass.</li>
     * <li>{@code iteration}: integer limit on the number of inner models used. A value that
     * is not an integer is passed on to the superclass.</li>
     * </ul>
     * All other parameters are delegated to {@code super.setParameter}.
     *
     * @throws OperatorException a {@code UserError} (code 303) if the predictions file
     *         cannot be created, or whatever the superclass throws for delegated parameters
     */
    public void setParameter(String name, String value) throws OperatorException {
        if (name.equalsIgnoreCase("print_to_stdout")) {
            // Presence of the parameter enables printing; only an explicit "false" turns it off.
            this.print_to_stdout = value == null || !value.trim().equalsIgnoreCase("false");
            return;
        }
        if (name.equalsIgnoreCase(PRED_TO_FILE)) {
            if (value != null) {
                File file = new File(value);
                if (file.exists()) {
                    file.delete();
                }
                try {
                    file.createNewFile();
                }
                catch (IOException e) {
                    throw new UserError(null, 303, value, (Object)e.getMessage());
                }
                this.predictionsFile = file;
                return;
            }
        } else if (name.equalsIgnoreCase(MAX_MODEL_NUMBER)) {
            try {
                this.maxModelNumber = Integer.parseInt(value);
                return;
            }
            catch (NumberFormatException ignored) {
                // Not an integer: fall through and let the superclass handle the value.
            }
        }
        super.setParameter(name, value);
    }

    /**
     * Returns the number of inner models in use: all stored models, capped by the
     * {@code iteration} parameter when that is non-negative. Returns 0 when no model list
     * has been set (no-arg or label-only constructor).
     */
    public int getNumberOfModels() {
        if (this.modelInfo == null) {
            return 0;
        }
        int available = this.modelInfo.size();
        return this.maxModelNumber >= 0 ? Math.min(this.maxModelNumber, available) : available;
    }

    /** Returns the weight row of inner model {@code modelNr} for the class it {@code predicted}. */
    private double[] getWeightsForModel(int modelNr, int predicted) {
        Object[] obj = (Object[])this.modelInfo.get(modelNr);
        double[][] weight = (double[][])obj[1];
        return weight[predicted];
    }

    /** Returns the stored prior of the given class index (requires {@code priors} to be set). */
    private double getPriorOfClass(int classIndex) {
        return this.priors[classIndex];
    }

    /** Returns the inner model stored at {@code index}. */
    public PredictionModel getModel(int index) {
        Object[] obj = (Object[])this.modelInfo.get(index);
        return (PredictionModel)obj[0];
    }

    /**
     * Applies all inner models and stores, for each example of {@code exampleSet}, the
     * combined positive-class score as its predicted label.
     *
     * <p>If a predictions file is configured, one line per example is written containing
     * every inner model's predicted label followed by the true label. The file is closed
     * even if prediction fails. If {@code print_to_stdout} is set, average ruleset
     * statistics are printed afterwards.
     *
     * @throws OperatorException a {@code UserError} (code 303) if the predictions file
     *         cannot be opened, or any exception raised by the inner models
     */
    public void performPrediction(ExampleSet exampleSet, Attribute predictedLabelAttribute) throws OperatorException {
        PrintStream predOut = null;
        if (this.predictionsFile != null) {
            try {
                predOut = new PrintStream(new BufferedOutputStream(new FileOutputStream(this.predictionsFile)));
            }
            catch (IOException e) {
                throw new UserError(null, 303, this.predictionsFile.getName(), (Object)e.getMessage());
            }
        }
        int numModels = this.getNumberOfModels();
        int[] numCovered = new int[numModels];
        int[] posCovered = new int[numModels];
        int posTotal;
        try {
            posTotal = this.combinePredictions(exampleSet, predOut, numCovered, posCovered);
        }
        finally {
            // Release the file handle on both the normal and the exceptional path.
            if (predOut != null) {
                predOut.close();
            }
        }
        if (this.print_to_stdout) {
            this.printRulesetStatistics(numCovered, posCovered, posTotal, exampleSet.size());
        }
    }

    /**
     * Applies each inner model to its own clone of {@code exampleSet}, then walks all clones
     * in lock-step with the original set and writes the combined score into each original
     * example. When {@code print_to_stdout} is set, also fills {@code numCovered[k]} (examples
     * model k assigns to the positive class) and {@code posCovered[k]} (of those, how many are
     * truly positive).
     *
     * @param predOut optional sink for raw predictions; may be {@code null}
     * @return the number of truly positive examples (only counted when {@code print_to_stdout}
     *         is set and there is at least one inner model; otherwise 0)
     */
    private int combinePredictions(ExampleSet exampleSet, PrintStream predOut, int[] numCovered, int[] posCovered) throws OperatorException {
        int numModels = numCovered.length;
        // Apply every model first, then create the iterators, preserving the original order
        // of side effects on the (possibly shared) underlying data.
        ExampleSet[] copies = new ExampleSet[numModels];
        for (int m = 0; m < numModels; m++) {
            copies[m] = (ExampleSet)exampleSet.clone();
            this.getModel(m).apply(copies[m]);
        }
        ArrayList readers = new ArrayList(numModels);
        for (int m = 0; m < numModels; m++) {
            readers.add(copies[m].iterator());
        }

        Iterator originalReader = exampleSet.iterator();
        int posIndex = SDRulesetInduction.getPosIndex(exampleSet.getAttributes().getLabel());
        int posTotal = 0;
        while (originalReader.hasNext()) {
            Example example = (Example)originalReader.next();
            double sumPos = 0.0;
            double sumTotal = 0.0;
            for (int k = 0; k < numModels; k++) {
                Example e = (Example)((Iterator)readers.get(k)).next();
                if (predOut != null) {
                    predOut.print(String.valueOf(e.getPredictedLabel()) + " ");
                }
                int predicted = (int)e.getPredictedLabel();
                double[] modelWeights = this.getWeightsForModel(k, predicted);
                for (int c = 0; c < modelWeights.length; c++) {
                    sumTotal += modelWeights[c];
                }
                sumPos += modelWeights[posIndex];
                if (this.print_to_stdout) {
                    int label = (int)e.getLabel();
                    // Count true positives once per example, using the first model's copy.
                    if (k == 0 && label == posIndex) {
                        posTotal++;
                    }
                    if (predicted == posIndex) {
                        numCovered[k]++;
                        if (label == predicted) {
                            posCovered[k]++;
                        }
                    }
                }
            }
            if (predOut != null) {
                predOut.println(example.getLabel());
            }
            // Without any weight mass, fall back to the positive-class prior.
            example.setPredictedLabel(sumTotal > 0.0 ? sumPos / sumTotal : this.getPriorOfClass(posIndex));
        }
        return posTotal;
    }

    /**
     * Prints the average coverage, lift and weighted relative accuracy (WRAcc) of the inner
     * rule sets. For each rule set the complement (uncovered examples) is evaluated as well
     * and used instead when the rule set covers nothing or its complement has a higher WRAcc.
     * NaN WRAcc values count as 0 and NaN lift values as 1 in the averages.
     */
    private void printRulesetStatistics(int[] numCovered, int[] posCovered, int posTotal, int numExamples) {
        int numModels = numCovered.length;
        double priorPos = (double)posTotal / (double)numExamples;
        double avgCov = 0.0;
        double avgWRacc = 0.0;
        double avgLift = 0.0;
        for (int i = 0; i < numModels; i++) {
            double coverage = (double)numCovered[i] / (double)numExamples;
            double precision = (double)posCovered[i] / (double)numCovered[i];
            double wracc = coverage * Math.abs(precision - priorPos);
            double lift = Math.max(precision / priorPos, (1.0 - precision) / (1.0 - priorPos));

            // Same measures for the complement of the covered examples.
            double dualCov = 1.0 - coverage;
            double posNotCov = priorPos - (double)posCovered[i] / (double)numExamples;
            double dualPrec = posNotCov / dualCov;
            double dualWracc = dualCov * Math.abs(dualPrec - priorPos);
            if (coverage == 0.0 || dualWracc > wracc) {
                coverage = dualCov;
                wracc = dualWracc;
                lift = dualPrec / priorPos;
            }
            avgCov += coverage;
            avgWRacc += Double.isNaN(wracc) ? 0.0 : wracc;
            avgLift += Double.isNaN(lift) ? 1.0 : lift;
        }
        avgCov /= (double)numModels;
        avgLift /= (double)numModels;
        avgWRacc /= (double)numModels;
        System.out.println("Average ruleset performance: [Number of rules: " + numModels + "], [Cov: " + avgCov + "], [Lift: " + avgLift + "], [WRAcc: " + avgWRacc + "]");
    }

    /**
     * Creates the predicted-label attribute via the superclass and replaces it with a copy
     * whose value type is changed to 4 (presumably the numeric/real value type, since the
     * ensemble predicts a continuous score — TODO confirm against the type constants).
     */
    protected Attribute createPredictedLabel(ExampleSet exampleSet) {
        Attribute predictedLabel = super.createPredictedLabel(exampleSet);
        return exampleSet.getAttributes().replace(predictedLabel, AttributeFactory.changeValueType(predictedLabel, 4));
    }
}

