/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.local;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.HeaderExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.local.Neighborhood;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.container.Tupel;
import com.rapidminer.tools.math.LinearRegression;
import com.rapidminer.tools.math.VectorMath;
import com.rapidminer.tools.math.container.GeometricDataCollection;
import com.rapidminer.tools.math.smoothing.SmoothingKernel;
import java.io.Serializable;
import java.util.Collection;

public class LocalPolynomialRegressionModel
extends PredictionModel {
    private GeometricDataCollection<RegressionData> samples;
    private Neighborhood neighborhood;
    private SmoothingKernel kernelSmoother;
    private int degree;
    private double ridge;
    private static final long serialVersionUID = -4874020185611138104L;

    protected LocalPolynomialRegressionModel(ExampleSet trainingExampleSet, GeometricDataCollection<RegressionData> data, Neighborhood neighborhood, SmoothingKernel kernelSmoother, int degree, double ridge) {
        super(trainingExampleSet);
        this.samples = data;
        this.neighborhood = neighborhood;
        this.kernelSmoother = kernelSmoother;
        this.degree = degree;
        this.ridge = ridge;
    }

    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        Attributes attributes = exampleSet.getAttributes();
        double[] probe = new double[attributes.size()];
        for (Example example : exampleSet) {
            int i = 0;
            for (Attribute attribute : attributes) {
                probe[i] = example.getValue(attribute);
                ++i;
            }
            Collection<Tupel<Double, RegressionData>> localExamples = this.neighborhood.getNeighbourhood(this.samples, probe);
            if (localExamples.size() > 1) {
                double[][] x = new double[localExamples.size()][];
                double[][] y = new double[localExamples.size()][1];
                double[] distance = new double[localExamples.size()];
                double[] weight = new double[localExamples.size()];
                int j = 0;
                for (Tupel<Double, RegressionData> tupel : localExamples) {
                    distance[j] = tupel.getFirst();
                    x[j] = VectorMath.polynomialExpansion(tupel.getSecond().getExampleValues(), this.degree);
                    y[j][0] = tupel.getSecond().getExampleLabel();
                    weight[j] = tupel.getSecond().getExampleWeight();
                    ++j;
                }
                double maxDistance = Double.NEGATIVE_INFINITY;
                for (j = 0; j < distance.length; ++j) {
                    maxDistance = maxDistance < distance[j] ? distance[j] : maxDistance;
                }
                for (j = 0; j < distance.length; ++j) {
                    weight[j] = weight[j] * this.kernelSmoother.getWeight(distance[j], maxDistance);
                }
                double[] coefficients = LinearRegression.performRegression(new Matrix((double[][])x), new Matrix(y), weight, this.ridge);
                double[] probeExpaneded = VectorMath.polynomialExpansion(probe, this.degree);
                example.setPredictedLabel(VectorMath.vectorMultiplication(probeExpaneded, coefficients));
                continue;
            }
            if (localExamples.size() == 1) {
                example.setPredictedLabel(localExamples.iterator().next().getSecond().getExampleLabel());
                continue;
            }
            example.setPredictedLabel(Double.NaN);
        }
        return exampleSet;
    }

    @Override
    public String toString() {
        StringBuffer buffer = new StringBuffer();
        buffer.append("This model contains " + this.samples.size() + " examples for determining the neighborhood." + Tools.getLineSeparator());
        buffer.append("The fitted polynomial is of degree " + this.degree + " and is fitted with a ridge factor of " + this.ridge + Tools.getLineSeparator());
        buffer.append("It uses the " + this.neighborhood.toString() + " for neighborhood determination." + Tools.getLineSeparator());
        buffer.append("Weighting is performed using the " + this.kernelSmoother.toString());
        return buffer.toString();
    }

    public GeometricDataCollection<RegressionData> getSamples() {
        return this.samples;
    }

    public Neighborhood getNeighborhood() {
        return this.neighborhood;
    }

    public SmoothingKernel getKernelSmoother() {
        return this.kernelSmoother;
    }

    public int getDegree() {
        return this.degree;
    }

    public double getRidge() {
        return this.ridge;
    }

    public String[] getAttributeNames() {
        HeaderExampleSet trainSet = this.getTrainingHeader();
        Attributes attributes = trainSet.getAttributes();
        String[] attributeNames = new String[attributes.size()];
        int i = 0;
        for (Attribute attribute : attributes) {
            attributeNames[i] = attribute.getName();
            ++i;
        }
        return attributeNames;
    }

    public static class RegressionData
    implements Serializable {
        private static final long serialVersionUID = 8540161261369474329L;
        private double[] exampleValues;
        private double exampleLabel;
        private double exampleWeight;

        public RegressionData(double[] exampleValues, double exampleLabel, double exampleWeight) {
            this.exampleValues = exampleValues;
            this.exampleLabel = exampleLabel;
            this.exampleWeight = exampleWeight;
        }

        public double[] getExampleValues() {
            return this.exampleValues;
        }

        public double getExampleLabel() {
            return this.exampleLabel;
        }

        public double getExampleWeight() {
            return this.exampleWeight;
        }
    }
}

