/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.Partition;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.operator.learner.meta.Binary2MultiClassModel;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.PredictionModelMetaData;
import com.rapidminer.operator.ports.metadata.SetRelation;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Meta learner that solves a classification task with more than two label values by training its
 * inner learner on several derived binary problems and combining the resulting models in a
 * {@link Binary2MultiClassModel}.
 *
 * <p>The decomposition is chosen with {@link #PARAMETER_CLASSIFICATION_STRATEGIES}:
 * <ul>
 * <li>"1 against all": one model per class, trained on that class versus all other classes;</li>
 * <li>"1 against 1": one model per unordered pair of classes, trained only on examples of the pair;</li>
 * <li>"exhaustive code (ECOC)": 2^(k-1) - 1 models for k classes, one per column of the
 * generated code matrix;</li>
 * <li>"random code (ECOC)": (int)(k * {@link #PARAMETER_RANDOM_CODE_MULTIPLICATOR}) models, one per
 * column of a random code matrix.</li>
 * </ul>
 * If the label has exactly two values, the inner learner is applied directly and its model is
 * returned as is.
 */
public class Binary2MultiClassLearner
extends AbstractMetaLearner {
    public static final String PARAMETER_CLASSIFICATION_STRATEGIES = "classification_strategies";
    public static final String PARAMETER_RANDOM_CODE_MULTIPLICATOR = "random_code_multiplicator";
    private static final String[] STRATEGIES = new String[]{"1 against all", "1 against 1", "exhaustive code (ECOC)", "random code (ECOC)"};
    private static final int ONE_AGAINST_ALL = 0;
    private static final int ONE_AGAINST_ONE = 1;
    private static final int EXHAUSTIVE_CODE = 2;
    private static final int RANDOM_CODE = 3;
    /** Names of the binary models built by the most recent 1-vs-all / 1-vs-1 code pattern. */
    private final LinkedList<String> modelNames = new LinkedList<String>();

    public Binary2MultiClassLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Tells the inner learner what label it will see: a nominal label is re-typed to type 6
     * (presumably Ontology.BINOMINAL, inlined by the decompiler; confirm) whose values are a
     * subset of the original label's values.
     */
    @Override
    protected MetaData modifyExampleSetMetaData(ExampleSetMetaData metaData) {
        AttributeMetaData labelAMD = metaData.getAttributeByRole("label");
        if (labelAMD != null && labelAMD.isNominal()) {
            labelAMD.setType(6);
            labelAMD.setValueSetRelation(SetRelation.SUBSET);
        }
        return metaData;
    }

    /**
     * Re-types the "prediction" attribute of the generated model to type 1 (presumably
     * Ontology.NOMINAL; confirm), since the combined model predicts the original multi-valued label.
     */
    @Override
    protected MetaData modifyGeneratedModelMetaData(PredictionModelMetaData unmodifiedMetaData) {
        for (AttributeMetaData amd : unmodifiedMetaData.getPredictionAttributeMetaData()) {
            if ("prediction".equals(amd.getRole())) {
                amd.setType(1);
            }
        }
        return super.modifyGeneratedModelMetaData(unmodifiedMetaData);
    }

    /**
     * Builds a {@link SplittedExampleSet} over a clone of {@code inputSet}. Partition number
     * {@code v} holds the examples whose label value index is {@code v}, so classes can later be
     * selected by their mapping index.
     */
    private SplittedExampleSet constructClassPartitionSet(ExampleSet inputSet) {
        Attribute classLabel = inputSet.getAttributes().getLabel();
        int numberOfClasses = classLabel.getMapping().size();
        int[] examples = new int[inputSet.size()];
        int i = 0;
        for (Example e : inputSet) {
            examples[i++] = (int) e.getValue(classLabel);
        }
        Partition separatedClasses = new Partition(examples, numberOfClasses);
        return new SplittedExampleSet((ExampleSet) inputSet.clone(), separatedClasses);
    }

    /**
     * Trains one inner model per column ("function") of {@code codePattern}.
     *
     * <p>For each column, only the class partitions enabled in that column are selected. Every
     * selected example gets a temporary label whose value is its class's code string for that
     * column, and the inner learner is trained on that label. Afterwards the original label is
     * restored and the temporary attribute is removed from the example table. The cleanup also
     * runs if training fails, because the table is shared with the input data.
     *
     * @param seSet       class-partitioned example set, see {@link #constructClassPartitionSet}
     * @param classLabel  the original (multi-valued) label attribute
     * @param codePattern rows = classes in mapping order, columns = binary functions
     * @return one model per column of the code pattern
     * @throws OperatorException if the inner learner fails or the process is stopped
     */
    private Model[] applyCodePattern(SplittedExampleSet seSet, Attribute classLabel, CodePattern codePattern) throws OperatorException {
        int numberOfClasses = classLabel.getMapping().size();
        int numberOfFunctions = codePattern.data[0].length;
        Model[] models = new Model[numberOfFunctions];

        // Label value index -> row of the code pattern. This does not depend on the current
        // function, so build it once.
        HashMap<Integer, Integer> classIndexMap = new HashMap<Integer, Integer>(numberOfClasses);
        int index = 0;
        for (String currentClass : classLabel.getMapping().getValues()) {
            classIndexMap.put(classLabel.getMapping().mapString(currentClass), index);
            ++index;
        }

        for (int currentFunction = 0; currentFunction < numberOfFunctions; ++currentFunction) {
            // Restrict the view to the classes that take part in this binary problem.
            seSet.clearSelection();
            int counter = 0;
            for (String currentClass : classLabel.getMapping().getValues()) {
                if (codePattern.partitionEnabled[counter][currentFunction]) {
                    seSet.selectAdditionalSubset(classLabel.getMapping().mapString(currentClass));
                }
                ++counter;
            }

            Attribute workingLabel = AttributeFactory.createAttribute("multiclass_working_label", 6);
            seSet.getExampleTable().addAttribute(workingLabel);
            seSet.getAttributes().addRegular(workingLabel);
            try {
                for (Example e : seSet) {
                    int classRow = classIndexMap.get((int) e.getValue(classLabel));
                    if (codePattern.partitionEnabled[classRow][currentFunction]) {
                        e.setValue(workingLabel, workingLabel.getMapping().mapString(codePattern.data[classRow][currentFunction]));
                    }
                }
                // Swap the working attribute into the label role for the inner learner.
                seSet.getAttributes().remove(workingLabel);
                seSet.getAttributes().setLabel(workingLabel);
                models[currentFunction] = this.applyInnerLearner(seSet);
                this.inApplyLoop();
            } finally {
                seSet.getAttributes().setLabel(classLabel);
                seSet.getExampleTable().removeAttribute(workingLabel);
            }
        }
        return models;
    }

    /**
     * Builds a square code pattern for the 1-vs-all strategy. In column {@code j}, class {@code j}
     * keeps its own name and every other class is coded as "all_other_classes". Also refills
     * {@link #modelNames} with "&lt;class&gt; vs. all other".
     */
    private CodePattern buildCodePattern_ONE_VS_ALL(Attribute classLabel) {
        int numberOfClasses = classLabel.getMapping().size();
        CodePattern codePattern = new CodePattern(numberOfClasses, numberOfClasses);
        this.modelNames.clear();
        int i = 0;
        for (String currentClass : classLabel.getMapping().getValues()) {
            this.modelNames.add(currentClass + " vs. all other");
            for (int j = 0; j < numberOfClasses; ++j) {
                codePattern.data[i][j] = (i == j) ? currentClass : "all_other_classes";
            }
            ++i;
        }
        return codePattern;
    }

    /**
     * Builds the 1-vs-1 code pattern. There is one column per unordered class pair (a, b), a &lt; b,
     * in lexicographic order. Only the two classes of the pair are enabled in a column, and each
     * is coded with its own name. Also refills {@link #modelNames} with "&lt;a&gt; vs. &lt;b&gt;".
     */
    private CodePattern buildCodePattern_ONE_VS_ONE(Attribute classLabel) {
        int numberOfClasses = classLabel.getMapping().size();
        int numberOfCombinations = numberOfClasses * (numberOfClasses - 1) / 2;
        CodePattern codePattern = new CodePattern(numberOfClasses, numberOfCombinations);
        this.modelNames.clear();
        // The constructor enables everything; here a class takes part only in its own pairs.
        for (int i = 0; i < numberOfClasses; ++i) {
            for (int j = 0; j < numberOfCombinations; ++j) {
                codePattern.partitionEnabled[i][j] = false;
            }
        }
        String[] classNames = new String[numberOfClasses];
        int classIndex = 0;
        for (String className : classLabel.getMapping().getValues()) {
            classNames[classIndex++] = className;
        }
        int counter = 0;
        for (int classA = 0; classA < numberOfClasses - 1; ++classA) {
            for (int classB = classA + 1; classB < numberOfClasses; ++classB) {
                codePattern.partitionEnabled[classA][counter] = true;
                codePattern.partitionEnabled[classB][counter] = true;
                codePattern.data[classA][counter] = classNames[classA];
                codePattern.data[classB][counter] = classNames[classB];
                this.modelNames.add(classNames[classA] + " vs. " + classNames[classB]);
                ++counter;
            }
        }
        return codePattern;
    }

    /**
     * Builds the exhaustive code pattern for k classes with 2^(k-1) - 1 columns. Row 0 is all
     * "true". Row i (i &gt;= 1) of column j is "true" iff bit (k-1-i) of j is set, which gives
     * alternating runs of length 2^(k-1-i). The last (all-"true") column is omitted.
     */
    private CodePattern buildCodePattern_EXHAUSTIVE_CODE(Attribute classLabel) {
        int numberOfClasses = classLabel.getMapping().size();
        int numberOfFunctions = (int) Math.pow(2.0, numberOfClasses - 1) - 1;
        CodePattern codePattern = new CodePattern(numberOfClasses, numberOfFunctions);
        for (int j = 0; j < numberOfFunctions; ++j) {
            codePattern.data[0][j] = "true";
        }
        for (int i = 1; i < numberOfClasses; ++i) {
            int currentStep = (int) Math.pow(2.0, numberOfClasses - (i + 1));
            for (int j = 0; j < numberOfFunctions; ++j) {
                codePattern.data[i][j] = String.valueOf(j / currentStep % 2 > 0);
            }
        }
        return codePattern;
    }

    /**
     * Builds a random code pattern with (int)(k * multiplicator) columns of random "true"/"false"
     * entries. Any column without a "true" (or without a "false") gets one randomly chosen row
     * flipped, so every column defines a real two-class split.
     *
     * @throws OperatorException if the multiplicator parameter cannot be read
     */
    private CodePattern buildCodePattern_RANDOM_CODE(Attribute classLabel) throws OperatorException {
        double multiplicator = this.getParameterAsDouble(PARAMETER_RANDOM_CODE_MULTIPLICATOR);
        int numberOfClasses = classLabel.getMapping().size();
        int numberOfFunctions = (int) ((double) numberOfClasses * multiplicator);
        CodePattern codePattern = new CodePattern(numberOfClasses, numberOfFunctions);
        RandomGenerator randomGenerator = RandomGenerator.getRandomGenerator(this);
        for (int i = 0; i < numberOfClasses; ++i) {
            for (int j = 0; j < numberOfFunctions; ++j) {
                codePattern.data[i][j] = String.valueOf(randomGenerator.nextBoolean());
            }
        }
        for (int j = 0; j < numberOfFunctions; ++j) {
            boolean containsTrue = false;
            boolean containsFalse = false;
            for (int i = 0; i < numberOfClasses; ++i) {
                if ("true".equals(codePattern.data[i][j])) {
                    containsTrue = true;
                } else {
                    containsFalse = true;
                }
            }
            if (!containsTrue) {
                codePattern.data[randomRow(randomGenerator, numberOfClasses)][j] = "true";
            }
            if (!containsFalse) {
                codePattern.data[randomRow(randomGenerator, numberOfClasses)][j] = "false";
            }
        }
        return codePattern;
    }

    /**
     * Returns a uniformly chosen row index in [0, numberOfRows). The previous code scaled by
     * numberOfRows - 1, so the last row could never be chosen.
     */
    private static int randomRow(RandomGenerator randomGenerator, int numberOfRows) {
        return (int) (randomGenerator.nextDouble() * (double) numberOfRows);
    }

    /**
     * Trains the combined model. A label with exactly two values is delegated to the inner learner
     * unchanged. Otherwise the selected strategy's code pattern is built and one inner model is
     * trained per column.
     *
     * @throws OperatorException if training fails or the strategy index is unknown
     */
    @Override
    public Model learn(ExampleSet inputSet) throws OperatorException {
        Attribute classLabel = inputSet.getAttributes().getLabel();
        if (classLabel.getMapping().size() == 2) {
            return this.applyInnerLearner(inputSet);
        }
        int classificationStrategy = this.getParameterAsInt(PARAMETER_CLASSIFICATION_STRATEGIES);
        SplittedExampleSet seSet = this.constructClassPartitionSet(inputSet);
        switch (classificationStrategy) {
            case ONE_AGAINST_ALL: {
                this.getLogger().fine("Binary2MultiCLassLearner set to <<1-vs-all>>");
                CodePattern codePattern = this.buildCodePattern_ONE_VS_ALL(classLabel);
                Model[] models = this.applyCodePattern(seSet, classLabel, codePattern);
                // Copy: modelNames is reused (cleared and refilled) by the next learn() call.
                return new Binary2MultiClassModel(inputSet, models, classificationStrategy, new LinkedList<String>(this.modelNames));
            }
            case ONE_AGAINST_ONE: {
                this.getLogger().fine("Binary2MultiCLassLearner set to <<1-vs-1>>");
                CodePattern codePattern = this.buildCodePattern_ONE_VS_ONE(classLabel);
                Model[] models = this.applyCodePattern(seSet, classLabel, codePattern);
                return new Binary2MultiClassModel(inputSet, models, classificationStrategy, new LinkedList<String>(this.modelNames));
            }
            case EXHAUSTIVE_CODE: {
                this.getLogger().fine("Binary2MultiCLassLearner set to <<exhaustive code>>");
                CodePattern codePattern = this.buildCodePattern_EXHAUSTIVE_CODE(classLabel);
                Model[] models = this.applyCodePattern(seSet, classLabel, codePattern);
                return new Binary2MultiClassModel(inputSet, models, classificationStrategy, codePattern.data);
            }
            case RANDOM_CODE: {
                this.getLogger().fine("Binary2MultiCLassLearner set to <<random code>>");
                CodePattern codePattern = this.buildCodePattern_RANDOM_CODE(classLabel);
                Model[] models = this.applyCodePattern(seSet, classLabel, codePattern);
                return new Binary2MultiClassModel(inputSet, models, classificationStrategy, codePattern.data);
            }
            default:
                break;
        }
        throw new OperatorException("Binary2MultiCLassLearner: Unknown classification strategy selected");
    }

    /** Reports every capability as supported except numerical/no label, updating and formulas. */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        switch (capability) {
            case NUMERICAL_LABEL:
            case NO_LABEL:
            case UPDATABLE:
            case FORMULA_PROVIDER:
                return false;
            default:
                return true;
        }
    }

    /**
     * Adds the strategy selector, the random-code multiplicator and the random generator parameters
     * to the inherited parameter types.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeCategory(PARAMETER_CLASSIFICATION_STRATEGIES, "What strategy should be used for multi class classifications?", STRATEGIES, 0, false));
        ParameterTypeDouble type = new ParameterTypeDouble(PARAMETER_RANDOM_CODE_MULTIPLICATOR, "A multiplicator regulating the codeword length in random code modus.", 1.0, Double.POSITIVE_INFINITY, 2.0, false);
        // Only the random code strategy reads the multiplicator (buildCodePattern_RANDOM_CODE).
        type.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_CLASSIFICATION_STRATEGIES, STRATEGIES, true, RANDOM_CODE));
        types.add(type);
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }

    /**
     * Code matrix used to derive binary problems: rows are classes (in label mapping order),
     * columns are binary functions. {@code data[i][j]} is the working-label value for class i in
     * function j. {@code partitionEnabled[i][j]} says whether class i takes part in function j at
     * all. Every entry is enabled on construction.
     */
    private static class CodePattern {
        String[][] data;
        boolean[][] partitionEnabled;

        public CodePattern(int numberOfClasses, int numberOfFunctions) {
            this.data = new String[numberOfClasses][numberOfFunctions];
            this.partitionEnabled = new boolean[numberOfClasses][numberOfFunctions];
            for (int i = 0; i < numberOfClasses; ++i) {
                for (int j = 0; j < numberOfFunctions; ++j) {
                    this.partitionEnabled[i][j] = true;
                }
            }
        }
    }
}

