/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.operator.learner.meta.ThresholdModel;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.optimization.ec.es.ESOptimization;
import com.rapidminer.tools.math.optimization.ec.es.Individual;
import java.util.LinkedList;
import java.util.List;

/**
 * Meta learner that wraps an inner learner and tunes per-class confidence
 * thresholds with respect to user-defined class weights (costs).
 *
 * <p>The input data is split into two parts according to
 * {@link #PARAMETER_TRAINING_RATIO}: the inner learner is trained on the first
 * part, the resulting model is applied to the second part, and an
 * {@link ESOptimization} run then searches for threshold values in [0, 1]
 * that minimize the accumulated costs on that second part. The result is a
 * {@link ThresholdModel} wrapping the inner model and the best thresholds.
 */
public class CostBasedThresholdLearner
extends AbstractMetaLearner {
    public static final String PARAMETER_CLASS_WEIGHTS = "class_weights";
    public static final String PARAMETER_ALLOW_UNKOWN_PREDICTIONS = "allow_unkown_predictions";
    public static final String PARAMETER_PREDICT_UNKNOWN_COSTS = "predict_unknown_costs";
    public static final String PARAMETER_TRAINING_RATIO = "training_ratio";
    public static final String PARAMETER_NUMBER_OF_ITERATIONS = "number_of_iterations";

    public CostBasedThresholdLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Reads the class weights from the operator parameters and delegates to
     * the threshold optimization.
     *
     * <p>Every class starts with weight 1; entries of the
     * {@link #PARAMETER_CLASS_WEIGHTS} list override the weight of the class
     * they name. Entries naming a class that is not part of the label mapping
     * are skipped with a warning instead of failing with an
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param exampleSet the data used for both inner learning and threshold optimization
     * @return the model produced by the threshold optimization
     * @throws OperatorException if the label is not nominal (user error 101), if no
     *         class weights are defined (user error 205), or if inner learning fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Attribute label = exampleSet.getAttributes().getLabel();
        List<String[]> classWeights = this.getParameterList(PARAMETER_CLASS_WEIGHTS);
        if (!label.isNominal()) {
            throw new UserError((Operator)this, 101, this.getName(), label.getName());
        }
        // NOTE(review): the parameter description claims an empty list means "1 for all
        // classes", but an empty list is rejected here.
        if (classWeights.size() == 0) {
            throw new UserError((Operator)this, 205, PARAMETER_CLASS_WEIGHTS);
        }
        // A negative unknown weight signals "no dedicated unknown costs" to the cost function.
        double unknownWeight = this.getParameterAsBoolean(PARAMETER_ALLOW_UNKOWN_PREDICTIONS) ? this.getParameterAsDouble(PARAMETER_PREDICT_UNKNOWN_COSTS) : -1.0;
        double[] weights = new double[label.getMapping().size()];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = 1.0;
        }
        for (String[] classWeightArray : classWeights) {
            String className = classWeightArray[0];
            int index = label.getMapping().getIndex(className);
            // Names that are not part of the label mapping (e.g. a typo, or the '?' entry
            // mentioned in the parameter description) have no slot in the weight array.
            if (index < 0 || index >= weights.length) {
                this.logWarning("Ignoring class weight for unknown class '" + className + "' of label '" + label.getName() + "'.");
                continue;
            }
            weights[index] = Double.parseDouble(classWeightArray[1]);
        }
        LinkedList<String> weightList = new LinkedList<String>();
        for (double d : weights) {
            weightList.add(Tools.formatIntegerIfPossible(d));
        }
        this.log("Used class weights --> " + weightList + ", unknown weight: " + Tools.formatIntegerIfPossible(unknownWeight));
        return this.calculateThresholdModel(exampleSet, weights, unknownWeight);
    }

    /**
     * Trains the inner learner on the first split, applies it to the second
     * split and optimizes the thresholds on the applied data.
     *
     * @param exampleSet the complete input data
     * @param classWeights the cost per class, indexed by label mapping index
     * @param unknownWeight the cost used when the confidence does not exceed the
     *        threshold; a negative value means the weight of the example's true class
     *        is used instead
     * @return a {@link ThresholdModel} built from the inner model and the best values found
     * @throws OperatorException if inner learning, model application or optimization fails
     */
    private Model calculateThresholdModel(ExampleSet exampleSet, final double[] classWeights, final double unknownWeight) throws OperatorException {
        SplittedExampleSet trainingSet = new SplittedExampleSet(exampleSet, this.getParameterAsDouble(PARAMETER_TRAINING_RATIO), 2, this.getParameterAsBoolean("use_local_random_seed"), this.getParameterAsInt("local_random_seed"));
        // Subset 0 is used for inner learning, subset 1 for threshold optimization.
        trainingSet.selectSingleSubset(0);
        Model innerModel = this.applyInnerLearner(trainingSet);
        trainingSet.selectSingleSubset(1);
        final ExampleSet appliedTrainingSet = innerModel.apply(trainingSet);
        final Attribute label = appliedTrainingSet.getAttributes().getLabel();
        int numberOfGenerations = this.getParameterAsInt(PARAMETER_NUMBER_OF_ITERATIONS);
        ESOptimization optimization = new ESOptimization(0.0, 1.0, 5, classWeights.length, 0, numberOfGenerations, Math.max(1, numberOfGenerations / 10), 6, 0.4, true, 1, 0.9, false, false, RandomGenerator.getRandomGenerator(this), this){

            /**
             * Sums up the costs that the individual's values cause on the applied data
             * when interpreted as thresholds indexed by predicted class.
             */
            @Override
            public PerformanceVector evaluateIndividual(Individual individual) throws OperatorException {
                double costs = 0.0;
                double[] thresholds = individual.getValues();
                for (Example example : appliedTrainingSet) {
                    int predictionIndex = (int)example.getPredictedLabel();
                    String className = label.getMapping().mapIndex(predictionIndex);
                    double confidence = example.getConfidence(className);
                    boolean correct = example.getLabel() == example.getPredictedLabel();
                    if (confidence > thresholds[predictionIndex]) {
                        // Prediction is kept: only a misclassification costs something.
                        if (!correct) {
                            costs += classWeights[(int)example.getLabel()];
                        }
                    } else if (correct) {
                        // Confidence does not exceed the threshold although the prediction
                        // was right: charge the unknown cost, or the true class weight
                        // when no unknown cost is set.
                        costs += unknownWeight < 0.0 ? classWeights[(int)example.getLabel()] : unknownWeight;
                    }
                }
                PerformanceVector performanceVector = new PerformanceVector();
                performanceVector.addCriterion(new EstimatedPerformance("Costs", costs, 1, true));
                return performanceVector;
            }
        };
        try {
            optimization.optimize();
        } finally {
            // Always drop the temporary prediction, even if the optimization fails.
            PredictionModel.removePredictedLabel(appliedTrainingSet);
        }
        double[] bestValues = optimization.getBestValuesEver();
        return new ThresholdModel(appliedTrainingSet, innerModel, bestValues);
    }

    /**
     * Extends the inherited parameters with the class weights, the unknown
     * prediction settings, the training ratio, the number of optimization
     * iterations and the random generator parameters.
     *
     * @return the list of parameter types of this operator
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterType type = new ParameterTypeList(PARAMETER_CLASS_WEIGHTS, "The weights for all classes, empty: using 1 for all classes. The costs for not classifying at all are defined with class name '?'.", (ParameterType)new ParameterTypeString("class_name", "The name of the class."), (ParameterType)new ParameterTypeDouble("weight", "The weight for this class.", 0.0, Double.POSITIVE_INFINITY, 1.0));
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeBoolean(PARAMETER_ALLOW_UNKOWN_PREDICTIONS, "This indicates if unkown predictions are allowed. If checked, the costs for unkown predictions must be specified.", false));
        type = new ParameterTypeDouble(PARAMETER_PREDICT_UNKNOWN_COSTS, "Use this cost value for predicting an example as unknown.", 0.0, Double.POSITIVE_INFINITY, 1.0);
        type.setExpert(false);
        // The unknown cost parameter depends on the "allow unknown predictions" flag.
        type.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_ALLOW_UNKOWN_PREDICTIONS, true, true));
        types.add(type);
        types.add(new ParameterTypeDouble(PARAMETER_TRAINING_RATIO, "Use this amount of input data for model learning and the rest for threshold optimization.", 0.0, 1.0, 0.7));
        types.add(new ParameterTypeInt(PARAMETER_NUMBER_OF_ITERATIONS, "Defines the number of optimization iterations.", 1, Integer.MAX_VALUE, 200));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }

    /**
     * Reports every capability as supported except numerical labels, missing
     * labels, updatability and formula provision.
     *
     * @param capability the capability to check
     * @return {@code false} for the four excluded capabilities, {@code true} otherwise
     */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        switch (capability) {
            case NUMERICAL_LABEL: 
            case NO_LABEL: 
            case UPDATABLE: 
            case FORMULA_PROVIDER: {
                return false;
            }
        }
        return true;
    }
}

