/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.lazy;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.UpdateablePredictionModel;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.container.Tupel;
import com.rapidminer.tools.math.container.GeometricDataCollection;
import java.util.ArrayList;
import java.util.Collection;

/**
 * Prediction model for k-nearest-neighbour regression.
 *
 * <p>The model keeps its samples in a {@link GeometricDataCollection} that maps attribute-value
 * vectors to numeric label values. A prediction is either the plain mean of the labels of the
 * {@code k} nearest samples or, when {@code weightByDistance} is set, a distance-weighted mean in
 * which closer neighbours receive larger weights.
 */
public class KNNRegressionModel extends UpdateablePredictionModel {
    private static final long serialVersionUID = -6292869962412072573L;
    /** Number of neighbours consulted for each prediction. */
    private int k;
    /** Number of examples the model holds: the training set size plus examples added by {@link #update}. */
    private int size;
    /** Stored samples: attribute-value vectors mapped to label values. */
    private GeometricDataCollection<Double> samples;
    /** Names of the training set's regular attributes, in the order used to build sample vectors. */
    private ArrayList<String> sampleAttributeNames;
    /** Whether neighbour labels are weighted by distance instead of averaged uniformly. */
    private boolean weightByDistance;

    /**
     * Creates a model over an already populated sample collection.
     *
     * @param trainingSet the example set the model was trained on; its regular attribute names (in
     *     iteration order) define the layout of sample vectors
     * @param samples the collection mapping attribute-value vectors to label values
     * @param k the number of neighbours used for each prediction
     * @param weightByDistance {@code true} to weight neighbour labels by their distance
     */
    public KNNRegressionModel(ExampleSet trainingSet, GeometricDataCollection<Double> samples, int k, boolean weightByDistance) {
        super(trainingSet);
        this.k = k;
        this.samples = samples;
        this.weightByDistance = weightByDistance;
        this.size = trainingSet.size();
        Attributes attributes = trainingSet.getAttributes();
        this.sampleAttributeNames = new ArrayList<String>(attributes.size());
        for (Attribute attribute : attributes) {
            this.sampleAttributeNames.add(attribute.getName());
        }
    }

    /**
     * Writes a prediction for every example of {@code exampleSet} into {@code predictedLabel}.
     *
     * <p>Each example is turned into a vector of its values for this model's sample attributes
     * (looked up by name). The prediction comes from the labels of its nearest stored samples. If no
     * neighbour is found, the prediction is {@code NaN}.
     *
     * @param exampleSet the examples to predict; modified in place
     * @param predictedLabel the attribute receiving the predicted values
     * @return {@code exampleSet}
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        ArrayList<Attribute> sampleAttributes = resolveSampleAttributes(exampleSet.getAttributes());
        // One query buffer is reused for all examples; this assumes the neighbour lookup does not
        // retain the array it is given.
        double[] values = new double[sampleAttributes.size()];
        for (Example example : exampleSet) {
            readValues(example, sampleAttributes, values);
            double result = this.weightByDistance
                    ? predictDistanceWeighted(this.samples.getNearestValueDistances(this.k, values))
                    : predictMean(this.samples.getNearestValues(this.k, values));
            example.setValue(predictedLabel, result);
        }
        return exampleSet;
    }

    /**
     * Adds the examples of {@code updateSet} to the stored samples.
     *
     * <p>Only example sets with a numerical label are used. If the label is missing or not
     * numerical, the call does nothing. Attribute values are read by name, in the same order as
     * for prediction, so the update set's own attribute order does not matter.
     *
     * @param updateSet the examples to add
     */
    @Override
    public void update(ExampleSet updateSet) throws OperatorException {
        Attribute label = updateSet.getAttributes().getLabel();
        // A regression model can only learn from numerical targets.
        if (label == null || !label.isNumerical()) {
            return;
        }
        ArrayList<Attribute> sampleAttributes = resolveSampleAttributes(updateSet.getAttributes());
        for (Example example : updateSet) {
            // Fresh array per example: the collection presumably keeps a reference to it.
            double[] values = new double[sampleAttributes.size()];
            readValues(example, sampleAttributes, values);
            this.samples.add(values, example.getValue(label));
            ++this.size;
        }
    }

    /**
     * Looks up this model's sample attributes by name in {@code attributes}, in sample-vector order.
     *
     * <p>NOTE(review): presumably {@code Attributes.get} yields {@code null} for an attribute that is
     * absent, which would then fail when its value is read; confirm and consider a user error.
     */
    private ArrayList<Attribute> resolveSampleAttributes(Attributes attributes) {
        ArrayList<Attribute> sampleAttributes = new ArrayList<Attribute>(this.sampleAttributeNames.size());
        for (String attributeName : this.sampleAttributeNames) {
            sampleAttributes.add(attributes.get(attributeName));
        }
        return sampleAttributes;
    }

    /** Copies the example's values for {@code sampleAttributes} into {@code values}, in list order. */
    private static void readValues(Example example, ArrayList<Attribute> sampleAttributes, double[] values) {
        int i = 0;
        for (Attribute attribute : sampleAttributes) {
            values[i++] = example.getValue(attribute);
        }
    }

    /**
     * Returns the arithmetic mean of the neighbour labels.
     *
     * <p>The divisor is the number of neighbours actually found. This can be smaller than
     * {@code k} if the collection holds fewer samples, so dividing by {@code k} would bias the result.
     *
     * @param neighbourLabels label values of the nearest samples
     * @return the mean label, or {@code NaN} if there are no neighbours
     */
    private static double predictMean(Collection<Double> neighbourLabels) {
        if (neighbourLabels.isEmpty()) {
            return Double.NaN;
        }
        double sum = 0.0;
        for (double label : neighbourLabels) {
            sum += label;
        }
        return sum / neighbourLabels.size();
    }

    /**
     * Returns the distance-weighted mean of the neighbour labels.
     *
     * <p>With {@code n} neighbours whose distances sum to {@code D}, neighbour {@code i} at distance
     * {@code d_i} gets weight {@code (1 - d_i / D) / (n - 1)}; these weights sum to one. If all
     * distances are zero, or only one neighbour was found, all neighbours are weighted equally. For
     * a single neighbour the formula would otherwise assign weight zero.
     *
     * @param neighbourTupels (distance, label value) pairs of the nearest samples
     * @return the weighted label, or {@code NaN} if there are no neighbours
     */
    private static double predictDistanceWeighted(Collection<Tupel<Double, Double>> neighbourTupels) {
        int count = neighbourTupels.size();
        if (count == 0) {
            return Double.NaN;
        }
        double totalDistance = 0.0;
        for (Tupel<Double, Double> tupel : neighbourTupels) {
            totalDistance += tupel.getFirst().doubleValue();
        }
        if (count == 1 || totalDistance == 0.0) {
            double sum = 0.0;
            for (Tupel<Double, Double> tupel : neighbourTupels) {
                sum += tupel.getSecond().doubleValue();
            }
            return sum / count;
        }
        double normalizer = count - 1;
        double result = 0.0;
        for (Tupel<Double, Double> tupel : neighbourTupels) {
            result += tupel.getSecond() * (1.0 - tupel.getFirst() / totalDistance) / normalizer;
        }
        return result;
    }

    /** Returns a short description of the model: its type, {@code k}, sample count and dimensionality. */
    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();
        if (this.weightByDistance) {
            buffer.append("Weighted ");
        }
        buffer.append(this.k + "-Nearest Neighbour model for regression." + Tools.getLineSeparator());
        buffer.append("The model contains " + this.size + " examples with " + this.sampleAttributeNames.size() + " dimensions.");
        return buffer.toString();
    }
}

