/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.features.weighting;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.features.weighting.AbstractWeighting;
import com.rapidminer.operator.preprocessing.discretization.BinDiscretization;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.OperatorService;
import com.rapidminer.tools.math.ContingencyTableTools;
import java.util.List;

public class ChiSquaredWeighting
extends AbstractWeighting {
    public ChiSquaredWeighting(OperatorDescription description) {
        super(description);
    }

    @Override
    protected AttributeWeights calculateWeights(ExampleSet exampleSet) throws OperatorException {
        Attribute label = exampleSet.getAttributes().getLabel();
        if (!label.isNominal()) {
            throw new UserError((Operator)this, 101, "chi squared test", label.getName());
        }
        BinDiscretization discretization = null;
        try {
            discretization = OperatorService.createOperator(BinDiscretization.class);
        }
        catch (OperatorCreationException e) {
            throw new UserError((Operator)this, 904, "Discretization", e.getMessage());
        }
        int numberOfBins = this.getParameterAsInt("number_of_bins");
        discretization.setParameter("number_of_bins", numberOfBins + "");
        exampleSet = discretization.doWork(exampleSet);
        int maximumNumberOfNominalValues = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (!attribute.isNominal()) continue;
            maximumNumberOfNominalValues = Math.max(maximumNumberOfNominalValues, attribute.getMapping().size());
        }
        if (numberOfBins < maximumNumberOfNominalValues) {
            this.getLogger().warning("Number of bins too small, was " + numberOfBins + ". Set to maximum number of occurring nominal values (" + maximumNumberOfNominalValues + ")");
            numberOfBins = maximumNumberOfNominalValues;
        }
        double[][][] counters = new double[exampleSet.getAttributes().size()][numberOfBins][label.getMapping().size()];
        Attribute weightAttribute = exampleSet.getAttributes().getWeight();
        int exampleCounter = 0;
        double[] temporaryCounters = new double[label.getMapping().size()];
        for (Example example : exampleSet) {
            int labelIndex;
            double weight = 1.0;
            if (weightAttribute != null) {
                weight = example.getValue(weightAttribute);
            }
            int n = labelIndex = (int)example.getLabel();
            temporaryCounters[n] = temporaryCounters[n] + weight;
            ++exampleCounter;
        }
        for (int k = 0; k < counters.length; ++k) {
            for (int i = 0; i < temporaryCounters.length; ++i) {
                counters[k][0][i] = temporaryCounters[i];
            }
        }
        for (Example example : exampleSet) {
            int labelIndex = (int)example.getLabel();
            double weight = 1.0;
            if (weightAttribute != null) {
                weight = example.getValue(weightAttribute);
            }
            int attributeCounter = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                int attributeIndex = (int)example.getValue(attribute);
                double[] dArray = counters[attributeCounter][attributeIndex];
                int n = labelIndex;
                dArray[n] = dArray[n] + weight;
                double[] dArray2 = counters[attributeCounter][0];
                int n2 = labelIndex;
                dArray2[n2] = dArray2[n2] - weight;
                ++attributeCounter;
            }
        }
        AttributeWeights weights = new AttributeWeights(exampleSet);
        int attributeCounter = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            double weight = ContingencyTableTools.getChiSquaredStatistics(ContingencyTableTools.deleteEmpty(counters[attributeCounter]), false);
            weights.setWeight(attribute.getName(), weight);
            ++attributeCounter;
        }
        return weights;
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeInt("number_of_bins", "The number of bins used for discretization of numerical attributes before the chi squared test can be performed.", 2, Integer.MAX_VALUE, 10));
        return types;
    }

    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        switch (capability) {
            case BINOMINAL_ATTRIBUTES: 
            case POLYNOMINAL_ATTRIBUTES: 
            case NUMERICAL_ATTRIBUTES: 
            case BINOMINAL_LABEL: 
            case POLYNOMINAL_LABEL: {
                return true;
            }
        }
        return false;
    }
}

