/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.SplittedExampleSet;
import edu.udo.cs.yale.operator.MissingIOObjectException;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.Value;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.PredictionModel;
import edu.udo.cs.yale.operator.learner.meta.AbstractMetaLearner;
import edu.udo.cs.yale.operator.learner.meta.BayBoostBaseModelInfo;
import edu.udo.cs.yale.operator.learner.meta.BayBoostModel;
import edu.udo.cs.yale.operator.learner.meta.ContingencyMatrix;
import edu.udo.cs.yale.operator.learner.meta.WeightedPerformanceMeasures;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeBoolean;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.Tools;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class BayesianBoosting
extends AbstractMetaLearner {
    public static final String NUM_OF_ITERATIONS = "iterations";
    public static final String INTERNAL_BOOTSTRAP = "use_subset_for_training";
    public static final String EQUALLY_PROB_LABELS = "rescale_label_priors";
    public static final String ALLOW_MARGINAL_SKEWS = "allow_marginal_skews";
    public static final double MIN_ADVANTAGE = 0.001;
    private Model startModel;
    protected int currentIteration;
    private double performance = 0.0;
    private double[] oldWeights;

    public BayesianBoosting(OperatorDescription description) {
        super(description);
        this.addValue(new Value("performance", "The performance."){

            public double getValue() {
                return BayesianBoosting.this.performance;
            }
        });
        this.addValue(new Value("iteration", "The current iteration."){

            public double getValue() {
                return BayesianBoosting.this.currentIteration;
            }
        });
    }

    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        if (lc == LearnerCapability.NUMERICAL_CLASS || lc == LearnerCapability.POLYNOMINAL_CLASS) {
            return false;
        }
        return super.supportsCapability(lc);
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(EQUALLY_PROB_LABELS, "Specifies whether the proportion of labels should be equal by construction after first iteration .", false));
        types.add(new ParameterTypeDouble(INTERNAL_BOOTSTRAP, "Fraction of examples used for training, remaining ones are used to estimate the confusion matrix. Set to 1 to turn off test set.", 0.0, 1.0, 1.0));
        types.add(new ParameterTypeInt(NUM_OF_ITERATIONS, "The maximum number of iterations.", 1, Integer.MAX_VALUE, 10));
        types.add(new ParameterTypeBoolean(ALLOW_MARGINAL_SKEWS, "Allow to skew the marginal distribution (P(x)) during learning.", true));
        return types;
    }

    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.readOptionalParameters();
        double[] classPriors = this.prepareWeights(exampleSet);
        double maxPrior = Double.NEGATIVE_INFINITY;
        double sumPriors = 0.0;
        int i = 0;
        while (i < classPriors.length) {
            if (classPriors[i] > maxPrior) {
                maxPrior = classPriors[i];
            }
            sumPriors += classPriors[i];
            ++i;
        }
        BayBoostModel model = Tools.isEqual(sumPriors, maxPrior) ? new BayBoostModel(exampleSet.getAttributes().getLabel(), new Vector<BayBoostBaseModelInfo>(), classPriors) : this.trainBoostingModel(exampleSet, classPriors);
        if (this.oldWeights != null) {
            Iterator reader = exampleSet.iterator();
            int i2 = 0;
            while (reader.hasNext() && i2 < this.oldWeights.length) {
                ((Example)reader.next()).setWeight(this.oldWeights[i2++]);
            }
        } else {
            Attribute weight = exampleSet.getAttributes().getWeight();
            exampleSet.getAttributes().remove(weight);
            exampleSet.getExampleTable().removeAttribute(weight);
        }
        return model;
    }

    protected double[] prepareWeights(ExampleSet exampleSet) {
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        if (weightAttr == null) {
            this.oldWeights = null;
            this.performance = exampleSet.size();
            return this.createNewWeightAttribute(exampleSet);
        }
        this.oldWeights = new double[exampleSet.size()];
        double[] priors = new double[exampleSet.getAttributes().getLabel().getMapping().size()];
        double totalWeight = 0.0;
        Iterator reader = exampleSet.iterator();
        int i = 0;
        while (reader.hasNext() && i < this.oldWeights.length) {
            Example example = (Example)reader.next();
            if (example != null) {
                double weight;
                this.oldWeights[i] = weight = example.getWeight();
                int label = (int)example.getLabel();
                if (label >= 0 && label < priors.length) {
                    int n = label;
                    priors[n] = priors[n] + weight;
                    totalWeight += weight;
                } else {
                    example.setWeight(0.0);
                }
            }
            ++i;
        }
        this.performance = totalWeight;
        i = 0;
        while (i < priors.length) {
            int n = i++;
            priors[n] = priors[n] / totalWeight;
        }
        return priors;
    }

    /*
     * Unable to fully structure code
     */
    private double[] createNewWeightAttribute(ExampleSet exampleSet) {
        block2: {
            edu.udo.cs.yale.example.Tools.createWeightAttribute(exampleSet);
            exRead = exampleSet.iterator();
            numClasses = exampleSet.getAttributes().getLabel().getMapping().getValues().size();
            classPriors = new double[numClasses];
            total = exampleSet.size();
            invTotal = 1.0 / (double)total;
            if (this.getParameterAsBoolean("rescale_label_priors")) ** GOTO lbl18
            while (exRead.hasNext()) {
                example = (Example)exRead.next();
                example.setWeight(1.0);
                v0 = (int)example.getLabel();
                classPriors[v0] = classPriors[v0] + invTotal;
            }
            break block2;
lbl-1000:
            // 1 sources

            {
                v1 = (int)((Example)exRead.next()).getLabel();
                classPriors[v1] = classPriors[v1] + invTotal;
lbl18:
                // 2 sources

                ** while (exRead.hasNext())
            }
lbl19:
            // 1 sources

            this.rescaleToEqualPriors(exampleSet, classPriors);
        }
        return classPriors;
    }

    private void rescaleToEqualPriors(ExampleSet exampleSet, double[] currentPriors) {
        double[] weights = new double[currentPriors.length];
        int i = 0;
        while (i < weights.length) {
            weights[i] = 1.0 / ((double)weights.length * currentPriors[i]);
            ++i;
        }
        for (Example example : exampleSet) {
            example.setWeight(weights[(int)example.getLabel()]);
        }
    }

    protected Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
        Model model = this.applyInnerLearner(exampleSet);
        return model;
    }

    private void readOptionalParameters() {
        try {
            this.startModel = this.getInput(Model.class);
        }
        catch (MissingIOObjectException e) {
            LogService.logMessage(String.valueOf(this.getName()) + ": No model found in input.", 2);
        }
    }

    private void applyPriorModel(ExampleSet trainingSet, List<BayBoostBaseModelInfo> modelInfo) throws OperatorException {
        if (this.startModel != null) {
            this.startModel.apply(trainingSet);
            WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(trainingSet);
            this.reweightExamples(wp, trainingSet);
            modelInfo.add(new BayBoostBaseModelInfo(this.startModel, wp.getContingencyMatrix()));
            PredictionModel.removePredictedLabel(trainingSet);
        }
    }

    private BayBoostModel trainBoostingModel(ExampleSet trainingSet, double[] classPriors) throws OperatorException {
        Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>();
        this.applyPriorModel(trainingSet, modelInfo);
        double splitRatio = this.getParameterAsDouble(INTERNAL_BOOTSTRAP);
        boolean bootstrap = splitRatio > 0.0 && splitRatio < 1.0;
        LogService.logMessage(bootstrap ? "Bootstrapping enabled." : "Bootstrapping disabled.", 2);
        boolean allowSkew = this.getParameterAsBoolean(ALLOW_MARGINAL_SKEWS);
        SplittedExampleSet splittedSet = null;
        if (bootstrap) {
            splittedSet = new SplittedExampleSet(trainingSet, splitRatio, 1, -1);
        }
        int iterations = this.getParameterAsInt(NUM_OF_ITERATIONS);
        int i = 0;
        while (i < iterations) {
            WeightedPerformanceMeasures wp;
            Model model;
            this.currentIteration = i;
            if (bootstrap) {
                splittedSet.selectSingleSubset(0);
                model = this.trainBaseModel(splittedSet);
                model.apply(trainingSet);
                wp = new WeightedPerformanceMeasures(splittedSet);
                WeightedPerformanceMeasures.reweightExamples(splittedSet, wp.getContingencyMatrix(), allowSkew);
                splittedSet.selectSingleSubset(1);
                wp = new WeightedPerformanceMeasures(splittedSet);
                this.performance = WeightedPerformanceMeasures.reweightExamples(splittedSet, wp.getContingencyMatrix(), allowSkew);
            } else {
                model = this.trainBaseModel(trainingSet);
                model.apply(trainingSet);
                wp = new WeightedPerformanceMeasures(trainingSet);
                this.performance = this.reweightExamples(wp, trainingSet);
            }
            PredictionModel.removePredictedLabel(trainingSet);
            if (classPriors.length == 2) {
                this.debugMessage(wp);
            }
            if (wp.getNumberOfNonEmptyClasses() < 2) {
                modelInfo.add(new BayBoostBaseModelInfo(model, wp.getContingencyMatrix()));
                break;
            }
            ContingencyMatrix cm = wp.getContingencyMatrix();
            modelInfo.add(new BayBoostBaseModelInfo(model, cm));
            if (!this.isModelUseful(cm)) {
                LogService.logMessage("Discard model because of low advantage on training data.", 2);
                modelInfo.remove(modelInfo.size() - 1);
                break;
            }
            if (this.performance == 0.0) break;
            this.inApplyLoop();
            ++i;
        }
        return new BayBoostModel(trainingSet.getAttributes().getLabel(), modelInfo, classPriors);
    }

    private void debugMessage(WeightedPerformanceMeasures wp) {
        String message = String.valueOf(Tools.getLineSeparator()) + "Model learned - training performance of base learner:" + Tools.getLineSeparator() + "TPR: " + wp.getProbability(0, 0) + " FPR: " + wp.getProbability(1, 0) + " | Positively predicted: " + (wp.getProbability(1, 0) + wp.getProbability(0, 0)) + Tools.getLineSeparator() + "FNR: " + wp.getProbability(0, 1) + " TNR: " + wp.getProbability(1, 1) + " | Negatively predicted: " + (wp.getProbability(0, 1) + wp.getProbability(1, 1)) + Tools.getLineSeparator() + "Positively labelled: " + (wp.getProbability(0, 0) + wp.getProbability(0, 1)) + Tools.getLineSeparator() + "Negatively labelled: " + (wp.getProbability(1, 0) + wp.getProbability(1, 1));
        LogService.logMessage(message, 2);
    }

    protected double reweightExamples(WeightedPerformanceMeasures wp, ExampleSet exampleSet) throws OperatorException {
        boolean allowMarginalSkews = this.getParameterAsBoolean(ALLOW_MARGINAL_SKEWS);
        double remainingWeight = WeightedPerformanceMeasures.reweightExamples(exampleSet, wp.getContingencyMatrix(), allowMarginalSkews);
        return remainingWeight;
    }

    private boolean isModelUseful(ContingencyMatrix cm) {
        return true;
    }
}

