/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.bayes;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.AbstractLearner;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.MixedDistributionsDistribution;
import edu.udo.cs.yale.operator.learner.bayes.DiscreteDistribution;
import edu.udo.cs.yale.operator.learner.bayes.DistributionModel;
import edu.udo.cs.yale.operator.learner.bayes.NormalDistribution;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeBoolean;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Naive Bayes learner.
 *
 * <p>Builds a {@link DistributionModel} holding one smoothed prior per class and,
 * for every (class, attribute) pair, one distribution estimated from the examples
 * of that class: a {@link DiscreteDistribution} for nominal attributes and either a
 * single {@link NormalDistribution} or (when the {@code use_kernel} parameter is set)
 * a {@link MixedDistributionsDistribution} with one component per example for all
 * other attributes.
 */
public class NaiveBayes
extends AbstractLearner {
    /** Key of the boolean parameter that switches non-nominal attributes to per-example kernels. */
    private static final String PARAMETER_USE_KERNEL = "use_kernel";

    public NaiveBayes(OperatorDescription description) {
        super(description);
    }

    /**
     * Learns the distribution model from the given example set.
     *
     * @param exampleSet the training examples; its label attribute defines the classes
     * @return the {@link DistributionModel} populated with priors and per-attribute distributions
     * @throws OperatorException propagated from the operator framework calls made during learning
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Attribute classAttribute = exampleSet.getAttributes().getLabel();
        List<String> classMappings = classAttribute.getMapping().getValues();
        int classCount = classAttribute.getMapping().size();
        int exampleCount = exampleSet.size();
        int attributeCount = exampleSet.getAttributes().size();

        // Bucket the examples by the index of their label value.
        List<List<Example>> examplesByClass = new ArrayList<List<Example>>(classCount);
        for (int classIndex = 0; classIndex < classCount; ++classIndex) {
            examplesByClass.add(new ArrayList<Example>());
        }
        for (Example currentExample : exampleSet) {
            // NOTE(review): a label string that is not part of the mapping yields -1 here and
            // makes the get() below fail with an index error - confirm how unlabeled
            // examples are supposed to be handled.
            int classIndex = this.getIndexOfString(classMappings, currentExample.getValueAsString(classAttribute));
            examplesByClass.get(classIndex).add(currentExample);
            this.checkForStop();
        }

        // Attribute names in iteration order; the same order indexes the distributions below.
        String[] attributeNames = new String[attributeCount];
        int nameIndex = 0;
        for (Attribute currentAttribute : exampleSet.getAttributes()) {
            attributeNames[nameIndex++] = currentAttribute.getName();
        }

        // Add-one smoothed priors: (n_c + 1) / (N + C).
        double[] classProbabilities = new double[classCount];
        for (int classIndex = 0; classIndex < classCount; ++classIndex) {
            classProbabilities[classIndex] =
                    ((double)examplesByClass.get(classIndex).size() + 1.0) / (double)(exampleCount + classCount);
        }

        DistributionModel model = new DistributionModel(classAttribute, classCount, classProbabilities, attributeNames);

        // The parameter cannot change while learning, so read it once instead of per attribute.
        boolean useKernel = this.getParameterAsBoolean(PARAMETER_USE_KERNEL);
        for (int classIndex = 0; classIndex < classCount; ++classIndex) {
            List<Example> members = examplesByClass.get(classIndex);
            int attributeIndex = 0;
            for (Attribute currentAttribute : exampleSet.getAttributes()) {
                if (currentAttribute.isNominal()) {
                    this.addNominalDistribution(model, classIndex, attributeIndex, currentAttribute, members);
                } else {
                    this.addNumericalDistribution(model, classIndex, attributeIndex, currentAttribute, members, useKernel);
                    this.checkForStop();
                }
                ++attributeIndex;
            }
        }
        return model;
    }

    /**
     * Estimates the distribution of a non-nominal attribute within one class and registers it
     * in the model: a single normal distribution, or one normal component per example when
     * kernels are enabled.
     */
    private void addNumericalDistribution(DistributionModel model, int classIndex, int attributeIndex,
            Attribute attribute, List<Example> members, boolean useKernel) throws OperatorException {
        if (!useKernel) {
            double mean = this.getMeanOfList(members, attribute);
            double deviation = this.getDeviationOfList(members, attribute, mean);
            model.addDistribution(classIndex, attributeIndex, new NormalDistribution(mean, deviation));
            return;
        }
        // Every example becomes its own component; the shared spread shrinks as 1/sqrt(n).
        double bandwidth = 1.0 / Math.sqrt(members.size());
        MixedDistributionsDistribution mixture = new MixedDistributionsDistribution();
        for (Example example : members) {
            mixture.addDistribution(new NormalDistribution(example.getValue(attribute), bandwidth));
        }
        model.addDistribution(classIndex, attributeIndex, mixture);
    }

    /**
     * Collects the values a nominal attribute takes within one class (both the distinct values
     * in first-seen order and the full list of observations) and registers the resulting
     * discrete distribution in the model.
     */
    private void addNominalDistribution(DistributionModel model, int classIndex, int attributeIndex,
            Attribute attribute, List<Example> members) throws OperatorException {
        LinkedHashSet<Double> distinctValues = new LinkedHashSet<Double>();
        ArrayList<Double> observedValues = new ArrayList<Double>();
        for (Example example : members) {
            Double value = example.getValue(attribute);
            distinctValues.add(value);
            observedValues.add(value);
        }
        model.addDistribution(classIndex, attributeIndex,
                new DiscreteDistribution(distinctValues, attribute.getMapping().size(), observedValues));
    }

    /**
     * Returns the arithmetic mean of the attribute over the given examples.
     * An empty list yields {@code NaN} (0.0 / 0).
     */
    private double getMeanOfList(List<Example> exampleCollection, Attribute attribute) {
        double sum = 0.0;
        for (Example example : exampleCollection) {
            sum += example.getValue(attribute);
        }
        return sum / (double)exampleCollection.size();
    }

    /**
     * Returns the square root of the summed squared differences from {@code mean},
     * divided by {@code n - 1}.
     *
     * <p>NOTE(review): for a list with exactly one example this evaluates 0/0 and returns
     * {@code NaN}; which fallback is appropriate for such classes is a modelling decision
     * and is left unchanged here.
     */
    private double getDeviationOfList(List<Example> exampleCollection, Attribute attribute, double mean) {
        double squaredDifferences = 0.0;
        for (Example example : exampleCollection) {
            squaredDifferences += Math.pow(example.getValue(attribute) - mean, 2.0);
        }
        return Math.sqrt(squaredDifferences / (double)(exampleCollection.size() - 1));
    }

    /**
     * Returns the position (in iteration order) of the first element equal to {@code value},
     * or {@code -1} if the collection contains no such element.
     */
    private int getIndexOfString(Collection<String> collection, String value) {
        int index = 0;
        for (String candidate : collection) {
            if (candidate.equals(value)) {
                return index;
            }
            ++index;
        }
        return -1;
    }

    /**
     * Reports the supported capabilities: polynominal, binominal and numerical attributes,
     * and polynominal and binominal class labels.
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        return lc == LearnerCapability.POLYNOMINAL_ATTRIBUTES
                || lc == LearnerCapability.BINOMINAL_ATTRIBUTES
                || lc == LearnerCapability.NUMERICAL_ATTRIBUTES
                || lc == LearnerCapability.POLYNOMINAL_CLASS
                || lc == LearnerCapability.BINOMINAL_CLASS;
    }

    /**
     * Extends the inherited parameter list with the non-expert boolean {@code use_kernel}
     * parameter (default {@code false}).
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeBoolean type = new ParameterTypeBoolean(PARAMETER_USE_KERNEL, "Using kernels might reduce error", false);
        type.setExpert(false);
        types.add(type);
        return types;
    }
}

