package edu.udo.cs.mySVMdb.Container;
import java.sql.*;
//import oracle.sql.*;
//import oracle.jdbc.driver.*;
import java.util.Vector;

/**
 * Implementation of an SVM container whose examples, parameters and model
 * live in database tables that are accessed via JDBC.
 * <p>
 * Parameters are read from a parameter table with the columns
 * {@code parameter} and {@code value}. The model table has the columns
 * {@code key} and {@code alpha}; the row with a null {@code key} stores
 * the threshold b.
 *
 * @author Stefan R&uuml;ping
 * @version 1.0
 */
public class JDBCDatabaseContainer
{
    /** Fully qualified class name of the Oracle JDBC driver loaded by open(). */
    private static final String ORACLE_DRIVER_CLASS = "oracle.jdbc.driver.OracleDriver";

    /** Number of attributes (x columns) per example. */
    protected int dim;

    /** Number of training examples, i.e. the valid entries of keys, alphas and ys. */
    protected int train_size;

    /** Number of test examples, i.e. the entries of test_keys. */
    protected int test_size;

    /**
     * Column means used for scaling. Index i &lt; dim belongs to x_column[i],
     * index dim belongs to the y column. Null when no scaling is done.
     */
    public double[] Exp;

    /** Column standard deviations used for scaling, indexed like Exp. */
    public double[] Dev;

    /** SQL expression computing the kernel value between rows aliased x and y. */
    String select_text;

    /**
     * Name of the examples table
     */
    protected String examples_name;

    /**
     * Name of the test examples table
     */
    protected String test_examples_name;

    /**
     * Name of the model table
     */
    protected String model_name;

    /**
     * Name of the parameters table
     */
    protected String parameters_name;

    /**
     * Name of the predictions table
     */
    protected String predictions_name;

    /**
     * JDBC connection object
     */
    protected Connection conn;

    /**
     * Database URL
     */
    protected String db_url = "jdbc:oracle:oci8:@some_instance";

    /**
     * Database login
     */
    protected String db_login = "scott";

    /**
     * Database password
     */
    protected String db_password = "tiger";

    /**
     * mapping int -> key of the training examples
     */
    public String[] keys;

    /**
     * mapping int -> key of the test examples
     */
    protected String[] test_keys;

    /**
     * Name of the column containing the examples key.
     */
    protected String key_column;

    /**
     * Comma separated names of the attribute columns, used in SQL queries.
     */
    protected String x_columns;

    /** Names of the attribute columns, one entry per dimension. */
    public String[] x_column;

    /**
     * Name of the column containing the examples y values.
     */
    protected String y_column;

    /**
     * Entry in y_column that is the positive class
     */
    protected String target_concept;

    /**
     * Name of the column containing the predicted y values.
     * Usually equals y_column
     */
    protected String pred_column = "Y";

    /** Alpha value of each training example, indexed like keys. */
    double[] alphas;

    /** (Possibly scaled or {-1,1}-mapped) y value of each training example. */
    double[] ys;

    /** Threshold of the SVM. */
    double b;

    /**
     * Prepared Statement for get_example
     */
    PreparedStatement get_example_statement;

    /**
     * Prepared Statement for get_param
     */
    PreparedStatement get_param_statement;

    /**
     * Prepared Statement for get_test_example, created lazily and reset
     * whenever a new test set is read.
     */
    private PreparedStatement get_test_example_statement;

    /**
     * Class constructor with default parameters. Does not open a connection.
     */
    public JDBCDatabaseContainer()
    {
    }

    /**
     * Class constructor with parameters; opens the connection via init.
     * @param params Array of parameters in the form "name:value".
     * @exception Exception on unknown parameter or database error
     */
    public JDBCDatabaseContainer(String[] params)
        throws Exception
    {
        init(params);
    }

    /**
     * Initialize container access.
     * <p>
     * Entries "db_url:...", "db_login:..." and "db_password:..." (the name
     * part is case-insensitive) set the connection info. Any other non-null
     * entry is taken, unchanged, as the name of the parameter table (the
     * last one wins). Null entries are ignored.
     * @param params Names of the parameter table and connection info
     * @exception Exception on database error or missing parameter table
     */
    public void init(String[] params)
        throws Exception
    {
        String parameter_table = null;
        if(params != null){
            for(int i=0;i<params.length;i++){
                String entry = params[i];
                if(entry == null){
                    continue;
                }
                int pos = entry.indexOf(':');
                if(pos < 0){
                    parameter_table = entry;
                    continue;
                }
                // locale-independent lower casing, so "DB_URL" also works in e.g. Turkish locales
                String param = entry.substring(0,pos).toLowerCase(java.util.Locale.ENGLISH);
                String value = entry.substring(pos+1);
                if(param.equals("db_url")){
                    db_url = value;
                }
                else if(param.equals("db_login")){
                    db_login = value;
                }
                else if(param.equals("db_password")){
                    db_password = value;
                }
                else{
                    parameter_table = entry;
                }
            }
        }

        open();
        read_params(parameter_table);
    }

    /**
     * Loads the Oracle driver (if available) and opens the database
     * connection with auto-commit switched off.
     * @exception Exception On database connection error
     */
    protected void open()
        throws Exception
    {
        try{
            // loading the class registers the driver with the DriverManager
            Class.forName(ORACLE_DRIVER_CLASS);
        }
        catch(ClassNotFoundException e){
            // Oracle driver not on the classpath: rely on automatic JDBC driver
            // loading; getConnection below reports if no driver accepts db_url.
        }
        conn = DriverManager.getConnection(db_url,db_login,db_password);
        conn.setAutoCommit(false);
    }

    /**
     * Reads the model (alphas of all non-null keys and b from the null-key
     * row) from the model table and prepares the example statement.
     * @exception Exception on SQL error
     */
    public void read_model()
        throws Exception
    {
        read_trainset_meta_info();

        Statement stmt = conn.createStatement();
        try{
            // the null-key row holds b and is not an example
            train_size = query_int(stmt, "select count(*) from "+model_name+" where key is not null");
            keys = new String[train_size];
            alphas = new double[train_size];
            ResultSet rset = stmt.executeQuery("select key, alpha from "+model_name+" where key is not null");
            try{
                int i = 0;
                while(rset.next()){
                    check_row_index(i, train_size, model_name);
                    keys[i] = rset.getString(1);
                    alphas[i] = rset.getDouble(2);
                    i++;
                }
            }
            finally{
                close_silently(rset);
            }

            rset = stmt.executeQuery("select alpha from "+model_name+" where key is null");
            try{
                // no b row stored: assume b = 0
                b = rset.next() ? rset.getDouble(1) : 0;
            }
            finally{
                close_silently(rset);
            }
        }
        finally{
            close_silently(stmt);
        }

        prepare_example_statement();
    }

    /**
     * Generates and writes prediction of examples in table test_examples_name.
     * Not implemented yet; currently does nothing.
     * @exception Exception on SQL error
     */
    public void write_prediction()
        throws Exception
    {
        // TODO: prediction can go into a new table or a column of test_examples_name:
        // check if table/column exists, prepare a statement to read the examples,
        // create a batch statement to write, then predict.
    }

    /**
     * Writes the model to the database: replaces the model table by a
     * persistent one holding all examples with non-zero alpha plus the
     * null-key row for b. If the parameter "view_name" is set, also tries to
     * create a prediction view of that name.
     * @exception Exception on SQL error
     */
    public void write_model()
        throws Exception
    {
        Statement ddl = conn.createStatement();
        PreparedStatement insert = null;
        try{
            // drop temporary table and create persistent table
            ddl.execute("drop table "+model_name);
            ddl.execute("create table "+model_name+" (key varchar(60), alpha number)");

            insert = conn.prepareStatement("insert into "+model_name+" values (?,?)");
            insert.clearBatch();
            for(int i=0;i<train_size;i++){
                if(alphas[i] != 0.0){
                    insert.setString(1,keys[i]);
                    insert.setDouble(2,alphas[i]);
                    insert.addBatch();
                }
            }
            insert.setNull(1,Types.VARCHAR);
            insert.setDouble(2,b);
            insert.addBatch();
            insert.executeBatch();

            // create view for prediction (best effort: failure is only reported)
            String viewname = get_param("view_name");
            if(viewname != null){
                try{
                    ddl.execute(
                        "create view "+viewname+" as select x."+key_column
                        +" as key, b.alpha + (select sum(z.alpha * ("+select_text
                        +")) from "+model_name+" z, "+examples_name
                        +" y where z.key = y."+key_column+") as svm_pred from "
                        +model_name+" b, "+examples_name+" x where b.key is null");
                }
                catch(SQLException e){
                    System.out.println("ERROR: could not create view "+viewname+", error message: "+e.getMessage());
                }
            }
        }
        finally{
            close_silently(insert);
            close_silently(ddl);
        }
    }

    /**
     * Closes all prepared statements, commits the open transaction and
     * closes the database connection.
     * @exception Exception on SQL error during commit or close
     */
    public void close()
        throws Exception
    {
        close_silently(get_example_statement);
        get_example_statement = null;
        close_silently(get_test_example_statement);
        get_test_example_statement = null;
        close_silently(get_param_statement);
        get_param_statement = null;
        if(conn != null){
            try{
                conn.commit();
            }
            finally{
                // close even if the commit failed
                conn.close();
            }
        }
    }

    /**
     * Uses the given table as parameter table and checks that it can be read.
     * @param location name of the parameter table
     * @exception Exception if the table is missing or cannot be read
     */
    public void read_params(String location)
        throws Exception
    {
        if(location == null){
            throw new IllegalArgumentException("No parameter table given");
        }
        parameters_name = location;

        close_silently(get_param_statement);
        get_param_statement = conn.prepareStatement("select value from " + parameters_name + " where parameter = ?");

        // dummy call to check if table exists
        get_param("model_name");
    }

    /**
     * Prepares data structures for learning: makes sure the model table
     * (parameter "model_name", default "svm_model") exists, creating it as a
     * global temporary table if needed, and reads the training set named by
     * parameter "trainset".
     * @exception Exception if an error occured
     */
    public void init_for_learning()
        throws Exception
    {
        model_name = get_param("model_name");
        if(model_name == null){
            model_name = "svm_model";
        }
        Statement stmt = conn.createStatement();
        try{
            if(! table_exists(stmt, model_name)){
                stmt.executeUpdate("create global temporary table "+model_name+" (key varchar(60), alpha number)");
            }
        }
        finally{
            close_silently(stmt);
        }

        examples_name = get_param("trainset");
        read_trainset();
    }

    /**
     * Reads key/y column names, the attribute columns (from the parameter
     * table entries "x_column" or, if there are none, from all numeric
     * columns of the examples table) and the scaling information.
     * @exception Exception on SQL error or if no attribute column is found
     */
    public void read_trainset_meta_info()
        throws Exception
    {
        key_column = get_param("key_column");
        if(key_column == null){
            key_column = "ROWID";
        }
        y_column = get_param("y_column");
        if(y_column == null){
            y_column = "Y";
        }

        Statement stmt = conn.createStatement();
        try{
            dim = query_int(stmt, "select count(*) from "+parameters_name+" where parameter = 'x_column'");
            if(dim > 0){
                read_x_columns_from_params(stmt);
            }
            else{
                read_x_columns_from_table(stmt);
            }
            read_scaling(stmt);
        }
        finally{
            close_silently(stmt);
        }
    }

    /** Reads the dim attribute column names listed as "x_column" in the parameter table. */
    private void read_x_columns_from_params(Statement stmt)
        throws SQLException
    {
        x_column = new String[dim];
        ResultSet rset = stmt.executeQuery("select value from "+parameters_name+" where parameter = 'x_column'");
        try{
            int i = 0;
            while(rset.next()){
                check_row_index(i, dim, parameters_name);
                x_column[i] = rset.getString(1);
                i++;
            }
        }
        finally{
            close_silently(rset);
        }
        x_columns = join_columns(x_column, dim);
    }

    /** Uses all numeric columns of the examples table except key and y as attributes. */
    private void read_x_columns_from_table(Statement stmt)
        throws Exception
    {
        // only the metadata is needed, so do not fetch any rows
        ResultSet rset = stmt.executeQuery("select * from "+examples_name+" where 1=0");
        try{
            ResultSetMetaData meta = rset.getMetaData();
            int column_count = meta.getColumnCount();
            if(column_count <= 0){
                throw new Exception("No columns found in "+examples_name);
            }

            String[] found = new String[column_count];
            int found_count = 0;
            for(int col=1;col<=column_count;col++){
                String name = meta.getColumnName(col);
                // unquoted SQL identifiers are case-insensitive, so compare ignoring case
                if(! name.equalsIgnoreCase(y_column)
                   && ! name.equalsIgnoreCase(key_column)
                   && is_numeric_type(meta.getColumnType(col))){
                    found[found_count] = name;
                    found_count++;
                }
            }
            if(found_count == 0){
                throw new Exception("No x-columns found in "+examples_name);
            }

            dim = found_count;
            x_column = new String[dim];
            System.arraycopy(found, 0, x_column, 0, dim);
            x_columns = join_columns(x_column, dim);
        }
        finally{
            close_silently(rset);
        }
    }

    /** Returns true for JDBC types whose values can be read as a double. */
    private static boolean is_numeric_type(int sql_type)
    {
        switch(sql_type){
        case Types.TINYINT:
        case Types.SMALLINT:
        case Types.INTEGER:
        case Types.BIGINT:
        case Types.REAL:
        case Types.FLOAT:
        case Types.DOUBLE:
        case Types.DECIMAL:
        case Types.NUMERIC:
        case Types.BIT:
            return true;
        default:
            return false;
        }
    }

    /**
     * Computes Exp and Dev from the examples table. y is not scaled for
     * svm_type "pattern" or scale "x"; then Exp[dim] = 0 and Dev[dim] = 1.
     */
    private void read_scaling(Statement stmt)
        throws Exception
    {
        // NOTE(review): x scaling is always on, so the "no scaling" branch
        // below is currently unreachable; confirm whether a "scale" value
        // that disables it was intended.
        boolean do_scale_x = true;
        boolean do_scale_y = true;
        if("pattern".equals(get_param("svm_type"))){
            do_scale_y = false;
        }
        String scale = get_param("scale");
        if("x".equals(scale)){
            do_scale_x = true;
            do_scale_y = false;
        }
        else if("xy".equals(scale)){
            do_scale_x = true;
            do_scale_y = true;
        }

        if(! (do_scale_x || do_scale_y)){
            Exp = null;
            Dev = null;
            return;
        }

        Exp = new double[dim+1];
        Dev = new double[dim+1];
        String the_query = "select ";
        for(int i=0;i<dim;i++){
            if(i > 0){
                the_query += ", ";
            }
            the_query += "AVG("+x_column[i]+"), STDDEV("+x_column[i]+")";
        }
        // y statistics are only needed (and only computable for numeric y) when y is scaled
        if(do_scale_y){
            the_query += ", AVG("+y_column+"), STDDEV("+y_column+")";
        }
        the_query += " from "+examples_name;

        ResultSet rset = stmt.executeQuery(the_query);
        try{
            if(! rset.next()){
                throw new SQLException("No result for query: "+the_query);
            }
            for(int i=0;i<dim;i++){
                Exp[i] = rset.getDouble(2*i+1);
                Dev[i] = rset.getDouble(2*i+2);
            }
            if(do_scale_y){
                Exp[dim] = rset.getDouble(2*dim+1);
                Dev[dim] = rset.getDouble(2*dim+2);
            }
            else{
                Exp[dim] = 0;
                Dev[dim] = 1;
            }
        }
        finally{
            close_silently(rset);
        }
    }

    /**
     * Reads the training set: meta information, the example keys (from the
     * model table if parameter "read_keys_from_model" is "true", otherwise
     * from the examples table, which also resets the model table) and the
     * y values.
     * @exception Exception if an error opening the container occurred
     */
    public void read_trainset()
        throws Exception
    {
        read_trainset_meta_info();

        Statement stmt = conn.createStatement();
        try{
            keys = null;
            if("true".equals(get_param("read_keys_from_model"))){
                // the null-key row holds b and is not an example
                train_size = query_int(stmt, "select count(*) from "+model_name+" where key is not null");
                keys = read_string_column(stmt, "select key from "+model_name+" where key is not null", train_size, model_name);
            }
            else{
                train_size = query_int(stmt, "select count(*) from "+examples_name);
                keys = read_string_column(stmt, "select "+key_column+" from "+examples_name, train_size, examples_name);
                // reset the model table to one (key, null) row per example
                stmt.clearBatch();
                stmt.addBatch("delete from "+model_name);
                stmt.addBatch("insert into "+model_name+" (key, alpha) select "+key_column+", null from "+examples_name);
                stmt.executeBatch();
            }
        }
        finally{
            close_silently(stmt);
        }

        alphas = new double[train_size];
        ys = new double[train_size];
        read_ys();

        prepare_example_statement();
    }

    /**
     * Fills ys: mapped to {-1,1} if a target concept is set, otherwise the
     * raw y value, scaled with Exp[dim]/Dev[dim] when scaling applies.
     */
    private void read_ys()
        throws Exception
    {
        target_concept = get_param("target_concept");
        boolean unscaled = (Dev == null)
            || (Dev[dim] == 0)
            || ((Exp[dim] == 0) && (Dev[dim] == 1));

        PreparedStatement y_stmt = conn.prepareStatement("select "+y_column+" from "+examples_name+" where "+key_column+" = ?");
        try{
            for(int i=0;i<train_size;i++){
                y_stmt.setString(1,keys[i]);
                ResultSet rset = y_stmt.executeQuery();
                try{
                    if(! rset.next()){
                        throw new SQLException("No example with key "+keys[i]+" in "+examples_name);
                    }
                    if(target_concept != null){
                        ys[i] = target_concept.equals(rset.getString(1)) ? 1 : -1;
                    }
                    else if(unscaled){
                        ys[i] = rset.getDouble(1);
                    }
                    else{
                        ys[i] = (rset.getDouble(1)-Exp[dim])/Dev[dim];
                    }
                }
                finally{
                    close_silently(rset);
                }
            }
        }
        finally{
            close_silently(y_stmt);
        }
    }

    /**
     * Reads the keys of the test set from the given table (which is assumed
     * to have the same format as the training set).
     * @param location name of the test examples table
     * @exception Exception if an error opening the container occurred
     */
    public void read_testset(String location)
        throws Exception
    {
        test_examples_name = location;
        // the cached statement refers to the previous test table
        close_silently(get_test_example_statement);
        get_test_example_statement = null;

        Statement stmt = conn.createStatement();
        try{
            test_size = query_int(stmt, "select count(*) from "+test_examples_name);
            test_keys = read_string_column(stmt, "select "+key_column+" from "+test_examples_name, test_size, test_examples_name);
        }
        finally{
            close_silently(stmt);
        }
    }

    /**
     * Sets the name of the model table. Does not read or check the table.
     * @param location name of the model table
     * @exception Exception never thrown currently
     */
    public void read_model(String location)
        throws Exception
    {
        model_name = location;
        // TODO: check if table exists
    }

    /**
     * Creates a table with the given name and column definition.
     * @param name table name
     * @param definition SQL column definition list
     * @return the table name
     * @exception Exception on SQL error
     */
    protected String create_temp_table(String name, String definition)
        throws Exception
    {
        Statement stmt = conn.createStatement();
        try{
            stmt.executeUpdate("create table "+name+" ("+definition+")");
        }
        finally{
            close_silently(stmt);
        }
        return name;
    }

    /**
     * Counts the training examples.
     * @return Number of examples
     */
    public int count_examples()
    {
        return train_size;
    }

    /**
     * Counts the positive examples in the examples table: rows whose y
     * equals the target concept if one is set, otherwise rows with y &gt; 0.
     * @return Number of positive examples
     * @exception Exception on SQL error
     */
    public int count_pos_examples()
        throws Exception
    {
        if(target_concept != null){
            // consistent with read_trainset, where y == target_concept means +1
            PreparedStatement ps = conn.prepareStatement("select count(*) from "+examples_name+" where "+y_column+" = ?");
            try{
                ps.setString(1,target_concept);
                ResultSet rset = ps.executeQuery();
                try{
                    rset.next();
                    return rset.getInt(1);
                }
                finally{
                    close_silently(rset);
                }
            }
            finally{
                close_silently(ps);
            }
        }

        Statement stmt = conn.createStatement();
        try{
            return query_int(stmt, "select count(*) from "+examples_name+" where "+y_column+" > 0");
        }
        finally{
            close_silently(stmt);
        }
    }

    /**
     * Gets the dimension of the examples
     * @return dim
     */
    public int get_dim()
    {
        return dim;
    }

    /**
     * Counts the test examples
     * @return Number of test examples
     */
    public int count_test_examples()
    {
        return test_size;
    }

    /**
     * Gets a (scaled) training example.
     * @param pos Number of example
     * @return Array of example attributes in their default order
     * @exception Exception on SQL error or if the example does not exist
     */
    public double[] get_example(int pos)
        throws Exception
    {
        String key = keys[pos];
        get_example_statement.setString(1,key);
        ResultSet rset = get_example_statement.executeQuery();
        try{
            return read_scaled_row(rset, key, examples_name);
        }
        finally{
            close_silently(rset);
        }
    }

    /**
     * Gets a (scaled) test example.
     * @param pos Number of example
     * @return Array of example attributes in their default order
     * @exception Exception on SQL error or if the example does not exist
     */
    public double[] get_test_example(int pos)
        throws Exception
    {
        if(get_test_example_statement == null){
            get_test_example_statement = conn.prepareStatement("select "+x_columns+" from "+test_examples_name+" where "+key_column+" = ?");
        }
        String key = test_keys[pos];
        // bound parameter instead of a quoted literal: keys may contain quotes
        get_test_example_statement.setString(1,key);
        ResultSet rset = get_test_example_statement.executeQuery();
        try{
            return read_scaled_row(rset, key, test_examples_name);
        }
        finally{
            close_silently(rset);
        }
    }

    /** Reads the first row of rset as dim attribute values and scales them. */
    private double[] read_scaled_row(ResultSet rset, String key, String table)
        throws SQLException
    {
        if(! rset.next()){
            throw new SQLException("No example with key "+key+" in "+table);
        }
        double[] x = new double[dim];
        for(int i=0;i<dim;i++){
            x[i] = scale_x(rset.getDouble(i+1), i); // columns start at 1
        }
        return x;
    }

    /**
     * Scales an attribute value with Exp[i] and Dev[i].
     * @param value raw attribute value
     * @param i attribute index
     * @return value unchanged if there is no scaling info, value - Exp[i] for
     *         a constant column (Dev[i] == 0, avoids NaN), else
     *         (value - Exp[i]) / Dev[i]
     */
    protected double scale_x(double value, int i)
    {
        if(Exp == null){
            return value;
        }
        if(Dev[i] == 0){
            return value - Exp[i];
        }
        return (value - Exp[i]) / Dev[i];
    }

    /**
     * Gets an y-value.
     * @param pos Number of example
     * @return y
     */
    public double get_y(int pos)
    {
        return ys[pos];
    }

    /**
     * Gets the y array
     * @return y
     */
    public double[] get_ys()
    {
        return ys;
    }

    /**
     * Gets an alpha-value.
     * @param pos Number of example
     * @return alpha
     */
    public double get_alpha(int pos)
    {
        return alphas[pos];
    }

    /**
     * Gets the alpha array
     * @return alpha
     */
    public double[] get_alphas()
    {
        return alphas;
    }

    /**
     * swap two training examples (key, alpha and y)
     * @param pos1 index of the first example
     * @param pos2 index of the second example
     */
    public void swap(int pos1, int pos2)
    {
        String key = keys[pos1];
        keys[pos1] = keys[pos2];
        keys[pos2] = key;
        double value = alphas[pos1];
        alphas[pos1] = alphas[pos2];
        alphas[pos2] = value;
        value = ys[pos1];
        ys[pos1] = ys[pos2];
        ys[pos2] = value;
    }

    /**
     * get b
     * @return b
     */
    public double get_b()
    {
        return b;
    }

    /**
     * Sets b and stores it in the null-key row of the model table
     * (updating that row, or inserting it if it does not exist).
     * @param new_b new threshold
     * @exception Exception on SQL error
     */
    public void set_b(double new_b)
        throws Exception
    {
        b = new_b;
        PreparedStatement update = null;
        try{
            update = conn.prepareStatement("update "+model_name+" set alpha = ? where key is null");
            update.setDouble(1,b);
            if(update.executeUpdate() != 0){
                return;
            }
        }
        catch(SQLException e){
            // update not possible: fall back to inserting the row below
        }
        finally{
            close_silently(update);
        }

        PreparedStatement insert = conn.prepareStatement("insert into "+model_name+" values (NULL, ?)");
        try{
            insert.setDouble(1,b);
            insert.executeUpdate();
        }
        finally{
            close_silently(insert);
        }
    }

    /**
     * Inserts a prediction (key, value) into the predictions table. With a
     * target concept the value is the target concept for y &gt; 0 and 0
     * otherwise; without one it is y itself.
     * @param pos Number of example
     * @param y New value
     * @exception Exception on SQL error or if no predictions table is set
     */
    public void set_test_y(int pos, double y)
        throws Exception
    {
        if(predictions_name == null){
            throw new IllegalStateException("No predictions table set");
        }
        PreparedStatement ps = conn.prepareStatement("insert into "+predictions_name+" values (?, ?)");
        try{
            ps.setString(1,test_keys[pos]);
            if(target_concept == null){
                ps.setDouble(2,y);
            }
            else if(y > 0){
                ps.setString(2,target_concept);
            }
            else{
                ps.setInt(2,0);
            }
            ps.executeUpdate();
        }
        finally{
            close_silently(ps);
        }
    }

    /**
     * sets an alpha value (in memory only).
     * @param pos Number of example
     * @param alpha New value
     */
    public void set_alpha(int pos, double alpha)
    {
        alphas[pos] = alpha;
    }

    /**
     * sets a kernel row. Not used by this container; does nothing.
     * @param i example index
     * @param K_row Array of new value K(example i, example j) for all j
     */
    public void set_K_row(int i, double[] K_row)
    {
    }

    /**
     * checks if alphas are initialised
     * @return always false for this container
     */
    public boolean initialised_alpha()
    {
        return false;
    }

    /**
     * Get parameter value from the parameter table.
     * @param param parameter name
     * @return the value of the first matching row, or null if there is none
     * @exception Exception on SQL error
     */
    public String get_param(String param)
        throws Exception
    {
        get_param_statement.setString(1,param);
        ResultSet rset = get_param_statement.executeQuery();
        try{
            return rset.next() ? rset.getString(1) : null;
        }
        finally{
            close_silently(rset);
        }
    }

    /**
     * Sets the SQL kernel expression used by the prepare*Statement methods.
     * @param new_select_text SQL expression over rows aliased x and y
     */
    public void set_select_text(String new_select_text)
    {
        select_text = new_select_text;
    }

    /**
     * Prepares a statement computing the kernel value K(x,y) for two keys
     * (parameter 1: key of x, parameter 2: key of y).
     * @return the prepared statement; the caller must close it
     * @exception Exception on SQL error
     */
    public PreparedStatement prepareKijStatement()
        throws Exception
    {
        return conn.prepareStatement(
            "select "+select_text+" as K"
            +" from "+examples_name+" x, "
            +examples_name+" y"
            +" where x."+key_column+" = ? and y."+key_column+" = ?");
    }

    /**
     * Prepares a statement computing the kernel values between one example
     * (parameter 1: its key) and all examples in the model table.
     * @return the prepared statement; the caller must close it
     * @exception Exception on SQL error
     */
    public PreparedStatement prepareKiStatement()
        throws Exception
    {
        return conn.prepareStatement(
            "select /*+ ORDERED ALL_ROWS */ "+select_text
            +", y."+key_column
            +" from "
            +examples_name+" x, "
            +examples_name+" y,"
            +model_name+" t"
            +" where x."+key_column+" = ?"
            +" and t.key = y."+key_column,ResultSet.TYPE_FORWARD_ONLY,ResultSet.CONCUR_READ_ONLY);
    }

    /**
     * Prepares a statement computing the kernel values between the examples
     * of a working set (one key parameter per working set element, size
     * from parameter "working_set_size", default 10) and all examples in
     * the model table, ordered by the working set key.
     * @return the prepared statement; the caller must close it
     * @exception Exception on SQL error
     */
    public PreparedStatement prepareKisStatement()
        throws Exception
    {
        int working_set_size = 10; // default, has to be identical to SVM::init
        String value = get_param("working_set_size");
        if(value != null){
            try{
                working_set_size = Integer.parseInt(value.trim());
            }
            catch(NumberFormatException e){
                // invalid value: keep the default
            }
        }

        String pstext =
            "select /*+ ORDERED ALL_ROWS */ "+select_text
            +", x."+key_column
            +", y."+key_column+" from "
            +model_name+" t,"
            +examples_name+" y,"
            +examples_name+" x where y."+key_column
            +" = t.key and x."+key_column+" in (?";
        for(int i=1;i<working_set_size;i++){
            pstext += ",?";
        }
        pstext += ") order by x."+key_column;
        return conn.prepareStatement(pstext,ResultSet.TYPE_FORWARD_ONLY,ResultSet.CONCUR_READ_ONLY);
    }

    /**
     * Recreates the model table as a global temporary table containing
     * (key, null) rows for the first {@code to} training examples.
     * @param from unused here (NOTE(review): presumably the old size; confirm with callers)
     * @param to number of leading examples to keep in the model table
     * @exception Exception on SQL error
     */
    public void shrink(int from, int to)
        throws Exception
    {
        Statement stmt = conn.createStatement();
        try{
            stmt.execute("drop table "+model_name);
            stmt.execute("create global temporary table "+model_name+" (key varchar(60), alpha number)");
        }
        finally{
            close_silently(stmt);
        }

        PreparedStatement insert = conn.prepareStatement("insert into "+model_name+" values (?,null)");
        try{
            insert.clearBatch();
            for(int i=0;i<to;i++){
                insert.setString(1,keys[i]);
                insert.addBatch();
            }
            insert.executeBatch();
        }
        finally{
            close_silently(insert);
        }
    }

    /** Closes and replaces get_example_statement for the current examples table. */
    private void prepare_example_statement()
        throws SQLException
    {
        close_silently(get_example_statement);
        get_example_statement = null;
        get_example_statement = conn.prepareStatement("select "+x_columns+" from "+examples_name+" where "+key_column+" = ?");
    }

    /** Returns true if a query on the table succeeds; any SQL error counts as "missing". */
    private static boolean table_exists(Statement stmt, String table)
    {
        try{
            // "where 1=0" avoids scanning the table
            close_silently(stmt.executeQuery("select 1 from "+table+" where 1=0"));
            return true;
        }
        catch(SQLException e){
            return false;
        }
    }

    /** Executes a query returning a single integer (e.g. count(*)) and returns it. */
    private static int query_int(Statement stmt, String sql)
        throws SQLException
    {
        ResultSet rset = stmt.executeQuery(sql);
        try{
            if(! rset.next()){
                throw new SQLException("No result for query: "+sql);
            }
            return rset.getInt(1);
        }
        finally{
            close_silently(rset);
        }
    }

    /** Reads the first column of all rows of a query into an array of the given size. */
    private static String[] read_string_column(Statement stmt, String sql, int size, String table)
        throws SQLException
    {
        String[] values = new String[size];
        ResultSet rset = stmt.executeQuery(sql);
        try{
            int i = 0;
            while(rset.next()){
                check_row_index(i, size, table);
                values[i] = rset.getString(1);
                i++;
            }
        }
        finally{
            close_silently(rset);
        }
        return values;
    }

    /** Fails clearly if a query returns more rows than were counted before. */
    private static void check_row_index(int row, int capacity, String table)
        throws SQLException
    {
        if(row >= capacity){
            throw new SQLException("More rows in "+table+" than counted before; table changed while reading");
        }
    }

    /** Joins the first count column names with ",". */
    private static String join_columns(String[] columns, int count)
    {
        String joined = "";
        for(int i=0;i<count;i++){
            if(i > 0){
                joined += ",";
            }
            joined += columns[i];
        }
        return joined;
    }

    /** Closes a statement, ignoring errors, so a failing close cannot mask another exception. */
    private static void close_silently(Statement stmt)
    {
        if(stmt != null){
            try{
                stmt.close();
            }
            catch(SQLException ignored){
                // nothing sensible to do on close failure
            }
        }
    }

    /** Closes a result set, ignoring errors, so a failing close cannot mask another exception. */
    private static void close_silently(ResultSet rset)
    {
        if(rset != null){
            try{
                rset.close();
            }
            catch(SQLException ignored){
                // nothing sensible to do on close failure
            }
        }
    }
}
