/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.kernel.evosvm;

import edu.udo.cs.yale.datatable.SimpleDataTable;
import edu.udo.cs.yale.datatable.SimpleDataTableRow;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.gui.plotter.SimplePlotterDialog;
import edu.udo.cs.yale.operator.IOContainer;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorCreationException;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.kernel.evosvm.ClassificationOptimizationFunction;
import edu.udo.cs.yale.operator.learner.kernel.evosvm.EvoOptimization;
import edu.udo.cs.yale.operator.learner.kernel.evosvm.EvoSVMModel;
import edu.udo.cs.yale.operator.learner.kernel.evosvm.Kernel;
import edu.udo.cs.yale.operator.learner.kernel.evosvm.OptimizationFunction;
import edu.udo.cs.yale.operator.learner.kernel.evosvm.SupportVector;
import edu.udo.cs.yale.operator.performance.EstimatedPerformance;
import edu.udo.cs.yale.operator.performance.PerformanceEvaluator;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.OperatorService;
import edu.udo.cs.yale.tools.RandomGenerator;
import edu.udo.cs.yale.tools.math.optimization.ec.es.ESOptimization;
import edu.udo.cs.yale.tools.math.optimization.ec.es.Individual;
import edu.udo.cs.yale.tools.math.optimization.ec.es.NonDominatedSortingSelection;
import edu.udo.cs.yale.tools.math.optimization.ec.es.Population;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;

public class PatternEvoOptimization
extends ESOptimization
implements EvoOptimization {
    private static final double IS_ZERO = 1.0E-8;
    private ExampleSet exampleSet;
    private Kernel kernel;
    private double c;
    private double[] ys;
    private OptimizationFunction optimizationFunction;
    private ExampleSet holdOutSet = null;
    private int populationSize = 10;

    public PatternEvoOptimization(ExampleSet exampleSet, Kernel kernel, double c, int initType, int maxIterations, int generationsWithoutImprovement, int popSize, int selectionType, double tournamentFraction, boolean keepBest, int mutationType, double crossoverProb, boolean showConvergencePlot, ExampleSet holdOutSet, RandomGenerator random) {
        super(0.0, PatternEvoOptimization.determineC(c, kernel, exampleSet, selectionType), popSize, exampleSet.size(), initType, maxIterations, generationsWithoutImprovement, selectionType, tournamentFraction, keepBest, mutationType, crossoverProb, showConvergencePlot, random);
        this.exampleSet = exampleSet;
        this.holdOutSet = holdOutSet;
        this.populationSize = popSize;
        this.kernel = kernel;
        this.c = this.getMax(0);
        this.ys = new double[exampleSet.size()];
        Iterator reader = exampleSet.iterator();
        int index = 0;
        Attribute label = exampleSet.getAttributes().getLabel();
        while (reader.hasNext()) {
            Example example = (Example)reader.next();
            double d = this.ys[index++] = example.getLabel() == (double)label.getMapping().getPositiveIndex() ? 1.0 : -1.0;
        }
        this.optimizationFunction = new ClassificationOptimizationFunction(selectionType == 7);
    }

    private static final double determineC(double _c, Kernel kernel, ExampleSet exampleSet, int selectionType) {
        kernel.init(exampleSet);
        if (selectionType == 7) {
            return 1000.0;
        }
        if (_c <= 0.0) {
            double c = 0.0;
            int i = 0;
            while (i < exampleSet.size()) {
                c += kernel.getDistance(i, i);
                ++i;
            }
            c = (double)exampleSet.size() / c;
            LogService.logMessage("Determine probably good value for C: set to " + c, 2);
            return c;
        }
        return _c;
    }

    public PerformanceVector evaluateIndividual(Individual individual) {
        double[] fitness = this.optimizationFunction.getFitness(individual.getValues(), this.ys, this.kernel);
        PerformanceVector performanceVector = new PerformanceVector();
        if (fitness.length == 1) {
            performanceVector.addCriterion(new EstimatedPerformance("SVM_fitness", fitness[0], 1, false));
        } else {
            performanceVector.addCriterion(new EstimatedPerformance("alpha_sum", fitness[0], 1, false));
            performanceVector.addCriterion(new EstimatedPerformance("svm_objective_function", fitness[1], 1, false));
            if (fitness.length == 3) {
                performanceVector.addCriterion(new EstimatedPerformance("alpha_label_sum", fitness[2], 1, false));
            }
        }
        return performanceVector;
    }

    public EvoSVMModel train() throws OperatorException {
        this.optimize();
        if (this.holdOutSet != null) {
            SimpleDataTable holdOutSetPerfomance = new SimpleDataTable("Generalization Performance", new String[]{"individual", "training error", "test error"});
            Population population = this.getPopulation();
            NonDominatedSortingSelection selection = new NonDominatedSortingSelection(this.populationSize);
            selection.operate(population);
            /*
             * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
             */
            private class TrainingTestError
            implements Comparable<TrainingTestError> {
                private double trainingError;
                private double testError;
                private double[] alphas;

                TrainingTestError(double trainingError, double testError, double[] alphas) {
                    this.trainingError = trainingError;
                    this.testError = testError;
                    this.alphas = alphas;
                }

                @Override
                public int compareTo(TrainingTestError o) {
                    return -1 * Double.compare(this.trainingError, o.trainingError);
                }

                public boolean equals(Object o) {
                    if (!(o instanceof TrainingTestError)) {
                        return false;
                    }
                    return this.trainingError == ((TrainingTestError)o).trainingError;
                }

                public int hashCode() {
                    return Double.valueOf(this.trainingError).hashCode();
                }
            }
            LinkedList<TrainingTestError> errorList = new LinkedList<TrainingTestError>();
            int i = 0;
            while (i < population.getNumberOfIndividuals()) {
                double[] currentValues = population.get(i).getValues();
                EvoSVMModel model = null;
                try {
                    model = this.getModel(currentValues);
                }
                catch (IllegalArgumentException illegalArgumentException) {
                    // empty catch block
                }
                if (model != null) {
                    double trainingError = this.getError(this.exampleSet, model);
                    double testError = this.getError(this.holdOutSet, model);
                    errorList.add(new TrainingTestError(trainingError, testError, currentValues));
                }
                ++i;
            }
            Collections.sort(errorList);
            Iterator i2 = errorList.iterator();
            int counter = 0;
            int bestIndex = -1;
            double bestValue = Double.POSITIVE_INFINITY;
            while (i2.hasNext()) {
                TrainingTestError error = (TrainingTestError)i2.next();
                holdOutSetPerfomance.add(new SimpleDataTableRow(new double[]{counter, error.trainingError, error.testError}));
                if (error.testError < bestValue) {
                    bestIndex = counter;
                    bestValue = error.testError;
                }
                ++counter;
            }
            SimplePlotterDialog plotter = new SimplePlotterDialog(holdOutSetPerfomance, false);
            plotter.setXAxis(0);
            plotter.plotColumn(1, true);
            plotter.plotColumn(2, true);
            plotter.setDraw2DLines(false);
            plotter.setVisible(true);
            return this.getModel(((TrainingTestError)errorList.get(bestIndex)).alphas);
        }
        return this.getModel(this.getBestValuesEver());
    }

    private double getError(ExampleSet exampleSet, Model model) throws OperatorException {
        model.apply(exampleSet);
        try {
            PerformanceEvaluator evaluator = (PerformanceEvaluator)OperatorService.createOperator(PerformanceEvaluator.class);
            evaluator.setParameter("classification_error", "true");
            IOContainer result = evaluator.apply(new IOContainer(new IOObject[]{exampleSet}));
            PerformanceVector performance = result.get(PerformanceVector.class);
            return performance.getMainCriterion().getAverage();
        }
        catch (OperatorCreationException e) {
            e.printStackTrace();
            return Double.NaN;
        }
    }

    public PerformanceVector getOptimizationPerformance() {
        double[] bestValuesEver = this.getBestValuesEver();
        double[] finalFitness = this.optimizationFunction.getFitness(bestValuesEver, this.ys, this.kernel);
        PerformanceVector result = new PerformanceVector();
        if (finalFitness.length == 1) {
            result.addCriterion(new EstimatedPerformance("svm_objective_function", finalFitness[0], 1, false));
        } else {
            result.addCriterion(new EstimatedPerformance("alpha_sum", finalFitness[0], 1, false));
            result.addCriterion(new EstimatedPerformance("svm_objective_function", finalFitness[1], 1, false));
            if (finalFitness.length == 3) {
                result.addCriterion(new EstimatedPerformance("alpha_label_sum", finalFitness[2], 1, false));
            }
        }
        return result;
    }

    private EvoSVMModel getModel(double[] alphas) {
        Iterator reader = this.exampleSet.iterator();
        ArrayList<SupportVector> supportVectors = new ArrayList<SupportVector>();
        int index = 0;
        while (reader.hasNext()) {
            double currentAlpha = alphas[index];
            Example currentExample = (Example)reader.next();
            if (currentAlpha != 0.0) {
                double[] x = new double[this.exampleSet.getAttributes().size()];
                int a = 0;
                for (Attribute attribute : this.exampleSet.getAttributes()) {
                    x[a++] = currentExample.getValue(attribute);
                }
                supportVectors.add(new SupportVector(x, this.ys[index], currentAlpha));
            }
            ++index;
        }
        double[] sum = new double[this.exampleSet.size()];
        reader = this.exampleSet.iterator();
        index = 0;
        while (reader.hasNext()) {
            Example current = (Example)reader.next();
            double[] x = new double[this.exampleSet.getAttributes().size()];
            int a = 0;
            for (Attribute attribute : this.exampleSet.getAttributes()) {
                x[a++] = current.getValue(attribute);
            }
            sum[index] = this.kernel.getSum(supportVectors, x);
            ++index;
        }
        double bSum = 0.0;
        int bCounter = 0;
        int i = 0;
        while (i < alphas.length) {
            if (this.ys[i] * alphas[i] - this.c < -1.0E-8 && this.ys[i] * alphas[i] > 1.0E-8) {
                bSum += this.ys[i] - sum[i];
                ++bCounter;
            } else if (this.ys[i] * alphas[i] + this.c > 1.0E-8 && this.ys[i] * alphas[i] < -1.0E-8) {
                bSum += this.ys[i] - sum[i];
                ++bCounter;
            }
            ++i;
        }
        if (bCounter == 0) {
            bSum = 0.0;
            i = 0;
            while (i < alphas.length) {
                if (this.ys[i] * alphas[i] < 1.0E-8 && this.ys[i] * alphas[i] > -1.0E-8) {
                    bSum += this.ys[i] - sum[i];
                    ++bCounter;
                }
                ++i;
            }
            if (bCounter == 0) {
                bSum = 0.0;
                i = 0;
                while (i < alphas.length) {
                    bSum += this.ys[i] - sum[i];
                    ++bCounter;
                    ++i;
                }
            }
        }
        return new EvoSVMModel(this.exampleSet.getAttributes().getLabel(), supportVectors, this.kernel, bSum / (double)bCounter);
    }
}

