/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.lazy;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.UpdateablePredictionModel;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.container.Tupel;
import com.rapidminer.tools.math.container.GeometricDataCollection;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;

/**
 * k-nearest-neighbour model for nominal (classification) labels.
 *
 * <p>The model keeps the stored samples in a {@link GeometricDataCollection} whose values are
 * label indices. It predicts the label of a new example from the labels of its {@code k} nearest
 * stored samples, optionally weighting each neighbour's vote by its distance. Further labeled
 * examples can be added with {@link #update(ExampleSet)}.
 */
public class KNNClassificationModel
extends UpdateablePredictionModel {
    private static final long serialVersionUID = -6292869962412072573L;
    /** Number of neighbours consulted per prediction. */
    private final int k;
    /** Number of examples stored in the model (training set plus updates); reported by toString. */
    private int size;
    /** Stored samples; the value attached to each sample is its label index. */
    private final GeometricDataCollection<Integer> samples;
    /** Names of the training attributes, in the order their values are stored in each sample. */
    private final ArrayList<String> sampleAttributeNames;
    /** Whether neighbour votes are weighted by distance (not applied when k == 1). */
    private final boolean weightByDistance;

    public KNNClassificationModel(ExampleSet trainingSet, GeometricDataCollection<Integer> samples, int k, boolean weightByDistance) {
        super(trainingSet);
        this.k = k;
        this.size = trainingSet.size();
        this.samples = samples;
        this.weightByDistance = weightByDistance;
        Attributes attributes = trainingSet.getAttributes();
        this.sampleAttributeNames = new ArrayList<String>(attributes.size());
        for (Attribute attribute : attributes) {
            this.sampleAttributeNames.add(attribute.getName());
        }
    }

    /**
     * Predicts the label of every example from its nearest stored samples.
     *
     * <p>Each example's prediction is the class with the largest vote share. It is NaN when no
     * class received a positive share, e.g. when no neighbours were found. Every class's share is
     * written as that class's confidence.
     *
     * @param exampleSet examples to label; must contain all attributes the model was trained on
     * @param predictedLabel attribute receiving the predicted label index
     * @return {@code exampleSet}, with predictions and confidences set
     * @throws OperatorException if a training attribute is missing from {@code exampleSet}
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        ArrayList<Attribute> sampleAttributes = resolveSampleAttributes(exampleSet.getAttributes());
        int numberOfClasses = predictedLabel.getMapping().size();
        // Reused for every example: the array is only used as a query point, never stored.
        double[] values = new double[sampleAttributes.size()];
        for (Example example : exampleSet) {
            readValues(example, sampleAttributes, values);
            double[] counter = countVotes(values, numberOfClasses);
            int mostFrequentIndex = indexOfLargestPositive(counter);
            if (mostFrequentIndex < 0) {
                example.setValue(predictedLabel, Double.NaN);
            } else {
                example.setValue(predictedLabel, mostFrequentIndex);
            }
            for (int index = 0; index < counter.length; ++index) {
                example.setConfidence(predictedLabel.getMapping().mapIndex(index), counter[index]);
            }
        }
        return exampleSet;
    }

    /**
     * Computes the per-class vote shares for the query point {@code values}.
     *
     * <p>Shares are normalized by the number of neighbours actually returned. That number can be
     * smaller than {@code k} when fewer samples are stored, so the shares sum to 1 whenever at
     * least one neighbour exists. With no neighbours, every share stays 0.
     *
     * @param values query point, ordered like {@link #sampleAttributeNames}
     * @param numberOfClasses length of the returned array
     * @return vote share per label index
     */
    private double[] countVotes(double[] values, int numberOfClasses) {
        double[] counter = new double[numberOfClasses];
        if (!this.weightByDistance || this.k == 1) {
            Collection<Integer> neighbourLabels = this.samples.getNearestValues(this.k, values);
            int found = neighbourLabels.size();
            for (int labelIndex : neighbourLabels) {
                counter[labelIndex] += 1.0 / found;
            }
            return counter;
        }
        Collection<Tupel<Double, Integer>> neighbours = this.samples.getNearestValueDistances(this.k, values);
        int found = neighbours.size();
        double totalDistance = 0.0;
        for (Tupel<Double, Integer> neighbour : neighbours) {
            totalDistance += neighbour.getFirst();
        }
        if (totalDistance == 0.0 || found == 1) {
            // Distances carry no information here, so every neighbour gets an equal vote.
            // For a single neighbour, the weight 1 - d/d would be 0 and discard its label.
            for (Tupel<Double, Integer> neighbour : neighbours) {
                counter[neighbour.getSecond()] += 1.0 / found;
            }
        } else {
            // The weights (1 - d_i / D) sum to (found - 1), so dividing by that normalizes to 1.
            for (Tupel<Double, Integer> neighbour : neighbours) {
                counter[neighbour.getSecond()] += (1.0 - neighbour.getFirst() / totalDistance) / (found - 1);
            }
        }
        return counter;
    }

    /**
     * Returns the index of the largest strictly positive entry, or -1 if no entry is positive.
     * On ties, the first index wins.
     */
    private static int indexOfLargestPositive(double[] counter) {
        int bestIndex = -1;
        double bestValue = 0.0;
        for (int index = 0; index < counter.length; ++index) {
            if (counter[index] > bestValue) {
                bestValue = counter[index];
                bestIndex = index;
            }
        }
        return bestIndex;
    }

    /**
     * Looks up the model's training attributes in {@code attributes}, in training order.
     *
     * @throws OperatorException if one of the training attributes is not present
     */
    private ArrayList<Attribute> resolveSampleAttributes(Attributes attributes) throws OperatorException {
        ArrayList<Attribute> sampleAttributes = new ArrayList<Attribute>(this.sampleAttributeNames.size());
        for (String attributeName : this.sampleAttributeNames) {
            Attribute attribute = attributes.get(attributeName);
            if (attribute == null) {
                throw new OperatorException("KNN model: attribute '" + attributeName
                        + "' used for training is missing in the given example set.");
            }
            sampleAttributes.add(attribute);
        }
        return sampleAttributes;
    }

    /** Copies the values of {@code example} for {@code sampleAttributes} into {@code values}, in order. */
    private static void readValues(Example example, ArrayList<Attribute> sampleAttributes, double[] values) {
        int i = 0;
        for (Attribute attribute : sampleAttributes) {
            values[i++] = example.getValue(attribute);
        }
    }

    /**
     * Adds the labeled examples of {@code updateSet} to the stored samples.
     *
     * <p>Nothing is added when the update set has no label or a non-nominal label. Examples whose
     * label value is missing (NaN) are skipped. Otherwise they would be stored as class index 0.
     * NOTE(review): label indices come from {@code updateSet}'s own label mapping. Confirm that
     * callers guarantee it matches the training label mapping.
     *
     * @param updateSet examples to add; must contain all attributes the model was trained on
     * @throws OperatorException if a training attribute is missing from {@code updateSet}
     */
    @Override
    public void update(ExampleSet updateSet) throws OperatorException {
        Attribute label = updateSet.getAttributes().getLabel();
        if (label == null || !label.isNominal()) {
            return;
        }
        // Use the training attribute order so new samples line up with existing ones.
        ArrayList<Attribute> sampleAttributes = resolveSampleAttributes(updateSet.getAttributes());
        for (Example example : updateSet) {
            double labelValue = example.getValue(label);
            if (Double.isNaN(labelValue)) {
                continue;
            }
            // A fresh array per sample: the collection keeps a reference to it.
            double[] values = new double[sampleAttributes.size()];
            readValues(example, sampleAttributes, values);
            this.samples.add(values, (int) labelValue);
            ++this.size;
        }
    }

    @Override
    public String toString() {
        String lineSeparator = Tools.getLineSeparator();
        StringBuilder buffer = new StringBuilder();
        if (this.weightByDistance) {
            buffer.append("Weighted ");
        }
        buffer.append(this.k).append("-Nearest Neighbour model for classification.").append(lineSeparator);
        buffer.append("The model contains ").append(this.size).append(" examples with ")
                .append(this.sampleAttributeNames.size()).append(" dimensions of the following classes:");
        buffer.append(lineSeparator);
        for (String value : this.getTrainingHeader().getAttributes().getLabel().getMapping().getValues()) {
            buffer.append("  ").append(value).append(lineSeparator);
        }
        return buffer.toString();
    }
}

