/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.lazy;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.annotation.ResourceConsumptionEstimator;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.lazy.KNNClassificationModel;
import com.rapidminer.operator.learner.lazy.KNNRegressionModel;
import com.rapidminer.operator.ports.metadata.DistanceMeasurePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.OperatorResourceConsumptionHandler;
import com.rapidminer.tools.math.container.LinearList;
import com.rapidminer.tools.math.similarity.DistanceMeasure;
import com.rapidminer.tools.math.similarity.DistanceMeasureHelper;
import com.rapidminer.tools.math.similarity.DistanceMeasures;
import java.util.List;

/**
 * Learner operator that stores the training examples in a {@link LinearList} and wraps
 * them in a k-nearest-neighbour model: a {@link KNNClassificationModel} when the label
 * is nominal, otherwise a {@link KNNRegressionModel}.
 *
 * <p>The distance measure is obtained from a {@link DistanceMeasureHelper}; its parameter
 * types are appended to this operator's own parameters ({@link #PARAMETER_K} and
 * {@link #PARAMETER_WEIGHTED_VOTE}).
 */
public class KNNLearner extends AbstractLearner {
    /** Key of the parameter holding the number of nearest neighbours. */
    public static final String PARAMETER_K = "k";

    /** Key of the parameter that switches similarity-weighted voting on or off. */
    public static final String PARAMETER_WEIGHTED_VOTE = "weighted_vote";

    /** Resolves the distance measure selected through this operator's parameters. */
    private final DistanceMeasureHelper measureHelper = new DistanceMeasureHelper(this);

    /**
     * Creates the operator and registers a {@link DistanceMeasurePrecondition} on the
     * example set input port.
     *
     * @param description the operator description handed on to the super class
     */
    public KNNLearner(OperatorDescription description) {
        super(description);
        this.getExampleSetInputPort().addPrecondition(new DistanceMeasurePrecondition(this.getExampleSetInputPort(), this));
    }

    /**
     * Collects every example of the given set as a (feature vector, label value) sample
     * and returns a k-NN model built on those samples.
     *
     * @param exampleSet the training data; its label attribute decides which model type is built
     * @return a {@link KNNClassificationModel} for a nominal label, otherwise a
     *         {@link KNNRegressionModel}
     * @throws OperatorException propagated from the measure initialization, the parameter
     *         lookups or {@code checkForStop()}
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        DistanceMeasure measure = this.measureHelper.getInitializedMeasure(exampleSet);
        Attributes attributes = exampleSet.getAttributes();
        Attribute label = attributes.getLabel();
        int valuesSize = attributes.size();

        if (label.isNominal()) {
            LinearList<Integer> samples = new LinearList<Integer>(measure);
            for (Example example : exampleSet) {
                double[] values = KNNLearner.toFeatureVector(example, attributes, valuesSize);
                // NOTE(review): if getValue can return NaN for a missing label, this cast
                // yields 0 (Java narrowing rule) -- confirm missing labels are excluded upstream.
                int labelValue = (int) example.getValue(label);
                samples.add(values, labelValue);
                // Called once per example so a stop request is noticed while collecting.
                this.checkForStop();
            }
            return new KNNClassificationModel(exampleSet, samples, this.getParameterAsInt(PARAMETER_K), this.getParameterAsBoolean(PARAMETER_WEIGHTED_VOTE));
        }

        LinearList<Double> samples = new LinearList<Double>(measure);
        for (Example example : exampleSet) {
            double[] values = KNNLearner.toFeatureVector(example, attributes, valuesSize);
            double labelValue = example.getValue(label);
            samples.add(values, labelValue);
            this.checkForStop();
        }
        return new KNNRegressionModel(exampleSet, samples, this.getParameterAsInt(PARAMETER_K), this.getParameterAsBoolean(PARAMETER_WEIGHTED_VOTE));
    }

    /**
     * Copies the example's value for each attribute yielded by iterating {@code attributes}
     * into a fresh array, in iteration order.
     *
     * @param example the example to read from
     * @param attributes the attributes to iterate
     * @param valuesSize length of the returned array (the caller passes {@code attributes.size()})
     * @return a new array holding one value per iterated attribute
     */
    private static double[] toFeatureVector(Example example, Attributes attributes, int valuesSize) {
        double[] values = new double[valuesSize];
        int index = 0;
        for (Attribute attribute : attributes) {
            values[index++] = example.getValue(attribute);
        }
        return values;
    }

    /**
     * Returns the model class reported by the super class.
     *
     * @return the result of {@code super.getModelClass()}
     */
    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return super.getModelClass();
    }

    /**
     * Reports which capabilities this learner supports, depending on the selected
     * measure type.
     *
     * <p>NOTE(review): the numeric measure types 0-3 look like compile-time constants that
     * were inlined (presumably the {@code *_TYPE} constants of {@link DistanceMeasures}:
     * mixed, nominal, numerical, divergences) -- confirm against that class.
     *
     * @param capability the capability being queried
     * @return {@code true} for nominal attributes when the measure type is 0 or 1, for
     *         numerical attributes when it is 0, 2 or 3, and always for the label,
     *         weighted-example and missing-value capabilities; {@code false} otherwise
     */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        int measureType = 0;
        try {
            measureType = this.measureHelper.getSelectedMeasureType();
        }
        catch (Exception ignored) {
            // Deliberate best effort: when the selected measure type cannot be determined,
            // keep the default type 0 rather than failing the capability query.
        }
        switch (capability) {
            case BINOMINAL_ATTRIBUTES:
            case POLYNOMINAL_ATTRIBUTES:
                return measureType == 0 || measureType == 1;
            case NUMERICAL_ATTRIBUTES:
                return measureType == 0 || measureType == 3 || measureType == 2;
            case POLYNOMINAL_LABEL:
            case BINOMINAL_LABEL:
            case NUMERICAL_LABEL:
            case WEIGHTED_EXAMPLES:
            case MISSING_VALUES:
                return true;
            default:
                return false;
        }
    }

    /**
     * Extends the inherited parameter list with {@link #PARAMETER_K},
     * {@link #PARAMETER_WEIGHTED_VOTE} and the parameter types supplied by
     * {@link DistanceMeasures#getParameterTypes}.
     *
     * @return the list obtained from the super class with the additional entries appended
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        // Constructor arguments are presumably (key, description, min, max, default) -- confirm.
        ParameterTypeInt neighbourCount = new ParameterTypeInt(PARAMETER_K, "The used number of nearest neighbors.", 1, Integer.MAX_VALUE, 1);
        neighbourCount.setExpert(false);
        types.add(neighbourCount);

        types.add(new ParameterTypeBoolean(PARAMETER_WEIGHTED_VOTE, "Indicates if the votes should be weighted by similarity.", false, false));
        types.addAll(DistanceMeasures.getParameterTypes(this));
        return types;
    }

    /**
     * Looks up the resource consumption estimator registered for this operator class.
     *
     * @return the estimator returned by {@link OperatorResourceConsumptionHandler} for the
     *         example set input port and {@code KNNLearner.class}
     */
    @Override
    public ResourceConsumptionEstimator getResourceConsumptionEstimator() {
        return OperatorResourceConsumptionHandler.getResourceConsumptionEstimator(this.getExampleSetInputPort(), KNNLearner.class, null);
    }
}

