/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.example.set.Condition;
import com.rapidminer.example.set.ConditionedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.operator.learner.meta.BayBoostBaseModelInfo;
import com.rapidminer.operator.learner.meta.BayBoostModel;
import com.rapidminer.operator.learner.meta.ContingencyMatrix;
import com.rapidminer.operator.learner.meta.WeightedPerformanceMeasures;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.GenerateNewMDRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.math.RunVector;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/**
 * Meta learner that walks through an example set in fixed-size batches and maintains
 * {@link BayBoostModel} ensembles built from base models of the inner learner.
 *
 * <p>Two candidate ensembles are kept while iterating: one fitted to the newest batch only
 * ("new batch") and one fitted to all batches starting at the first still-open batch
 * ("extended batch"). Each incoming batch is first used to evaluate the existing ensembles
 * (the accuracy is recorded in a {@link RunVector}) and only afterwards used for training.
 *
 * <p>Class statistics are collected in two-element arrays, so label indices other than
 * 0 and 1 are not handled by this implementation.
 */
public class BayBoostStream
extends AbstractMetaLearner {
    /** Port at which {@link #doWork()} delivers the run vector filled by {@link #learn(ExampleSet)}. */
    private OutputPort runVectorOutput = (OutputPort)this.getOutputPorts().createPort("run vector");
    /** Parameter key: number of examples per batch. */
    public static final String PARAMETER_BATCH_SIZE = "batch_size";
    /** Parameter key: whether example weights are rescaled so that both classes get equal total weight. */
    public static final String PARAMETER_RESCALE_LABEL_PRIORS = "rescale_label_priors";
    /** Parameter key: fraction of each batch that is held out for ensemble selection (0 disables it). */
    public static final String PARAMETER_FRACTION_HOLD_OUT_SET = "fraction_hold_out_set";
    /** Minimum absolute deviation of a lift from 1.0 for a base model to be considered useful. */
    public static final double MIN_ADVANTAGE = 0.02;
    /** Name of the special attribute that stores the batch number of every example. */
    public static final String STREAM_CONTROL_ATTRIB_NAME = "BayBoostStream.StreamControl";
    /** Not referenced inside this class. */
    public static final double MIN_LIFT_RATIO_SOFT_CLASSIFIER = 0.2;
    /** Per-batch accuracy estimates of the most recent {@link #learn(ExampleSet)} call. */
    private RunVector runVector;
    /** One-based number of the batch currently being processed. */
    private int currentIteration;
    /** Accuracy of the currently preferred ensemble on the most recently evaluated batch. */
    private double performance = 0.0;
    /** Backup of the weights found on the input; {@code null} if the weight attribute was created here. */
    private double[] oldWeights;

    /**
     * Creates the operator, registers the meta data rule for the run vector port and exposes
     * the "performance" and "iteration" values.
     *
     * @param description the operator description handed to the super class
     */
    public BayBoostStream(OperatorDescription description) {
        super(description);
        this.getTransformer().addRule(new GenerateNewMDRule(this.runVectorOutput, RunVector.class));
        this.addValue(new ValueDouble("performance", "The performance."){

            @Override
            public double getDoubleValue() {
                return BayBoostStream.this.performance;
            }
        });
        this.addValue(new ValueDouble("iteration", "The current iteration."){

            @Override
            public double getDoubleValue() {
                return BayBoostStream.this.currentIteration;
            }
        });
    }

    /**
     * Reports which learner capabilities this operator claims.
     *
     * @param lc the capability in question
     * @return {@code false} for numerical labels, missing labels, updatability and formula
     *         provision; {@code true} for everything else
     */
    @Override
    public boolean supportsCapability(OperatorCapability lc) {
        switch (lc) {
            case NUMERICAL_LABEL:
            case NO_LABEL:
            case UPDATABLE:
            case FORMULA_PROVIDER: {
                return false;
            }
            default: {
                return true;
            }
        }
    }

    /**
     * Makes sure the example set has a weight attribute with all weights set to 1.
     * Existing weights are backed up in {@link #oldWeights}; if no weight attribute exists, one
     * is created and {@link #oldWeights} is set to {@code null} to mark it as temporary.
     *
     * @param exampleSet the example set whose weights are initialized
     */
    protected void prepareWeights(ExampleSet exampleSet) {
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        if (weightAttr == null) {
            this.oldWeights = null;
            Tools.createWeightAttribute(exampleSet);
            return;
        }
        this.oldWeights = new double[exampleSet.size()];
        Iterator<Example> reader = exampleSet.iterator();
        for (int i = 0; reader.hasNext() && i < this.oldWeights.length; ++i) {
            Example example = reader.next();
            if (example == null) {
                // keep the index aligned with the iteration position
                continue;
            }
            this.oldWeights[i] = example.getWeight();
            example.setWeight(1.0);
        }
    }

    /**
     * Undoes {@link #prepareWeights(ExampleSet)}: writes the backed-up weights back, or removes
     * the weight attribute again if it was created by this operator.
     *
     * @param exampleSet the example set that was passed to {@link #prepareWeights(ExampleSet)}
     */
    private void restoreOldWeights(ExampleSet exampleSet) {
        if (this.oldWeights != null) {
            Iterator<Example> reader = exampleSet.iterator();
            for (int i = 0; reader.hasNext() && i < this.oldWeights.length; ++i) {
                reader.next().setWeight(this.oldWeights[i]);
            }
            return;
        }
        Attribute weight = exampleSet.getAttributes().getWeight();
        if (weight != null) {
            exampleSet.getAttributes().remove(weight);
            exampleSet.getExampleTable().removeAttribute(weight);
        }
    }

    /**
     * Trains the ensemble batch by batch on the given example set.
     *
     * <p>The weight attribute is always prepared up front and restored afterwards, also when
     * learning fails, so weights that existed on the input are neither lost nor left modified.
     * The stream control attribute is left on the example set.
     *
     * @param exampleSet the training data, processed in iteration order
     * @return the extended-batch ensemble if one exists after the last batch, otherwise the
     *         new-batch ensemble ({@code null} if the example set was empty)
     * @throws OperatorException if a parameter is undefined or the inner learner or a model
     *         application fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.runVector = new RunVector();
        this.currentIteration = 0;
        Attribute streamControlAttribute = this.prepareStreamControlAttribute(exampleSet);
        // Unconditional: prepareWeights itself distinguishes "backup existing weights" from
        // "create temporary attribute", and restoreOldWeights relies on that bookkeeping.
        this.prepareWeights(exampleSet);
        try {
            return this.learnFromBatches(exampleSet, streamControlAttribute);
        } finally {
            this.restoreOldWeights(exampleSet);
        }
    }

    /**
     * Looks up the stream control attribute, creating it if necessary. An already existing
     * attribute of that name is reused after resetting all its values to 0.
     */
    private Attribute prepareStreamControlAttribute(ExampleSet exampleSet) throws OperatorException {
        Attribute existing = exampleSet.getAttributes().get(STREAM_CONTROL_ATTRIB_NAME);
        if (existing == null) {
            // NOTE(review): 3 is a value type code, presumably Ontology.INTEGER -- confirm.
            return Tools.createSpecialAttribute(exampleSet, STREAM_CONTROL_ATTRIB_NAME, 3);
        }
        this.logWarning("Attribute with the (reserved) name of the stream control attribute exists. It is probably an old version created by this operator. Trying to recycle it... ");
        for (Example example : exampleSet) {
            example.setValue(existing, 0.0);
        }
        return existing;
    }

    /**
     * Main loop of {@link #learn(ExampleSet)}: one iteration per batch. Each iteration evaluates
     * the existing ensembles on the fresh batch, records the estimate and then refits both the
     * new-batch and the extended-batch ensemble.
     */
    private Model learnFromBatches(ExampleSet exampleSet, Attribute streamControlAttribute) throws OperatorException {
        BayBoostModel ensembleNewBatch = null;
        BayBoostModel ensembleExtBatch = null;
        // modelInfo backs the new-batch ensemble, modelInfo2 the extended-batch ensemble
        Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>();
        Vector<BayBoostBaseModelInfo> modelInfo2 = new Vector<BayBoostBaseModelInfo>();
        int firstOpenBatch = 1;
        boolean estimateFavoursExtBatch = true;
        Iterator<Example> reader = exampleSet.iterator();
        while (reader.hasNext()) {
            double[] classPriors = this.prepareBatch(++this.currentIteration, reader, streamControlAttribute);
            ConditionedExampleSet trainingSet = new ConditionedExampleSet(exampleSet, new BatchFilterCondition(streamControlAttribute, this.currentIteration));

            // Step 1: evaluate the ensembles learned so far on the not yet seen batch.
            EstimatedPerformance estPerf = null;
            if (ensembleExtBatch != null) {
                trainingSet = (ConditionedExampleSet)ensembleExtBatch.apply(trainingSet);
                this.performance = this.evaluatePredictions(trainingSet);
                trainingSet = (ConditionedExampleSet)ensembleNewBatch.apply(trainingSet);
                double newBatchPerformance = this.evaluatePredictions(trainingSet);
                // report the accuracy of the ensemble that was predicted to be the better one
                double reportedAccuracy = estimateFavoursExtBatch ? this.performance : newBatchPerformance;
                estPerf = new EstimatedPerformance("accuracy", reportedAccuracy, trainingSet.size(), false);
                if (newBatchPerformance > this.performance) {
                    // continue with the new-batch ensemble and close older batches
                    this.performance = newBatchPerformance;
                    firstOpenBatch = Math.max(1, this.currentIteration - 1);
                } else {
                    // continue with the base models of the extended-batch ensemble
                    modelInfo.clear();
                    modelInfo.addAll(modelInfo2);
                }
            } else if (ensembleNewBatch != null) {
                trainingSet = (ConditionedExampleSet)ensembleNewBatch.apply(trainingSet);
                this.performance = this.evaluatePredictions(trainingSet);
                firstOpenBatch = Math.max(1, this.currentIteration - 1);
                estPerf = new EstimatedPerformance("accuracy", this.performance, trainingSet.size(), false);
            }
            if (estPerf != null) {
                PerformanceVector perf = new PerformanceVector();
                perf.addAveragable(estPerf);
                this.runVector.addVector(perf);
            }

            // Step 2: train on the current batch.
            if (this.getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS)) {
                this.rescalePriors(trainingSet, classPriors);
            }
            estimateFavoursExtBatch = true;
            if (modelInfo.isEmpty()) {
                // no base models yet: start a fresh ensemble from this batch alone
                this.trainAdditionalModel(trainingSet, modelInfo);
                ensembleNewBatch = new BayBoostModel(exampleSet, modelInfo, classPriors);
                ensembleExtBatch = null;
                estimateFavoursExtBatch = false;
                continue;
            }

            // snapshot of the base models before they are re-weighted for the new batch
            modelInfo2 = new Vector<BayBoostBaseModelInfo>(modelInfo);
            double holdOutRatio = this.getParameterAsDouble(PARAMETER_FRACTION_HOLD_OUT_SET);
            Vector<Example> holdOutExamples = new Vector<Example>();
            if (holdOutRatio > 0.0) {
                RandomGenerator random = RandomGenerator.getRandomGenerator(this);
                for (Example example : trainingSet) {
                    if (random.nextDoubleInRange(0.0, 1.0) <= holdOutRatio) {
                        // batch number 0 never satisfies a BatchFilterCondition of this loop
                        example.setValue(streamControlAttribute, 0.0);
                        holdOutExamples.add(example);
                    }
                }
            }

            // new-batch ensemble: re-weight the old models, add one more if examples are left
            if (this.adjustBaseModelWeights(trainingSet, modelInfo)) {
                this.trainAdditionalModel(trainingSet, modelInfo);
            }
            ensembleNewBatch = new BayBoostModel(exampleSet, modelInfo, classPriors);

            // extended-batch ensemble: same procedure on all batches since firstOpenBatch,
            // with the most recent base model replaced by a retrained one
            ConditionedExampleSet extendedBatch = new ConditionedExampleSet(exampleSet, new BatchFilterCondition(streamControlAttribute, firstOpenBatch));
            classPriors = this.prepareExtendedBatch(extendedBatch);
            if (this.getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS)) {
                this.rescalePriors(extendedBatch, classPriors);
            }
            modelInfo2.remove(modelInfo2.size() - 1);
            boolean trainingExamplesLeft = this.adjustBaseModelWeights(extendedBatch, modelInfo2);
            if (!trainingExamplesLeft || this.trainAdditionalModel(extendedBatch, modelInfo2)) {
                ensembleExtBatch = new BayBoostModel(exampleSet, modelInfo2, classPriors);
            } else {
                ensembleExtBatch = null;
                estimateFavoursExtBatch = false;
            }

            // Step 3 (optional): choose between the ensembles on the hold-out examples.
            // Skipped when nothing was held out: the error rates would be 0/0 and the last
            // contingency matrix would be computed from all-zero weights.
            if (holdOutExamples.isEmpty()) {
                continue;
            }
            for (Example example : holdOutExamples) {
                // put the held-out examples back into the current batch
                example.setValue(streamControlAttribute, this.currentIteration);
            }
            if (ensembleExtBatch != null) {
                trainingSet = (ConditionedExampleSet)ensembleNewBatch.apply(trainingSet);
                double newBatchErr = BayBoostStream.errorRate(holdOutExamples);
                trainingSet = (ConditionedExampleSet)ensembleExtBatch.apply(trainingSet);
                double extBatchErr = BayBoostStream.errorRate(holdOutExamples);
                estimateFavoursExtBatch = extBatchErr <= newBatchErr;
                if (estimateFavoursExtBatch) {
                    ensembleExtBatch = this.retrainLastWeight(ensembleExtBatch, trainingSet, holdOutExamples);
                } else {
                    ensembleNewBatch = this.retrainLastWeight(ensembleNewBatch, trainingSet, holdOutExamples);
                }
            } else {
                ensembleNewBatch = this.retrainLastWeight(ensembleNewBatch, trainingSet, holdOutExamples);
            }
        }
        return ensembleExtBatch == null ? ensembleNewBatch : ensembleExtBatch;
    }

    /** Fraction of the given examples whose predicted label differs from the label. */
    private static double errorRate(Vector<Example> examples) {
        int errors = 0;
        for (Example example : examples) {
            if (example.getPredictedLabel() != example.getLabel()) {
                ++errors;
            }
        }
        return (double)errors / (double)examples.size();
    }

    /**
     * Rebuilds an ensemble so that the contingency matrix of its last base model is estimated
     * from the hold-out examples only; all earlier models keep their matrices.
     *
     * @param ensemble the ensemble to rebuild
     * @param exampleSet the batch the ensemble is re-applied to; its weights are overwritten
     * @param holdOutSet the examples that determine the last model's contingency matrix
     * @return the rebuilt ensemble, or {@code ensemble} itself if it contains no models
     * @throws OperatorException if applying a base model fails
     */
    private BayBoostModel retrainLastWeight(BayBoostModel ensemble, ExampleSet exampleSet, Vector<Example> holdOutSet) throws OperatorException {
        int modelNum = ensemble.getNumberOfModels();
        if (modelNum == 0) {
            // nothing to re-estimate; the code below would ask for model index -1
            return ensemble;
        }
        this.prepareExtendedBatch(exampleSet);
        Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>();
        double[] priors = ensemble.getPriors();
        // replay all but the last model to reproduce the weights the last model was trained on
        for (int i = 0; i < modelNum - 1; ++i) {
            Model model = ensemble.getModel(i);
            ContingencyMatrix cm = ensemble.getContingencyMatrix(i);
            modelInfo.add(new BayBoostBaseModelInfo(model, cm));
            exampleSet = model.apply(exampleSet);
            WeightedPerformanceMeasures.reweightExamples(exampleSet, cm, false);
        }
        Model latestModel = ensemble.getModel(modelNum - 1);
        exampleSet = latestModel.apply(exampleSet);

        // keep the weights of the hold-out examples, zero everything, then write them back
        double[] weights = new double[holdOutSet.size()];
        int index = 0;
        for (Example example : holdOutSet) {
            weights[index++] = example.getWeight();
        }
        for (Example example : exampleSet) {
            example.setWeight(0.0);
        }
        index = 0;
        for (Example example : holdOutSet) {
            example.setWeight(weights[index++]);
        }
        WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(exampleSet);
        modelInfo.add(new BayBoostBaseModelInfo(latestModel, wp.getContingencyMatrix()));
        return new BayBoostModel(exampleSet, modelInfo, priors);
    }

    /**
     * Runs the super class implementation and then delivers the run vector at its port.
     *
     * @throws OperatorException if the super class implementation fails
     */
    @Override
    public void doWork() throws OperatorException {
        super.doWork();
        this.runVectorOutput.deliver(this.runVector);
    }

    /**
     * Sets each example's weight to {@code 1 / (numberOfClasses * prior(label))}, so that every
     * class receives the same total weight.
     *
     * @param exampleSet the examples to re-weight
     * @param classPriors the class priors, indexed by label index
     */
    private void rescalePriors(ExampleSet exampleSet, double[] classPriors) {
        double[] weights = new double[classPriors.length];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = 1.0 / ((double)weights.length * classPriors[i]);
        }
        for (Example example : exampleSet) {
            example.setWeight(weights[(int)example.getLabel()]);
        }
    }

    /**
     * Trains one base model with the inner learner and drops any predicted label attribute
     * that is present on the example set afterwards.
     */
    private Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
        Model model = this.applyInnerLearner(exampleSet);
        BayBoostStream.createOrReplacePredictedLabelFor(exampleSet, model);
        return model;
    }

    /**
     * Pulls up to {@code batch_size} examples from the reader, tags them with the batch number,
     * resets their weight to 1 and counts the two classes.
     *
     * @param currentBatchNum the batch number written to the stream control attribute
     * @param reader the iterator over the complete example set, advanced by this call
     * @param batchAttribute the stream control attribute
     * @return the relative frequencies of label index 0 and 1 within the batch
     * @throws UndefinedParameterError if the batch size parameter is not defined
     */
    private double[] prepareBatch(int currentBatchNum, Iterator<Example> reader, Attribute batchAttribute) throws UndefinedParameterError {
        int batchSize = this.getParameterAsInt(PARAMETER_BATCH_SIZE);
        int[] classCount = new int[2];
        int batchCount = 0;
        while (batchCount < batchSize && reader.hasNext()) {
            Example example = reader.next();
            example.setValue(batchAttribute, currentBatchNum);
            example.setWeight(1.0);
            classCount[(int)example.getLabel()]++;
            ++batchCount;
        }
        return new double[]{(double)classCount[0] / (double)batchCount, (double)classCount[1] / (double)batchCount};
    }

    /**
     * Resets all weights of the given examples to 1 and counts the two classes.
     *
     * @param extendedBatch the examples to reset
     * @return the relative frequencies of label index 0 and 1
     */
    private double[] prepareExtendedBatch(ExampleSet extendedBatch) {
        int[] classCount = new int[2];
        for (Example example : extendedBatch) {
            example.setWeight(1.0);
            classCount[(int)example.getLabel()]++;
        }
        int sum = classCount[0] + classCount[1];
        return new double[]{(double)classCount[0] / (double)sum, (double)classCount[1] / (double)sum};
    }

    /**
     * Computes the unweighted accuracy of the predictions stored in the example set.
     *
     * @param exampleSet examples carrying both label and predicted label
     * @return the fraction of examples whose predicted label equals the label
     */
    private double evaluatePredictions(ExampleSet exampleSet) {
        int count = 0;
        int correct = 0;
        for (Example example : exampleSet) {
            ++count;
            if (example.getLabel() == example.getPredictedLabel()) {
                ++correct;
            }
        }
        return (double)correct / (double)count;
    }

    /**
     * Trains one more base model on the (weighted) training set and appends it to the list
     * if {@link #isModelUseful(ContingencyMatrix)} accepts its contingency matrix.
     *
     * @param trainingSet the weighted training examples
     * @param modelInfo the list the new model is appended to
     * @return {@code true} if a model was added, {@code false} if it was discarded
     * @throws OperatorException if training or applying the model fails
     */
    private boolean trainAdditionalModel(ExampleSet trainingSet, Vector<BayBoostBaseModelInfo> modelInfo) throws OperatorException {
        Model model = this.trainBaseModel(trainingSet);
        trainingSet = model.apply(trainingSet);
        WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(trainingSet);
        ContingencyMatrix cm = wp.getContingencyMatrix();
        if (!this.isModelUseful(cm)) {
            this.log("Discard model because of low advantage on training data.");
            return false;
        }
        modelInfo.add(new BayBoostBaseModelInfo(model, cm));
        return true;
    }

    /**
     * Re-estimates the contingency matrix of every base model on the given examples, in order,
     * and re-weights the examples after each model.
     *
     * @param exampleSet the weighted examples to apply the models to
     * @param modelInfo the base models; entries are replaced or removed in place
     * @return {@code false} as soon as re-weighting reports no remaining weight,
     *         {@code true} if all models were processed
     * @throws OperatorException if applying a model fails or it predicts a non-nominal label
     */
    private boolean adjustBaseModelWeights(ExampleSet exampleSet, Vector<BayBoostBaseModelInfo> modelInfo) throws OperatorException {
        int j = 0;
        while (j < modelInfo.size()) {
            BayBoostBaseModelInfo consideredModelInfo = modelInfo.get(j);
            Model consideredModel = consideredModelInfo.getModel();
            ContingencyMatrix cm = consideredModelInfo.getContingencyMatrix();
            BayBoostStream.createOrReplacePredictedLabelFor(exampleSet, consideredModel);
            exampleSet = consideredModel.apply(exampleSet);
            if (!exampleSet.getAttributes().getPredictedLabel().isNominal()) {
                throw new UserError((Operator)this, 101, exampleSet.getAttributes().getLabel(), "BayBoostStream base learners");
            }
            WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(exampleSet);
            ContingencyMatrix cmNew = wp.getContingencyMatrix();
            // NOTE(review): usefulness is judged on the previously stored matrix (cm), not on
            // the one just estimated (cmNew) -- confirm that this delay is intended.
            if (!this.isModelUseful(cm)) {
                // the next model moves to index j, so the index is not advanced
                modelInfo.remove(j);
                this.log("Discard base model because of low advantage.");
                continue;
            }
            modelInfo.set(j, new BayBoostBaseModelInfo(consideredModel, cmNew));
            boolean stillUncoveredExamples = WeightedPerformanceMeasures.reweightExamples(exampleSet, cmNew, false) > 0.0;
            if (!stillUncoveredExamples) {
                return false;
            }
            ++j;
        }
        return true;
    }

    /**
     * Checks whether any lift of the contingency matrix deviates from 1.0 by more than
     * {@link #MIN_ADVANTAGE}.
     *
     * @param cm the contingency matrix of a base model
     * @return {@code true} if at least one such lift exists
     */
    private boolean isModelUseful(ContingencyMatrix cm) {
        for (int row = 0; row < cm.getNumberOfPredictions(); ++row) {
            for (int col = 0; col < cm.getNumberOfClasses(); ++col) {
                if (Math.abs(cm.getLift(row, col) - 1.0) > MIN_ADVANTAGE) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Removes an existing predicted label attribute from the example set and its example table.
     * Despite the name nothing is created here, and {@code model} is not used.
     *
     * @param exampleSet the example set to clean up
     * @param model unused
     */
    private static void createOrReplacePredictedLabelFor(ExampleSet exampleSet, Model model) {
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        if (predictedLabel == null) {
            return;
        }
        exampleSet.getAttributes().remove(predictedLabel);
        exampleSet.getExampleTable().removeAttribute(predictedLabel);
    }

    /**
     * Extends the inherited parameters by the label prior rescaling flag, the batch size, the
     * hold-out fraction and the random generator parameters.
     *
     * @return the list of parameter types of this operator
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(PARAMETER_RESCALE_LABEL_PRIORS, "Specifies whether the proportion of labels should be equal by construction after first iteration .", false));
        types.add(new ParameterTypeInt(PARAMETER_BATCH_SIZE, "Size of the batches. Minimum number of examples used to train a model.", 1, Integer.MAX_VALUE, 100));
        types.add(new ParameterTypeDouble(PARAMETER_FRACTION_HOLD_OUT_SET, "Rel. size of hold out set for ensemble selection. Set to 0 to turn off.", 0.0, 1.0, 0.0));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }

    /**
     * Condition that accepts every example whose value of the given attribute is at least
     * the configured batch number.
     */
    public static class BatchFilterCondition
    implements Condition {
        private static final long serialVersionUID = 7910713773299060449L;
        /** Smallest batch number that is accepted. */
        private final int batchNumber;
        /** Attribute holding the batch number of each example. */
        private final Attribute attribute;

        /**
         * @param attribute the attribute holding the batch number of each example
         * @param batchNumber the smallest batch number to accept
         */
        public BatchFilterCondition(Attribute attribute, int batchNumber) {
            this.batchNumber = batchNumber;
            this.attribute = attribute;
        }

        /**
         * @param example the example to test
         * @return {@code true} if the example's batch number is at least the configured one
         */
        @Override
        public boolean conditionOk(Example example) {
            return example.getValue(this.attribute) >= (double)this.batchNumber;
        }

        /**
         * @return this instance; both fields are final, so no copy is made
         */
        @Override
        public Condition duplicate() {
            return this;
        }
    }
}

