/*
 * MiningMart Version 1.0
 * 
 * Copyright (C) 2006 Martin Scholz, Timm Euler, 
 *                    Daniel Hakenjos, Katharina Morik
 *
 * Contact: miningmart@ls8.cs.uni-dortmund.de
 *
 * A list of contributing developers (other than the copyright 
 * holders) can be found at
 * http://mmart.cs.uni-dortmund.de/downloads/download.html
 * 
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program, see the file MM_HOME/LICENSE; if not, write
 * to the Free Software Foundation, Inc., 51 Franklin Street, Fifth
 * Floor, Boston, MA 02110-1301, USA.
 */
package edu.udo.cs.miningmart.operator;

import java.util.Collection;
import java.util.Iterator;
import java.util.Vector;

import edu.udo.cs.miningmart.compiler.wrapper.DB_SVM_CL;
import edu.udo.cs.miningmart.compiler.wrapper.SVM_CL;
import edu.udo.cs.miningmart.compiler.wrapper.SVM_Wrapper;

import edu.udo.cs.miningmart.exception.M4CompilerError;
import edu.udo.cs.miningmart.exception.M4Exception;
import edu.udo.cs.miningmart.m4.BaseAttribute;
import edu.udo.cs.miningmart.m4.Column;
import edu.udo.cs.miningmart.m4.Columnset;
import edu.udo.cs.miningmart.m4.Feature;
import edu.udo.cs.miningmart.m4.MultiColumnFeature;
import edu.udo.cs.miningmart.m4.Value;
import edu.udo.cs.miningmart.m4.utils.Print;

/**
 * This operator uses the SVM algorithm mySVM/db, which estimates the
 * generalisation performance of an SVM on different feature subsets,
 * to choose the best feature subset.
 * <p>
 * The subset is found by a greedy search (forward or backward, depending
 * on the parameter 'SearchDirection'); every candidate subset is scored
 * with the XiAlpha estimation delivered by the SVM wrapper, and the
 * search treats larger estimates as better.
 * 
 * @author Timm Euler
 * @version $Id: FeatureSelectionWithSVM.java,v 1.5 2006/04/11 14:10:11 euler Exp $
 */
public class FeatureSelectionWithSVM extends FeatureSelection
{
	// Wrapper around the SVM; created by prepareSVM() before the search starts.
	private SVM_Wrapper mysvm = null;

	// default; overridden by parameter 'UseDB_SVM'
	private boolean useDatabaseSVM = false;

	/**
	 * Prepares the SVM wrapper and searches the given superset of
	 * features for the best subset.
	 * 
	 * @param theSuperset the features among which to select
	 * @return a Collection (here a Vector) of the selected Features
	 * @throws M4CompilerError if the SVM cannot be prepared or called
	 * @see FeatureSelection
	 */
	protected Collection computeListOfAttributes(Feature[] theSuperset) throws M4CompilerError
	{
		this.useDatabaseSVM = this.mustUseDatabaseSVM();
		this.prepareSVM();

		Feature[] bestSubset = this.search(theSuperset);

		// turn bestSubset into a collection, here a vector:
		Vector v = new Vector();
		for (int i = 0; i < bestSubset.length; i++)
		{   v.add(bestSubset[i]);   }
		v.trimToSize();

		return v;
	}

	/**
	 * Creates the SVM wrapper according to the parameters 'TheKey',
	 * 'PositiveTargetValue', 'SampleSize' and 'UseDB_SVM'.
	 * <p>
	 * Failures are propagated instead of being only printed: without a
	 * wrapper the subsequent search could only fail with a
	 * NullPointerException that hides the real cause.
	 * 
	 * @throws M4CompilerError if a required parameter is missing or
	 *         malformed, or if the wrapper cannot be created
	 */
	private void prepareSVM() throws M4CompilerError
	{
		try
		{
			BaseAttribute keyBA = this.getKey();
			String key = null;
			if (keyBA != null)
			{  key = keyBA.getCurrentColumn().getName();  }

			Value positive = this.getPositiveTarget();
			if (positive == null)
			{  throw new M4CompilerError("Operator FeatureSelectionWithSVM: Need parameter 'PositiveTargetValue'!");  }

			String positiveTargetValue = positive.getValue();

			Value sa = this.getSampleSize();

			if (sa == null)
			{
				// without a sample size only the database SVM can be used
				if ( ! this.useDatabaseSVM)
				{  throw new M4CompilerError("FeatureSelectionWithSVM: parameter 'SampleSize' must be specified if SVM outside database is used!");  }

				this.mysvm = new DB_SVM_CL( this.getM4Db(),
											this.getM4Db().getCasePrintObject(),
											this.getPrefixForDatabaseObjects(),
											this.getDatabaseSchema(),
											this.getStep().getId(),
											key,
											positiveTargetValue);
			}
			else
			{
				long sample;
				try
				{   sample = Long.parseLong(sa.getValue());  }
				catch (NumberFormatException nfe)
				{
					// report the offending sample size, not the positive target value
					throw new M4CompilerError("FeatureSelectionWithSVM: SampleSize must be an integer! Found: " + sa.getValue());
				}

				if (this.useDatabaseSVM)
				{
					this.mysvm = new DB_SVM_CL( this.getM4Db(),
												this.getM4Db().getCasePrintObject(),
												this.getPrefixForDatabaseObjects(),
												this.getDatabaseSchema(),
												this.getStep().getId(),
												key,
												sample,
												positiveTargetValue);
				}
				else
				{
					this.mysvm = new SVM_CL( this.getM4Db(),
											 this.getM4Db().getCasePrintObject(),
											 this.getPrefixForDatabaseObjects(),
											 this.getDatabaseSchema(),
											 this.getStep().getId(),
											 sample,
											 positiveTargetValue);
				}
			}
		}
		catch (M4CompilerError mce)
		{
			// already carries a meaningful message; pass it on unchanged
			throw mce;
		}
		catch (java.lang.Exception e)
		{
			String message = "Error with SVM for FeatureSelection: " + e.getMessage();
			this.doPrint(Print.MAX, message);
			throw new M4CompilerError(message);
		}
	}

	/**
	 * Returns a string that is used by the wrapper to create database
	 * objects (whose name is prefixed by this string). It is the name of
	 * the current column set of the input concept.
	 * 
	 * @return the prefix for database objects
	 * @throws M4CompilerError if the M4 interface reports an error
	 */
	private String getPrefixForDatabaseObjects() throws M4CompilerError
	{
		try {
			return this.getInputConcept().getCurrentColumnSet().getName();
		}
		catch (M4Exception m4e)
		{   throw new M4CompilerError("M4 interface error in " + this.getName() + ": " + m4e.getMessage());  }
	}

	/**
	 * Returns the name of the database schema in which the input table
	 * or view lives.
	 * 
	 * @return the schema of the current column set of the input concept
	 * @throws M4CompilerError if the M4 interface reports an error
	 */
	protected String getDatabaseSchema() throws M4CompilerError {
		try {
			return this.getInputConcept().getCurrentColumnSet().getSchema();
		}
		catch (M4Exception m4e)
		{   throw new M4CompilerError("M4 interface error in " + this.getName() + ": " + m4e.getMessage());  }
	}

	/**
	 * Runs the forward or backward search (parameter 'SearchDirection')
	 * on the given features and prints the names of the selected ones.
	 * 
	 * @param theSet the features to search in
	 * @return the selected features; may be empty if the search selects nothing
	 * @throws M4CompilerError if 'SearchDirection' is missing or the SVM fails
	 */
	private Feature[] search(Feature[] theSet) throws M4CompilerError
	{
		// use Vectors for searching:
		Vector all = new Vector();
		for (int i = 0; i < theSet.length; i++)
		{	all.add(theSet[i]);	}
		all.trimToSize();

		Value direction = this.getDirection();
		if (direction == null || direction.getValue() == null)
		{  throw new M4CompilerError("Operator FeatureSelectionWithSVM: Need parameter 'SearchDirection'!");  }

		Vector best;
		if (direction.getValue().equalsIgnoreCase("forward"))
		{
			Vector empty = new Vector(0);
			best = this.searchForward(empty, all, -1);
		}
		else
		{
			best = this.searchBackward(all, -1);
		}

		// return an array, and collect the names for the log message;
		// the separator is only inserted between names, so an empty
		// result yields an empty list instead of a substring error:
		Feature[] bestSubset = new Feature[best.size()];
		StringBuffer theNames = new StringBuffer();
		for (int i = 0; i < bestSubset.length; i++)
		{
			bestSubset[i] = (Feature) best.get(i);
			if (i > 0)
			{  theNames.append(", ");  }
			theNames.append(bestSubset[i].getName());
		}

		this.doPrint(Print.OPERATOR, "FSwithSVM: selected the following features: " + theNames);

		return bestSubset;
	}

	/**
	 * Greedy backward elimination: tries all subsets with one feature
	 * less than the current set and continues with the best of them as
	 * long as it is estimated to be at least as good as the current set.
	 * 
	 * @param current the current feature set
	 * @param currentPerformance the estimation for the current set, or -1
	 *        if it has not been computed yet
	 * @return the selected feature set
	 * @throws M4CompilerError if the SVM call fails
	 */
	private Vector searchBackward(Vector current, double currentPerformance) throws M4CompilerError
	{
		// stop criterion: nothing can be removed from a set with at most
		// one feature (this also covers an empty input set)
		if (current.size() <= 1)
		{   return current;   }

		// evaluate current if it's not done yet
		if (currentPerformance == -1)
		{
			this.doPrint(Print.OPERATOR, "FSwithSVM: evaluating all features...");
			currentPerformance = this.getSVMEstimation(current);
		}

		// evaluate all daughters of current
		Vector daughter;
		Vector bestDaughter = current;
		double bestPerformance = -1;
		double daughterPerformance;

		this.doPrint(Print.OPERATOR, "Trying all sets of " + (current.size() - 1) + " features.");

		for (int i = 0; i < current.size(); i++)
		{
			daughter = this.removeIndexedFeature(current, i);
			daughterPerformance = this.getSVMEstimation(daughter);
			if (daughterPerformance > bestPerformance)
			{
				bestPerformance = daughterPerformance;
				bestDaughter = daughter;
			}
		}

		// ties favour the smaller set
		if (bestPerformance >= currentPerformance)
		{
			// choose best daughter, call again
			return searchBackward(bestDaughter, bestPerformance);
		}
		else
		{
			return current;
		}
	}

	/**
	 * Greedy forward selection: tries all sets that extend the current
	 * set by one feature and continues with the best of them as long as
	 * it is estimated to be strictly better than the current set.
	 * 
	 * @param current the current feature set (initially empty)
	 * @param all all available features
	 * @param currentPerformance the estimation for the current set, or -1
	 *        if it has not been computed yet
	 * @return the selected feature set
	 * @throws M4CompilerError if the SVM call fails
	 */
	private Vector searchForward(Vector current, Vector all, double currentPerformance) throws M4CompilerError
	{
		// one stop criterion:
		if (current.size() == all.size())
		{   return current;   }

		// evaluate current if it's not done yet
		// (an empty set is never evaluated and keeps the value -1)
		if ((currentPerformance == -1) && ( ! current.isEmpty()))
		{   currentPerformance = this.getSVMEstimation(current);  }

		// evaluate all daughters of current
		Vector daughter;
		Vector bestDaughter = current;
		double bestPerformance = -1;
		double daughterPerformance;

		this.doPrint(Print.OPERATOR, "FSwithSVM: forward search on level " + current.size() + "...");

		for (int i = 0; i < all.size(); i++)
		{
			daughter = this.addIndexedFeature(current, all, i);

			// if feature i was in current anyway, try i+1:
			if (daughter == null)
			{  continue;  }

			daughterPerformance = this.getSVMEstimation(daughter);
			if (daughterPerformance > bestPerformance)
			{
				bestPerformance = daughterPerformance;
				bestDaughter = daughter;
			}
		}
		if (bestPerformance > currentPerformance)
		{
			// choose best daughter, call again
			this.doPrint(Print.OPERATOR, "FSwithSVM: best performance on this level was estimated: " + bestPerformance);

			return searchForward(bestDaughter, all, bestPerformance);
		}
		else
		{
			this.doPrint(Print.OPERATOR, "FSwithSVM: returned from this level with estimation: " + currentPerformance);
			return current;
		}
	}

	/**
	 * Returns the XiAlpha estimation of mySVM/db, where the SVM
	 * is trained on the features in the given vector.
	 * 
	 * @param theFeatures a Vector of Features; every element that is not
	 *        a BaseAttribute is treated as a MultiColumnFeature
	 * @return the XiAlpha estimation delivered by the SVM wrapper
	 * @throws M4CompilerError if the M4 interface reports an error
	 */
	private double getSVMEstimation(Vector theFeatures) throws M4CompilerError
	{
		try {
			Columnset theColumnSet = this.getInputConcept().getCurrentColumnSet();

			// collect the columns of all features:
			Vector theCols = new Vector();
			Feature f; BaseAttribute ba; MultiColumnFeature mcf;
			Collection theBAs;

			for (int i = 0; i < theFeatures.size(); i++)
			{
				f = (Feature) theFeatures.get(i);
				if (f instanceof BaseAttribute)
				{
					ba = (BaseAttribute) f;
					theCols.add(ba.getCurrentColumn());
				}
				else
				{
					// a MultiColumnFeature contributes the columns of all its BaseAttributes
					mcf = (MultiColumnFeature) f;
					theBAs = mcf.getBaseAttributes();
					Iterator it = theBAs.iterator();
					while (it.hasNext())
					{
						theCols.add(((BaseAttribute) it.next()).getCurrentColumn());
					}
				}
			}
			theCols.trimToSize();

			Column theTargetAttributeColumn = this.getTheTargetAttribute().getCurrentColumn();

			mysvm.callSVM(	theColumnSet,
							theTargetAttributeColumn,
							this.getInputConcept().getId(),
							this.getC().getValue(),
							this.getKernelType().getValue(),
							this.getEpsilon().getValue(),
							theCols);

			return mysvm.getXiAlphaEstimation();
		}
		catch (M4Exception m4e)
		{   throw new M4CompilerError("M4 interface error in " + this.getName() + ": " + m4e.getMessage());  }
	}

	/**
	 * Adds the indexed element of Vector "from" to Vector "to" if it
	 * is not there yet. The given Vectors are not modified.
	 * 
	 * @param to the Vector to extend
	 * @param from the Vector to take the element from
	 * @param index the index of the element in "from"
	 * @return a new, extended Vector, or <code>null</code> if the element
	 *         is already contained in "to"
	 */
	private Vector addIndexedFeature(Vector to, Vector from, int index)
	{
		Object o = from.get(index);

		if (to.contains(o))
		{
			return null;
		}
		else
		{
			Vector v = new Vector();
			for (int i = 0; i < to.size(); i++)
			{   v.add(to.get(i));   }
			v.add(o);
			v.trimToSize();

			return v;
		}
	}

	/**
	 * Returns a new Vector with all elements of the given one
	 * except the indexed element. The given Vector is not modified.
	 * 
	 * @param from the Vector to copy
	 * @param index the index of the element to leave out
	 * @return the reduced copy
	 */
	private Vector removeIndexedFeature(Vector from, int index)
	{
		Vector v = new Vector();
		for (int i = 0; i < from.size(); i++)
		{
			if (i != index)
			{  v.add(from.get(i));  }
		}
		v.trimToSize();
		return v;
	}

	/**
	 * Get the parameter "SearchDirection".
	 * 
	 * @return the direction as a Value
	 */
	private Value getDirection() throws M4CompilerError
	{   return (Value) this.getSingleParameter("SearchDirection");   }

	/**
	 * Get the parameter "PositiveTargetValue".
	 * 
	 * @return the positive target value as a Value
	 */
	private Value getPositiveTarget() throws M4CompilerError
	{   return (Value) this.getSingleParameter("PositiveTargetValue");   }

	/**
	 * Get the parameter "SampleSize".
	 * 
	 * @return the sample size as a Value
	 */
	private Value getSampleSize() throws M4CompilerError
	{   return (Value) this.getSingleParameter("SampleSize");   }

	/**
	 * Gets the parameter 'C'.
	 * 
	 * @return Returns a Value
	 */
	private Value getC() throws M4CompilerError {
		return (Value) this.getSingleParameter("C");
	}

	/**
	 * Gets the parameter 'TheKey'.
	 * 
	 * @return Returns a BaseAttribute
	 */
	private BaseAttribute getKey() throws M4CompilerError {
		return (BaseAttribute) this.getSingleParameter("TheKey");
	}

	/**
	 * Gets the parameter 'TheTargetAttribute'.
	 * 
	 * @return Returns a BaseAttribute
	 */
	private BaseAttribute getTheTargetAttribute() throws M4CompilerError {
		return (BaseAttribute) this.getSingleParameter("TheTargetAttribute");
	}

	/**
	 * Gets the parameter 'Epsilon'.
	 * 
	 * @return Returns a Value
	 */
	private Value getEpsilon() throws M4CompilerError {
		return (Value) this.getSingleParameter("Epsilon");
	}

	/**
	 * Gets the parameter 'KernelType'.
	 * 
	 * @return Returns a Value
	 */
	private Value getKernelType() throws M4CompilerError {
		return (Value) this.getSingleParameter("KernelType");
	}

	/**
	 * Reads parameter 'UseDB_SVM'.
	 * 
	 * @return true iff the parameter was set to true (ignoring case);
	 *         false if the parameter is missing
	 */
	private boolean mustUseDatabaseSVM() throws M4CompilerError {
		Value v = (Value) this.getSingleParameter("UseDB_SVM");
		if (v == null)
		{   return false;   }
		// constant first, so that a Value without content counts as false
		return "true".equalsIgnoreCase(v.getValue());
	}
}
/*
 * History
 * --------
 *
 * $Log: FeatureSelectionWithSVM.java,v $
 * Revision 1.5  2006/04/11 14:10:11  euler
 * Updated license text.
 *
 * Revision 1.4  2006/04/06 16:31:11  euler
 * Prepended license remark.
 *
 * Revision 1.3  2006/03/30 16:07:13  scholz
 * fixed author tags for release
 *
 * Revision 1.2  2006/03/23 11:13:45  euler
 * Improved exception handling.
 *
 * Revision 1.1  2006/01/03 09:54:22  hakenjos
 * Initial version!
 *
 */
