/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.meta;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.SplittedExampleSet;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.learner.PredictionModel;
import edu.udo.cs.yale.operator.learner.meta.AbstractMetaLearner;
import edu.udo.cs.yale.operator.learner.meta.ThresholdModel;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeDouble;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.operator.parameter.ParameterTypeList;
import edu.udo.cs.yale.operator.performance.EstimatedPerformance;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.RandomGenerator;
import edu.udo.cs.yale.tools.Tools;
import edu.udo.cs.yale.tools.math.optimization.ec.es.ESOptimization;
import edu.udo.cs.yale.tools.math.optimization.ec.es.Individual;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Meta learner that wraps an inner learner and optimizes one confidence threshold per label
 * value so that a class-weighted cost on held-out examples is minimized.
 *
 * <p>The input examples are split (parameter {@code training_ratio}) into a part used to train
 * the inner model and a part on which the inner model's predictions are evaluated. An evolution
 * strategy ({@link ESOptimization}) then searches the thresholds. The best values found are
 * returned together with the inner model as a {@link ThresholdModel}.
 */
public class CostBasedThresholdLearner
extends AbstractMetaLearner {
    public CostBasedThresholdLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Builds the per-class cost weights from the parameters and learns the threshold model.
     *
     * <p>Every label value starts with weight 1; entries in {@code class_weights} override
     * individual classes.
     *
     * @param exampleSet the input examples; the label must be nominal
     * @return a {@link ThresholdModel} wrapping the inner model and the optimized thresholds
     * @throws UserError if the label is not nominal or {@code class_weights} is empty
     * @throws OperatorException if {@code class_weights} names a class that is not a value of
     *     the label, or if inner learning / optimization fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Attribute label = exampleSet.getAttributes().getLabel();
        List classWeights = this.getParameterList("class_weights");
        if (!label.isNominal()) {
            throw new UserError((Operator)this, 101, this.getName(), (Object)label.getName());
        }
        // NOTE(review): the class_weights description says "empty: using 1 for all classes",
        // but an empty list is rejected here; one of the two should be aligned.
        if (classWeights.size() == 0) {
            throw new UserError((Operator)this, 205, "class_weights");
        }
        double unknownWeight = this.getParameterAsDouble("predict_unknown_costs");

        double[] weights = new double[label.getMapping().size()];
        Arrays.fill(weights, 1.0);
        // getParameterList returns a raw List whose entries are {class name, weight} pairs.
        for (Object entry : classWeights) {
            Object[] classWeightArray = (Object[]) entry;
            String className = (String) classWeightArray[0];
            double classWeight = (Double) classWeightArray[1];
            int index = label.getMapping().getIndex(className);
            // Guard against names that are not label values instead of failing with an
            // ArrayIndexOutOfBoundsException.
            if (index < 0 || index >= weights.length) {
                throw new OperatorException(this.getName() + ": unknown class name '" + className
                        + "' in parameter 'class_weights'.");
            }
            weights[index] = classWeight;
        }

        List<String> weightList = new LinkedList<String>();
        for (double weight : weights) {
            weightList.add(Tools.formatIntegerIfPossible(weight));
        }
        LogService.logMessage(String.valueOf(this.getName()) + ": used class weights --> " + weightList + ", unknown weight: " + Tools.formatIntegerIfPossible(unknownWeight), 2);
        return this.calculateThresholdModel(exampleSet, weights, unknownWeight);
    }

    /**
     * Trains the inner model on subset 0 of a random split and optimizes per-class thresholds
     * on subset 1.
     *
     * <p>Cost of one held-out example:
     * <ul>
     *   <li>If the confidence of the predicted class exceeds that class's threshold, the
     *       prediction is accepted. A wrong accepted prediction costs the weight of the true
     *       class.</li>
     *   <li>Otherwise the example counts as "unknown". This costs {@code unknownWeight} (or the
     *       true class's weight when {@code unknownWeight < 0}), but only if the rejected
     *       prediction would have been correct.</li>
     * </ul>
     *
     * @param exampleSet the input examples
     * @param classWeights cost weight per label index
     * @param unknownWeight cost of rejecting a correct prediction; negative means "use the true
     *     class's weight"
     * @return the threshold model built from the best thresholds found
     * @throws OperatorException if inner learning, model application or optimization fails
     */
    private Model calculateThresholdModel(ExampleSet exampleSet, final double[] classWeights, final double unknownWeight) throws OperatorException {
        // Subset 0 trains the inner model; subset 1 is the hold-out set for threshold tuning.
        final SplittedExampleSet splitSet = new SplittedExampleSet(exampleSet, this.getParameterAsDouble("training_ratio"), 2, this.getParameterAsInt("local_random_seed"));
        splitSet.selectSingleSubset(0);
        Model innerModel = this.applyInnerLearner(splitSet);
        splitSet.selectSingleSubset(1);
        innerModel.apply(splitSet);
        final Attribute label = splitSet.getAttributes().getLabel();
        int numberOfGenerations = this.getParameterAsInt("number_of_iterations");
        // One optimized value per label index. The meaning of the remaining constructor
        // arguments is defined by ESOptimization (0.0 / 1.0 presumably bound the thresholds;
        // confirm there).
        ESOptimization optimization = new ESOptimization(0.0, 1.0, 5, classWeights.length, 0, numberOfGenerations, Math.max(1, numberOfGenerations / 10), 6, 0.4, true, 1, 0.9, false, RandomGenerator.getRandomGenerator(this.getParameterAsInt("local_random_seed"))){

            public PerformanceVector evaluateIndividual(Individual individual) throws OperatorException {
                double[] thresholds = individual.getValues();
                double costs = 0.0;
                for (Example example : splitSet) {
                    int predictionIndex = (int) example.getPredictedLabel();
                    String className = label.getMapping().mapIndex(predictionIndex);
                    double confidence = example.getConfidence(className);
                    boolean correct = example.getLabel() == example.getPredictedLabel();
                    if (confidence > thresholds[predictionIndex]) {
                        // Prediction accepted: only a wrong prediction costs something.
                        if (!correct) {
                            costs += classWeights[(int) example.getLabel()];
                        }
                    } else if (correct) {
                        // Prediction rejected as unknown although it would have been correct.
                        costs += unknownWeight < 0.0 ? classWeights[(int) example.getLabel()] : unknownWeight;
                    }
                }
                PerformanceVector performanceVector = new PerformanceVector();
                performanceVector.addCriterion(new EstimatedPerformance("Costs", costs, 1, true));
                return performanceVector;
            }
        };
        try {
            optimization.optimize();
        } finally {
            // Always drop the temporary prediction attribute, even if the optimization fails.
            PredictionModel.removePredictedLabel(splitSet);
        }
        double[] bestValues = optimization.getBestValuesEver();
        return new ThresholdModel(label, innerModel, bestValues);
    }

    /**
     * Declares the operator parameters: class cost weights, the abstention cost, the train /
     * hold-out split ratio, the number of optimization iterations and the random seed.
     *
     * @return the parameter types of the superclass plus those of this learner
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        // NOTE(review): this description mentions class name '?' for abstention costs, but
        // learn() takes those from predict_unknown_costs and rejects names that are not label
        // values.
        ParameterType type = new ParameterTypeList("class_weights", "The weights for all classes (first column: class names, second column: weight), empty: using 1 for all classes. The costs for not classifying at all are defined with class name '?'.", new ParameterTypeDouble("weight", "The weight for the specified class.", 0.0, Double.POSITIVE_INFINITY, 1.0));
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble("predict_unknown_costs", "Use this cost value for predicting an example as unknown (-1: use same costs as for correct class).", -1.0, Double.POSITIVE_INFINITY, -1.0);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeDouble("training_ratio", "Use this amount of input data for model learning and the rest for threshold optimization.", 0.0, 1.0, 0.7));
        types.add(new ParameterTypeInt("number_of_iterations", "Defines the number of optimization iterations.", 1, Integer.MAX_VALUE, 200));
        types.add(new ParameterTypeInt("local_random_seed", "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1));
        return types;
    }
}

