/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.AttributeFactory;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.learner.IOModel;
import edu.udo.cs.yale.operator.learner.Model;
import edu.udo.cs.yale.operator.learner.meta.SDRulesetInduction;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

public class SDEnsemble
extends IOModel {
    public static final String ID = "SD Ruleset";
    private static final int FILE_MODEL = 1;
    private static final int IO_MODEL = 2;
    public static final short RULE_COMBINE_ADDITIVE = 1;
    public static final short RULE_COMBINE_MULTIPLY = 2;
    private short combinationMethod;
    private List modelInfo;
    private int maxModelNumber = -1;
    private static final String MAX_MODEL_NUMBER = "iteration";
    private static final String PRED_TO_FILE = "predictions_to_file";
    private File predictionsFile = null;
    private boolean print_to_stdout = false;
    private double[] priors;

    public SDEnsemble(Attribute label) {
        super(label);
    }

    public SDEnsemble(Attribute label, List modelInfo, double[] priors, short combinationMethod) {
        super(label);
        this.modelInfo = modelInfo;
        this.priors = priors;
        this.combinationMethod = combinationMethod;
    }

    public String getIdentifier() {
        return ID;
    }

    public void readData(ObjectInputStream in) throws IOException {
        Vector<Object[]> modelList = new Vector<Object[]>();
        int numModels = in.readInt();
        this.combinationMethod = in.readShort();
        for (int i = 0; i < numModels; ++i) {
            Model model = Model.readModel(in);
            int rows = in.readInt();
            int cols = in.readInt();
            double[][] factors = new double[rows][cols];
            for (int j = 0; j < rows; ++j) {
                for (int k = 0; k < cols; ++k) {
                    factors[j][k] = in.readDouble();
                }
            }
            modelList.add(new Object[]{model, factors});
        }
        double[] classPriors = new double[in.readInt()];
        for (int i = 0; i < classPriors.length; ++i) {
            classPriors[i] = in.readDouble();
        }
        this.modelInfo = modelList;
        this.priors = classPriors;
    }

    public void writeData(ObjectOutputStream out) throws IOException {
        List modelList = this.modelInfo;
        out.writeInt(modelList.size());
        out.writeShort(this.combinationMethod);
        Iterator it = modelList.iterator();
        while (it.hasNext()) {
            Object[] obj = (Object[])it.next();
            Model model = (Model)obj[0];
            double[][] factors = (double[][])obj[1];
            model.writeModel(out);
            int rows = factors.length;
            int cols = rows > 0 && factors[0] != null ? factors[0].length : 0;
            out.writeInt(rows);
            out.writeInt(cols);
            for (int j = 0; j < rows; ++j) {
                for (int k = 0; k < cols; ++k) {
                    out.writeDouble(factors[j][k]);
                }
            }
        }
        out.writeInt(this.priors.length);
        for (int i = 0; i < this.priors.length; ++i) {
            out.writeDouble(this.priors[i]);
        }
    }

    public String toString() {
        String result = super.toString() + "\nNumber of inner models: " + this.getNumberOfModels();
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            Model model = this.getModel(i);
            result = result + (i > 0 ? "\n" : "") + "(Embedded model #" + i + "):" + model.toResultString();
        }
        return result;
    }

    public void setPredictionParameter(String name, String value) throws OperatorException {
        if (name.equalsIgnoreCase("print_to_stdout")) {
            this.print_to_stdout = true;
            return;
        }
        if (name.equalsIgnoreCase(PRED_TO_FILE)) {
            if (value != null) {
                String filename = value;
                File file = new File(filename);
                if (file.exists()) {
                    file.delete();
                }
                try {
                    file.createNewFile();
                }
                catch (IOException e) {
                    throw new UserError(null, 303, filename, (Object)e.getMessage());
                }
                this.predictionsFile = file;
                return;
            }
        } else {
            try {
                if (name.equalsIgnoreCase(MAX_MODEL_NUMBER)) {
                    this.maxModelNumber = Integer.parseInt(value);
                    return;
                }
            }
            catch (NumberFormatException numberFormatException) {
                // empty catch block
            }
        }
        super.setPredictionParameter(name, value);
    }

    public int getNumberOfModels() {
        if (this.maxModelNumber >= 0) {
            return Math.min(this.maxModelNumber, this.modelInfo.size());
        }
        return this.modelInfo.size();
    }

    private double[] getWeightsForModel(int modelNr, int predicted) {
        Object[] obj = (Object[])this.modelInfo.get(modelNr);
        double[][] weight = (double[][])obj[1];
        return weight[predicted];
    }

    private double getPriorOfClass(int classIndex) {
        return this.priors[classIndex];
    }

    public Model getModel(int index) {
        Object[] obj = (Object[])this.modelInfo.get(index);
        return (Model)obj[0];
    }

    public void apply(ExampleSet exampleSet) throws OperatorException {
        PrintStream predOut = null;
        if (this.predictionsFile != null) {
            try {
                predOut = new PrintStream(new BufferedOutputStream(new FileOutputStream(this.predictionsFile)));
            }
            catch (IOException e) {
                throw new UserError(null, 303, this.predictionsFile.getName(), (Object)e.getMessage());
            }
        }
        ExampleSet[] eSet = new ExampleSet[this.getNumberOfModels()];
        for (int i = 0; i < this.getNumberOfModels(); ++i) {
            Model model = this.getModel(i);
            eSet[i] = (ExampleSet)exampleSet.clone();
            model.createPredictedLabel(eSet[i]);
            model.apply(eSet[i]);
        }
        ExampleReader[] reader = new ExampleReader[eSet.length];
        for (int r = 0; r < reader.length; ++r) {
            reader[r] = eSet[r].getExampleReader();
        }
        ExampleReader originalReader = exampleSet.getExampleReader();
        int posIndex = SDRulesetInduction.getPosIndex(exampleSet.getLabel());
        int[] numCovered = new int[this.getNumberOfModels()];
        int[] posCovered = new int[this.getNumberOfModels()];
        int posTotal = 0;
        while (originalReader.hasNext()) {
            Example example = originalReader.next();
            double sumPos = 0.0;
            double sumTotal = 0.0;
            for (int k = 0; k < reader.length; ++k) {
                Example e = reader[k].next();
                if (predOut != null) {
                    predOut.print(e.getPredictedLabel() + " ");
                }
                int predicted = (int)e.getPredictedLabel() - 0;
                double[] modelWeights = this.getWeightsForModel(k, predicted);
                for (int i = 0; i < modelWeights.length; ++i) {
                    sumTotal += modelWeights[i];
                }
                sumPos += modelWeights[posIndex];
                if (!this.print_to_stdout) continue;
                int label = (int)e.getLabel() - 0;
                if (k == 0 && label == posIndex) {
                    ++posTotal;
                }
                if (predicted != posIndex) continue;
                int n = k;
                numCovered[n] = numCovered[n] + 1;
                if (label != predicted) continue;
                int n2 = k;
                posCovered[n2] = posCovered[n2] + 1;
            }
            if (predOut != null) {
                predOut.println(example.getLabel());
            }
            sumPos = sumTotal > 0.0 ? (sumPos /= sumTotal) : this.getPriorOfClass(posIndex);
            example.setPredictedLabel(sumPos);
        }
        if (predOut != null) {
            predOut.close();
            predOut = null;
        }
        if (this.print_to_stdout) {
            double avgCov = 0.0;
            double avgWRacc = 0.0;
            double avgLift = 0.0;
            for (int i = 0; i < this.getNumberOfModels(); ++i) {
                double coverage = (double)numCovered[i] / (double)exampleSet.getSize();
                double precision = (double)posCovered[i] / (double)numCovered[i];
                double priorPos = (double)posTotal / (double)exampleSet.getSize();
                double bias = Math.abs(precision - priorPos);
                double wracc = coverage * bias;
                double lift = Math.max(precision / priorPos, (1.0 - precision) / (1.0 - priorPos));
                double dualCov = 1.0 - coverage;
                double posNotCov = priorPos - (double)posCovered[i] / (double)exampleSet.getSize();
                double dualPrec = posNotCov / dualCov;
                double dualBias = Math.abs(dualPrec - priorPos);
                double dualWracc = dualCov * dualBias;
                if (coverage == 0.0 || dualWracc > wracc) {
                    coverage = dualCov;
                    wracc = dualWracc;
                    lift = dualPrec / priorPos;
                }
                avgCov += coverage;
                avgWRacc += Double.isNaN(wracc) ? 0.0 : wracc;
                avgLift += Double.isNaN(lift) ? 1.0 : lift;
            }
            System.out.println("Average ruleset performance: [Number of rules: " + this.getNumberOfModels() + "], [Cov: " + (avgCov /= (double)this.getNumberOfModels()) + "], [Lift: " + (avgLift /= (double)this.getNumberOfModels()) + "], [WRAcc: " + (avgWRacc /= (double)this.getNumberOfModels()) + "]");
        }
    }

    public Attribute createPredictedLabel(ExampleSet exampleSet, String name) {
        Attribute predictedLabel = super.createPredictedLabel(exampleSet, name);
        return exampleSet.replaceAttribute(predictedLabel, AttributeFactory.changeValueType(predictedLabel, 4));
    }
}

