/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.Value;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.operator.learner.meta.AbstractWeightedPerformanceMeasures;
import com.rapidminer.operator.learner.meta.BayBoostBaseModelInfo;
import com.rapidminer.operator.learner.meta.BayBoostModel;
import com.rapidminer.operator.learner.meta.ContingencyMatrix;
import com.rapidminer.operator.learner.meta.FuzzyWeightedPerformanceMeasures;
import com.rapidminer.operator.learner.meta.StandardWeightedPerformanceMeasures;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.operator.ports.metadata.PredictionModelMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.RandomGenerator;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/**
 * Boosting meta learner that repeatedly trains its inner learner and reweights the training
 * examples, collecting every base model together with its weighted contingency matrix into a
 * {@link BayBoostModel}.
 *
 * <p>The weighting arithmetic is delegated, via reflection, to the static
 * {@code reweightExamples(ExampleSet, ContingencyMatrix, boolean, boolean)} method of either
 * {@link FuzzyWeightedPerformanceMeasures} or {@link StandardWeightedPerformanceMeasures}
 * (selected by {@link #PARAMETER_FUZZY_PARTITION_SIZES}). An optional start model can be fed
 * through the "model" input port and is applied before the first iteration.
 *
 * <p>When learning finishes (normally or with an exception) the original example weights are
 * restored, or the weight attribute created by this operator is removed again.
 */
public class FuzzyBayesianBoosting
extends AbstractMetaLearner {
    public static final String PARAMETER_ITERATIONS = "iterations";
    public static final String PARAMETER_USE_SUBSET_FOR_TRAINING = "use_subset_for_training";
    public static final String PARAMETER_RESCALE_LABEL_PRIORS = "rescale_label_priors";
    public static final String PARAMETER_ALLOW_MARGINAL_SKEWS = "allow_marginal_skews";
    public static final String PARAMETER_FUZZY_PARTITION_SIZES = "fuzzy_partition_sizes";
    public static final String PARAMETER_FUZZY_EXAMPLE_REWEIGHTING = "fuzzy_reweighting";
    /** Minimum advantage threshold; currently NOT consulted by {@code isModelUseful}. */
    public static final double MIN_ADVANTAGE = 0.001;

    /** Parameter types of the static {@code reweightExamples} method looked up on {@link #pmClass}. */
    private static final Class<?>[] REWEIGHT_SIGNATURE = {
        ExampleSet.class, ContingencyMatrix.class, Boolean.TYPE, Boolean.TYPE
    };

    /** Optional model read from the "model" input port; null if none was supplied. */
    private Model startModel;
    /** Zero-based index of the running boosting iteration (exposed as the "iteration" value). */
    protected int currentIteration;
    /**
     * Most recent weight total: the initial total example weight, then the value returned by the
     * last {@code reweightExamples} call (exposed as the "performance" value).
     */
    private double performance = 0.0;
    /** Weights present before learning, or null if this operator created the weight attribute. */
    private double[] oldWeights;
    private final InputPort modelInput = (InputPort)this.getInputPorts().createPort("model");
    /** {@code (ExampleSet)} constructor of {@link #pmClass}. */
    protected Constructor<?> pmConstructor;
    /** Weighted performance measure implementation chosen for the current run. */
    protected Class<?> pmClass;
    /** Value of {@link #PARAMETER_FUZZY_EXAMPLE_REWEIGHTING} for the current run. */
    protected boolean fuzzyReweighting;

    public FuzzyBayesianBoosting(OperatorDescription description) {
        super(description);
        this.modelInput.addPrecondition((Precondition)new SimplePrecondition(this.modelInput, (MetaData)new PredictionModelMetaData(PredictionModel.class, new ExampleSetMetaData()), false));
        this.addValue((Value)new ValueDouble("performance", "The performance."){

            @Override
            public double getDoubleValue() {
                return FuzzyBayesianBoosting.this.performance;
            }
        });
        this.addValue((Value)new ValueDouble("iteration", "The current iteration."){

            @Override
            public double getDoubleValue() {
                return FuzzyBayesianBoosting.this.currentIteration;
            }
        });
    }

    /** Declares the weight attribute this operator adds to the example set meta data. */
    @Override
    protected MetaData modifyExampleSetMetaData(ExampleSetMetaData unmodifiedMetaData) {
        // 4 is presumably Ontology.REAL (value type constant) — confirm against Ontology.
        AttributeMetaData weightAttribute = new AttributeMetaData("weight", 4, "weight");
        unmodifiedMetaData.addAttribute(weightAttribute);
        return super.modifyExampleSetMetaData(unmodifiedMetaData);
    }

    /** Supports every capability except unlabeled data, updating and formula providing. */
    @Override
    public boolean supportsCapability(OperatorCapability lc) {
        switch (lc) {
            case NO_LABEL:
            case UPDATABLE:
            case FORMULA_PROVIDER:
                return false;
            default:
                return true;
        }
    }

    /**
     * Learns a {@link BayBoostModel} from the given example set.
     *
     * <p>If one class already holds all the prior mass (sum of priors equals the maximum prior),
     * no base models are trained. Example weights are restored in a {@code finally} block so that
     * a failure during training does not leave modified weights behind.
     *
     * @param exampleSet labeled training data; its weights are modified temporarily
     * @return the boosted model
     * @throws OperatorException if the inner learner or the reweighting fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.readOptionalParameters();
        double[] classPriors = this.prepareWeights(exampleSet);
        try {
            double maxPrior = Double.NEGATIVE_INFINITY;
            double sumPriors = 0.0;
            for (double prior : classPriors) {
                if (prior > maxPrior) {
                    maxPrior = prior;
                }
                sumPriors += prior;
            }
            if (com.rapidminer.tools.Tools.isEqual(sumPriors, maxPrior)) {
                return new BayBoostModel(exampleSet, new Vector<BayBoostBaseModelInfo>(), classPriors);
            }
            return this.trainBoostingModel(exampleSet, classPriors);
        } finally {
            this.restoreWeights(exampleSet);
        }
    }

    /**
     * Undoes the weight changes made by {@link #prepareWeights}: writes back the saved weights,
     * or removes the weight attribute if it was created by this operator.
     */
    private void restoreWeights(ExampleSet exampleSet) {
        if (this.oldWeights != null) {
            Iterator<Example> reader = exampleSet.iterator();
            for (int i = 0; reader.hasNext() && i < this.oldWeights.length; ++i) {
                reader.next().setWeight(this.oldWeights[i]);
            }
        } else {
            Attribute weight = exampleSet.getAttributes().getWeight();
            if (weight != null) {
                exampleSet.getAttributes().remove(weight);
                exampleSet.getExampleTable().removeAttribute(weight);
            }
        }
    }

    /**
     * Saves existing example weights (or creates a weight attribute) and computes class priors.
     *
     * <p>With an existing weight attribute, examples whose label is missing or outside the label
     * mapping get weight 0 and are not counted in the priors. The priors are normalized by the
     * total counted weight, which is also stored as {@code performance}.
     *
     * @param exampleSet labeled example set
     * @return class priors indexed by label value
     */
    protected double[] prepareWeights(ExampleSet exampleSet) {
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        if (weightAttr == null) {
            this.oldWeights = null;
            this.performance = exampleSet.size();
            return this.createNewWeightAttribute(exampleSet);
        }
        this.oldWeights = new double[exampleSet.size()];
        double[] priors = new double[exampleSet.getAttributes().getLabel().getMapping().size()];
        double totalWeight = 0.0;
        Iterator<Example> reader = exampleSet.iterator();
        for (int i = 0; reader.hasNext() && i < this.oldWeights.length; ++i) {
            Example example = reader.next();
            if (example == null) {
                continue;
            }
            double weight = example.getWeight();
            this.oldWeights[i] = weight;
            double labelValue = example.getLabel();
            int label = (int) labelValue;
            // (int) NaN == 0, so a missing label must be rejected explicitly.
            if (!Double.isNaN(labelValue) && 0 <= label && label < priors.length) {
                priors[label] += weight;
                totalWeight += weight;
            } else {
                example.setWeight(0.0);
            }
        }
        this.performance = totalWeight;
        // NOTE(review): if totalWeight is 0 every prior becomes NaN here.
        for (int i = 0; i < priors.length; ++i) {
            priors[i] /= totalWeight;
        }
        return priors;
    }

    /**
     * Creates a weight attribute and computes the relative class frequencies.
     *
     * <p>Without {@link #PARAMETER_RESCALE_LABEL_PRIORS} every weight is set to 1; otherwise the
     * weights are chosen so that all classes carry the same total weight.
     */
    private double[] createNewWeightAttribute(ExampleSet exampleSet) {
        Tools.createWeightAttribute(exampleSet);
        int numClasses = exampleSet.getAttributes().getLabel().getMapping().getValues().size();
        double[] classPriors = new double[numClasses];
        double invTotal = 1.0 / (double) exampleSet.size();
        boolean rescale = this.getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS);
        for (Example example : exampleSet) {
            if (!rescale) {
                example.setWeight(1.0);
            }
            classPriors[(int) example.getLabel()] += invTotal;
        }
        if (rescale) {
            this.rescaleToEqualPriors(exampleSet, classPriors);
        }
        return classPriors;
    }

    /** Sets each example's weight to {@code 1 / (numClasses * prior(label))}. */
    private void rescaleToEqualPriors(ExampleSet exampleSet, double[] currentPriors) {
        double[] weights = new double[currentPriors.length];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = 1.0 / ((double) weights.length * currentPriors[i]);
        }
        for (Example example : exampleSet) {
            example.setWeight(weights[(int) example.getLabel()]);
        }
    }

    /** Trains one base model with the inner learner on the given (weighted) example set. */
    protected Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
        return this.applyInnerLearner(exampleSet);
    }

    /** Reads the optional start model from the "model" port; logs when none is present. */
    private void readOptionalParameters() throws UserError {
        this.startModel = (Model)this.modelInput.getDataOrNull();
        if (this.startModel == null) {
            this.log(this.getName() + ": No model found in input.");
        }
    }

    /**
     * If a start model is present, applies it to a clone of the training set, reweights the
     * examples from its contingency matrix and records it as the first base model.
     */
    private void applyPriorModel(ExampleSet trainingSet, List<BayBoostBaseModelInfo> modelInfo) throws OperatorException {
        if (this.startModel == null) {
            return;
        }
        ExampleSet resultSet = this.startModel.apply((ExampleSet)trainingSet.clone());
        AbstractWeightedPerformanceMeasures wp = this.createPerformanceMeasures(resultSet);
        this.reweightExamples(wp, resultSet);
        modelInfo.add(new BayBoostBaseModelInfo(this.startModel, wp.getContingencyMatrix()));
        PredictionModel.removePredictedLabel(resultSet);
    }

    /** Selects the weighted performance measure class and resolves its (ExampleSet) constructor. */
    private void initPerformanceMeasureClass() throws OperatorException {
        this.pmClass = this.getParameterAsBoolean(PARAMETER_FUZZY_PARTITION_SIZES) ? FuzzyWeightedPerformanceMeasures.class : StandardWeightedPerformanceMeasures.class;
        try {
            this.pmConstructor = this.pmClass.getConstructor(ExampleSet.class);
        } catch (NoSuchMethodException e) {
            this.pmConstructor = null;
            throw new OperatorException("cannot instantiate WeightedPerformanceMeasures", e);
        }
    }

    /** Instantiates {@link #pmClass} for the given (already predicted) example set. */
    private AbstractWeightedPerformanceMeasures createPerformanceMeasures(ExampleSet exampleSet) throws OperatorException {
        try {
            return (AbstractWeightedPerformanceMeasures) this.pmConstructor.newInstance(exampleSet);
        } catch (Exception e) {
            throw reflectionFailure("cannot instantiate WeightedPerformanceMeasures", e);
        }
    }

    /**
     * Invokes the static {@code reweightExamples} method of {@link #pmClass}.
     *
     * @return the value returned by {@code reweightExamples}
     */
    private double invokeReweight(ExampleSet exampleSet, ContingencyMatrix cm, boolean allowSkew) throws OperatorException {
        try {
            Method reweight = this.pmClass.getMethod("reweightExamples", REWEIGHT_SIGNATURE);
            return (Double) reweight.invoke(null, exampleSet, cm, allowSkew, this.fuzzyReweighting);
        } catch (Exception e) {
            throw reflectionFailure("cannot call reweightExamples", e);
        }
    }

    /**
     * Converts a reflection failure into an {@link OperatorException}, unwrapping
     * {@link InvocationTargetException} so that the callee's own exception is preserved.
     */
    private static OperatorException reflectionFailure(String message, Exception e) {
        Throwable cause = e;
        if (e instanceof InvocationTargetException && e.getCause() != null) {
            cause = e.getCause();
            if (cause instanceof OperatorException) {
                return (OperatorException) cause;
            }
        }
        return new OperatorException(message, cause);
    }

    /**
     * Runs up to {@link #PARAMETER_ITERATIONS} boosting iterations and assembles the model.
     *
     * <p>Stops early when the performance measures see fewer than two non-empty classes, when a
     * model is judged not useful, or when reweighting returns 0.
     */
    private BayBoostModel trainBoostingModel(ExampleSet trainingSet, double[] classPriors) throws OperatorException {
        this.fuzzyReweighting = this.getParameterAsBoolean(PARAMETER_FUZZY_EXAMPLE_REWEIGHTING);
        this.initPerformanceMeasureClass();
        Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>();
        this.applyPriorModel(trainingSet, modelInfo);
        double splitRatio = this.getParameterAsDouble(PARAMETER_USE_SUBSET_FOR_TRAINING);
        boolean bootstrap = splitRatio > 0.0 && splitRatio < 1.0;
        this.log(bootstrap ? "Bootstrapping enabled." : "Bootstrapping disabled.");
        boolean allowSkew = this.getParameterAsBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS);
        SplittedExampleSet splittedSet = null;
        if (bootstrap) {
            splittedSet = new SplittedExampleSet(trainingSet, splitRatio, 1, this.getParameterAsBoolean("use_local_random_seed"), this.getParameterAsInt("local_random_seed"));
        }
        int iterations = this.getParameterAsInt(PARAMETER_ITERATIONS);
        for (int i = 0; i < iterations; ++i) {
            this.currentIteration = i;
            ExampleSet iterationSet = (ExampleSet)trainingSet.clone();
            if (bootstrap) {
                // NOTE(review): the subset model is applied to iterationSet, while the performance
                // measures below are built on splittedSet — confirm the split view carries the
                // predicted label at this point.
                splittedSet.selectSingleSubset(0);
                Model subsetModel = this.trainBaseModel(splittedSet);
                iterationSet = subsetModel.apply(iterationSet);
                AbstractWeightedPerformanceMeasures subsetWp = this.createPerformanceMeasures(splittedSet);
                this.invokeReweight(splittedSet, subsetWp.getContingencyMatrix(), allowSkew);
                splittedSet.selectSingleSubset(1);
                subsetWp = this.createPerformanceMeasures(splittedSet);
                this.performance = this.invokeReweight(splittedSet, subsetWp.getContingencyMatrix(), allowSkew);
            }
            Model model = this.trainBaseModel(iterationSet);
            iterationSet = model.apply(iterationSet);
            AbstractWeightedPerformanceMeasures wp = this.createPerformanceMeasures(iterationSet);
            this.performance = this.reweightExamples(wp, iterationSet);
            PredictionModel.removePredictedLabel(iterationSet);
            if (wp.getNumberOfNonEmptyClasses() < 2) {
                modelInfo.add(new BayBoostBaseModelInfo(model, wp.getContingencyMatrix()));
                break;
            }
            ContingencyMatrix cm = wp.getContingencyMatrix();
            if (!this.isModelUseful(cm)) {
                this.log("Discard model because of low advantage on training data.");
                break;
            }
            modelInfo.add(new BayBoostBaseModelInfo(model, cm));
            if (this.performance == 0.0) {
                break;
            }
            this.inApplyLoop();
        }
        return new BayBoostModel(trainingSet, modelInfo, classPriors);
    }

    /**
     * Reweights the examples from the contingency matrix of {@code wp}, honoring
     * {@link #PARAMETER_ALLOW_MARGINAL_SKEWS} and the fuzzy-reweighting flag.
     *
     * @return the value returned by the static {@code reweightExamples} method
     * @throws OperatorException if the reflective call fails
     */
    protected double reweightExamples(AbstractWeightedPerformanceMeasures wp, ExampleSet exampleSet) throws OperatorException {
        boolean allowMarginalSkews = this.getParameterAsBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS);
        return this.invokeReweight(exampleSet, wp.getContingencyMatrix(), allowMarginalSkews);
    }

    /** Stub: always accepts the model ({@link #MIN_ADVANTAGE} is not evaluated). */
    private boolean isModelUseful(ContingencyMatrix cm) {
        return true;
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterType type = new ParameterTypeDouble(PARAMETER_USE_SUBSET_FOR_TRAINING, "Fraction of examples used for training, remaining ones are used to estimate the confusion matrix. Set to 1 to turn off test set.", 0.0, 1.0, 1.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeInt(PARAMETER_ITERATIONS, "The maximum number of iterations.", 1, Integer.MAX_VALUE, 10);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeBoolean(PARAMETER_RESCALE_LABEL_PRIORS, "Specifies whether the proportion of labels should be equal by construction after first iteration .", false));
        types.add(new ParameterTypeBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS, "Allow to skew the marginal distribution (P(x)) during learning.", true));
        types.add(new ParameterTypeBoolean(PARAMETER_FUZZY_EXAMPLE_REWEIGHTING, "Specifies whether the example weights should be calculated based on fuzzy sets.", false));
        types.add(new ParameterTypeBoolean(PARAMETER_FUZZY_PARTITION_SIZES, "Specifies if the counting of tp, np etc is based on confidences instead of crisp predictions.", false));
        types.addAll(RandomGenerator.getRandomGeneratorParameters((Operator)this));
        return types;
    }
}

