/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Condition;
import edu.udo.cs.yale.example.ConditionedExampleSet;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.Value;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.PredictionModel;
import edu.udo.cs.yale.operator.learner.meta.AbstractMetaLearner;
import edu.udo.cs.yale.operator.learner.meta.BayBoostBaseModelInfo;
import edu.udo.cs.yale.operator.learner.meta.BayBoostModel;
import edu.udo.cs.yale.operator.learner.meta.ContingencyMatrix;
import edu.udo.cs.yale.operator.learner.meta.WeightedPerformanceMeasures;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeBoolean;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.operator.parameter.UndefinedParameterError;
import edu.udo.cs.yale.operator.performance.EstimatedPerformance;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.RandomGenerator;
import edu.udo.cs.yale.tools.Tools;
import edu.udo.cs.yale.tools.math.RunVector;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Meta learner that applies BayBoost to a stream of examples processed batch by batch.
 *
 * <p>The input example set is consumed in consecutive batches of {@code batch_size} examples.
 * Every example is tagged with its batch number in a special attribute named
 * {@link #STREAM_CONTROL_ATTRIB_NAME}. Two candidate ensembles are maintained:
 * <ul>
 *   <li>a "new batch" ensemble, whose latest base model is trained on the current batch only;</li>
 *   <li>an "extended batch" ensemble, whose latest base model is trained on all batches since
 *       {@code firstOpenBatch}.</li>
 * </ul>
 * Each fresh batch is first used to estimate the accuracy of both ensembles, and the better one
 * decides whether the next model starts a new batch or extends the current one. Per-batch accuracy
 * estimates are collected in a {@link RunVector} that {@link #apply()} delivers as extra output.
 *
 * <p>Labels are assumed to be binary with indices 0 and 1 (class counts are kept in arrays of
 * size 2). The stream control attribute is left on the example set after learning.
 */
public class BayBoostStream
extends AbstractMetaLearner {
    /** Parameter name: number of examples per batch. */
    public static final String BATCH_SIZE = "batch_size";
    /** Parameter name: whether examples are reweighted so that both labels get equal total weight. */
    public static final String EQUALLY_PROB_LABELS = "rescale_label_priors";
    /** Parameter name: fraction of each batch held out to choose between the two ensembles. */
    public static final String HOLD_OUT_RATIO = "fraction_hold_out_set";
    /** Minimum deviation |lift - 1| of some contingency-matrix cell for a base model to be kept. */
    public static final double MIN_ADVANTAGE = 0.02;
    /** Name of the special attribute that stores the batch number of every example. */
    public static final String STREAM_CONTROL_ATTRIB_NAME = "BayBoostStream.StreamControl";
    // NOTE(review): not referenced inside this class; kept because it is public API.
    public static final double MIN_LIFT_RATIO_SOFT_CLASSIFIER = 0.2;
    /** Accuracy estimates collected once per batch; exposed as additional output by apply(). */
    private RunVector runVector;
    /** Number of the batch currently being processed (1-based); exposed as value "iteration". */
    private int currentIteration;
    /** Accuracy of the selected ensemble on the latest batch; exposed as value "performance". */
    private double performance = 0.0;
    /** Weights saved by prepareWeights, or null if a temporary weight attribute was created. */
    private double[] oldWeights;

    public BayBoostStream(OperatorDescription description) {
        super(description);
        this.addValue(new Value("performance", "The performance."){

            public double getValue() {
                return BayBoostStream.this.performance;
            }
        });
        this.addValue(new Value("iteration", "The current iteration."){

            public double getValue() {
                return BayBoostStream.this.currentIteration;
            }
        });
    }

    /** Numerical labels are not supported; everything else is delegated to the superclass. */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        if (lc == LearnerCapability.NUMERICAL_CLASS) {
            return false;
        }
        return super.supportsCapability(lc);
    }

    /**
     * Prepares the example weights for learning.
     *
     * <p>If the example set has no weight attribute, one is created and {@code oldWeights} is set
     * to null, so that {@link #restoreOldWeights} removes it again. Otherwise the current weights
     * are saved in {@code oldWeights} and every weight is reset to 1.0.
     */
    protected void prepareWeights(ExampleSet exampleSet) {
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        if (weightAttr == null) {
            this.oldWeights = null;
            edu.udo.cs.yale.example.Tools.createWeightAttribute(exampleSet);
        } else {
            this.oldWeights = new double[exampleSet.size()];
            Iterator<Example> reader = exampleSet.iterator();
            int i = 0;
            while (reader.hasNext() && i < this.oldWeights.length) {
                Example example = reader.next();
                if (example != null) {
                    this.oldWeights[i] = example.getWeight();
                    example.setWeight(1.0);
                }
                ++i;
            }
        }
    }

    /**
     * Undoes {@link #prepareWeights}: writes the saved weights back in iteration order, or
     * removes the temporary weight attribute if none were saved.
     */
    private void restoreOldWeights(ExampleSet exampleSet) {
        if (this.oldWeights != null) {
            Iterator<Example> reader = exampleSet.iterator();
            int i = 0;
            while (reader.hasNext() && i < this.oldWeights.length) {
                reader.next().setWeight(this.oldWeights[i++]);
            }
        } else {
            Attribute weight = exampleSet.getAttributes().getWeight();
            exampleSet.getAttributes().remove(weight);
            exampleSet.getExampleTable().removeAttribute(weight);
        }
    }

    /**
     * Learns a BayBoost ensemble by streaming over {@code exampleSet} in batches.
     *
     * @param exampleSet the training data; its weights are restored when learning finishes,
     *     also if an exception is thrown
     * @return the extended-batch ensemble if one exists after the last batch, otherwise the
     *     new-batch ensemble ({@code null} for an empty example set)
     * @throws OperatorException if the inner learner or a model application fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.checkLearnerCapabilities(exampleSet);
        this.runVector = new RunVector();
        this.currentIteration = 0;
        Attribute streamControlAttribute = this.prepareStreamControlAttribute(exampleSet);
        // Always prepare the weights: prepareWeights saves pre-existing weights (restored below)
        // or creates a temporary weight attribute. Calling it only when no weight attribute existed
        // overwrote user weights and made restoreOldWeights delete the user's weight attribute
        // (oldWeights was null or left over from a previous run).
        this.prepareWeights(exampleSet);
        try {
            return this.learnFromBatches(exampleSet, streamControlAttribute);
        } finally {
            this.restoreOldWeights(exampleSet);
        }
    }

    /** Creates the stream control attribute, or recycles an existing one by resetting it to 0. */
    private Attribute prepareStreamControlAttribute(ExampleSet exampleSet) throws OperatorException {
        Attribute existing = exampleSet.getAttributes().get(STREAM_CONTROL_ATTRIB_NAME);
        if (existing == null) {
            return edu.udo.cs.yale.example.Tools.createSpecialAttribute(exampleSet, STREAM_CONTROL_ATTRIB_NAME, 3);
        }
        LogService.logMessage("[Warning] BayBoostStream operator: Attribute with the (reserved) name of the stream control attribute exists. It is probably an old version created by this operator. Trying to recycle it... ", 4);
        for (Example example : exampleSet) {
            example.setValue(existing, 0.0);
        }
        return existing;
    }

    /** Main batch loop of {@link #learn}; weights must already be prepared by the caller. */
    private Model learnFromBatches(ExampleSet exampleSet, Attribute streamControlAttribute) throws OperatorException {
        PredictionModel ensembleNewBatch = null;
        PredictionModel ensembleExtBatch = null;
        Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>();
        Vector<BayBoostBaseModelInfo> modelInfo2 = new Vector<BayBoostBaseModelInfo>();
        int firstOpenBatch = 1;
        boolean estimateFavoursExtBatch = true;
        Iterator<Example> reader = exampleSet.iterator();
        while (reader.hasNext()) {
            double[] classPriors = this.prepareBatch(++this.currentIteration, reader, streamControlAttribute);
            ConditionedExampleSet trainingSet = new ConditionedExampleSet(exampleSet, new BatchFilterCondition(streamControlAttribute, this.currentIteration));

            // 1. Use the fresh batch as test data for the ensembles built so far.
            EstimatedPerformance estPerf = null;
            if (ensembleExtBatch != null) {
                ensembleExtBatch.apply(trainingSet);
                this.performance = this.evaluatePredictions(trainingSet);
                ensembleNewBatch.apply(trainingSet);
                double newBatchPerformance = this.evaluatePredictions(trainingSet);
                double estimate = estimateFavoursExtBatch ? this.performance : newBatchPerformance;
                estPerf = new EstimatedPerformance("accuracy", estimate, trainingSet.size(), false);
                String message = "Estimated acc for < new batch: " + newBatchPerformance + " > - < extended batch: " + this.performance + " > ==> ";
                double[] ensembleWeights;
                if (newBatchPerformance > this.performance) {
                    message += "Starting new batch";
                    this.performance = newBatchPerformance;
                    firstOpenBatch = Math.max(1, this.currentIteration - 1);
                    ensembleWeights = ((BayBoostModel)ensembleNewBatch).getModelWeights();
                } else {
                    // Continue with the base models of the extended-batch ensemble.
                    modelInfo.clear();
                    modelInfo.addAll(modelInfo2);
                    message += "Extending batch. Last model used batches " + firstOpenBatch + " to " + (this.currentIteration - 1);
                    ensembleWeights = ((BayBoostModel)ensembleExtBatch).getModelWeights();
                }
                LogService.logMessage(message, 2);
                LogService.logMessage(describeEnsembleWeights(ensembleWeights), 2);
            } else if (ensembleNewBatch != null) {
                ensembleNewBatch.apply(trainingSet);
                this.performance = this.evaluatePredictions(trainingSet);
                firstOpenBatch = Math.max(1, this.currentIteration - 1);
                estPerf = new EstimatedPerformance("accuracy", this.performance, trainingSet.size(), false);
            }
            if (estPerf != null) {
                PerformanceVector perf = new PerformanceVector();
                perf.addAveragable(estPerf);
                this.runVector.addVector(perf);
            }
            if (this.getParameterAsBoolean(EQUALLY_PROB_LABELS)) {
                this.rescalePriors(trainingSet, classPriors);
            }
            estimateFavoursExtBatch = true;

            if (modelInfo.isEmpty()) {
                // No base models yet: train one on this batch; there is no extended ensemble.
                this.trainAdditionalModel(trainingSet, modelInfo);
                ensembleNewBatch = new BayBoostModel(exampleSet.getAttributes().getLabel(), modelInfo, classPriors);
                ensembleExtBatch = null;
                estimateFavoursExtBatch = false;
                continue;
            }

            // 2. Build both candidate ensembles from the current base models.
            modelInfo2 = new Vector<BayBoostBaseModelInfo>(modelInfo);
            double holdOutRatio = this.getParameterAsDouble(HOLD_OUT_RATIO);
            Vector<Example> holdOutExamples = new Vector<Example>();
            if (holdOutRatio > 0.0) {
                RandomGenerator rg = RandomGenerator.getRandomGenerator(0);
                for (Example example : trainingSet) {
                    if (rg.nextDoubleInRange(0.0, 1.0) <= holdOutRatio) {
                        // Batch number 0 is below every batch filter threshold (batches start
                        // at 1), so the example is excluded from training for now.
                        example.setValue(streamControlAttribute, 0.0);
                        holdOutExamples.add(example);
                    }
                }
                trainingSet.updateCondition();
            }
            boolean trainingExamplesLeft = this.adjustBaseModelWeights(trainingSet, modelInfo);
            if (trainingExamplesLeft && !this.trainAdditionalModel(trainingSet, modelInfo)) {
                LogService.logMessage("Model for new single batch discarded!", 2);
            }
            ensembleNewBatch = new BayBoostModel(exampleSet.getAttributes().getLabel(), modelInfo, classPriors);

            ConditionedExampleSet extendedBatch = new ConditionedExampleSet(exampleSet, new BatchFilterCondition(streamControlAttribute, firstOpenBatch));
            classPriors = this.prepareExtendedBatch(extendedBatch);
            if (this.getParameterAsBoolean(EQUALLY_PROB_LABELS)) {
                this.rescalePriors(extendedBatch, classPriors);
            }
            // Drop the most recently added base model; a replacement is trained on the whole
            // extended batch below.
            modelInfo2.remove(modelInfo2.size() - 1);
            if (!this.adjustBaseModelWeights(extendedBatch, modelInfo2)) {
                ensembleExtBatch = new BayBoostModel(exampleSet.getAttributes().getLabel(), modelInfo2, classPriors);
            } else if (this.trainAdditionalModel(extendedBatch, modelInfo2)) {
                ensembleExtBatch = new BayBoostModel(exampleSet.getAttributes().getLabel(), modelInfo2, classPriors);
            } else {
                LogService.logMessage("Model for extended batch discarded!", 2);
                ensembleExtBatch = null;
                estimateFavoursExtBatch = false;
            }

            // 3. Optionally choose between the ensembles on the held-out examples. With an empty
            // hold-out set the error rates would be 0/0 (NaN) and the last model's matrix would be
            // estimated with all weights at 0, so selection is skipped in that case.
            if (!(holdOutRatio > 0.0) || holdOutExamples.isEmpty()) {
                continue;
            }
            for (Example example : holdOutExamples) {
                example.setValue(streamControlAttribute, this.currentIteration);
            }
            trainingSet.updateCondition();
            if (ensembleExtBatch == null) {
                ensembleNewBatch = this.retrainLastWeight((BayBoostModel)ensembleNewBatch, trainingSet, holdOutExamples);
                continue;
            }
            ensembleNewBatch.apply(trainingSet);
            double newBatchErr = holdOutErrorRate(holdOutExamples);
            ensembleExtBatch.apply(trainingSet);
            double extBatchErr = holdOutErrorRate(holdOutExamples);
            LogService.logMessage("[Hold out estimates: <newBatch: " + (1.0 - newBatchErr) + "> - <extBatch: " + (1.0 - extBatchErr) + ">]", 2);
            estimateFavoursExtBatch = extBatchErr <= newBatchErr;
            if (estimateFavoursExtBatch) {
                ensembleExtBatch = this.retrainLastWeight((BayBoostModel)ensembleExtBatch, trainingSet, holdOutExamples);
            } else {
                ensembleNewBatch = this.retrainLastWeight((BayBoostModel)ensembleNewBatch, trainingSet, holdOutExamples);
            }
        }
        return ensembleExtBatch == null ? ensembleNewBatch : ensembleExtBatch;
    }

    /** Fraction of the given examples whose predicted label differs from the true label. */
    private static double holdOutErrorRate(List<Example> holdOutExamples) {
        int errors = 0;
        for (Example example : holdOutExamples) {
            if (example.getPredictedLabel() != example.getLabel()) {
                ++errors;
            }
        }
        return (double)errors / (double)holdOutExamples.size();
    }

    /** Formats ensemble weights as "[Weights of ensemble models: [0: w0] [1: w1] ...]". */
    private static String describeEnsembleWeights(double[] ensembleWeights) {
        StringBuilder text = new StringBuilder("[Weights of ensemble models:");
        for (int i = 0; i < ensembleWeights.length; ++i) {
            text.append(" [").append(i).append(": ").append(ensembleWeights[i]).append(']');
        }
        return text.append(']').toString();
    }

    /**
     * Rebuilds {@code ensemble}, re-estimating the contingency matrix of its last base model on
     * the hold-out examples only.
     *
     * <p>All weights of {@code exampleSet} are reset to 1.0 and the earlier base models are
     * applied in order to reweight the examples. The last model is then applied, every example
     * except the hold-out examples gets weight 0, and the last model's matrix is estimated from
     * the remaining weights.
     *
     * @return the rebuilt ensemble, or {@code ensemble} itself if it contains no base models
     */
    private BayBoostModel retrainLastWeight(BayBoostModel ensemble, ExampleSet exampleSet, List<Example> holdOutSet) throws OperatorException {
        this.prepareExtendedBatch(exampleSet);
        int modelNum = ensemble.getNumberOfModels();
        if (modelNum == 0) {
            // All base models may have been discarded; there is no last model to re-estimate.
            return ensemble;
        }
        Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>();
        double[] priors = ensemble.getPriors();
        for (int i = 0; i < modelNum - 1; ++i) {
            Model model = ensemble.getModel(i);
            ContingencyMatrix cm = ensemble.getContingencyMatrix(i);
            modelInfo.add(new BayBoostBaseModelInfo(model, cm));
            model.apply(exampleSet);
            WeightedPerformanceMeasures.reweightExamples(exampleSet, cm, false);
        }
        Model latestModel = ensemble.getModel(modelNum - 1);
        latestModel.apply(exampleSet);
        // Keep the (boosted) weights of the hold-out examples and zero out everything else.
        double[] weights = new double[holdOutSet.size()];
        int index = 0;
        for (Example example : holdOutSet) {
            weights[index++] = example.getWeight();
        }
        for (Example example : exampleSet) {
            example.setWeight(0.0);
        }
        index = 0;
        for (Example example : holdOutSet) {
            example.setWeight(weights[index++]);
        }
        WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(exampleSet);
        modelInfo.add(new BayBoostBaseModelInfo(latestModel, wp.getContingencyMatrix()));
        return new BayBoostModel(exampleSet.getAttributes().getLabel(), modelInfo, priors);
    }

    /** Returns the superclass outputs followed by the per-batch {@link RunVector}. */
    @Override
    public IOObject[] apply() throws OperatorException {
        IOObject[] ioOld = super.apply();
        int oldLength = ioOld == null ? 0 : ioOld.length;
        IOObject[] ioNew = new IOObject[oldLength + 1];
        if (ioOld != null) {
            System.arraycopy(ioOld, 0, ioNew, 0, oldLength);
        }
        ioNew[oldLength] = this.runVector;
        return ioNew;
    }

    /** Declares the superclass output classes plus {@link RunVector}. */
    @Override
    public Class[] getOutputClasses() {
        Class[] classArray = super.getOutputClasses();
        Class[] classArrayNew = new Class[classArray.length + 1];
        System.arraycopy(classArray, 0, classArrayNew, 0, classArray.length);
        classArrayNew[classArrayNew.length - 1] = RunVector.class;
        return classArrayNew;
    }

    /**
     * Sets every example's weight to 1 / (2 * prior of its label), giving both labels the same
     * total weight. Assumes label indices 0 and 1.
     */
    private void rescalePriors(ExampleSet exampleSet, double[] classPriors) {
        double[] weights = new double[2];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = 1.0 / ((double)weights.length * classPriors[i]);
        }
        for (Example example : exampleSet) {
            example.setWeight(weights[(int)example.getLabel()]);
        }
    }

    /** Trains the inner learner on {@code exampleSet} and applies the new model to it. */
    private Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
        Model model = this.applyInnerLearner(exampleSet);
        BayBoostStream.createOrReplacePredictedLabelFor(exampleSet, model);
        model.apply(exampleSet);
        return model;
    }

    /**
     * Consumes up to {@code batch_size} examples from {@code reader}, tags them with
     * {@code currentBatchNum}, resets their weights to 1.0 and returns the relative frequencies
     * of labels 0 and 1 in the batch. Must only be called while {@code reader} has more examples.
     */
    private double[] prepareBatch(int currentBatchNum, Iterator<Example> reader, Attribute batchAttribute) throws UndefinedParameterError {
        int batchSize = this.getParameterAsInt(BATCH_SIZE);
        int batchCount = 0;
        int[] classCount = new int[2];
        while (batchCount < batchSize && reader.hasNext()) {
            Example example = reader.next();
            example.setValue(batchAttribute, currentBatchNum);
            example.setWeight(1.0);
            classCount[(int)example.getLabel()]++;
            ++batchCount;
        }
        return new double[]{(double)classCount[0] / (double)batchCount, (double)classCount[1] / (double)batchCount};
    }

    /** Resets all weights of {@code extendedBatch} to 1.0 and returns its label-0/1 frequencies. */
    private double[] prepareExtendedBatch(ExampleSet extendedBatch) {
        int[] classCount = new int[2];
        for (Example example : extendedBatch) {
            example.setWeight(1.0);
            classCount[(int)example.getLabel()]++;
        }
        double[] classPriors = new double[2];
        int sum = classCount[0] + classCount[1];
        classPriors[0] = (double)classCount[0] / (double)sum;
        classPriors[1] = (double)classCount[1] / (double)sum;
        return classPriors;
    }

    /** Unweighted accuracy: fraction of examples whose predicted label equals the label. */
    private double evaluatePredictions(ExampleSet exampleSet) {
        int count = 0;
        int correct = 0;
        for (Example example : exampleSet) {
            ++count;
            if (example.getLabel() == example.getPredictedLabel()) {
                ++correct;
            }
        }
        return (double)correct / (double)count;
    }

    /**
     * Trains a new base model on {@code trainingSet} and appends it, with its weighted
     * contingency matrix, to {@code modelInfo} if {@link #isModelUseful} accepts that matrix.
     *
     * @return whether the model was added
     */
    private boolean trainAdditionalModel(ExampleSet trainingSet, Vector<BayBoostBaseModelInfo> modelInfo) throws OperatorException {
        Model model = this.trainBaseModel(trainingSet);
        WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(trainingSet);
        BayBoostStream.debugMessage(wp);
        if (!this.isModelUseful(wp.getContingencyMatrix())) {
            LogService.logMessage("Discard model because of low advantage on training data.", 2);
            return false;
        }
        modelInfo.add(new BayBoostBaseModelInfo(model, wp.getContingencyMatrix()));
        return true;
    }

    /**
     * Applies the base models in {@code modelInfo} to {@code exampleSet} in order. Models whose
     * stored contingency matrix fails {@link #isModelUseful} are removed. Each remaining model
     * gets its matrix replaced by one estimated on {@code exampleSet`'s current weights, and the
     * examples are reweighted with it.
     *
     * @return {@code false} as soon as reweighting returns a value {@code <= 0} (presumably no
     *     uncovered example weight is left), {@code true} otherwise
     * @throws UserError if a base model produces a non-nominal prediction
     */
    private boolean adjustBaseModelWeights(ExampleSet exampleSet, Vector<BayBoostBaseModelInfo> modelInfo) throws OperatorException {
        for (int j = 0; j < modelInfo.size(); ++j) {
            BayBoostBaseModelInfo consideredModelInfo = modelInfo.get(j);
            Model consideredModel = consideredModelInfo.getModel();
            ContingencyMatrix cm = consideredModelInfo.getContingencyMatrix();
            BayBoostStream.createOrReplacePredictedLabelFor(exampleSet, consideredModel);
            consideredModel.apply(exampleSet);
            if (!exampleSet.getAttributes().getPredictedLabel().isNominal()) {
                throw new UserError((Operator)this, 101, new Object[]{exampleSet.getAttributes().getLabel(), "BayBoostStream base learners"});
            }
            WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(exampleSet);
            ContingencyMatrix cmNew = wp.getContingencyMatrix();
            // NOTE(review): usefulness is judged on the previously stored matrix (cm), not on
            // cmNew from the current data; kept as in the original algorithm.
            if (!this.isModelUseful(cm)) {
                modelInfo.remove(j);
                --j;
                LogService.logMessage("Discard base model because of low advantage.", 2);
                continue;
            }
            modelInfo.set(j, new BayBoostBaseModelInfo(consideredModel, cmNew));
            boolean stillUncoveredExamples = WeightedPerformanceMeasures.reweightExamples(exampleSet, cmNew, false) > 0.0;
            if (!stillUncoveredExamples) {
                return false;
            }
        }
        return true;
    }

    /** True if any cell's lift deviates from 1 by more than {@link #MIN_ADVANTAGE}. */
    private boolean isModelUseful(ContingencyMatrix cm) {
        for (int row = 0; row < cm.getNumberOfPredictions(); ++row) {
            for (int col = 0; col < cm.getNumberOfClasses(); ++col) {
                if (Math.abs(cm.getLift(row, col) - 1.0) > MIN_ADVANTAGE) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Logs the base learner's training-set probabilities as a TPR/FPR/FNR/TNR summary. */
    private static void debugMessage(WeightedPerformanceMeasures wp) {
        String nl = String.valueOf(Tools.getLineSeparator());
        String message = nl + "Model learned - training performance of base learner:" + nl
            + "TPR: " + wp.getProbability(0, 0) + " FPR: " + wp.getProbability(1, 0) + " | Positively predicted: " + (wp.getProbability(1, 0) + wp.getProbability(0, 0)) + nl
            + "FNR: " + wp.getProbability(0, 1) + " TNR: " + wp.getProbability(1, 1) + " | Negatively predicted: " + (wp.getProbability(0, 1) + wp.getProbability(1, 1)) + nl
            + "Positively labelled: " + (wp.getProbability(0, 0) + wp.getProbability(0, 1)) + nl
            + "Negatively labelled: " + (wp.getProbability(1, 0) + wp.getProbability(1, 1));
        LogService.logMessage(message, 2);
    }

    /**
     * Removes the current predicted label attribute (if any) from the example set and its table.
     * Despite the name, nothing is created here and {@code model} is unused; presumably the
     * subsequent {@code model.apply} creates a fresh predicted label (TODO confirm).
     */
    private static void createOrReplacePredictedLabelFor(ExampleSet exampleSet, Model model) {
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        if (predictedLabel != null) {
            exampleSet.getAttributes().remove(predictedLabel);
            exampleSet.getExampleTable().removeAttribute(predictedLabel);
        }
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(EQUALLY_PROB_LABELS, "Specifies whether the proportion of labels should be equal by construction after first iteration .", false));
        types.add(new ParameterTypeInt(BATCH_SIZE, "Size of the batches. Minimum number of examples used to train a model.", 1, Integer.MAX_VALUE, 100));
        types.add(new ParameterTypeDouble(HOLD_OUT_RATIO, "Rel. size of hold out set for ensemble selection. Set to 0 to turn off.", 0.0, 1.0, 0.0));
        return types;
    }

    /**
     * Accepts examples whose stream control value is at least {@code batchNumber}, i.e. the
     * batch {@code batchNumber} and all later batches. Immutable, so duplicate() returns itself.
     */
    public static class BatchFilterCondition
    implements Condition {
        private final int batchNumber;
        private final Attribute attribute;

        public BatchFilterCondition(Attribute attribute, int batchNumber) {
            this.batchNumber = batchNumber;
            this.attribute = attribute;
        }

        public boolean conditionOk(Example example) {
            return example.getValue(this.attribute) >= (double)this.batchNumber;
        }

        public Condition duplicate() {
            return this;
        }
    }
}

