/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.operator.learner.meta.AdaBoostModel;
import com.rapidminer.operator.learner.meta.AdaBoostPerformanceMeasures;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeInt;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/**
 * Meta learner that builds a boosted ensemble: an inner learner is trained
 * repeatedly on the same example set while the example weights are adjusted
 * after every round by {@link AdaBoostPerformanceMeasures}.
 *
 * <p>The operator exposes two loggable values: {@code performance} (the total
 * example weight last computed) and {@code iteration} (the zero-based index of
 * the boosting round currently running).
 *
 * <p>During {@link #learn(ExampleSet)} the weight column of the supplied
 * example set is modified in place. The original weights are put back (or the
 * temporary weight attribute is removed) when learning ends, regardless of
 * whether it ended normally or with an exception.
 */
public class AdaBoost
extends AbstractMetaLearner {
    /** Key of the parameter holding the maximum number of boosting rounds. */
    public static final String PARAMETER_ITERATIONS = "iterations";

    /** Public constant kept for API compatibility; it is not referenced inside this class. */
    public static final double MIN_ADVANTAGE = 0.001;

    /** Zero-based index of the boosting round currently being executed. */
    protected int currentIteration;

    /** Total example weight, updated before training and after each reweighting step. */
    private double performance = 0.0;

    /**
     * Snapshot of the weights the example set had before boosting started, or
     * {@code null} if the example set had no weight attribute (in which case a
     * temporary one is created and later removed).
     */
    private double[] oldWeights;

    /**
     * Creates the operator and registers the {@code performance} and
     * {@code iteration} values.
     *
     * @param description the operator description handed to the super class
     */
    public AdaBoost(OperatorDescription description) {
        super(description);
        this.addValue(new ValueDouble("performance", "The performance."){

            @Override
            public double getDoubleValue() {
                return AdaBoost.this.performance;
            }
        });
        this.addValue(new ValueDouble("iteration", "The current iteration."){

            @Override
            public double getDoubleValue() {
                return AdaBoost.this.currentIteration;
            }
        });
    }

    /**
     * Adds a {@code weight} attribute to the meta data before delegating to
     * the super class implementation.
     *
     * @param unmodifiedMetaData the incoming example set meta data; it is modified in place
     * @return whatever the super class returns for the extended meta data
     */
    @Override
    protected MetaData modifyExampleSetMetaData(ExampleSetMetaData unmodifiedMetaData) {
        // NOTE(review): the literal 4 is presumably the value-type code for a
        // real-valued attribute and the third argument the attribute role -
        // confirm against AttributeMetaData / Ontology.
        AttributeMetaData weightAttribute = new AttributeMetaData("weight", 4, "weight");
        unmodifiedMetaData.addAttribute(weightAttribute);
        return super.modifyExampleSetMetaData(unmodifiedMetaData);
    }

    /**
     * Reports which capabilities this learner supports.
     *
     * @param lc the capability in question
     * @return {@code false} for numerical labels, missing labels, updatable
     *         models and formula providers; {@code true} for everything else
     */
    @Override
    public boolean supportsCapability(OperatorCapability lc) {
        switch (lc) {
            case NUMERICAL_LABEL:
            case NO_LABEL:
            case UPDATABLE:
            case FORMULA_PROVIDER:
                return false;
            default:
                return true;
        }
    }

    /**
     * Trains the boosted ensemble on the given example set.
     *
     * <p>The weight column of {@code exampleSet} is changed while training
     * runs. The previous state is re-established in a {@code finally} block so
     * that a failing inner learner does not leave boosted weights or a stray
     * weight attribute behind.
     *
     * @param exampleSet the training data; must carry a nominal label
     * @return the ensemble model
     * @throws UserError (code 119) if the label is not nominal
     * @throws OperatorException if training the ensemble fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Attribute label = exampleSet.getAttributes().getLabel();
        if (!label.isNominal()) {
            throw new UserError((Operator)this, 119, label.getName(), this.getName());
        }
        this.performance = this.prepareWeights(exampleSet);
        try {
            return this.trainBoostingModel(exampleSet);
        } finally {
            // Undo the in-place weight changes on success and on failure alike.
            this.restoreWeights(exampleSet);
        }
    }

    /**
     * Puts the example set back into the weight state it had before
     * {@link #prepareWeights(ExampleSet)} ran and releases the weight snapshot.
     *
     * @param exampleSet the example set whose weights were modified during learning
     */
    private void restoreWeights(ExampleSet exampleSet) {
        double[] savedWeights = this.oldWeights;
        // Drop the snapshot so the operator does not keep it alive after learning.
        this.oldWeights = null;

        Attribute weightAttribute = exampleSet.getAttributes().getWeight();
        if (weightAttribute == null) {
            // Nothing to restore or remove; also keeps this clean-up step from
            // throwing and hiding an exception raised during training.
            return;
        }
        if (savedWeights != null) {
            Iterator<Example> reader = exampleSet.iterator();
            for (int i = 0; reader.hasNext() && i < savedWeights.length; ++i) {
                reader.next().setValue(weightAttribute, savedWeights[i]);
            }
        } else {
            // The weight attribute was created only for boosting: remove it
            // from the attribute set and from the underlying example table.
            exampleSet.getAttributes().remove(weightAttribute);
            exampleSet.getExampleTable().removeAttribute(weightAttribute);
        }
    }

    /**
     * Makes sure the example set has a weight attribute and remembers the
     * previous weights, if any.
     *
     * <p>If no weight attribute exists, one is created and every example gets
     * weight {@code 1.0}. Otherwise the existing weights are copied so they
     * can be restored after learning.
     *
     * @param exampleSet the example set to prepare
     * @return the sum of all example weights
     */
    protected double prepareWeights(ExampleSet exampleSet) {
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        double totalWeight = 0.0;
        if (weightAttr == null) {
            this.oldWeights = null;
            weightAttr = Tools.createWeightAttribute(exampleSet);
            for (Example example : exampleSet) {
                example.setValue(weightAttr, 1.0);
                totalWeight += 1.0;
            }
        } else {
            double[] snapshot = new double[exampleSet.size()];
            Iterator<Example> reader = exampleSet.iterator();
            for (int i = 0; reader.hasNext() && i < snapshot.length; ++i) {
                snapshot[i] = reader.next().getWeight();
                totalWeight += snapshot[i];
            }
            this.oldWeights = snapshot;
        }
        return totalWeight;
    }

    /**
     * Runs the boosting loop.
     *
     * <p>Each round trains the inner learner on a clone of the training set,
     * applies the resulting model, reweights the examples and stores the model
     * with weight {@code ln((1 - e) / e)}, where {@code e} is the error rate
     * reported for that round. The loop ends when the configured number of
     * iterations is reached, when the total weight drops to zero or below, or
     * when a model is not useful (see {@link #isModelUseful}).
     *
     * @param trainingSet the weighted training data
     * @return the ensemble of all models accepted so far
     * @throws OperatorException if training or applying an inner model fails
     */
    private AdaBoostModel trainBoostingModel(ExampleSet trainingSet) throws OperatorException {
        this.log("Total weight of example set at the beginning: " + this.performance);
        // Vector is kept because the AdaBoostModel constructor signature is not visible here.
        Vector<Model> ensembleModels = new Vector<Model>();
        Vector<Double> ensembleWeights = new Vector<Double>();
        int iterations = this.getParameterAsInt(PARAMETER_ITERATIONS);
        for (int i = 0; i < iterations && this.performance > 0.0; ++i) {
            this.currentIteration = i;

            // Train and apply one base model on a clone of the training data.
            ExampleSet iterationSet = (ExampleSet)trainingSet.clone();
            Model model = this.applyInnerLearner(iterationSet);
            iterationSet = model.apply(iterationSet);

            // Measure the model and reweight the examples accordingly.
            AdaBoostPerformanceMeasures wp = new AdaBoostPerformanceMeasures(iterationSet);
            this.performance = wp.reweightExamples(iterationSet);
            PredictionModel.removePredictedLabel(iterationSet);
            this.log("Total weight of example set after iteration " + (this.currentIteration + 1) + " is " + this.performance);

            if (!this.isModelUseful(wp)) {
                // Stop boosting and return what has been collected so far.
                this.log("Discard model because of low advantage on training data.");
                return new AdaBoostModel(trainingSet, ensembleModels, ensembleWeights);
            }

            ensembleModels.add(model);
            double errorRate = wp.getErrorRate();
            // An error rate of exactly zero would divide by zero; such a model
            // gets infinite weight instead.
            double weight = errorRate == 0.0 ? Double.POSITIVE_INFINITY : Math.log((1.0 - errorRate) / errorRate);
            ensembleWeights.add(weight);
        }
        return new AdaBoostModel(trainingSet, ensembleModels, ensembleWeights);
    }

    /**
     * Decides whether a base model is worth keeping.
     *
     * @param wp the performance measures of the model
     * @return {@code true} if the reported error rate is below {@code 0.5}
     */
    private boolean isModelUseful(AdaBoostPerformanceMeasures wp) {
        return wp.getErrorRate() < 0.5;
    }

    /**
     * Extends the inherited parameter list with the non-expert integer
     * parameter {@link #PARAMETER_ITERATIONS} (range 1 to
     * {@link Integer#MAX_VALUE}, default 10).
     *
     * @return the parameter types of this operator
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeInt type = new ParameterTypeInt(PARAMETER_ITERATIONS, "The maximum number of iterations.", 1, Integer.MAX_VALUE, 10);
        type.setExpert(false);
        types.add(type);
        return types;
    }
}

