/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.AttributeFactory;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.learner.IOModel;
import edu.udo.cs.yale.operator.learner.Model;
import edu.udo.cs.yale.tools.LogService;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/**
 * Ensemble model that combines the crisp predictions of a sequence of embedded
 * base models into per-class confidences.
 *
 * <p>Each entry of the internal model list is an {@code Object[]} holding a
 * {@link Model} at index 0 and a {@code double[][]} factor matrix at index 1.
 * During {@link #apply(ExampleSet)} the class priors are turned into odds, and
 * for every base model the matrix row selected by that model's predicted class
 * index is multiplied into the odds (see
 * {@link #adjustIntermediateProducts(double[], double[])}). The resulting odds
 * are converted back into probabilities and normalized.
 *
 * <p>Instances created with the one- or two-argument constructor carry no
 * model list or priors until {@link #readData(ObjectInputStream)} has been
 * called; most methods will fail with a {@code NullPointerException} before
 * that.
 */
public class BayBoostModel
extends IOModel {
    /** Identifier string returned by {@link #getIdentifier()}. */
    public static final String ID = "YALE BayBoost Model";
    // The following two constants are not referenced anywhere in this class.
    private static final int FILE_MODEL = 1;
    private static final int IO_MODEL = 2;
    /** Raw list of {@code Object[]{Model, double[][]}} entries, one per embedded model. */
    private List modelInfo;
    /** Prior value per class index. */
    private double[] priors;
    /** When {@code false}, {@link #createPredictedLabel} changes the value type of the predicted label. */
    private boolean crispPredictions;
    /** Upper bound on the number of embedded models used; negative means "use all". */
    private int maxModelNumber = -1;
    private static final String MAX_MODEL_NUMBER = "iteration";
    private static final String CONV_TO_CRISP = "crisp";
    /** Decision threshold on the positive class; only consulted when it differs from 0.5. */
    private double threshold = 0.5;
    private static final String PRED_TO_FILE = "predictions_to_file";
    /** Optional file that receives the base model predictions and true label per example. */
    private File predictionsFile = null;
    /** Set by the {@code print_to_stdout} prediction parameter; not read anywhere in this class. */
    private boolean print_to_stdout = false;

    /**
     * Creates an empty model for the given label.
     *
     * @param label the label attribute, passed to the superclass
     */
    public BayBoostModel(Attribute label) {
        super(label);
    }

    /**
     * Creates an empty model for the given label.
     *
     * @param label the label attribute, passed to the superclass
     * @param crispPredictions initial value of the crisp-predictions flag
     */
    public BayBoostModel(Attribute label, boolean crispPredictions) {
        super(label);
        this.crispPredictions = crispPredictions;
    }

    /**
     * Creates a fully initialized model. The list and the array are stored by
     * reference, not copied.
     *
     * @param label the label attribute, passed to the superclass
     * @param modelInfo list of {@code Object[]{Model, double[][]}} entries
     * @param priors prior value per class index
     * @param crispPredictions initial value of the crisp-predictions flag
     */
    public BayBoostModel(Attribute label, List modelInfo, double[] priors, boolean crispPredictions) {
        super(label);
        this.modelInfo = modelInfo;
        this.priors = priors;
        this.crispPredictions = crispPredictions;
    }

    /**
     * Handles the prediction parameters known to this model and delegates
     * everything else to the superclass.
     *
     * <ul>
     * <li>{@code print_to_stdout}: sets a flag, the value is ignored.</li>
     * <li>{@code predictions_to_file}: with a non-null value, (re)creates that
     *     file and remembers it for {@link #apply(ExampleSet)}.</li>
     * <li>{@code iteration}: limits the number of embedded models used.</li>
     * <li>{@code crisp}: switches to crisp predictions; if the model was not
     *     crisp before and the value is non-empty, it is parsed as the
     *     decision threshold.</li>
     * </ul>
     *
     * @param name parameter name, compared case-insensitively
     * @param value parameter value, may be {@code null}
     * @throws OperatorException a {@link UserError} if the predictions file
     *         cannot be created, or whatever the superclass throws
     */
    public void setPredictionParameter(String name, String value) throws OperatorException {
        if (name.equalsIgnoreCase("print_to_stdout")) {
            this.print_to_stdout = true;
            return;
        }
        if (name.equalsIgnoreCase(PRED_TO_FILE)) {
            if (value != null) {
                String filename = value;
                File file = new File(filename);
                // Start from an empty file; the results of delete/createNewFile are not inspected.
                if (file.exists()) {
                    file.delete();
                }
                try {
                    file.createNewFile();
                }
                catch (IOException e) {
                    throw new UserError(null, 303, filename, (Object)e.getMessage());
                }
                this.predictionsFile = file;
                return;
            }
            // A null value falls through to the superclass.
        } else if (name.equalsIgnoreCase(MAX_MODEL_NUMBER)) {
            try {
                this.maxModelNumber = Integer.parseInt(value);
                return;
            }
            catch (NumberFormatException ignored) {
                // Deliberately not handled here: an unparsable (or null) value
                // is handed to the superclass below.
            }
        } else if (name.equalsIgnoreCase(CONV_TO_CRISP)) {
            if (!this.crispPredictions && value != null && value.length() > 0) {
                try {
                    this.threshold = Double.parseDouble(value.trim());
                }
                catch (NumberFormatException e) {
                    LogService.logMessage("Ignoring invalid value '" + value + "' for parameter '" + CONV_TO_CRISP + "' in BayBoostModel.", 4);
                    this.threshold = 0.5;
                }
            }
            this.crispPredictions = true;
            return;
        }
        super.setPredictionParameter(name, value);
    }

    /**
     * @return the constant {@link #ID}
     */
    public String getIdentifier() {
        return ID;
    }

    /**
     * Restores the model list, the crisp flag and the priors from a stream in
     * the layout produced by {@link #writeData(ObjectOutputStream)}.
     *
     * @param in stream to read from
     * @throws IOException on read failure or if the stream contains a negative
     *         element count
     */
    public void readData(ObjectInputStream in) throws IOException {
        Vector<Object[]> modelList = new Vector<Object[]>();
        int numModels = in.readInt();
        if (numModels < 0) {
            throw new IOException("Corrupt BayBoostModel data: negative number of models (" + numModels + ").");
        }
        this.crispPredictions = in.readBoolean();
        for (int i = 0; i < numModels; ++i) {
            Model model = Model.readModel(in);
            int rows = in.readInt();
            int cols = in.readInt();
            if (rows < 0 || cols < 0) {
                throw new IOException("Corrupt BayBoostModel data: invalid factor matrix size " + rows + "x" + cols + " for model " + i + ".");
            }
            double[][] factors = new double[rows][cols];
            for (int j = 0; j < rows; ++j) {
                for (int k = 0; k < cols; ++k) {
                    factors[j][k] = in.readDouble();
                }
            }
            modelList.add(new Object[]{model, factors});
        }
        int numPriors = in.readInt();
        if (numPriors < 0) {
            throw new IOException("Corrupt BayBoostModel data: negative number of priors (" + numPriors + ").");
        }
        double[] classPriors = new double[numPriors];
        for (int i = 0; i < classPriors.length; ++i) {
            classPriors[i] = in.readDouble();
        }
        this.modelInfo = modelList;
        this.priors = classPriors;
    }

    /**
     * Writes the number of models, the crisp flag, each model followed by its
     * factor matrix (row count, column count, values row by row), and finally
     * the priors (count, values).
     *
     * <p>The column count written is the length of the first matrix row; all
     * rows are written with that many columns.
     *
     * @param out stream to write to
     * @throws IOException on write failure
     */
    public void writeData(ObjectOutputStream out) throws IOException {
        List modelList = this.modelInfo;
        out.writeInt(modelList.size());
        out.writeBoolean(this.crispPredictions);
        Iterator it = modelList.iterator();
        while (it.hasNext()) {
            Object[] entry = (Object[])it.next();
            Model model = (Model)entry[0];
            double[][] factors = (double[][])entry[1];
            model.writeModel(out);
            int rows = factors.length;
            int cols = rows > 0 && factors[0] != null ? factors[0].length : 0;
            out.writeInt(rows);
            out.writeInt(cols);
            for (int j = 0; j < rows; ++j) {
                for (int k = 0; k < cols; ++k) {
                    out.writeDouble(factors[j][k]);
                }
            }
        }
        out.writeInt(this.priors.length);
        for (int i = 0; i < this.priors.length; ++i) {
            out.writeDouble(this.priors[i]);
        }
    }

    /**
     * @return the superclass description followed by the number of embedded
     *         models in use and the result string of each of them
     */
    public String toString() {
        int numModels = this.getNumberOfModels();
        StringBuilder result = new StringBuilder();
        result.append(super.toString()).append("\nNumber of inner models: ").append(numModels);
        for (int i = 0; i < numModels; ++i) {
            if (i > 0) {
                result.append("\n");
            }
            result.append("(Embedded model #").append(i).append("):").append(this.getModel(i).toResultString());
        }
        return result.toString();
    }

    /**
     * @return the number of embedded models in use: the size of the model
     *         list, capped by the {@code iteration} parameter if that was set
     *         to a non-negative value
     */
    public int getNumberOfModels() {
        if (this.maxModelNumber >= 0) {
            return Math.min(this.maxModelNumber, this.modelInfo.size());
        }
        return this.modelInfo.size();
    }

    /**
     * Returns the factor-matrix row of model {@code modelNr} that belongs to
     * the predicted class index {@code predicted}. The row is not copied.
     */
    private double[] getFactorsForModel(int modelNr, int predicted) {
        Object[] entry = (Object[])this.modelInfo.get(modelNr);
        double[][] factor = (double[][])entry[1];
        return factor[predicted];
    }

    /** Returns the stored prior for the given class index. */
    private double getPriorOfClass(int classIndex) {
        return this.priors[classIndex];
    }

    /**
     * @return a copy of the class priors
     */
    public double[] getPriors() {
        double[] result = new double[this.priors.length];
        System.arraycopy(this.priors, 0, result, 0, result.length);
        return result;
    }

    /**
     * @param index position in the model list
     * @return the embedded model stored at that position
     */
    public Model getModel(int index) {
        Object[] entry = (Object[])this.modelInfo.get(index);
        return (Model)entry[0];
    }

    /**
     * @param modelNumber position in the model list
     * @return a deep copy of the factor matrix stored for that model
     */
    public double[][] getBiasMatrix(int modelNumber) {
        Object[] entry = (Object[])this.modelInfo.get(modelNumber);
        double[][] biasMatrix = (double[][])entry[1];
        double[][] result = new double[biasMatrix.length][];
        for (int i = 0; i < result.length; ++i) {
            result[i] = new double[biasMatrix[i].length];
            System.arraycopy(biasMatrix[i], 0, result[i], 0, result[i].length);
        }
        return result;
    }

    /**
     * Applies all embedded models in use to clones of the example set,
     * combines their predictions per example, and writes the predicted label
     * and the per-class confidences into the examples of {@code exampleSet}.
     *
     * <p>If a predictions file was configured, each line written to it holds
     * the predicted labels of the embedded models followed by the true label.
     * The file stream is closed even if applying a model fails.
     *
     * <p>A summary (number of models, error count, mean confidence of the true
     * class) is logged at the end.
     *
     * @param exampleSet the examples to label
     * @throws OperatorException a {@link UserError} if the predictions file
     *         cannot be opened, or whatever an embedded model throws
     */
    public void apply(ExampleSet exampleSet) throws OperatorException {
        PrintStream predOut = null;
        if (this.predictionsFile != null) {
            try {
                predOut = new PrintStream(new BufferedOutputStream(new FileOutputStream(this.predictionsFile)));
            }
            catch (IOException e) {
                throw new UserError(null, 303, this.predictionsFile.getName(), (Object)e.getMessage());
            }
        }
        try {
            // -4711 is a sentinel; posIndex is only used when the label is boolean.
            int posIndex = exampleSet.getLabel().isBooleanClassification() ? exampleSet.getLabel().getPositiveIndex() : -4711;

            // Let every embedded model label its own clone of the example set.
            ExampleSet[] eSet = new ExampleSet[this.getNumberOfModels()];
            for (int i = 0; i < this.getNumberOfModels(); ++i) {
                Model model = this.getModel(i);
                eSet[i] = (ExampleSet)exampleSet.clone();
                model.createPredictedLabel(eSet[i]);
                model.apply(eSet[i]);
            }
            ExampleReader[] reader = new ExampleReader[eSet.length];
            for (int r = 0; r < reader.length; ++r) {
                reader[r] = eSet[r].getExampleReader();
            }

            int errors = 0;
            double probCorrect = 0.0;
            ExampleReader originalReader = exampleSet.getExampleReader();
            while (originalReader.hasNext()) {
                // Start from the prior odds of each class.
                double[] intermediateProducts = new double[this.getLabel().getValues().size()];
                for (int k = 0; k < intermediateProducts.length; ++k) {
                    double pri = this.getPriorOfClass(k);
                    intermediateProducts[k] = pri == 1.0 ? Double.POSITIVE_INFINITY : pri / (1.0 - pri);
                }

                // Every reader is advanced for every example so that all of them
                // stay aligned, even once the class has been decided.
                boolean classKnown = false;
                for (int k = 0; k < reader.length; ++k) {
                    Example e = reader[k].next();
                    if (predOut != null) {
                        predOut.print(e.getPredictedLabel() + " ");
                    }
                    if (!classKnown) {
                        int predicted = (int)e.getPredictedLabel();
                        double[] biasFactors = this.getFactorsForModel(k, predicted);
                        classKnown = BayBoostModel.adjustIntermediateProducts(intermediateProducts, biasFactors);
                    }
                }

                // Convert odds back to probabilities and remember the largest one.
                double probSum = 0.0;
                double[] classProb = new double[intermediateProducts.length];
                int bestIndex = 0;
                for (int n = 0; n < classProb.length; ++n) {
                    classProb[n] = intermediateProducts[n] == Double.POSITIVE_INFINITY ? 1.0 : intermediateProducts[n] / (1.0 + intermediateProducts[n]);
                    probSum += classProb[n];
                    if (classProb[n] > classProb[bestIndex]) {
                        bestIndex = n;
                    }
                }
                if (probSum != 1.0) {
                    for (int n = 0; n < classProb.length; ++n) {
                        classProb[n] = classProb[n] / probSum;
                    }
                }

                Example example = originalReader.next();
                if (predOut != null) {
                    predOut.println(example.getLabel());
                }

                int mapIndex;
                if (exampleSet.getLabel().isBooleanClassification() && this.threshold != 0.5) {
                    // Anything outside [0, 1] (including NaN) is reset to 0.5.
                    this.threshold = this.threshold >= 0.0 && this.threshold <= 1.0 ? this.threshold : 0.5;
                    mapIndex = classProb[posIndex] >= this.threshold ? posIndex : 1 - posIndex;
                } else {
                    mapIndex = bestIndex;
                }
                example.setPredictedLabel(mapIndex);
                for (int i = 0; i < classProb.length; ++i) {
                    example.setConfidence(exampleSet.getLabel().mapIndex(i), classProb[i]);
                }

                // NOTE(review): errors are counted against bestIndex, not against the
                // thresholded mapIndex that was actually written - confirm this is intended.
                int correctLabel = (int)example.getLabel();
                if (bestIndex != correctLabel) {
                    ++errors;
                }
                probCorrect += classProb[correctLabel];
            }
            probCorrect /= (double)exampleSet.getSize();
            LogService.logMessage("< Number of models: " + this.getNumberOfModels() + " - Total number of errors: " + errors + ", prob. to predict correct label: " + probCorrect + " >", 2);
        }
        finally {
            if (predOut != null) {
                predOut.close();
            }
        }
    }

    /**
     * Multiplies each entry of {@code products} by the factor at the same
     * index, in place.
     *
     * <p>NaN factors are skipped. A factor of positive infinity whose product
     * is non-zero decides the outcome: all products are set to zero, the
     * product at that index is set to positive infinity, and processing stops.
     * An infinite factor whose product is exactly zero is skipped.
     *
     * @param products running per-class products, modified in place
     * @param biasFactors per-class factors to apply; must not be longer than
     *        {@code products}
     * @return {@code true} if an infinite factor decided the class,
     *         {@code false} otherwise
     */
    public static boolean adjustIntermediateProducts(double[] products, double[] biasFactors) {
        for (int i = 0; i < biasFactors.length; ++i) {
            // NaN never compares equal to anything, so it has to be tested with isNaN.
            if (Double.isNaN(biasFactors[i])) {
                continue;
            }
            if (biasFactors[i] == Double.POSITIVE_INFINITY) {
                if (products[i] == 0.0) {
                    continue;
                }
                for (int j = 0; j < products.length; ++j) {
                    products[j] = 0.0;
                }
                products[i] = biasFactors[i];
                return true;
            }
            products[i] = products[i] * biasFactors[i];
        }
        return false;
    }

    /**
     * Creates the predicted label via the superclass and, unless crisp
     * predictions are enabled, replaces it by a copy with value type 4.
     *
     * @param exampleSet the example set to extend
     * @param name name of the predicted label attribute
     * @return the predicted label attribute that ended up in the example set
     */
    public Attribute createPredictedLabel(ExampleSet exampleSet, String name) {
        Attribute predictedLabel = super.createPredictedLabel(exampleSet, name);
        if (!this.crispPredictions) {
            predictedLabel = exampleSet.replaceAttribute(predictedLabel, AttributeFactory.changeValueType(predictedLabel, 4));
        }
        return predictedLabel;
    }

    /**
     * Computes one weight per embedded model plus a leading offset from the
     * log prior odds and the log factors of the positive class.
     *
     * <p>Index 0 holds the log of the prior odds (positive over negative) plus
     * the accumulated model-independent parts; index {@code i} (for
     * {@code i >= 1}) holds the weight of embedded model {@code i - 1}. Log
     * factors are clamped to {@code [-10, 10]}.
     *
     * @return array of length {@code getNumberOfModels() + 1}
     * @throws OperatorException if the label does not have exactly two values
     */
    public double[] getModelWeights() throws OperatorException {
        if (this.getLabel().getNumberOfValues() != 2) {
            throw new OperatorException("BayBoostModel.getModelWeights() is only applicable for binary prediction tasks.");
        }
        final double maxWeight = 10.0;
        int pos = this.getLabel().getPositiveIndex();
        int neg = this.getLabel().getNegativeIndex();
        double[] weights = new double[this.getNumberOfModels() + 1];
        double odds = this.getPriorOfClass(pos) / this.getPriorOfClass(neg);
        weights[0] = Math.log(odds);
        for (int i = 1; i < weights.length; ++i) {
            // Log factor of the positive class when the model predicts positive ...
            double[] biasRatiosPos = this.getFactorsForModel(i - 1, pos);
            double logPosRatio = Math.log(biasRatiosPos[pos]);
            logPosRatio = Math.min(maxWeight, Math.max(-maxWeight, logPosRatio));
            // ... and when it predicts negative.
            double[] biasRatiosNeg = this.getFactorsForModel(i - 1, neg);
            double logNegRatio = Math.log(biasRatiosNeg[pos]);
            logNegRatio = Math.min(maxWeight, Math.max(-maxWeight, logNegRatio));
            // The mean of both is treated as the part independent of the prediction.
            double indep = (logPosRatio + logNegRatio) / 2.0;
            if (indep == maxWeight || indep == -maxWeight) {
                // Both log factors were clamped to the same bound.
                logPosRatio = 10.0 * indep;
                indep = 0.0;
            }
            weights[0] = weights[0] + indep;
            logPosRatio -= indep;
            weights[i] = logPosRatio;
        }
        return weights;
    }
}

