/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.Partition;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.operator.learner.meta.HierarchicalMultiClassModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.tools.RandomGenerator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Meta learner that solves a multi-class problem along a user-defined class hierarchy.
 *
 * <p>The {@value #PARAMETER_HIERARCHY} parameter lists (parent class, child class) pairs which are
 * assembled into a tree. Leaf classes must be values of the label attribute. Inner classes are
 * names for groups of classes. For every inner node, the inner learner is trained on the examples
 * of that node's subtree. Each example is relabeled with the name of the child subtree it belongs
 * to. The returned {@link HierarchicalMultiClassModel} holds the tree with one trained model
 * attached to every inner node.
 */
public class HierarchicalMultiClassLearner
extends AbstractMetaLearner {
    public static final String PARAMETER_HIERARCHY = "hierarchy";
    public static final String PARAMETER_PARENT_CLASS = "parent_class";
    public static final String PARAMETER_CHILD_CLASS = "child_class";

    public HierarchicalMultiClassLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Builds the class tree from the hierarchy parameter, validates it, and trains one inner model
     * per inner node.
     *
     * @param inputSet training data with a nominal label whose values cover all leaf classes
     * @return a model holding the trained class tree
     * @throws OperatorException if a child class is neither a label value nor a parent class
     *     (error 219), the hierarchy has more than one root (220) or none (221), an inner class has
     *     exactly one child (222), or training an inner model fails
     */
    @Override
    public Model learn(ExampleSet inputSet) throws OperatorException {
        Attribute labelAttribute = inputSet.getAttributes().getLabel();
        checkCompatibility(labelAttribute);

        List<String[]> hierarchyEntryPairs = getParameterList(PARAMETER_HIERARCHY);
        HashMap<String, HierarchicalMultiClassModel.Node> nodeMap = new HashMap<String, HierarchicalMultiClassModel.Node>();
        // Every node that appears as a child somewhere; the root is the only node not in here.
        HashSet<HierarchicalMultiClassModel.Node> nonRootNodes = new HashSet<HierarchicalMultiClassModel.Node>();
        for (String[] entryPair : hierarchyEntryPairs) {
            HierarchicalMultiClassModel.Node parentNode = getOrCreateNode(nodeMap, entryPair[0]);
            HierarchicalMultiClassModel.Node childNode = getOrCreateNode(nodeMap, entryPair[1]);
            // NOTE(review): duplicate pairs, classes with several parents, and cycles are not
            // rejected here; they lead to repeated or endless recursion during training.
            parentNode.addChild(childNode);
            nonRootNodes.add(childNode);
        }

        HierarchicalMultiClassModel.Node root = null;
        for (HierarchicalMultiClassModel.Node node : nodeMap.values()) {
            if (nonRootNodes.contains(node)) {
                continue;
            }
            if (root != null) {
                throw new UserError(this, 220, root.getClassName(), node.getClassName());
            }
            root = node;
        }
        if (root == null) {
            throw new UserError(this, 221);
        }

        // An inner class with a single child would give its learner a one-class problem.
        for (HierarchicalMultiClassModel.Node node : nodeMap.values()) {
            if (node.getChildren().size() == 1) {
                throw new UserError(this, 222, node.getClassName(), node.getChildren().size());
            }
        }

        computeModel(root, inputSet, labelAttribute);
        return new HierarchicalMultiClassModel(inputSet, root);
    }

    /** Returns the node registered under {@code className}, creating and registering it if absent. */
    private static HierarchicalMultiClassModel.Node getOrCreateNode(HashMap<String, HierarchicalMultiClassModel.Node> nodeMap, String className) {
        HierarchicalMultiClassModel.Node node = nodeMap.get(className);
        if (node == null) {
            node = new HierarchicalMultiClassModel.Node(className);
            nodeMap.put(className, node);
        }
        return node;
    }

    /**
     * Checks that every child class in the hierarchy is either a value of the label attribute or
     * itself a parent class elsewhere in the hierarchy.
     *
     * @throws UserError error 219 naming the first unknown child class
     */
    private void checkCompatibility(Attribute labelAttribute) throws UserError {
        List<String[]> pairs = getParameterList(PARAMETER_HIERARCHY);
        HashSet<String> knownClasses = new HashSet<String>(labelAttribute.getMapping().getValues());
        for (String[] pair : pairs) {
            knownClasses.add(pair[0]);
        }
        // Root uniqueness is validated in learn() once the tree has been built.
        for (String[] pair : pairs) {
            if (!knownClasses.contains(pair[1])) {
                throw new UserError(this, 219, pair[1]);
            }
        }
    }

    /**
     * Trains all inner-node models below {@code rootNode}.
     *
     * <p>This method temporarily replaces the label of {@code exampleSet` with a working copy.
     * The original label is parked under the special role {@code "label_original"}. The original
     * state is restored even if training fails.
     */
    private void computeModel(HierarchicalMultiClassModel.Node rootNode, ExampleSet exampleSet, Attribute originalLabel) throws OperatorException {
        exampleSet.getAttributes().setSpecialAttribute(originalLabel, "label_original");
        Attribute workingLabel = AttributeFactory.createAttribute(originalLabel.getName() + "_working", originalLabel.getValueType());
        // A freshly created attribute has its own, empty mapping. Give it a copy of the original
        // mapping so the copied values keep their meaning, and leaf classes resolve to the same
        // indices that are stored in the partition array below.
        workingLabel.setMapping((NominalMapping) originalLabel.getMapping().clone());
        exampleSet.getExampleTable().addAttribute(workingLabel);
        exampleSet.getAttributes().addRegular(workingLabel);
        exampleSet.getAttributes().setLabel(workingLabel);
        try {
            int[] partitions = new int[exampleSet.size()];
            // Leaf partition ids are label indices. Reserve the mapping's whole index range, not
            // only indices seen in the data, so inner-node ids never collide with an unseen class.
            int maxLeafId = originalLabel.getMapping().getValues().size() - 1;
            int i = 0;
            for (Example example : exampleSet) {
                double value = example.getValue(originalLabel);
                example.setValue(workingLabel, value);
                // NOTE(review): a missing label (NaN) casts to 0 and would be trained as the class
                // with index 0 -- confirm missing labels are rejected before learn() is called.
                partitions[i] = (int) value;
                maxLeafId = Math.max(maxLeafId, partitions[i]);
                ++i;
            }
            AtomicInteger nonLeafCounter = new AtomicInteger(maxLeafId);
            setPartitionIdRecursively(rootNode, nonLeafCounter, workingLabel);
            // Partition ids run from 0 up to nonLeafCounter.get() inclusive.
            computeModelRecursively(rootNode, partitions, nonLeafCounter.get() + 1, exampleSet);
        } finally {
            exampleSet.getAttributes().remove(workingLabel);
            exampleSet.getAttributes().setLabel(originalLabel);
            exampleSet.getExampleTable().removeAttribute(workingLabel);
        }
    }

    /**
     * Assigns partition ids to the subtree rooted at {@code node}.
     *
     * <p>Leaves get the index of their class in the working label's mapping. Inner nodes get fresh
     * ids from {@code nonLeafCounter}, assigned in post-order.
     */
    private void setPartitionIdRecursively(HierarchicalMultiClassModel.Node node, AtomicInteger nonLeafCounter, Attribute workingLabel) {
        if (node.isLeaf()) {
            node.setPartitionId(workingLabel.getMapping().mapString(node.getClassName()));
            return;
        }
        for (HierarchicalMultiClassModel.Node child : node.getChildren()) {
            setPartitionIdRecursively(child, nonLeafCounter, workingLabel);
        }
        // Exactly one id per inner node, taken after all descendants have been numbered.
        node.setPartitionId(nonLeafCounter.incrementAndGet());
    }

    /**
     * Trains the subtree bottom-up.
     *
     * <p>For an inner node, all children are processed first. Then the examples of each child
     * partition are relabeled with that child's class name, and the inner learner is trained on
     * their union. Finally the children's examples are merged into this node's partition, so the
     * parent sees them as one group.
     */
    private void computeModelRecursively(HierarchicalMultiClassModel.Node node, int[] partitions, int numberOfPartitions, ExampleSet exampleSet) throws OperatorException {
        if (node.isLeaf()) {
            return;
        }
        for (HierarchicalMultiClassModel.Node child : node.getChildren()) {
            computeModelRecursively(child, partitions, numberOfPartitions, exampleSet);
        }
        SplittedExampleSet trainSet = new SplittedExampleSet(exampleSet, new Partition(partitions, numberOfPartitions));
        Attribute workingLabel = trainSet.getAttributes().getLabel();
        // Work on a private copy of the mapping before clearing it, so the mapping object that
        // exampleSet's label refers to is left untouched.
        workingLabel.setMapping((NominalMapping) workingLabel.getMapping().clone());
        workingLabel.getMapping().clear();
        for (HierarchicalMultiClassModel.Node child : node.getChildren()) {
            trainSet.selectSingleSubset(child.getPartitionId());
            int nodeLabelIndex = workingLabel.getMapping().mapString(child.getClassName());
            for (Example example : trainSet) {
                example.setValue(workingLabel, nodeLabelIndex);
            }
        }
        trainSet.clearSelection();
        for (HierarchicalMultiClassModel.Node child : node.getChildren()) {
            trainSet.selectAdditionalSubset(child.getPartitionId());
        }
        Model model = applyInnerLearner(trainSet);
        node.setModel(model);

        int partitionId = node.getPartitionId();
        for (HierarchicalMultiClassModel.Node child : node.getChildren()) {
            int childPartitionId = child.getPartitionId();
            for (int i = 0; i < partitions.length; ++i) {
                if (partitions[i] == childPartitionId) {
                    partitions[i] = partitionId;
                }
            }
        }
    }

    /**
     * Reports whether this operator supports {@code capability}. Numerical, binominal, one-class
     * and missing labels are rejected, as are updating and formula output; everything else is
     * accepted.
     */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        switch (capability) {
            case NUMERICAL_LABEL:
            case NO_LABEL:
            case UPDATABLE:
            case FORMULA_PROVIDER:
            case BINOMINAL_LABEL:
            case ONE_CLASS_LABEL:
                return false;
            default:
                return true;
        }
    }

    /** Adds the hierarchy list and the random generator parameters to the inherited parameters. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeList(PARAMETER_HIERARCHY, "The class hierarchy, given as a list of (parent class, child class) pairs.", new ParameterTypeString(PARAMETER_PARENT_CLASS, "The parent class.", false), new ParameterTypeString(PARAMETER_CHILD_CLASS, "The child class.", false)));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }
}

