/*
 * MiningMart Version 1.1
 * 
 * Copyright (C) 2006 Martin Scholz, Timm Euler, 
 *                    Daniel Hakenjos, Katharina Morik
 *
 * Contact: miningmart@ls8.cs.uni-dortmund.de
 *
 * A list of contributing developers (other than the copyright 
 * holders) can be found at
 * http://mmart.cs.uni-dortmund.de/downloads/download.html
 * 
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program, see the file MM_HOME/LICENSE; if not, write
 * to the Free Software Foundation, Inc., 51 Franklin Street, Fifth
 * Floor, Boston, MA 02110-1301, USA.
 */
package edu.udo.cs.miningmart.operator;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import java.util.Vector;

import edu.udo.cs.miningmart.exception.M4CompilerError;
import edu.udo.cs.miningmart.exception.M4Exception;
import edu.udo.cs.miningmart.m4.BaseAttribute;
import edu.udo.cs.miningmart.m4.Column;
import edu.udo.cs.miningmart.m4.Columnset;
import edu.udo.cs.miningmart.m4.Value;

/**
 * @author Martin Scholz
 * @version $Id: RowSelectionByUnbiasing.java,v 1.5 2006/09/27 14:59:56 euler Exp $
 */
public class RowSelectionByUnbiasing extends RowSelection {

	// ***** Private fields *****

	private static final String ROW_ID          = "ID";
	private static final String PREDICTED_LABEL = "PL";
	private static final String CORRECT_LABEL   = "LA";

	private int numRowsInputCs;
	private String sqlQuery;
	
	private final HashMap nominalsToIntMappingPredictions = new HashMap();
	private final HashMap nominalsToIntMappingLabel = new HashMap();

	private int      remainingTupels;
	private int[][]  remainingTupelsByPL;

	private final Random random = new Random();
	
	// ***** Helper method associated to the private fields *****
	
	private static int getIdForNomToInt(HashMap hm, String prediction) {
		Integer index = (Integer) hm.get(prediction);
		if (index == null) {
			index =	new Integer(hm.size());
			hm.put(prediction, index);
		}
		return index.intValue();
	}

	private int getIdForPrediction(String prediction) {
		return getIdForNomToInt(this.nominalsToIntMappingPredictions, prediction);
	}

	private int getIdForLabel(String label) {
		return getIdForNomToInt(this.nominalsToIntMappingLabel, label);		
	}

	private boolean decreaseIfTupleUseful(int predicted, int label) {
		boolean useful = this.remainingTupelsByPL[predicted][label] > 0;
		if (useful) {
			this.remainingTupelsByPL[predicted][label] --;
			this.remainingTupels--;
		}
		return useful;
	}


	private String getTmpTableName() {
		return "tmp_" + this.getStep().getId();		
	}







	/**
	 * Overrides the method from RowSelection
	 * 
	 * @see edu.udo.cs.miningmart.m4.core.operator.SingleCSOperator#generateSQLDefinition(String)
	 * @see edu.udo.cs.miningmart.m4.core.operator.RowSelection#generateSQLDefinition(String)
	 */
    public String generateSQLDefinition(String selectPart) throws M4CompilerError {
    	try {
	        String viewDef = "(select " + selectPart +
		                     " from (" + getInputConcept().getCurrentColumnSet().getCompleteSQLQuery(ROW_ID)
		                     + ") AS T1, (SELECT " + ROW_ID + " FROM " + this.getTmpTableName() + ") AS T2"
		                     + " where T1." + ROW_ID + "=T2." + ROW_ID + ")";
		    this.selectRowsForResult();
	        return viewDef;
    	}
   		catch (M4Exception m4e)
   		{   throw new M4CompilerError("M4 interface error in " + this.getName() + ": " + m4e.getMessage());  } 
    }

	// **********************************
	
	/**
	 * This method reads a block of the target size from the input Columnset
	 * and initializes counters. The priors of the class labels are estimated
	 * from their frequency in this &quot;master block&quot;. The number of
	 * occurences of instances in different &quot;prediction subsets&quot; is
	 * kept, just the associated classes are randomly assigned according to the
	 * estimated priors.
	 */
	protected void initByMasterBlock() throws M4CompilerError, SQLException {
		this.remainingTupels = Math.min(this.getHowMany(), this.numRowsInputCs);
		Vector predPriors  = new Vector();
		Vector labelPriors = new Vector();
		for (int i=0; i<this.remainingTupels; i++) {
			int prediction, label;
			{
				int rowId = this.random.nextInt(this.numRowsInputCs);
				int[] predLabel = this.getTuple(rowId);
				prediction = predLabel[0];
				label = predLabel[1];
			}

			{ // Count number of same predictions:
				while (predPriors.size() < prediction + 1) {
					predPriors.add(new Integer(0));
				}
				int current = ((Integer) predPriors.get(prediction)).intValue();
				predPriors.set(prediction, new Integer(current + 1));
			}
			
			{ // Count number of same labels:
				while (labelPriors.size() < label + 1) {
					labelPriors.add(new Integer(0));
				}
				int current = ((Integer) labelPriors.get(label)).intValue();
				labelPriors.set(label, new Integer(current + 1));			
			}
		}

		double priors[] = new double[labelPriors.size()];
		for (int i=0; i<priors.length; i++) {
			priors[i] = ((Integer) labelPriors.get(i)).doubleValue() / this.remainingTupels;
		}
		
		int[][] countArray = new int[predPriors.size()][labelPriors.size()];
		Iterator it = predPriors.iterator();
		for (int row = 0; row < countArray.length; row++) {
			int predCount = ((Integer) it.next()).intValue();
			// int[] pl = new int[priors.length];
			for (int col=0; col<priors.length; col++) {
				countArray[row][col] = (int) (priors[col] * predCount);
			}
		}
		this.remainingTupelsByPL = countArray;
	}

	/**
	 * This method tries to find a set of tuples satisfying the specifications
	 * created by the <code>initByMasterBlock</code> method. A table is created
	 * for the selected row numbers and it is filled with the result of this
	 * sampling step. The number of tuples maximally read can be limited if the
	 * according parameter is set to a positive value. Otherwise the size of the
	 * input Columnset is chosen as the default value. Sampling is done with
	 * replacement.
	 */
    protected void selectRowsForResult() throws M4CompilerError {
    	try {
			Columnset inputCS = this.getInputConcept().getCurrentColumnSet();
			this.numRowsInputCs = Integer.parseInt(inputCS.readOrComputeCount());

			Column predCol  = this.getPredictedLabelBA().getCurrentColumn();
			Column labelCol = this.getLabelBA().getCurrentColumn();

			this.sqlQuery = "SELECT "
						  + predCol.getSQLDefinition()  + " AS " + PREDICTED_LABEL + ", "
			              + labelCol.getSQLDefinition() + " AS " + CORRECT_LABEL
			              + " FROM (" + inputCS.getCompleteSQLQuery(ROW_ID) + ")";

			try {
				this.initByMasterBlock();
				String tmpTableName = this.getTmpTableName();
				String createSql = "CREATE TABLE " + tmpTableName + " ( " + ROW_ID + " NUMBER )";
				this.getM4Db().dropBusinessTable(tmpTableName);
				this.executeBusinessSqlWrite(createSql);
				this.getM4Db().addTableToTrash(tmpTableName, inputCS.getSchema(), this.getStep().getId());
	
				int maxIterations = getMaxTuplesToRead();
				if (maxIterations == 0) {
					maxIterations = this.numRowsInputCs;
				}
				
				String sqlInsertPre  = "INSERT INTO " + tmpTableName + " VALUES ( ";
				String sqlInsertPost = " )";
				for (int i=0; (i<maxIterations && this.remainingTupels > 0); i++) {
					int rowId, pred, label;
					{
						rowId = this.random.nextInt(this.numRowsInputCs);
						int[] predLabel = this.getTuple(rowId);
						pred  = predLabel[0];
						label = predLabel[1];
					}
					boolean useful = this.decreaseIfTupleUseful(pred, label);
					if (useful) {
						String sqlInsert = sqlInsertPre + rowId + sqlInsertPost;
						this.executeBusinessSqlWrite(sqlInsert);
					}
				}
			}
			catch (SQLException e) {
				throw new M4CompilerError("SQLException caught:\n " + e.getMessage());
			}
			this.getM4Db().addTableToTrash(this.getNewCSName(), this.getInputConcept().getCurrentColumnSet().getSchema(), this.getStep().getId());
    	}
   		catch (M4Exception m4e)
   		{   throw new M4CompilerError("M4 interface error in " + this.getName() + ": " + m4e.getMessage());  } 
    }

	/** 
	 * This is a helper method reading a single row from the database and returning
	 * predicted value and label.
	 */
	private int[] getTuple(int rowId) throws M4CompilerError, SQLException {
		String sql = this.sqlQuery + " WHERE " + ROW_ID + " = " + rowId;
		ResultSet rs = null;
		int[] result = null;
		try {
			rs = this.executeBusinessSqlRead(sql);
			if (!rs.next()) {
				 throw new M4CompilerError("RowSelectionByUnbiasing: Cannot access row id " + rowId 
				 						  + "!\nQuery was: " + sql);
			}
			result = new int[2];
			result[0] = this.getIdForPrediction(rs.getString(PREDICTED_LABEL));
			result[1] = this.getIdForLabel(rs.getString(CORRECT_LABEL));	
		}
		finally {
			if (rs != null) {
				rs.close();	
			}
		}
		return result;
	}



	// ***** Getters for parameter and other public methods *****

    // only needed for java compiler...
    public String generateConditionForOp() { 
    	return null;
    }
    
	/**
	 * @return the attribute with the predicted label in case of models,
	 *         or an attribute defining a segment
	 */
    public BaseAttribute getPredictedLabelBA() throws M4CompilerError {
		return (BaseAttribute) this.getSingleParameter("PredictedLabel");
    }

	/**
	 * @return the attribute with a nominal label (or property of interest)
	 */
    public BaseAttribute getLabelBA() throws M4CompilerError {
		return (BaseAttribute) this.getSingleParameter("Label");
    }

	/** @return the number of tuples to be selected */
    public int getHowMany() throws M4CompilerError {
		Value v = (Value) this.getSingleParameter("HowMany");
		try {
			return Integer.parseInt(v.getValue());
		}
		catch (NumberFormatException e) {
			throw new M4CompilerError
			("Parameter 'HowMany' of Operator RowSelectionByUnbiasing does not"
			+ " contain an integer value: " + v.getValue());
		}
		catch (NullPointerException e) {
			throw new M4CompilerError
			("Parameter 'HowMany' of Operator RowSelectionByRowSelectionByUnbiasing not found!");
		}
    }

	/**
	 * @return the number of tuples to be read at most, '0' means
	 *         not to limit the search
	 */
    public int getMaxTuplesToRead() throws M4CompilerError {
		Value v = (Value) this.getSingleParameter("MaxNumOfTuplesRead");
		if (v == null)
			return 0; // 0: unlimited
		try {
			return Math.max(0, Integer.parseInt(v.getValue()));
		}
		catch (NumberFormatException e) {
			throw new M4CompilerError
			("Parameter 'MaxNumOfTuplesRead' of Operator RowSelectionByUnbiasing does not"
			+ " contain an integer value: " + v.getValue());
		}
    }
    
}
/*
 * History
 * --------
 * 
 * $Log: RowSelectionByUnbiasing.java,v $
 * Revision 1.5  2006/09/27 14:59:56  euler
 * New version 1.1
 *
 * Revision 1.4  2006/04/11 14:10:11  euler
 * Updated license text.
 *
 * Revision 1.3  2006/04/06 16:31:11  euler
 * Prepended license remark.
 *
 * Revision 1.2  2006/03/23 11:13:44  euler
 * Improved exception handling.
 *
 * Revision 1.1  2006/01/03 09:54:21  hakenjos
 * Initial version!
 *
 */
