/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.ExecutionUnit;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.Value;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.operator.learner.meta.AbstractWeightedPerformanceMeasures;
import com.rapidminer.operator.learner.meta.BayBoostBaseModelInfo;
import com.rapidminer.operator.learner.meta.BayBoostModel;
import com.rapidminer.operator.learner.meta.ContingencyMatrix;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.operator.ports.metadata.PredictionModelMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.RandomGenerator;
import java.lang.reflect.Constructor;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/**
 * Abstract base class for Bayesian boosting meta learners.
 *
 * <p>Besides the subprocesses inherited from {@link AbstractMetaLearner}, this operator adds a
 * subprocess at index {@link #GET_RANDOM_SAMPLE_SUBPROCESS} named "fetch random sample". Its inner
 * sink delivers an example set that is used whenever no example set is connected to the regular
 * input. An optional prediction model can be supplied on the "model" port. If present, it is
 * applied before boosting and used to reweight the training examples.
 *
 * <p>Performance measures and the static {@code reweightExamples} routine are accessed
 * reflectively through {@link #pmConstructor} and {@link #pmClass}. These fields, like
 * {@link #fuzzyReweighting}, are not assigned in this class. They are presumably initialised by
 * subclasses (TODO confirm).
 */
public abstract class AbstractBayesianBoosting
extends AbstractMetaLearner {
    /** Optional input port for a model that is applied before the first boosting iteration. */
    protected final InputPort modelInput = (InputPort)this.getInputPorts().createPort("model");
    /** Current boosting iteration; exposed to the process as the value "iteration". */
    protected int currentIteration;
    /** Index of the subprocess that fetches a random sample. */
    protected final int GET_RANDOM_SAMPLE_SUBPROCESS = 1;
    /** Inner sink of the sampling subprocess that receives the sampled example set. */
    protected InputPort innerRandomSamplePort;
    /** Weights read by {@link #prepareWeights}; {@code null} if the set had no weight attribute. */
    protected double[] oldWeights;
    /** Total example weight after the last weight preparation or reweighting step. */
    double performance = 0.0;
    /** Model read from {@link #modelInput}, or {@code null} if none was connected. */
    protected Model startModel;
    public static final String PARAMETER_RESCALE_LABEL_PRIORS = "rescale_label_priors";
    public static final String PARAMETER_USE_SUBSET_FOR_TRAINING = "use_subset_for_training";
    public static final String PARAMETER_ALLOW_MARGINAL_SKEWS = "allow_marginal_skews";
    public static final String PORT_INNER_EXAMPLE_SET = "Example Set";
    public static final String PARAMETER_FUZZY_PARTITION_SIZES = "fuzzy_partition_sizes";
    public static final String PARAMETER_FUZZY_EXAMPLE_REWEIGHTING = "fuzzy_reweighting";
    /** Constructor of the performance-measures class, invoked with a single {@link ExampleSet}. */
    protected Constructor<?> pmConstructor;
    /** Performance-measures class providing a static {@code reweightExamples} method. */
    protected Class<?> pmClass;
    /**
     * Passed as the last argument to the static {@code reweightExamples} call. It is not read
     * from {@link #PARAMETER_FUZZY_EXAMPLE_REWEIGHTING} in this class; presumably a subclass
     * sets it (TODO confirm).
     */
    protected boolean fuzzyReweighting;

    /**
     * Creates the operator.
     *
     * <p>The constructor registers the precondition of the model port, exposes the current
     * iteration as a loggable value, and creates the sampling subprocess.
     *
     * @param description the operator description
     */
    public AbstractBayesianBoosting(OperatorDescription description) {
        super(description);
        // A connected start model must be a prediction model. The trailing "false" presumably
        // makes the port optional, which matches readOptionalParameters() tolerating a missing model.
        this.modelInput.addPrecondition(new SimplePrecondition(this.modelInput, new PredictionModelMetaData(PredictionModel.class, new ExampleSetMetaData()), false));
        this.addValue(new ValueDouble("iteration", "The current iteration."){

            @Override
            public double getDoubleValue() {
                return AbstractBayesianBoosting.this.currentIteration;
            }
        });
        // NOTE(review): overridable method called from the constructor. Subclass fields are not
        // yet initialised when an override runs here.
        this.initializeSubprocesses();
    }

    /**
     * Trains the complete boosting model.
     *
     * @param exampleSet the training examples
     * @param priors per-label prior weights, presumably as returned by {@link #prepareWeights}
     *        (TODO confirm)
     * @return the trained boosting model
     * @throws OperatorException if training fails
     */
    protected abstract BayBoostModel trainBoostingModel(ExampleSet exampleSet, double[] priors) throws OperatorException;

    /**
     * Adds the "fetch random sample" subprocess and creates the inner sink that receives its
     * example set.
     */
    protected void initializeSubprocesses() {
        ExecutionUnit getRandomSample = this.addSubprocess(GET_RANDOM_SAMPLE_SUBPROCESS);
        getRandomSample.setName("fetch random sample");
        this.innerRandomSamplePort = this.getSubprocess(GET_RANDOM_SAMPLE_SUBPROCESS).getInnerSinks().createPort(PORT_INNER_EXAMPLE_SET, ExampleSet.class);
    }

    /**
     * Adds a "weight" attribute to the delivered meta data, because this operator always creates
     * or modifies example weights.
     *
     * @param unmodifiedMetaData the incoming example set meta data
     * @return the meta data after the superclass has applied its own modifications
     */
    @Override
    protected MetaData modifyExampleSetMetaData(ExampleSetMetaData unmodifiedMetaData) {
        // value type 4 is presumably Ontology.REAL (TODO confirm)
        AttributeMetaData weightAttribute = new AttributeMetaData("weight", 4, "weight");
        unmodifiedMetaData.addAttribute(weightAttribute);
        return super.modifyExampleSetMetaData(unmodifiedMetaData);
    }

    /**
     * Reports the supported learner capabilities.
     *
     * <p>Every capability is supported except learning without a label, updating an existing
     * model, and acting as a formula provider.
     *
     * @param lc the capability to query
     * @return {@code false} for NO_LABEL, UPDATABLE and FORMULA_PROVIDER; {@code true} otherwise
     */
    public boolean supportsCapability(OperatorCapability lc) {
        switch (lc) {
            case NO_LABEL:
            case UPDATABLE:
            case FORMULA_PROVIDER:
                return false;
            default:
                return true;
        }
    }

    /**
     * Runs the operator.
     *
     * <p>If no example set is connected to the regular input, one is fetched from the sampling
     * subprocess first.
     *
     * @throws OperatorException if sampling or learning fails
     */
    @Override
    public void doWork() throws OperatorException {
        if (this.exampleSetInput.getDataOrNull() == null) {
            ExampleSet exampleSet = this.getNewSample(null);
            this.exampleSetInput.receive(exampleSet);
        }
        super.doWork();
    }

    /**
     * Draws a new sample and replays the existing boosting models on it.
     *
     * <p>The sampling subprocess is executed, weights are prepared, and the start model (if any)
     * is applied. Then every given base model is applied in turn, and the examples are reweighted
     * after each one.
     *
     * @param modelInfo base models to replay on the sample; may be {@code null}
     * @return the reweighted sample, or {@code null} if a training subset is configured
     * @throws OperatorException if the subprocess, a model application or reweighting fails
     */
    protected ExampleSet getNewSample(Vector<BayBoostBaseModelInfo> modelInfo) throws OperatorException, UserError {
        this.getSubprocess(GET_RANDOM_SAMPLE_SUBPROCESS).execute();
        ExampleSet trainingSet = (ExampleSet)this.innerRandomSamplePort.getData(ExampleSet.class);
        this.prepareWeights(trainingSet);
        this.applyPriorModel(trainingSet, null);
        if (this.getBootstrap()) {
            // NOTE(review): with a training subset configured, no sample is returned, and doWork()
            // would then hand null to the example set input. Confirm this is intended.
            return null;
        }
        if (modelInfo != null) {
            for (BayBoostBaseModelInfo current : modelInfo) {
                trainingSet = current.getModel().apply(trainingSet);
                // The bootstrap case returned above, so the plain (non-split) path is used here.
                this.reweightExamplesWrapper(trainingSet, false);
            }
        }
        return trainingSet;
    }

    /**
     * Prepares the example weights and computes the per-label priors.
     *
     * <p>Without a weight attribute, one is created (see {@link #createNewWeightAttribute}) and
     * {@link #oldWeights} is reset to {@code null}. Otherwise the existing weights are copied into
     * {@link #oldWeights}, and examples whose label index lies outside the label mapping get
     * weight 0. {@link #performance} is set to the total weight considered.
     *
     * @param exampleSet the examples whose weights are prepared (modified in place)
     * @return the per-label priors, normalised by the total weight
     */
    protected double[] prepareWeights(ExampleSet exampleSet) {
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        if (weightAttr == null) {
            this.oldWeights = null;
            this.performance = exampleSet.size();
            return this.createNewWeightAttribute(exampleSet);
        }
        this.oldWeights = new double[exampleSet.size()];
        double[] priors = new double[exampleSet.getAttributes().getLabel().getMapping().size()];
        double totalWeight = 0.0;
        Iterator<Example> reader = exampleSet.iterator();
        for (int i = 0; reader.hasNext() && i < this.oldWeights.length; ++i) {
            Example example = reader.next();
            if (example == null) {
                continue;
            }
            double weight = example.getWeight();
            this.oldWeights[i] = weight;
            int label = (int)example.getLabel();
            if (0 <= label && label < priors.length) {
                priors[label] += weight;
                totalWeight += weight;
            } else {
                // Examples whose label index is outside the mapping are excluded from training.
                example.setWeight(0.0);
            }
        }
        this.performance = totalWeight;
        // NOTE(review): if totalWeight is 0 the priors become NaN. Confirm callers handle this.
        for (int i = 0; i < priors.length; ++i) {
            priors[i] /= totalWeight;
        }
        return priors;
    }

    /**
     * Creates a weight attribute and initialises it.
     *
     * <p>Weights are set either to 1.0 or, if {@link #PARAMETER_RESCALE_LABEL_PRIORS} is set, to
     * values that equalise the label proportions.
     *
     * @param exampleSet the examples to receive a weight attribute
     * @return the label frequencies in the example set (each example counts 1/size)
     */
    private double[] createNewWeightAttribute(ExampleSet exampleSet) {
        Tools.createWeightAttribute(exampleSet);
        int numClasses = exampleSet.getAttributes().getLabel().getMapping().getValues().size();
        double[] classPriors = new double[numClasses];
        double invTotal = 1.0 / (double)exampleSet.size();
        boolean rescale = this.getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS);
        for (Example example : exampleSet) {
            if (!rescale) {
                example.setWeight(1.0);
            }
            // NOTE(review): a missing label (NaN) casts to index 0 and is counted for the first class.
            classPriors[(int)example.getLabel()] += invTotal;
        }
        if (rescale) {
            this.rescaleToEqualPriors(exampleSet, classPriors);
        }
        return classPriors;
    }

    /**
     * Reweights the training examples according to the start model, if one was supplied.
     *
     * <p>The start model is applied to a clone of the training set. Examples are reweighted using
     * the resulting performance measures, and the predicted label is removed again afterwards.
     *
     * @param trainingSet the training examples
     * @param modelInfo if non-null, receives the start model together with its contingency matrix
     * @throws OperatorException if applying the model or reweighting fails
     */
    protected void applyPriorModel(ExampleSet trainingSet, List<BayBoostBaseModelInfo> modelInfo) throws OperatorException {
        if (this.startModel == null) {
            return;
        }
        ExampleSet resultSet = this.startModel.apply((ExampleSet)trainingSet.clone());
        AbstractWeightedPerformanceMeasures wp = this.createPerformanceMeasures(resultSet);
        this.reweightExamples(wp, resultSet);
        if (modelInfo != null) {
            modelInfo.add(new BayBoostBaseModelInfo(this.startModel, wp.getContingencyMatrix()));
        }
        PredictionModel.removePredictedLabel(resultSet);
    }

    /**
     * Tells whether only a subset of the examples is used for training.
     *
     * @return {@code true} iff {@link #PARAMETER_USE_SUBSET_FOR_TRAINING} lies strictly between 0 and 1
     * @throws UndefinedParameterError if the parameter is undefined
     */
    private boolean getBootstrap() throws UndefinedParameterError {
        double splitRatio = this.getParameterAsDouble(PARAMETER_USE_SUBSET_FOR_TRAINING);
        return splitRatio > 0.0 && splitRatio < 1.0;
    }

    /**
     * Reweights the examples and returns performance measures.
     *
     * <p>In the bootstrap case the given set must be a {@link SplittedExampleSet}. Reweighting
     * uses the measures of the currently selected subset, and the returned measures are computed
     * on subset 1, which stays selected afterwards.
     *
     * @param exampleSet the examples to reweight
     * @param bootstrap whether a training/test split is in use
     * @return the performance measures (on subset 1 in the bootstrap case)
     * @throws OperatorException if the performance measures cannot be created or invoked
     */
    protected AbstractWeightedPerformanceMeasures reweightExamplesWrapper(ExampleSet exampleSet, boolean bootstrap) throws OperatorException {
        if (bootstrap) {
            SplittedExampleSet splittedSet = (SplittedExampleSet)exampleSet;
            AbstractWeightedPerformanceMeasures trainingMeasures = this.createPerformanceMeasures(splittedSet);
            this.performance = this.invokeStaticReweight(splittedSet, trainingMeasures, this.getParameterAsBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS));
            splittedSet.selectSingleSubset(1);
            return this.createPerformanceMeasures(splittedSet);
        }
        AbstractWeightedPerformanceMeasures wp = this.createPerformanceMeasures(exampleSet);
        this.performance = this.reweightExamples(wp, exampleSet);
        return wp;
    }

    /**
     * Sets each example's weight so that every label has the same total weight.
     *
     * <p>The new weight is 1 / (numLabels * prior(label)). Labels with prior 0 would get an
     * infinite weight, but no example carries such a label.
     *
     * @param exampleSet the examples to reweight
     * @param currentPriors the current relative frequency of each label
     */
    private void rescaleToEqualPriors(ExampleSet exampleSet, double[] currentPriors) {
        double[] weights = new double[currentPriors.length];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = 1.0 / ((double)weights.length * currentPriors[i]);
        }
        for (Example example : exampleSet) {
            example.setWeight(weights[(int)example.getLabel()]);
        }
    }

    /**
     * Reweights the examples via the static {@code reweightExamples} method of {@link #pmClass}.
     *
     * @param wp performance measures whose contingency matrix drives the reweighting
     * @param exampleSet the examples to reweight
     * @return the remaining total weight reported by the reweighting routine
     * @throws OperatorException if the reflective call fails (the original cause is attached)
     */
    protected double reweightExamples(AbstractWeightedPerformanceMeasures wp, ExampleSet exampleSet) throws OperatorException {
        boolean allowMarginalSkews = this.getParameterAsBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS);
        return this.invokeStaticReweight(exampleSet, wp, allowMarginalSkews);
    }

    /**
     * Reads the optional start model from {@link #modelInput}, logging if none is present.
     *
     * @throws UserError declared for subclasses/callers; not thrown here
     */
    protected void readOptionalParameters() throws UserError {
        this.startModel = (Model)this.modelInput.getDataOrNull();
        if (this.startModel == null) {
            this.log(this.getName() + ": No model found in input.");
        }
    }

    /**
     * Trains one base model by running the inner learner on the given examples.
     *
     * @param exampleSet the (weighted) training examples
     * @return the learned base model
     * @throws OperatorException if the inner learner fails
     */
    protected Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
        return this.applyInnerLearner(exampleSet);
    }

    /**
     * Returns the inherited parameter types plus the Bayesian boosting parameters and the
     * random generator parameters.
     *
     * @return the parameter types of this operator
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeDouble type = new ParameterTypeDouble(PARAMETER_USE_SUBSET_FOR_TRAINING, "Fraction of examples used for training, remaining ones are used to estimate the confusion matrix. Set to 1 to turn off test set.", 0.0, 1.0, 1.0);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeBoolean(PARAMETER_RESCALE_LABEL_PRIORS, "Specifies whether the proportion of labels should be equal by construction after first iteration.", false));
        types.add(new ParameterTypeBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS, "Allow to skew the marginal distribution (P(x)) during learning.", true));
        types.add(new ParameterTypeBoolean(PARAMETER_FUZZY_EXAMPLE_REWEIGHTING, "Specifies whether the example weights should be calculated in a fuzzy way.", false));
        types.add(new ParameterTypeBoolean(PARAMETER_FUZZY_PARTITION_SIZES, "Specifies if the counting of tp, np etc is based on confidences instead of crisp predictions.", false));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }

    /**
     * Instantiates the performance-measures class for the given examples via {@link #pmConstructor}.
     *
     * @throws OperatorException wrapping any reflective failure, with the cause preserved
     */
    private AbstractWeightedPerformanceMeasures createPerformanceMeasures(ExampleSet exampleSet) throws OperatorException {
        try {
            return (AbstractWeightedPerformanceMeasures)this.pmConstructor.newInstance(exampleSet);
        }
        catch (Exception e) {
            throw new OperatorException("cannot call reweightExamples", e);
        }
    }

    /**
     * Invokes the static {@code reweightExamples(ExampleSet, ContingencyMatrix, boolean, boolean)}
     * method of {@link #pmClass} and returns its result.
     *
     * <p>This helper is private and non-overridable, so the bootstrap path keeps its direct call.
     *
     * @throws OperatorException wrapping any reflective failure, with the cause preserved
     */
    private double invokeStaticReweight(ExampleSet exampleSet, AbstractWeightedPerformanceMeasures wp, boolean allowMarginalSkews) throws OperatorException {
        try {
            return (Double)this.pmClass.getMethod("reweightExamples", ExampleSet.class, ContingencyMatrix.class, Boolean.TYPE, Boolean.TYPE).invoke(null, exampleSet, wp.getContingencyMatrix(), allowMarginalSkews, this.fuzzyReweighting);
        }
        catch (Exception e) {
            throw new OperatorException("cannot call reweightExamples", e);
        }
    }
}

