package com.rapidminer.operator.learner.meta;

import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.AttributeValueFilterSingleCondition;
import com.rapidminer.example.set.ConditionedExampleSet;
import com.rapidminer.example.table.DataRow;
import com.rapidminer.example.table.DataRowFactory;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.operator.AbstractModel;
import com.rapidminer.operator.MissingIOObjectException;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;

/**
 * A model which contains specialized models for a number of known clusters.
 * The input example set must contain a special cluster attribute. For
 * prediction, each specialized model is applied to the subset of examples
 * belonging to its cluster, and the per-cluster results are merged into one
 * new example set.
 * 
 * Examples in clusters for which no specialized model exists are kept in the
 * result with missing ("unknown") values for all attributes created by the
 * specialized models, provided that at least one specialized model could be
 * applied. They are appended after the examples that received a prediction.
 * 
 * The cluster attribute must be nominal, because the clusters are identified
 * by their names.
 * 
 * @author Marius Helf
 *
 */
public class CombinedModelOnCluster extends AbstractModel {
	private static final long serialVersionUID = 2907982539494087150L;
	
	/** Maps cluster names to the model specialized on that cluster. */
	private Map<String,Model> models;

	/**
	 * Creates a new instance of CombinedModelOnCluster
	 * @param exampleSet the example set on which the model has been trained
	 * @param models a mapping of cluster names to their specialized models
	 */
	public CombinedModelOnCluster(ExampleSet exampleSet, Map<String,Model> models) {
		super(exampleSet);
		this.models = models;
	}

	/**
	 * Applies each specialized model to the examples of its cluster and merges
	 * the results into a single example set.
	 * 
	 * Clusters without examples are skipped. Clusters with examples but without
	 * a model trigger a warning; their examples are appended to the result with
	 * missing predictions, as long as at least one model has been applied.
	 * 
	 * @param exampleSet the example set to predict; must have a nominal special cluster attribute
	 * @return a new example set containing the merged per-cluster results
	 * @throws OperatorException if the cluster attribute is absent or not nominal, if a
	 *         specialized model fails, or if no specialized model could be applied at all
	 */
	@Override
	public ExampleSet apply(ExampleSet exampleSet) throws OperatorException {
		Attribute clusterAttribute = exampleSet.getAttributes().getCluster();
		if (clusterAttribute == null) {
			// if clusterAttribute is null, the example set obviously does not contain a
			// special cluster attribute
			throw new OperatorException("example set has no special cluster attribute");
		}
		if (!clusterAttribute.isNominal()) {
			// clusters are looked up by name, which requires a nominal mapping
			throw new OperatorException("the special cluster attribute must be nominal");
		}

		List<ExampleSet> predictedExampleSets = new LinkedList<ExampleSet>();
		List<ExampleSet> unpredictedExampleSets = new LinkedList<ExampleSet>();
		for (String currentClusterName : clusterAttribute.getMapping().getValues()) {
			AttributeValueFilterSingleCondition condition = new AttributeValueFilterSingleCondition(
					clusterAttribute, 
					AttributeValueFilterSingleCondition.EQUALS, 
					currentClusterName);
			// work on a clone so that attributes added by one model do not leak into other subsets
			ConditionedExampleSet currentClusterSet = new ConditionedExampleSet((ExampleSet)exampleSet.clone(), condition);
			if (currentClusterSet.size() == 0) {
				continue;
			}

			Model model = (models == null) ? null : models.get(currentClusterName);
			if (model != null) {
				predictedExampleSets.add(model.apply(currentClusterSet));
			} else {
				logWarning("The model for cluster '" + currentClusterName + "' is null. Examples in this cluster will be classified as unknown.");
				unpredictedExampleSets.add(currentClusterSet);
			}
		}

		// The merged attribute set is taken from the first example set, so sets without
		// predictions must come after at least one predicted set. If nothing was
		// predicted, merge() is called with an empty list and reports the error.
		if (!predictedExampleSets.isEmpty()) {
			predictedExampleSets.addAll(unpredictedExampleSets);
		}
		return merge(predictedExampleSets);
	}

	/**
	 * Copies all examples of the given example sets into one new example set backed
	 * by a fresh memory table.
	 * 
	 * The attributes (and special roles) of the result are cloned from the first
	 * example set. Attributes of later sets are matched by name; an attribute which
	 * a later set does not contain is filled with missing values, and attributes
	 * which only exist in later sets are not copied.
	 * 
	 * @param allExampleSets the example sets to merge; the first one defines the attributes
	 * @return the merged example set
	 * @throws OperatorException if the list is null or empty
	 */
	ExampleSet merge(List<ExampleSet> allExampleSets) throws OperatorException {
		// throw error if no example sets were available
		if (allExampleSets == null || allExampleSets.isEmpty()) {
			throw new MissingIOObjectException(ExampleSet.class);
		}

		// create new example table
		ExampleSet firstSet = allExampleSets.get(0);
		List<Attribute> attributeList = new LinkedList<Attribute>();
		Map<Attribute, String> specialAttributes = new HashMap<Attribute, String>();
		Iterator<AttributeRole> roles = firstSet.getAttributes().allAttributeRoles();
		while (roles.hasNext()) {
			AttributeRole role = roles.next();
			Attribute attributeClone = (Attribute) role.getAttribute().clone();
			attributeList.add(attributeClone);
			if (role.isSpecial()) {
				specialAttributes.put(attributeClone, role.getSpecialName());
			}
		}
		MemoryExampleTable exampleTable = new MemoryExampleTable(attributeList);
		Attribute[] tableAttributes = exampleTable.getAttributes();

		DataRowFactory factory = new DataRowFactory(DataRowFactory.TYPE_DOUBLE_ARRAY, '.');
		for (ExampleSet currentExampleSet : allExampleSets) {
			// resolve the source attribute for every target attribute once per example set;
			// an entry stays null if this example set does not contain the attribute
			Attribute[] sourceAttributes = new Attribute[attributeList.size()];
			int index = 0;
			for (Attribute newAttribute : attributeList) {
				sourceAttributes[index++] = currentExampleSet.getAttributes().get(newAttribute.getName());
			}

			for (Example example : currentExampleSet) {
				DataRow dataRow = example.getDataRow();
				String[] newData = new String[sourceAttributes.length];
				int counter = 0;
				for (Attribute newAttribute : attributeList) {
					Attribute oldAttribute = sourceAttributes[counter];
					// an attribute unknown to this example set is treated as a missing value
					double oldValue = (oldAttribute == null) ? Double.NaN : dataRow.get(oldAttribute);
					if (Double.isNaN(oldValue)) {
						newData[counter] = Attribute.MISSING_NOMINAL_VALUE;
					} else if (newAttribute.isNominal()) {
						// translate via the value name, because the mappings of old and new attribute may differ
						newData[counter] = oldAttribute.getMapping().mapIndex((int) oldValue);
					} else {
						newData[counter] = oldValue + "";
					}
					counter++;
				}
				exampleTable.addDataRow(factory.create(newData, tableAttributes));
			}
		}

		// create result example set
		return exampleTable.createExampleSet(specialAttributes);
	}

}
