/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.SDEnsemble;
import com.rapidminer.operator.learner.meta.SDReweightMeasures;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetPassThroughRule;
import com.rapidminer.operator.ports.metadata.GeneratePredictionModelTransformationRule;
import com.rapidminer.operator.ports.metadata.SetRelation;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.container.Pair;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Vector;

/**
 * Operator chain that builds a rule ensemble by repeatedly running its inner "Training" subprocess
 * on a reweighted training set and collecting the accepted models into an {@link SDEnsemble}.
 *
 * <p>Each iteration trains one model (optionally on a random split of the data, evaluating it on
 * the remaining part), counts how the evaluated examples fall into the two prediction subsets,
 * reweights the full training set via {@link SDReweightMeasures}, and keeps the model unless one of
 * the prediction subsets is empty or, with the ROC filter enabled, the model does not extend the
 * current ROC convex hull.
 */
public class SDRulesetInduction extends OperatorChain {
    private final InputPort exampleSetInput = this.getInputPorts().createPort("training set", ExampleSet.class);
    private final OutputPort trainingInnerSource = (OutputPort) this.getSubprocess(0).getInnerSources().createPort("training set");
    private final InputPort modelInnerSink = this.getSubprocess(0).getInnerSinks().createPort("model", PredictionModel.class);
    private final OutputPort modelOutput = (OutputPort) this.getOutputPorts().createPort("model");
    public static final String PARAMETER_ITERATIONS = "iterations";
    public static final String PARAMETER_RATIO_INTERNAL_BOOTSTRAP = "ratio_internal_bootstrap";
    public static final String PARAMETER_ROC_CONVEX_HULL_FILTER = "ROC_convex_hull_filter";
    public static final String PARAMETER_ADDITIVE_REWEIGHT = "additive_reweight";
    public static final String PARAMETER_GAMMA = "gamma";
    /** Name (and special role) of the attribute used by additive reweighting. */
    public static final String TIMES_COVERED = "TIMES_COVERED_SPECIAL_ATTRIB";
    public static final double MIN_ADVANTAGE = 0.001;
    /**
     * Reported through the "performance" value. NOTE(review): nothing in this class assigns it, so
     * the value always reads 0.0 — confirm whether that is intended.
     */
    private double performance = 0.0;
    /** Zero-based index of the iteration currently executing; reported as the "iteration" value. */
    private int currentIteration;

    public SDRulesetInduction(OperatorDescription description) {
        super(description, "Training");
        // The inner training set carries the weight attribute and the TIMES_COVERED special
        // attribute in addition to everything present on the input.
        this.getTransformer().addRule(new ExampleSetPassThroughRule(this.exampleSetInput, this.trainingInnerSource, SetRelation.EQUAL) {

            @Override
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData metaData) {
                AttributeMetaData weightAttribute = new AttributeMetaData("weight", 4, "weight");
                weightAttribute.setValueSetRelation(SetRelation.UNKNOWN);
                metaData.addAttribute(weightAttribute);
                // Value type 3 matches the type prepareWeights() uses when it actually creates this
                // attribute; the meta data previously declared 4 and thus disagreed with runtime.
                AttributeMetaData specialAttribute = new AttributeMetaData(SDRulesetInduction.TIMES_COVERED, 3, SDRulesetInduction.TIMES_COVERED);
                specialAttribute.setValueSetRelation(SetRelation.UNKNOWN);
                metaData.addAttribute(specialAttribute);
                return metaData;
            }
        });
        this.getTransformer().addRule(new SubprocessTransformRule(this.getSubprocess(0)));
        this.getTransformer().addRule(new GeneratePredictionModelTransformationRule(this.exampleSetInput, this.modelOutput, PredictionModel.class));
        this.addValue(new ValueDouble("performance", "The performance.") {

            @Override
            public double getDoubleValue() {
                return SDRulesetInduction.this.performance;
            }
        });
        this.addValue(new ValueDouble("iteration", "The current iteration.") {

            @Override
            public double getDoubleValue() {
                return SDRulesetInduction.this.currentIteration;
            }
        });
    }

    /**
     * Returns the index that the label's nominal mapping designates as the positive class.
     *
     * @param label the label attribute; it must have a nominal mapping
     * @return the positive class index of that mapping
     */
    public static int getPosIndex(Attribute label) {
        return label.getMapping().getPositiveIndex();
    }

    /**
     * Initializes example weights so that both classes carry the same total weight.
     *
     * <p>Every example gets weight {@code 0.5 / prior(its class)}, so each class sums to half the
     * number of examples. With additive reweighting enabled, the TIMES_COVERED counter is created
     * if missing and reset to 0 for every example. The code assumes a two-class label whose
     * indices are 0 and 1.
     *
     * @param exampleSet the training set; modified in place
     * @return class priors indexed by class index (NaN entries if the set is empty)
     * @throws OperatorException if a parameter cannot be read
     */
    private double[] prepareWeights(ExampleSet exampleSet) throws OperatorException {
        Attribute weightAttr = Tools.createWeightAttribute(exampleSet);
        boolean additive = this.getParameterAsBoolean(PARAMETER_ADDITIVE_REWEIGHT);
        Attribute timesCoveredAttrib = null;
        if (additive) {
            timesCoveredAttrib = exampleSet.getAttributes().get(TIMES_COVERED);
            if (timesCoveredAttrib == null) {
                timesCoveredAttrib = Tools.createSpecialAttribute(exampleSet, TIMES_COVERED, 3);
                // NOTE(review): if createSpecialAttribute already registers the attribute in the
                // example table, this second registration is redundant — verify.
                exampleSet.getExampleTable().addAttribute(timesCoveredAttrib);
            }
        }
        int positiveClass = SDRulesetInduction.getPosIndex(exampleSet.getAttributes().getLabel());
        int negativeClass = 1 - positiveClass;

        int numPos = 0;
        for (Example example : exampleSet) {
            if (example.getLabel() == (double) positiveClass) {
                ++numPos;
            }
        }

        double[] classPriors = new double[2];
        classPriors[positiveClass] = (double) numPos / (double) exampleSet.size();
        classPriors[negativeClass] = 1.0 - classPriors[positiveClass];
        // If one class is absent its weight is infinite, but no example ever receives it.
        double posWeight = 0.5 / classPriors[positiveClass];
        double negWeight = 0.5 / classPriors[negativeClass];
        for (Example example : exampleSet) {
            example.setValue(weightAttr, example.getLabel() == (double) positiveClass ? posWeight : negWeight);
            if (additive) {
                example.setValue(timesCoveredAttrib, 0.0);
            }
        }
        return classPriors;
    }

    /**
     * Runs the inner "Training" subprocess once on the given example set.
     *
     * @param exampleSet data delivered to the inner training-set source
     * @return the model delivered to the inner model sink
     * @throws OperatorException if the subprocess fails or no model is delivered
     */
    private Model trainModel(ExampleSet exampleSet) throws OperatorException {
        this.trainingInnerSource.deliver(exampleSet);
        this.getSubprocess(0).execute();
        return (Model) this.modelInnerSink.getData();
    }

    /**
     * Reads the training set, initializes weights, trains the ensemble and delivers it.
     *
     * @throws OperatorException UserError 105 if the input has no label, or any training error
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.exampleSetInput.getData();
        if (exampleSet.getAttributes().getLabel() == null) {
            throw new UserError((Operator) this, 105);
        }
        SDEnsemble model = this.trainRuleset(exampleSet, this.prepareWeights(exampleSet));
        this.modelOutput.deliver(model);
    }

    /**
     * Main training loop: trains one model per iteration, reweights the training set, and collects
     * the accepted models together with their per-subset class-distribution matrices.
     *
     * @param trainingSet the weighted training set (reweighted in place every iteration)
     * @param classPriors class priors as computed by {@link #prepareWeights}
     * @return the ensemble of accepted models
     * @throws OperatorException if a parameter cannot be read or the inner process fails
     */
    private SDEnsemble trainRuleset(ExampleSet trainingSet, double[] classPriors) throws OperatorException {
        Vector<Pair<Model, double[][]>> modelInfo = new Vector<Pair<Model, double[][]>>();
        double splitRatio = this.getParameterAsDouble(PARAMETER_RATIO_INTERNAL_BOOTSTRAP);
        // Only a ratio strictly between 0 and 1 enables the internal train/estimate split.
        boolean bootstrap = splitRatio > 0.0 && splitRatio < 1.0;
        this.log(bootstrap ? "Bootstrapping enabled." : "Bootstrapping disabled.");
        int iterations = this.getParameterAsInt(PARAMETER_ITERATIONS);
        boolean rocFilter = this.getParameterAsBoolean(PARAMETER_ROC_CONVEX_HULL_FILTER);
        // Parameters cannot change during one execution: read them once instead of per iteration.
        boolean additive = this.getParameterAsBoolean(PARAMETER_ADDITIVE_REWEIGHT);
        double gamma = additive ? 0.0 : this.getParameterAsDouble(PARAMETER_GAMMA);
        boolean useLocalSeed = bootstrap && this.getParameterAsBoolean("use_local_random_seed");
        int localSeed = bootstrap ? this.getParameterAsInt("local_random_seed") : 0;

        // ROC points are stored as {TPr, FPr}, starting with the trivial classifiers.
        LinkedList<double[]> rocCurve = null;
        if (rocFilter) {
            rocCurve = new LinkedList<double[]>();
            rocCurve.add(new double[]{0.0, 0.0});
            rocCurve.add(new double[]{1.0, 1.0});
        }

        for (int i = 0; i < iterations; ++i) {
            this.currentIteration = i;
            Model model;
            ExampleSet resultSet;
            if (bootstrap) {
                // Train on subset 0 of the split, then estimate the model on the held-out subset 1.
                SplittedExampleSet splittedSet = new SplittedExampleSet(trainingSet, splitRatio, 1, useLocalSeed, localSeed);
                splittedSet.selectSingleSubset(0);
                model = this.trainModel(splittedSet);
                splittedSet.selectSingleSubset(1);
                resultSet = model.apply(splittedSet);
            } else {
                model = this.trainModel(trainingSet);
                resultSet = model.apply(trainingSet);
            }

            SDReweightMeasures wp = new SDReweightMeasures(resultSet);
            wp.setAdditive(additive);
            if (!additive) {
                wp.setGamma(gamma);
            }

            // predClasses[s] holds the counts returned for prediction subset s; presumably indexed
            // by true class (see the prior computation below) — confirm in SDReweightMeasures.
            int[][] predClasses = new int[][]{wp.getCoveredExamplesNumForPred(0), wp.getCoveredExamplesNumForPred(1)};
            int[] rowTotals = new int[]{predClasses[0][0] + predClasses[0][1], predClasses[1][0] + predClasses[1][1]};
            int total = rowTotals[0] + rowTotals[1];
            double cov0 = (double) rowTotals[0] / (double) total;
            double cov1 = (double) rowTotals[1] / (double) total;
            double prior0 = ((double) predClasses[0][0] + (double) predClasses[1][0]) / (double) total;
            double prior1 = ((double) predClasses[0][1] + (double) predClasses[1][1]) / (double) total;
            // Bias of a subset: distance of its class-0 share from the overall class-0 share.
            double bias0 = Math.abs((double) predClasses[0][0] / (double) rowTotals[0] - prior0);
            double bias1 = Math.abs((double) predClasses[1][0] / (double) rowTotals[1] - prior0);
            // Pick the subset with the larger coverage * bias; an empty subset yields NaN and loses.
            int subset = Double.isNaN(bias1) || cov0 * bias0 >= cov1 * bias1 ? 0 : 1;

            double[][] modelWeightMatrix = new double[2][2];
            modelWeightMatrix[subset][0] = (double) predClasses[subset][0] / (double) rowTotals[subset];
            modelWeightMatrix[subset][1] = (double) predClasses[subset][1] / (double) rowTotals[subset];
            // Fraction of each class's examples that fall into the chosen subset.
            double ratio0 = (double) predClasses[subset][0] / (double) total / prior0;
            double ratio1 = (double) predClasses[subset][1] / (double) total / prior1;
            wp.reweightExamples(trainingSet, ratio0 > ratio1 ? 0 : 1, subset);

            // A model whose evaluated examples all fall into one subset is a default rule: discard.
            boolean defaultRule = cov0 == 0.0 || cov1 == 0.0;
            boolean keep = !defaultRule;
            if (keep && rocFilter) {
                // Only consulted for non-default rules, because it mutates rocCurve.
                keep = this.isOnConvexHull(rocCurve, Math.max(ratio0, ratio1), Math.min(ratio0, ratio1));
            }
            if (keep) {
                modelInfo.add(new Pair<Model, double[][]>(model, modelWeightMatrix));
            }
            this.inApplyLoop();
        }

        if (rocFilter) {
            this.logConvexHull(rocCurve);
        }
        // NOTE(review): 1 / 2 presumably select additive / multiplicative combination in
        // SDEnsemble — confirm against its constants.
        short combinationMethod = additive ? (short) 1 : 2;
        return new SDEnsemble(trainingSet, modelInfo, classPriors, combinationMethod);
    }

    /**
     * Logs every point of the ROC convex hull as a "(TPr, FPr)" pair.
     *
     * @param rocCurve hull points stored as {TPr, FPr}
     */
    private void logConvexHull(List<double[]> rocCurve) {
        StringBuilder message = new StringBuilder("The convex hull in ROC space contains the following points (TPr/FPr):" + com.rapidminer.tools.Tools.getLineSeparator());
        for (double[] tpfp : rocCurve) {
            message.append("(" + tpfp[0] + ", " + tpfp[1] + ") ");
        }
        this.log(message.toString());
    }

    /**
     * Tests whether the ROC point (fpr, tpr) extends the convex hull and, if so, inserts it.
     *
     * <p>{@code rocCurve} holds points as {TPr, FPr} ordered by ascending FPr, beginning with
     * (0, 0) and ending with (1, 1). Points left of the new one that fall on or below its chord
     * are removed while scanning. When the point is accepted, it is inserted (or replaces a point
     * with equal FPr), and dominated points to its right are removed.
     *
     * @param rocCurve current hull, updated in place
     * @param tpr true positive rate of the candidate, accepted range (0, 1]
     * @param fpr false positive rate of the candidate, accepted range [0, 1)
     * @return {@code true} if the candidate was added to the hull
     */
    private boolean isOnConvexHull(List<double[]> rocCurve, double tpr, double fpr) {
        // NaN rates (e.g. a class absent from the evaluated examples) are rejected explicitly.
        if (Double.isNaN(tpr) || Double.isNaN(fpr) || tpr <= 0.0 || tpr > 1.0 || fpr < 0.0 || fpr >= 1.0) {
            return false;
        }
        ListIterator<double[]> iter = rocCurve.listIterator();
        double slope = Double.POSITIVE_INFINITY;
        // Scan points left of the candidate; terminates because (1, 1) has FPr 1 > fpr.
        while (true) {
            double[] current = iter.next();
            if (fpr > current[1]) {
                double newSlope = (tpr - current[0]) / (fpr - current[1]);
                if (newSlope >= slope) {
                    // current lies on or below the chord from the previous hull point.
                    iter.remove();
                    continue;
                }
                slope = newSlope;
                double finalSlope = (1.0 - current[0]) / (1.0 - current[1]);
                if (slope <= finalSlope) {
                    // Candidate lies on or below the line from current to (1, 1).
                    return false;
                }
                continue;
            }
            if (fpr == current[1]) {
                if (tpr > current[0]) {
                    rocCurve.set(iter.previousIndex(), new double[]{tpr, fpr});
                    break;
                }
                return false;
            }
            double nextSlope = (current[0] - tpr) / (current[1] - fpr);
            if (slope > nextSlope) {
                // Structural change through the list itself; iter is not used afterwards.
                rocCurve.add(iter.previousIndex(), new double[]{tpr, fpr});
                break;
            }
            return false;
        }
        // Remove points right of the candidate that are now below the hull, scanning from (1, 1).
        double slope2 = (1.0 - tpr) / (1.0 - fpr);
        ListIterator<double[]> back = rocCurve.listIterator(rocCurve.size());
        while (back.hasPrevious()) {
            double[] current = back.previous();
            if (current[1] <= fpr) {
                return true;
            }
            double newSlope = (current[0] - tpr) / (current[1] - fpr);
            if (current[1] < 1.0 && newSlope <= slope2) {
                back.remove();
                continue;
            }
            slope2 = newSlope;
        }
        return true;
    }

    /**
     * Declares the iteration count, internal split ratio, ROC filter, reweighting mode, gamma, and
     * the random-generator parameters.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeInt type = new ParameterTypeInt(PARAMETER_ITERATIONS, "The maximum number of iterations.", 1, Integer.MAX_VALUE, 10);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeDouble(PARAMETER_RATIO_INTERNAL_BOOTSTRAP, "Fraction of examples used for training (internal bootstrapping). If activated (value < 1) only the rest is used to estimate the biases.", 0.0, 1.0, 0.7));
        types.add(new ParameterTypeBoolean(PARAMETER_ROC_CONVEX_HULL_FILTER, "A parameter whether to discard all rules not lying on the convex hull in ROC space.", true));
        types.add(new ParameterTypeBoolean(PARAMETER_ADDITIVE_REWEIGHT, "If enabled then resampling is done by additive reweighting, otherwise by multiplicative reweighting.", true));
        types.add(new ParameterTypeDouble(PARAMETER_GAMMA, "Factor used for multiplicative reweighting. Has no effect in case of additive reweighting.", 0.0, 1.0, 0.9));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }
}

