/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.Partition;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.MetaModel;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;

/**
 * Meta model that predicts by walking a tree of {@link Node}s.
 *
 * <p>Every inner node holds a {@link Model}. The label that model predicts for an example is
 * looked up among the node's children (by class name) to decide where the example is routed
 * next. The partition id of the leaf an example finally reaches is written as its predicted label
 * value. NOTE(review): this assumes leaf partition ids coincide with the indices of the label
 * mapping; confirm with the learner that builds the tree.
 */
public class HierarchicalMultiClassModel
extends PredictionModel
implements MetaModel {
    private static final long serialVersionUID = -5792943818860734082L;

    /** Root of the node tree; routing of every example starts here. */
    private final Node root;

    /**
     * Creates the model.
     *
     * @param exampleSet example set handed to the {@link PredictionModel} constructor
     *                   (presumably the training set)
     * @param root       root of the already built node tree
     */
    public HierarchicalMultiClassModel(ExampleSet exampleSet, Node root) {
        super(exampleSet);
        this.root = root;
    }

    /**
     * Predicts every example by routing it from the root down to a leaf.
     *
     * <p>The partition id of the reached leaf is written as the predicted label value. The
     * confidence of the predicted class is the geometric mean of the confidences that the models
     * on the example's path reported for their predictions. Every label first receives
     * {@code (1 - confidence) / numberOfLabels}, and the predicted label is then overwritten with
     * {@code confidence}.
     *
     * @param exampleSet     the examples to score; results are written into this set
     * @param predictedLabel the prediction attribute to fill
     * @return {@code exampleSet} with predicted label and confidence values set
     * @throws OperatorException if applying one of the node models fails
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        ExampleSet applySet = (ExampleSet) exampleSet.clone();
        int size = applySet.size();
        double[] confidences = new double[size];
        int[] outcomes = new int[size];
        int[] depths = new int[size];
        // Every example starts in the root's partition, with a neutral confidence product.
        Arrays.fill(outcomes, this.root.getPartitionId());
        Arrays.fill(confidences, 1.0);
        this.performPredictionRecursivly(applySet, this.root, confidences, outcomes, depths, 0,
                this.root.getPartitionId() + 1);

        // Resolve the confidence columns through the mapping of the attribute the outcome is
        // written to. This keeps label values and confidence columns consistent. It also avoids
        // requiring the scored data to carry a label attribute (getLabel() would be null there).
        int numberOfLabels = predictedLabel.getMapping().size();
        Attribute[] confidenceAttributes = new Attribute[numberOfLabels];
        for (int labelIndex = 0; labelIndex < numberOfLabels; ++labelIndex) {
            confidenceAttributes[labelIndex] = exampleSet.getAttributes()
                    .getConfidence(predictedLabel.getMapping().mapIndex(labelIndex));
        }

        int exampleIndex = 0;
        for (Example example : exampleSet) {
            int outcome = outcomes[exampleIndex];
            example.setValue(predictedLabel, outcome);
            // depths[] counts the models applied on the path. It stays 0 only when the root
            // itself is a leaf: no model contributed, so keep the neutral product instead of
            // computing pow(x, 1/0).
            int applied = depths[exampleIndex];
            double confidence = applied > 0
                    ? Math.pow(confidences[exampleIndex], 1.0 / (double) applied)
                    : confidences[exampleIndex];
            double defaultConfidence = (1.0 - confidence) / (double) numberOfLabels;
            for (Attribute confidenceAttribute : confidenceAttributes) {
                example.setValue(confidenceAttribute, defaultConfidence);
            }
            example.setValue(confidenceAttributes[outcome], confidence);
            ++exampleIndex;
        }
        return exampleSet;
    }

    /**
     * Applies {@code node}'s model to the examples currently assigned to its partition, then
     * recurses into all children. Leaves are left untouched.
     *
     * @param applySet           clone of the example set being scored
     * @param node               node whose partition is processed
     * @param confidences        per-example running product of model confidences (updated)
     * @param outcomes           per-example partition id of the node the example is routed to (updated)
     * @param depths             per-example number of models applied so far (updated)
     * @param depth              number of models applied on the path above {@code node}
     * @param numberOfPartitions number of partitions used to split {@code applySet} by {@code outcomes}
     * @throws OperatorException if applying the node's model fails
     */
    private void performPredictionRecursivly(ExampleSet applySet, Node node, double[] confidences, int[] outcomes, int[] depths, int depth, int numberOfPartitions) throws OperatorException {
        if (node.isLeaf()) {
            return;
        }
        // View on exactly those examples whose current outcome is this node's partition.
        SplittedExampleSet splittedSet = new SplittedExampleSet(applySet, new Partition(outcomes, numberOfPartitions));
        splittedSet.selectSingleSubset(node.getPartitionId());
        ExampleSet currentResultSet = node.getModel().apply(splittedSet);
        Attribute predictionAttribute = currentResultSet.getAttributes().getPredictedLabel();
        int resultIndex = 0;
        for (Example example : currentResultSet) {
            int parentIndex = splittedSet.getActualParentIndex(resultIndex);
            String label = example.getValueAsString(predictionAttribute);
            confidences[parentIndex] *= example.getConfidence(label);
            outcomes[parentIndex] = node.getChild(label).getPartitionId();
            // Number of models applied so far, including this node's. It is the root of the
            // geometric mean taken in performPrediction.
            depths[parentIndex] = depth + 1;
            ++resultIndex;
        }
        PredictionModel.removePredictedLabel(currentResultSet);
        for (Node child : node.getChildren()) {
            this.performPredictionRecursivly(applySet, child, confidences, outcomes, depths, depth + 1, numberOfPartitions);
        }
    }

    /**
     * Returns the class names of all inner (non-leaf) nodes in pre-order. A node precedes its
     * children, and children appear in the order they were added.
     */
    @Override
    public List<String> getModelNames() {
        List<Node> innerNodes = new ArrayList<Node>();
        this.collectInnerNodes(this.root, innerNodes);
        ArrayList<String> names = new ArrayList<String>(innerNodes.size());
        for (Node node : innerNodes) {
            names.add(node.getClassName());
        }
        return names;
    }

    /**
     * Appends {@code node} and all inner (non-leaf) descendants to {@code nodes}, in pre-order.
     * Leaves are skipped; they have no descendants, so no inner node is missed.
     */
    private void collectInnerNodes(Node node, List<Node> nodes) {
        if (node.isLeaf()) {
            return;
        }
        nodes.add(node);
        for (Node child : node.getChildren()) {
            this.collectInnerNodes(child, nodes);
        }
    }

    /** Returns the models of all inner nodes, in the same order as {@link #getModelNames()}. */
    public List<Model> getModels() {
        List<Node> innerNodes = new ArrayList<Node>();
        this.collectInnerNodes(this.root, innerNodes);
        ArrayList<Model> models = new ArrayList<Model>(innerNodes.size());
        for (Node node : innerNodes) {
            models.add(node.getModel());
        }
        return models;
    }

    /**
     * One node of the model tree.
     *
     * <p>Inner nodes carry a {@link Model}; children are kept both in insertion order and keyed
     * by their class name, which is matched against the label strings predicted by the parent's
     * model.
     */
    public static class Node
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Name this node is registered under in its parent's child map. */
        private final String className;
        /** Partition id used to route examples to this node. */
        private int partitionId;
        /** Children keyed by class name, for lookup by predicted label. */
        private final LinkedHashMap<String, Node> children = new LinkedHashMap<String, Node>();
        /** Children in insertion order, for traversal. */
        private final List<Node> childrenList = new ArrayList<Node>();
        private Node parent = null;
        private Model model = null;

        /** Creates a node without parent, children or model. */
        public Node(String className) {
            this.className = className;
        }

        /** Returns the children in the order they were added (the live internal list). */
        public List<Node> getChildren() {
            return this.childrenList;
        }

        /** Adds {@code child} under its class name and makes this node its parent. */
        public void addChild(Node child) {
            this.childrenList.add(child);
            this.children.put(child.getClassName(), child);
            child.setParent(this);
        }

        public void setParent(Node parent) {
            this.parent = parent;
        }

        /** Returns {@code true} if this node has no parent. */
        public boolean isRoot() {
            return this.parent == null;
        }

        public void setPartitionId(int partition) {
            this.partitionId = partition;
        }

        public int getPartitionId() {
            return this.partitionId;
        }

        public Node getParent() {
            return this.parent;
        }

        public String getClassName() {
            return this.className;
        }

        /** Returns {@code true} if no child has been added. */
        public boolean isLeaf() {
            return this.children.isEmpty();
        }

        public void setModel(Model model) {
            this.model = model;
        }

        public Model getModel() {
            return this.model;
        }

        /** Returns the child registered under {@code label}, or {@code null} if there is none. */
        public Node getChild(String label) {
            return this.children.get(label);
        }
    }
}

