/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Condition;
import edu.udo.cs.yale.example.ConditionedExampleSet;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.Value;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.Model;
import edu.udo.cs.yale.operator.learner.meta.AbstractMetaLearner;
import edu.udo.cs.yale.operator.learner.meta.BayBoostModel;
import edu.udo.cs.yale.operator.learner.meta.WeightedPerformanceMeasures;
import edu.udo.cs.yale.operator.parameter.ParameterTypeBoolean;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.operator.performance.EstimatedPerformance;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.RandomGenerator;
import edu.udo.cs.yale.tools.math.RunVector;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/**
 * Stream-oriented variant of BayBoost for binary classification.
 *
 * <p>The training data is consumed in consecutive batches whose size is given by the
 * {@code batch_size} parameter. The batch number of every example is stored in a special
 * "stream control" attribute, which {@link BatchFilterCondition} uses to select batches.
 * After each batch, two candidate ensembles are built:
 * <ul>
 *   <li>a <em>new-batch</em> ensemble, which reweights all current models on the newest batch and
 *       may add a model trained on that batch, and</li>
 *   <li>an <em>extended-batch</em> ensemble, which drops the most recently added model and
 *       retrains on every batch from the first "open" batch onward.</li>
 * </ul>
 * The next batch (or, optionally, a hold-out fraction of the current batch) decides which
 * candidate is continued. Per-batch accuracy estimates are delivered as an additional
 * {@link RunVector} output.
 */
public class BayBoostStream
extends AbstractMetaLearner {
    /** Parameter key: maximum number of examples read per batch. */
    public static final String BATCH_SIZE = "batch_size";
    /** Parameter key: whether example weights are rescaled so both labels get equal total weight. */
    public static final String EQUALLY_PROB_LABELS = "rescale_label_priors";
    /** Parameter key: fraction of each batch held out for ensemble selection (0 disables it). */
    public static final String HOLD_OUT_RATIO = "fraction_hold_out_set";
    /** Minimum deviation of some bias-matrix entry from 1.0 for a base model to be kept. */
    public static final double MIN_ADVANTAGE = 0.02;
    /** Name of the special attribute that stores the batch number of every example. */
    public static final String STREAM_CONTROL_ATTRIB_NAME = "BayBoostStream.StreamControl";
    /** Not referenced inside this class; kept for external users. */
    public static final double MIN_LIFT_RATIO_SOFT_CLASSIFIER = 0.2;

    /** Accuracy estimates collected during the last call to {@link #learn(ExampleSet)}. */
    private RunVector runVector;
    /** Number of the batch currently being processed (1-based). */
    private int currentIteration;
    /** Most recent accuracy estimate, exposed as the "performance" value. */
    private double performance = 0.0;

    /**
     * Creates the operator and registers the loggable "performance" and "iteration" values.
     *
     * @param description the operator description passed on to the superclass
     */
    public BayBoostStream(OperatorDescription description) {
        super(description);
        this.addValue(new Value("performance", "The performance."){

            public double getValue() {
                return BayBoostStream.this.performance;
            }
        });
        this.addValue(new Value("iteration", "The current iteration."){

            public double getValue() {
                return BayBoostStream.this.currentIteration;
            }
        });
    }

    /**
     * Rejects polynominal and numerical labels; all other capabilities are delegated to the
     * superclass.
     *
     * @param lc the capability to check
     * @return {@code false} for polynominal or numerical classes, otherwise the superclass verdict
     */
    public boolean supportsCapability(LearnerCapability lc) {
        if (lc == LearnerCapability.POLYNOMINAL_CLASS || lc == LearnerCapability.NUMERICAL_CLASS) {
            return false;
        }
        return super.supportsCapability(lc);
    }

    /**
     * Adds the label-rescaling, batch-size and hold-out-ratio parameters to the superclass
     * parameters.
     *
     * @return the parameter type list
     */
    public List getParameterTypes() {
        List types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(EQUALLY_PROB_LABELS, "Specifies whether the proportion of labels should be equal by construction after first iteration .", false));
        types.add(new ParameterTypeInt(BATCH_SIZE, "Size of the batches. Minimum number of examples used to train a model.", 1, Integer.MAX_VALUE, 100));
        types.add(new ParameterTypeDouble(HOLD_OUT_RATIO, "Rel. size of hold out set for ensemble selection. Set to 0 to turn off.", 0.0, 1.0, 0.0));
        return types;
    }

    /**
     * Returns the number of steps reported for progress purposes; always 1.
     *
     * @return 1
     */
    public int getNumberOfSteps() {
        return 1;
    }

    /**
     * Trains the streamed BayBoost ensemble.
     *
     * <p>Examples are read batch by batch. At the start of each iteration, the two candidates
     * built in the previous iteration are applied to the new batch:
     * <ul>
     *   <li>If the new-batch candidate is strictly more accurate, its models are kept and the
     *       first open batch moves to the previous batch.</li>
     *   <li>Otherwise, the extended-batch models replace the current ones.</li>
     * </ul>
     * An accuracy estimate is appended to the run vector. Both candidates are then rebuilt. If
     * {@code fraction_hold_out_set} is positive, a random part of the current batch is excluded
     * from training and used to choose between the candidates and to re-estimate the bias
     * matrix of the chosen candidate's last model.
     *
     * @param exampleSet the training data; its label must have exactly two values. A stream
     *     control attribute (and a weight attribute, if missing) is added and modified in place.
     * @return the extended-batch ensemble if one exists after the last batch, otherwise the
     *     new-batch ensemble
     * @throws UserError if the label does not have exactly two values
     * @throws OperatorException if training or applying a model fails
     */
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.checkLearnerCapabilities(exampleSet);
        int numberOfLabels = exampleSet.getLabel().getValues().size();
        if (numberOfLabels != 2) {
            throw new UserError((Operator)this, 118, new Object[]{exampleSet.getLabel(), Integer.valueOf(numberOfLabels), Integer.valueOf(2)});
        }
        this.runVector = new RunVector();
        this.currentIteration = 0;
        Attribute streamControlAttribute = initStreamControlAttribute(exampleSet);
        if (exampleSet.getWeight() == null) {
            exampleSet.createWeightAttribute();
        }

        BayBoostModel ensembleNewBatch = null;
        BayBoostModel ensembleExtBatch = null;
        Vector<Object[]> modelInfo = new Vector<Object[]>();
        Vector<Object[]> modelInfo2 = null;
        int firstOpenBatch = 1;
        boolean estimateFavoursExtBatch = true;
        ExampleReader reader = exampleSet.getExampleReader();
        while (reader.hasNext()) {
            double[] classPriors = this.prepareBatch(++this.currentIteration, reader, streamControlAttribute);
            ConditionedExampleSet trainingSet = new ConditionedExampleSet(exampleSet, new BatchFilterCondition(streamControlAttribute, this.currentIteration));

            // Evaluate the candidates from the previous iteration on the fresh batch.
            EstimatedPerformance estPerf = null;
            if (ensembleExtBatch != null) {
                ensembleExtBatch.apply(trainingSet);
                this.performance = this.evaluatePredictions(trainingSet);
                ensembleNewBatch.apply(trainingSet);
                double newBatchPerformance = this.evaluatePredictions(trainingSet);
                double estimated = estimateFavoursExtBatch ? this.performance : newBatchPerformance;
                estPerf = new EstimatedPerformance("accuracy", estimated, trainingSet.getSize(), false);
                System.out.print("Estimated acc for < new batch: " + newBatchPerformance + " > - < extended batch: " + this.performance + " > ==> ");
                double[] ensembleWeights;
                if (newBatchPerformance > this.performance) {
                    System.out.println("Starting new batch");
                    this.performance = newBatchPerformance;
                    firstOpenBatch = Math.max(1, this.currentIteration - 1);
                    ensembleWeights = ensembleNewBatch.getModelWeights();
                } else {
                    modelInfo.clear();
                    modelInfo.addAll(modelInfo2);
                    System.out.println("Extending batch. Last model used batches " + firstOpenBatch + " to " + (this.currentIteration - 1));
                    ensembleWeights = ensembleExtBatch.getModelWeights();
                }
                printEnsembleWeights(ensembleWeights);
            } else if (ensembleNewBatch != null) {
                ensembleNewBatch.apply(trainingSet);
                this.performance = this.evaluatePredictions(trainingSet);
                firstOpenBatch = Math.max(1, this.currentIteration - 1);
                estPerf = new EstimatedPerformance("accuracy", this.performance, trainingSet.getSize(), false);
            }
            if (estPerf != null) {
                PerformanceVector perf = new PerformanceVector();
                perf.addAveragable(estPerf);
                this.runVector.addVector(perf);
            }
            if (this.getParameterAsBoolean(EQUALLY_PROB_LABELS)) {
                this.rescalePriors(trainingSet, classPriors);
            }
            estimateFavoursExtBatch = true;

            if (modelInfo.isEmpty()) {
                // No usable model yet: start the ensemble with a model for this batch alone.
                this.trainAdditionalModel(trainingSet, modelInfo);
                ensembleNewBatch = new BayBoostModel(exampleSet.getLabel(), modelInfo, classPriors, true);
                ensembleExtBatch = null;
                estimateFavoursExtBatch = false;
                continue;
            }

            // Snapshot the ensemble before it is reweighted on the new batch. The snapshot must
            // own its entries; sharing the Object[] entries would let the two candidates
            // overwrite each other's bias matrices.
            modelInfo2 = copyModelInfo(modelInfo);

            double holdOutRatio = this.getParameterAsDouble(HOLD_OUT_RATIO);
            Vector<Example> holdOutExamples = new Vector<Example>();
            if (holdOutRatio > 0.0) {
                RandomGenerator rg = RandomGenerator.getGlobalRandomGenerator();
                ExampleReader randBatchReader = trainingSet.getExampleReader();
                while (randBatchReader.hasNext()) {
                    Example example = randBatchReader.next();
                    if (rg.nextDoubleInRange(0.0, 1.0) <= holdOutRatio) {
                        // Batch number 0 is below every batch number used by BatchFilterCondition.
                        example.setValue(streamControlAttribute, 0.0);
                        holdOutExamples.add(example);
                    }
                }
                trainingSet.updateCondition();
            }

            // New-batch candidate: reweight all models on this batch, then add a new model.
            if (this.adjustBaseModelWeights(trainingSet, modelInfo) && !this.trainAdditionalModel(trainingSet, modelInfo)) {
                System.out.println("Model for new single batch discarded!");
            }
            ensembleNewBatch = new BayBoostModel(exampleSet.getLabel(), modelInfo, classPriors, true);

            // Extended-batch candidate: drop the most recent model and retrain on all open batches.
            ConditionedExampleSet extendedBatch = new ConditionedExampleSet(exampleSet, new BatchFilterCondition(streamControlAttribute, firstOpenBatch));
            classPriors = this.prepareExtendedBatch(extendedBatch);
            if (this.getParameterAsBoolean(EQUALLY_PROB_LABELS)) {
                this.rescalePriors(extendedBatch, classPriors);
            }
            modelInfo2.remove(modelInfo2.size() - 1);
            if (!this.adjustBaseModelWeights(extendedBatch, modelInfo2)) {
                ensembleExtBatch = new BayBoostModel(exampleSet.getLabel(), modelInfo2, classPriors, true);
            } else if (this.trainAdditionalModel(extendedBatch, modelInfo2)) {
                ensembleExtBatch = new BayBoostModel(exampleSet.getLabel(), modelInfo2, classPriors, true);
            } else {
                System.out.println("Model for extended batch discarded!");
                ensembleExtBatch = null;
                estimateFavoursExtBatch = false;
            }

            // Without held-out examples (disabled, or none drawn) there is nothing to evaluate on;
            // continuing would divide 0 by 0 and retrain on all-zero weights.
            if (holdOutExamples.isEmpty()) {
                continue;
            }
            for (Example example : holdOutExamples) {
                example.setValue(streamControlAttribute, this.currentIteration);
            }
            trainingSet.updateCondition();
            if (ensembleExtBatch == null) {
                ensembleNewBatch = this.retrainLastWeight(ensembleNewBatch, trainingSet, holdOutExamples);
                continue;
            }
            ensembleNewBatch.apply(trainingSet);
            double newBatchErr = holdOutError(holdOutExamples);
            ensembleExtBatch.apply(trainingSet);
            double extBatchErr = holdOutError(holdOutExamples);
            System.out.println("[Hold out estimates: <newBatch: " + (1.0 - newBatchErr) + "> - <extBatch: " + (1.0 - extBatchErr) + ">]");
            estimateFavoursExtBatch = extBatchErr <= newBatchErr;
            if (estimateFavoursExtBatch) {
                ensembleExtBatch = this.retrainLastWeight(ensembleExtBatch, trainingSet, holdOutExamples);
            } else {
                ensembleNewBatch = this.retrainLastWeight(ensembleNewBatch, trainingSet, holdOutExamples);
            }
        }
        return ensembleExtBatch == null ? ensembleNewBatch : ensembleExtBatch;
    }

    /**
     * Looks up the stream control attribute, creating it when absent. An existing attribute of
     * that name is recycled: a warning is logged and every example's value is reset to 0.
     *
     * @param exampleSet the example set to inspect and modify
     * @return the stream control attribute
     */
    private static Attribute initStreamControlAttribute(ExampleSet exampleSet) {
        Attribute existing = exampleSet.getAttribute(STREAM_CONTROL_ATTRIB_NAME);
        if (existing == null) {
            return exampleSet.createSpecialAttribute(STREAM_CONTROL_ATTRIB_NAME, 3);
        }
        LogService.logMessage("[Warning] BayBoostStream operator: Attribute with the (reserved) name of the stream control attribute exists. It is probably an old version created by this operator. Trying to recycle it... ", 4);
        ExampleReader resetReader = exampleSet.getExampleReader();
        while (resetReader.hasNext()) {
            resetReader.next().setValue(existing, 0.0);
        }
        return existing;
    }

    /**
     * Copies a model-info list so that the copy owns its entries and bias matrices. The models
     * themselves (entry index 0) are shared.
     *
     * @param modelInfo list of {@code Object[]{model, double[][] biasMatrix}} entries
     * @return an independent copy
     */
    private static Vector<Object[]> copyModelInfo(Vector<Object[]> modelInfo) {
        Vector<Object[]> copy = new Vector<Object[]>(modelInfo.size());
        for (Object[] entry : modelInfo) {
            double[][] biasMatrix = (double[][])entry[1];
            double[][] biasCopy = new double[biasMatrix.length][];
            for (int row = 0; row < biasMatrix.length; ++row) {
                biasCopy[row] = biasMatrix[row].clone();
            }
            Object[] entryCopy = entry.clone();
            entryCopy[1] = biasCopy;
            copy.add(entryCopy);
        }
        return copy;
    }

    /** Prints the ensemble's model weights to standard output. */
    private static void printEnsembleWeights(double[] ensembleWeights) {
        System.out.print("[Weights of ensemble models:");
        for (int i = 0; i < ensembleWeights.length; ++i) {
            System.out.print(" [" + i + ": " + ensembleWeights[i] + "]");
        }
        System.out.println("]");
    }

    /**
     * Returns the fraction of the given (non-empty) examples whose predicted label differs from
     * the true label.
     */
    private static double holdOutError(List<Example> examples) {
        int errors = 0;
        for (Example example : examples) {
            if (example.getPredictedLabel() != example.getLabel()) {
                ++errors;
            }
        }
        return (double)errors / (double)examples.size();
    }

    /**
     * Rebuilds {@code ensemble} with the bias matrix of its last model re-estimated on the
     * hold-out examples only.
     *
     * <p>All weights in {@code exampleSet} are reset to 1. Every model except the last keeps its
     * bias matrix and is applied to reweight the examples. Then all examples outside
     * {@code holdOutSet} get weight 0 before the last model's bias matrix is computed.
     *
     * @param ensemble the ensemble to rebuild (must contain at least one model)
     * @param exampleSet the current batch, including the hold-out examples
     * @param holdOutSet the hold-out examples (elements of {@code exampleSet})
     * @return a new ensemble with the same models and priors
     * @throws OperatorException if applying a model fails
     */
    private BayBoostModel retrainLastWeight(BayBoostModel ensemble, ExampleSet exampleSet, Vector<Example> holdOutSet) throws OperatorException {
        this.prepareExtendedBatch(exampleSet);
        int modelNum = ensemble.getNumberOfModels();
        Vector<Object[]> modelInfo = new Vector<Object[]>();
        double[] priors = ensemble.getPriors();
        for (int i = 0; i < modelNum - 1; ++i) {
            Model model = ensemble.getModel(i);
            double[][] biasMatrix = ensemble.getBiasMatrix(i);
            modelInfo.add(new Object[]{model, biasMatrix});
            model.apply(exampleSet);
            WeightedPerformanceMeasures.reweightExamples(exampleSet, biasMatrix, priors);
        }
        Model latestModel = ensemble.getModel(modelNum - 1);
        latestModel.apply(exampleSet);
        // Remember the hold-out weights, zero everything, then restore only the hold-out weights.
        double[] weights = new double[holdOutSet.size()];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = holdOutSet.get(i).getWeight();
        }
        ExampleReader reader = exampleSet.getExampleReader();
        while (reader.hasNext()) {
            reader.next().setWeight(0.0);
        }
        for (int i = 0; i < weights.length; ++i) {
            holdOutSet.get(i).setWeight(weights[i]);
        }
        WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(exampleSet);
        double[][] biasMatrix = wp.createBiasMatrix();
        modelInfo.add(new Object[]{latestModel, biasMatrix});
        return new BayBoostModel(exampleSet.getLabel(), modelInfo, priors, true);
    }

    /**
     * Runs the superclass and appends the collected {@link RunVector} to its output.
     *
     * @return the superclass output followed by the run vector
     * @throws OperatorException if the superclass fails
     */
    public IOObject[] apply() throws OperatorException {
        IOObject[] ioOld = super.apply();
        int oldLength = ioOld == null ? 0 : ioOld.length;
        IOObject[] ioNew = new IOObject[oldLength + 1];
        if (ioOld != null) {
            System.arraycopy(ioOld, 0, ioNew, 0, oldLength);
        }
        ioNew[oldLength] = this.runVector;
        return ioNew;
    }

    /**
     * Declares the superclass output classes plus {@link RunVector}.
     *
     * @return the output classes
     */
    public Class[] getOutputClasses() {
        Class[] classArray = super.getOutputClasses();
        Class[] classArrayNew = new Class[classArray.length + 1];
        System.arraycopy(classArray, 0, classArrayNew, 0, classArray.length);
        classArrayNew[classArray.length] = RunVector.class;
        return classArrayNew;
    }

    /**
     * Sets every example's weight to {@code 1 / (2 * classPriors[label])}. When
     * {@code classPriors} holds the label frequencies of {@code exampleSet}, both labels end up
     * with the same total weight.
     */
    private void rescalePriors(ExampleSet exampleSet, double[] classPriors) {
        double[] weights = new double[2];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = 1.0 / ((double)weights.length * classPriors[i]);
        }
        ExampleReader exRead = exampleSet.getExampleReader();
        while (exRead.hasNext()) {
            Example example = exRead.next();
            example.setWeight(weights[(int)example.getLabel()]);
        }
    }

    /** Trains the inner learner on {@code exampleSet} and applies the model to it. */
    private Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
        Model model = this.applyInnerLearner(exampleSet);
        BayBoostStream.createOrReplacePredictedLabelFor(exampleSet, model);
        model.apply(exampleSet);
        return model;
    }

    /**
     * Reads up to {@code batch_size} examples from {@code reader}. Each example is tagged with
     * {@code currentBatchNum} and its weight is reset to 1.
     *
     * @return the relative frequencies of label values 0 and 1 within the batch just read
     */
    private double[] prepareBatch(int currentBatchNum, ExampleReader reader, Attribute batchAttribute) {
        int batchSize = this.getParameterAsInt(BATCH_SIZE);
        int batchCount = 0;
        int[] classCount = new int[2];
        while (batchCount < batchSize && reader.hasNext()) {
            Example example = reader.next();
            example.setValue(batchAttribute, currentBatchNum);
            example.setWeight(1.0);
            classCount[(int)example.getLabel()]++;
            ++batchCount;
        }
        return new double[]{(double)classCount[0] / (double)batchCount, (double)classCount[1] / (double)batchCount};
    }

    /**
     * Resets all weights in {@code extendedBatch} to 1.
     *
     * @return the relative frequencies of label values 0 and 1
     */
    private double[] prepareExtendedBatch(ExampleSet extendedBatch) {
        int[] classCount = new int[2];
        ExampleReader reader = extendedBatch.getExampleReader();
        while (reader.hasNext()) {
            Example example = reader.next();
            example.setWeight(1.0);
            classCount[(int)example.getLabel()]++;
        }
        double sum = classCount[0] + classCount[1];
        return new double[]{(double)classCount[0] / sum, (double)classCount[1] / sum};
    }

    /** Returns the fraction of examples whose predicted label equals the true label. */
    private double evaluatePredictions(ExampleSet exampleSet) {
        ExampleReader reader = exampleSet.getExampleReader();
        int count = 0;
        int correct = 0;
        while (reader.hasNext()) {
            ++count;
            Example example = reader.next();
            if (example.getLabel() == example.getPredictedLabel()) {
                ++correct;
            }
        }
        return (double)correct / (double)count;
    }

    /**
     * Trains a new base model and appends it with its bias matrix to {@code modelInfo}, unless
     * {@link #isModelUseful} rejects the bias matrix.
     *
     * @return {@code true} if the model was added
     */
    private boolean trainAdditionalModel(ExampleSet trainingSet, Vector<Object[]> modelInfo) throws OperatorException {
        Model model = this.trainBaseModel(trainingSet);
        WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(trainingSet);
        BayBoostStream.debugMessage(wp);
        double[][] biasMatrix = wp.createBiasMatrix();
        if (!this.isModelUseful(biasMatrix)) {
            LogService.logMessage("Discard model because of low advantage on training data.", 2);
            return false;
        }
        modelInfo.add(new Object[]{model, biasMatrix});
        return true;
    }

    /**
     * Re-estimates the bias matrix of every model in {@code modelInfo} on {@code exampleSet}, in
     * order, reweighting the examples after each model. Models whose new bias matrix is rejected
     * by {@link #isModelUseful} are removed from the list.
     *
     * @return {@code false} as soon as a reweighting step reports that no examples remain
     *     uncovered, {@code true} otherwise
     * @throws UserError if a model produces a non-nominal predicted label
     */
    private boolean adjustBaseModelWeights(ExampleSet exampleSet, Vector<Object[]> modelInfo) throws OperatorException {
        for (int j = 0; j < modelInfo.size(); ++j) {
            Object[] consideredModelInfo = modelInfo.get(j);
            Model consideredModel = (Model)consideredModelInfo[0];
            double[][] oldBiasMatrix = (double[][])consideredModelInfo[1];
            BayBoostStream.createOrReplacePredictedLabelFor(exampleSet, consideredModel);
            if (!exampleSet.getPredictedLabel().isNominal()) {
                throw new UserError((Operator)this, 101, new Object[]{exampleSet.getLabel(), "BayBoostStream base learners"});
            }
            consideredModel.apply(exampleSet);
            WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(exampleSet);
            double[][] newBiasMatrix = wp.createBiasMatrix();
            if (!this.isModelUseful(newBiasMatrix)) {
                modelInfo.remove(j);
                --j; // the next model now sits at index j
                LogService.logMessage("Discard base model because of low advantage.", 2);
                continue;
            }
            consideredModelInfo[1] = newBiasMatrix;
            LogService.logMessage("Weights of model number " + j + " updated (old/new):", 2);
            for (int m = 0; m < newBiasMatrix.length; ++m) {
                StringBuilder line = new StringBuilder();
                for (int n = 0; n < newBiasMatrix[m].length; ++n) {
                    line.append('(').append(oldBiasMatrix[m][n]).append('/').append(newBiasMatrix[m][n]).append(") ");
                }
                LogService.logMessage(line.toString(), 2);
            }
            if (!wp.reweightExamples(exampleSet)) {
                return false;
            }
        }
        return true;
    }

    /**
     * A model is useful if at least one bias-matrix entry deviates from 1.0 by more than
     * {@link #MIN_ADVANTAGE}.
     */
    private boolean isModelUseful(double[][] biasMatrix) {
        for (double[] row : biasMatrix) {
            for (double entry : row) {
                if (Math.abs(entry - 1.0) > MIN_ADVANTAGE) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Logs the base learner's training contingency probabilities. */
    private static void debugMessage(WeightedPerformanceMeasures wp) {
        String message = "\nModel learned - training performance of base learner:\nTPR: " + wp.getProbability(0, 0) + " FPR: " + wp.getProbability(1, 0) + " | Positively predicted: " + (wp.getProbability(1, 0) + wp.getProbability(0, 0)) + "\nFNR: " + wp.getProbability(0, 1) + " TNR: " + wp.getProbability(1, 1) + " | Negatively predicted: " + (wp.getProbability(0, 1) + wp.getProbability(1, 1)) + "\nPositively labelled: " + (wp.getProbability(0, 0) + wp.getProbability(0, 1)) + "\nNegatively labelled: " + (wp.getProbability(1, 0) + wp.getProbability(1, 1));
        LogService.logMessage(message, 2);
    }

    /**
     * Removes any existing predicted-label attribute from the example set and its table, then
     * lets {@code model} create a fresh one.
     */
    private static void createOrReplacePredictedLabelFor(ExampleSet exampleSet, Model model) {
        Attribute predictedLabel = exampleSet.getPredictedLabel();
        if (predictedLabel != null) {
            exampleSet.clearPredictedLabel();
            exampleSet.getExampleTable().removeAttribute(predictedLabel);
        }
        model.createPredictedLabel(exampleSet);
    }

    /**
     * Accepts examples whose stream control value is at least the given batch number. This
     * selects that batch plus all later ones; examples with value 0 (e.g. hold-out examples) are
     * never accepted, because batch numbers start at 1.
     */
    public class BatchFilterCondition
    implements Condition {
        private final int batchNumber;
        private final Attribute attribute;

        /**
         * @param attribute the stream control attribute
         * @param batchNumber the lowest batch number to accept
         */
        public BatchFilterCondition(Attribute attribute, int batchNumber) {
            this.batchNumber = batchNumber;
            this.attribute = attribute;
        }

        /** Returns whether the example's batch number is {@code >= batchNumber}. */
        public boolean conditionOk(Example example) {
            return example.getValue(this.attribute) >= (double)this.batchNumber;
        }

        /** The condition is immutable, so it is its own duplicate. */
        public Condition duplicate() {
            return this;
        }
    }
}

