/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.kernel.evosvm;

import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.functions.kernel.SupportVector;
import com.rapidminer.operator.learner.functions.kernel.evosvm.ClassificationOptimizationFunction;
import com.rapidminer.operator.learner.functions.kernel.evosvm.EvoOptimization;
import com.rapidminer.operator.learner.functions.kernel.evosvm.EvoSVM;
import com.rapidminer.operator.learner.functions.kernel.evosvm.EvoSVMModel;
import com.rapidminer.operator.learner.functions.kernel.evosvm.OptimizationFunction;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceEvaluator;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.tools.LoggingHandler;
import com.rapidminer.tools.OperatorService;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.math.kernels.Kernel;
import com.rapidminer.tools.math.optimization.ec.es.ESOptimization;
import com.rapidminer.tools.math.optimization.ec.es.Individual;
import com.rapidminer.tools.math.optimization.ec.es.NonDominatedSortingSelection;
import com.rapidminer.tools.math.optimization.ec.es.Population;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Evolution-strategy optimizer (built on {@link ESOptimization}) for the per-example alpha
 * coefficients of a binary classification SVM.
 *
 * <p>Each individual holds one alpha per training example. Its fitness is computed by a
 * {@link ClassificationOptimizationFunction} from the alphas, the +1/-1 labels and the kernel.
 * The final alphas are turned into an {@link EvoSVMModel}. This happens either directly from
 * the best individual ever seen, or, when a hold-out set is given, from the individual with
 * the lowest hold-out classification error.
 */
public class ClassificationEvoOptimization
extends ESOptimization
implements EvoOptimization {
    /** Numerical tolerance used to classify alpha values as zero, free, or at the upper bound. */
    private static final double IS_ZERO = 1.0E-8;
    private ExampleSet exampleSet;
    private Kernel kernel;
    /** Upper bound of the alphas; taken from the optimizer's bound for dimension 0. */
    private double c;
    /** Label of each training example encoded as +1.0 (positive class) or -1.0. */
    private double[] ys;
    private OptimizationFunction optimizationFunction;
    private ExampleSet holdOutSet = null;
    private int populationSize = 10;

    public ClassificationEvoOptimization(ExampleSet exampleSet, Kernel kernel, double c, int initType, int maxIterations, int generationsWithoutImprovement, int popSize, int selectionType, double tournamentFraction, boolean keepBest, int mutationType, double crossoverProb, boolean showConvergencePlot, boolean showPopulationPlot, ExampleSet holdOutSet, RandomGenerator random, LoggingHandler logging) {
        super(EvoSVM.createBoundArray(0.0, exampleSet.size()), EvoSVM.determineMax(c, kernel, exampleSet, selectionType, exampleSet.size()), popSize, exampleSet.size(), initType, maxIterations, generationsWithoutImprovement, selectionType, tournamentFraction, keepBest, mutationType, Double.NaN, crossoverProb, showConvergencePlot, showPopulationPlot, random, logging);
        this.exampleSet = exampleSet;
        this.holdOutSet = holdOutSet;
        this.populationSize = popSize;
        this.kernel = kernel;
        // The effective C is the upper bound the super class was configured with
        // (EvoSVM.determineMax), not necessarily the raw parameter c.
        this.c = this.getMax(0);
        this.ys = new double[exampleSet.size()];
        Attribute label = exampleSet.getAttributes().getLabel();
        double positiveIndex = label.getMapping().getPositiveIndex();
        Iterator<Example> reader = exampleSet.iterator();
        int index = 0;
        while (reader.hasNext()) {
            Example example = reader.next();
            this.ys[index++] = example.getLabel() == positiveIndex ? 1.0 : -1.0;
        }
        // NOTE(review): 7 presumably denotes the multi-objective (non-dominated sorting)
        // selection type of ESOptimization -- confirm against its constants.
        this.optimizationFunction = new ClassificationOptimizationFunction(selectionType == 7);
    }

    /**
     * Evaluates one individual of the population.
     *
     * @param individual the individual whose values are interpreted as alphas
     * @return a vector containing "SVM_fitness" for a single-objective fitness, otherwise
     *         "alpha_sum", "svm_objective_function" and optionally "alpha_label_sum"
     */
    @Override
    public PerformanceVector evaluateIndividual(Individual individual) {
        double[] fitness = this.optimizationFunction.getFitness(individual.getValues(), this.ys, this.kernel);
        return createPerformanceVector(fitness, "SVM_fitness");
    }

    /**
     * Runs the optimization and builds the resulting model.
     *
     * <p>Without a hold-out set, the model is built from the best individual ever seen. With a
     * hold-out set, the final population is first reduced by non-dominated sorting selection.
     * Every remaining individual that yields a valid model is then evaluated, and the one with
     * the smallest hold-out classification error is returned. If no individual produces a
     * usable hold-out error, the best individual ever seen is used instead.
     *
     * @return the trained model
     * @throws OperatorException if applying or evaluating a candidate model fails
     */
    @Override
    public EvoSVMModel train() throws OperatorException {
        this.optimize();
        if (this.holdOutSet == null) {
            return this.getModel(this.getBestValuesEver());
        }
        Population population = this.getPopulation();
        NonDominatedSortingSelection selection = new NonDominatedSortingSelection(this.populationSize);
        selection.operate(population);

        List<TrainingTestError> errorList = this.evaluateOnHoldOutSet(population);
        // Sorted by descending training error: among candidates with equal hold-out error the
        // strict '<' below keeps the first one, i.e. the one with the larger training error.
        Collections.sort(errorList);
        TrainingTestError best = null;
        double bestValue = Double.POSITIVE_INFINITY;
        for (TrainingTestError error : errorList) {
            if (error.testError < bestValue) {
                best = error;
                bestValue = error.testError;
            }
        }
        if (best == null) {
            // No candidate produced a model with a finite hold-out error (all models failed to
            // build, or all errors were NaN/infinite); fall back instead of indexing with -1.
            return this.getModel(this.getBestValuesEver());
        }
        return this.getModel(best.alphas);
    }

    /**
     * Builds a model for every individual of the population and measures its classification
     * error on the training set and on the hold-out set. Individuals for which
     * {@link #getModel(double[])} throws an {@link IllegalArgumentException} are skipped.
     */
    private List<TrainingTestError> evaluateOnHoldOutSet(Population population) throws OperatorException {
        List<TrainingTestError> errors = new ArrayList<TrainingTestError>(population.getNumberOfIndividuals());
        for (int i = 0; i < population.getNumberOfIndividuals(); ++i) {
            double[] currentValues = population.get(i).getValues();
            EvoSVMModel model;
            try {
                model = this.getModel(currentValues);
            }
            catch (IllegalArgumentException ignored) {
                // Deliberately best-effort: an individual that cannot be turned into a model
                // is simply not a candidate.
                continue;
            }
            double trainingError = this.getError(this.exampleSet, model);
            double testError = this.getError(this.holdOutSet, model);
            errors.add(new TrainingTestError(trainingError, testError, currentValues));
        }
        return errors;
    }

    /**
     * Applies {@code model} to {@code exampleSet} and returns the averaged value of the main
     * criterion of a performance evaluator configured with "classification_error".
     *
     * <p>NOTE(review): {@code model.apply} may attach prediction attributes to the given set --
     * confirm whether that is acceptable for the training set.
     *
     * @throws OperatorException if the model cannot be applied or the evaluator cannot be
     *         created or run
     */
    private double getError(ExampleSet exampleSet, Model model) throws OperatorException {
        ExampleSet appliedSet = model.apply(exampleSet);
        try {
            PerformanceEvaluator evaluator = OperatorService.createOperator(PerformanceEvaluator.class);
            evaluator.setParameter("classification_error", "true");
            PerformanceVector performance = evaluator.doWork(appliedSet);
            return performance.getMainCriterion().getAverage();
        }
        catch (OperatorCreationException e) {
            // Previously printed and returned NaN, which later crashed train() with an
            // IndexOutOfBoundsException; surface the real cause instead.
            throw new OperatorException("Cannot create performance evaluator for hold-out evaluation: " + e.getMessage(), e);
        }
    }

    /**
     * Returns the fitness of the best individual ever seen.
     *
     * @return a vector containing "svm_objective_function" for a single-objective fitness,
     *         otherwise "alpha_sum", "svm_objective_function" and optionally "alpha_label_sum"
     */
    @Override
    public PerformanceVector getOptimizationPerformance() {
        double[] finalFitness = this.optimizationFunction.getFitness(this.getBestValuesEver(), this.ys, this.kernel);
        return createPerformanceVector(finalFitness, "svm_objective_function");
    }

    /**
     * Wraps a fitness array in a performance vector.
     *
     * @param fitness one, two or three fitness values
     * @param singleCriterionName criterion name used when {@code fitness} has a single value
     */
    private static PerformanceVector createPerformanceVector(double[] fitness, String singleCriterionName) {
        PerformanceVector vector = new PerformanceVector();
        if (fitness.length == 1) {
            vector.addCriterion(new EstimatedPerformance(singleCriterionName, fitness[0], 1, false));
        } else {
            vector.addCriterion(new EstimatedPerformance("alpha_sum", fitness[0], 1, false));
            vector.addCriterion(new EstimatedPerformance("svm_objective_function", fitness[1], 1, false));
            if (fitness.length == 3) {
                vector.addCriterion(new EstimatedPerformance("alpha_label_sum", fitness[2], 1, false));
            }
        }
        return vector;
    }

    /**
     * Builds a model from one alpha per training example.
     *
     * <p>Every example with a non-zero alpha becomes a support vector. The bias is derived from
     * the kernel sums over all training examples (see {@link #computeBias}).
     *
     * @param alphas one coefficient per training example, in example-set order
     */
    private EvoSVMModel getModel(double[] alphas) {
        ArrayList<SupportVector> supportVectors = new ArrayList<SupportVector>();
        Iterator<Example> reader = this.exampleSet.iterator();
        int index = 0;
        while (reader.hasNext()) {
            double currentAlpha = alphas[index];
            Example currentExample = reader.next();
            // exact comparison: only alphas that are exactly zero are dropped
            if (currentAlpha != 0.0) {
                supportVectors.add(new SupportVector(this.toValueArray(currentExample), this.ys[index], currentAlpha));
            }
            ++index;
        }

        double[] sum = new double[this.exampleSet.size()];
        reader = this.exampleSet.iterator();
        index = 0;
        while (reader.hasNext()) {
            sum[index++] = this.kernel.getSum(supportVectors, this.toValueArray(reader.next()));
        }
        return new EvoSVMModel(this.exampleSet, supportVectors, this.kernel, this.computeBias(alphas, sum));
    }

    /**
     * Estimates the bias as the mean of {@code ys[i] - sum[i]} over a subset of examples.
     *
     * <p>The subset is chosen by preference:
     * <ol>
     * <li>free support vectors ({@code IS_ZERO < |y*alpha|} and {@code |y*alpha| < c - IS_ZERO}),</li>
     * <li>otherwise examples with {@code |y*alpha| < IS_ZERO},</li>
     * <li>otherwise all examples.</li>
     * </ol>
     * Returns 0 when there are no examples at all.
     */
    private double computeBias(double[] alphas, double[] sum) {
        double bSum = 0.0;
        int bCounter = 0;
        for (int i = 0; i < alphas.length; ++i) {
            double ya = this.ys[i] * alphas[i];
            boolean freePositive = ya > IS_ZERO && ya - this.c < -IS_ZERO;
            boolean freeNegative = ya < -IS_ZERO && ya + this.c > IS_ZERO;
            if (freePositive || freeNegative) {
                bSum += this.ys[i] - sum[i];
                ++bCounter;
            }
        }
        if (bCounter == 0) {
            bSum = 0.0;
            for (int i = 0; i < alphas.length; ++i) {
                double ya = this.ys[i] * alphas[i];
                if (ya < IS_ZERO && ya > -IS_ZERO) {
                    bSum += this.ys[i] - sum[i];
                    ++bCounter;
                }
            }
        }
        if (bCounter == 0) {
            bSum = 0.0;
            for (int i = 0; i < alphas.length; ++i) {
                bSum += this.ys[i] - sum[i];
                ++bCounter;
            }
        }
        // guard against 0/0 for an empty alpha array
        return bCounter == 0 ? 0.0 : bSum / (double)bCounter;
    }

    /** Copies the values of all regular attributes of {@code example} into a new array. */
    private double[] toValueArray(Example example) {
        double[] values = new double[this.exampleSet.getAttributes().size()];
        int a = 0;
        for (Attribute attribute : this.exampleSet.getAttributes()) {
            values[a++] = example.getValue(attribute);
        }
        return values;
    }

    /**
     * Training/hold-out error pair of one candidate individual, ordered by descending training
     * error. {@code equals}/{@code hashCode} are consistent with {@code compareTo}.
     */
    private static final class TrainingTestError
    implements Comparable<TrainingTestError> {
        private final double trainingError;
        private final double testError;
        private final double[] alphas;

        TrainingTestError(double trainingError, double testError, double[] alphas) {
            this.trainingError = trainingError;
            this.testError = testError;
            this.alphas = alphas;
        }

        @Override
        public int compareTo(TrainingTestError o) {
            // descending by training error
            return Double.compare(o.trainingError, this.trainingError);
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof TrainingTestError)) {
                return false;
            }
            return Double.compare(this.trainingError, ((TrainingTestError)o).trainingError) == 0;
        }

        @Override
        public int hashCode() {
            return Double.valueOf(this.trainingError).hashCode();
        }
    }
}

