/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2010 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.meta;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import Jama.Matrix;

import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;

/**
 * This operator finds the optimal values for a set of parameters using a
 * quadratic interaction model. The parameter <var>parameters</var> is a list
 * of key value pairs where the keys are of the form
 * <code>OperatorName.parameter_name</code> and the value is a comma
 * separated list of values (as for the GridParameterOptimization operator). <br/> 
 * The operator returns an optimal
 * {@link ParameterSet} which can also be written to a file with a
 * {@link com.rapidminer.operator.io.ParameterSetWriter}. This parameter set
 * can be read in another process using a
 * {@link com.rapidminer.operator.io.ParameterSetLoader}. <br/> The file
 * format of the parameter set file is straightforward and can also easily be
 * generated by external applications. Each line is of the form 
 * <center><code>operator_name.parameter_name = value</code></center>.
 * 
 * @author Stefan Rueping, Helge Homburg
 */
public class QuadraticParameterOptimizationOperator extends GridSearchParameterOptimizationOperator {

	/** The parameter name for &quot;What to do if the region of interest is exceeded.&quot; */
	public static final String PARAMETER_IF_EXCEEDS_REGION = "if_exceeds_region";

	/** The parameter name for &quot;What to do if range is exceeded.&quot; */
	public static final String PARAMETER_IF_EXCEEDS_RANGE = "if_exceeds_range";

	/** Category names of both exceed parameters; the indices are IGNORE, CLIP and FAIL. */
	private static final String[] EXCEED_BEHAVIORS = { "ignore", "clip", "fail" };

	private static final int IGNORE = 0;

	private static final int CLIP = 1;

	private static final int FAIL = 2;

	/** The best parameter set found so far; {@code null} until the first combination is evaluated. */
	private ParameterSet best;

	public QuadraticParameterOptimizationOperator(OperatorDescription description) {
		super(description);
	}

	/**
	 * Returns the average of the main criterion of the best parameter set found so far, or
	 * {@link Double#NaN} if no parameter set has been evaluated yet.
	 */
	@Override
	public double getCurrentBestPerformance() {
		if (best != null) {
			return best.getPerformance().getMainCriterion().getAverage();
		} else {
			return Double.NaN;
		}
	}

	/**
	 * Evaluates the full parameter grid, then tries to improve on the best grid point. It fits a
	 * quadratic model to the 3^k neighborhood of that point (k = number of parameters with at
	 * least three values) and evaluates the model's stationary point. The better of the two
	 * parameter sets is delivered.
	 */
	@Override
	public void doWork() throws OperatorException {
		getParametersToOptimize();

		int ifExceedsRegion = getParameterAsInt(PARAMETER_IF_EXCEEDS_REGION);
		int ifExceedsRange = getParameterAsInt(PARAMETER_IF_EXCEEDS_RANGE);

		// the neighborhood and bound logic below requires ascending numeric order
		for (int index = 0; index < numberOfParameters; index++) {
			sortNumerically(values[index]);
		}

		int[] bestIndex = new int[numberOfParameters];
		ParameterSet[] allParameters = evaluateGrid(bestIndex);
		optimizeQuadratically(bestIndex, allParameters, ifExceedsRegion, ifExceedsRange);

		deliver(best);
	}

	/** Sorts the given value strings in place by their numeric (double) value, ascending. */
	private static void sortNumerically(String[] valuesToSort) {
		Arrays.sort(valuesToSort, new Comparator<String>() {
			@Override
			public int compare(String a, String b) {
				return Double.compare(Double.parseDouble(a), Double.parseDouble(b));
			}
		});
	}

	/**
	 * Evaluates every combination of parameter values. Parameter 0 varies fastest, so the
	 * combination with value indices (i_0, ..., i_{n-1}) is stored at position
	 * i_0 + |v_0| * (i_1 + |v_1| * (...)) of the returned array. Updates {@link #best} and
	 * copies the value indices of the best combination into {@code bestIndex}.
	 */
	private ParameterSet[] evaluateGrid(int[] bestIndex) throws OperatorException {
		ParameterSet[] allParameters = new ParameterSet[numberOfCombinations];
		int paramIndex = 0;
		best = null;
		while (true) {
			getLogger().fine("Using parameter set");
			for (int j = 0; j < operators.length; j++) {
				operators[j].getParameters().setParameter(parameters[j], values[j][currentIndex[j]]);
				getLogger().fine(operators[j] + "." + parameters[j] + " = " + values[j][currentIndex[j]]);
			}

			PerformanceVector performance = getPerformance(true);

			String[] currentValues = new String[parameters.length];
			for (int j = 0; j < parameters.length; j++) {
				currentValues[j] = values[j][currentIndex[j]];
			}
			allParameters[paramIndex] = new ParameterSet(operators, parameters, currentValues, performance);

			if ((best == null) || (performance.compareTo(best.getPerformance()) > 0)) {
				best = allParameters[paramIndex];
				for (int j = 0; j < numberOfParameters; j++) {
					bestIndex[j] = currentIndex[j];
				}
			}

			// advance to the next combination like an odometer; stop once every digit wrapped
			int k = 0;
			while (!(++currentIndex[k] < values[k].length)) {
				currentIndex[k] = 0;
				k++;
				if (k >= currentIndex.length) {
					return allParameters;
				}
			}
			paramIndex++;
		}
	}

	/**
	 * Fits a quadratic response surface around the best grid point and evaluates its stationary
	 * point. If that point performs better, it replaces {@link #best}. Does nothing (apart from
	 * logging) if no parameter has at least three values or the model cannot be solved.
	 */
	private void optimizeQuadratically(int[] bestIndex, ParameterSet[] allParameters, int ifExceedsRegion, int ifExceedsRange) throws OperatorException {
		// only parameters with >= 3 values take part; move best indices at the border one step
		// inwards so that bestIndex-1 .. bestIndex+1 lies inside the value list
		int nrParameters = 0;
		for (int i = 0; i < numberOfParameters; i++) {
			if (values[i].length > 2) {
				log("Param " + i + ", bestI = " + bestIndex[i]);
				nrParameters++;
				if (bestIndex[i] == 0) {
					bestIndex[i]++;
				}
				if (bestIndex[i] == values[i].length - 1) {
					bestIndex[i]--;
				}
			} else {
				getLogger().warning("Parameter " + parameters[i] + " has less than 3 values, skipped.");
			}
		}

		if (nrParameters > 3) {
			getLogger().warning("Optimization not recommended for more than 3 parameters. Check results carefully!");
		}

		if (nrParameters == 0) {
			getLogger().warning("No parameters to optimize");
			return;
		}

		int threetok = 1;
		for (int i = 0; i < nrParameters; i++) {
			threetok *= 3;
		}

		log("Optimising " + nrParameters + " parameters");

		// columns: constant, linear terms, pairwise interactions, squares
		Matrix designMatrix = new Matrix(threetok, nrParameters + nrParameters * (nrParameters + 1) / 2 + 1);
		Matrix y = new Matrix(threetok, 1);
		fillDesignMatrix(designMatrix, y, bestIndex, allParameters);

		// least-squares fit (QR) for k >= 2, exact solve (LU) for k == 1
		Matrix beta = solveOrNull(designMatrix, y);
		if (beta == null) {
			logWarning("Quadratic optimization failed. (invalid matrix)");
			return;
		}
		for (int i = 0; i < beta.getRowDimension(); i++) {
			log(" -- Writing " + beta.get(i, 0) + " at position " + i + " in vector b");
		}

		Matrix x = computeStationaryPoint(beta, nrParameters);
		if (x == null) {
			logWarning("Quadratic optimization failed. (invalid matrix)");
			return;
		}

		// the stationary point is not guaranteed to be a maximum; it is only accepted below if
		// its evaluated performance beats the grid optimum
		String[] qValues = new String[numberOfParameters];
		int pc = 0;
		boolean ok = true;
		for (int j = 0; j < numberOfParameters; j++) {
			if (values[j].length <= 2) {
				qValues[j] = values[j][bestIndex[j]];
				continue;
			}
			double value = x.get(pc, 0);

			double lowerRegion = Double.parseDouble(values[j][bestIndex[j] - 1]);
			double upperRegion = Double.parseDouble(values[j][bestIndex[j] + 1]);
			if ((value > upperRegion) || (value < lowerRegion)) {
				logWarning("Parameter " + parameters[j] + " exceeds region of interest (" + value + ")");
				if (ifExceedsRegion == CLIP) {
					value = (value > upperRegion) ? upperRegion : lowerRegion;
				} else if (ifExceedsRegion == FAIL) {
					ok = false;
				}
			}

			double lowerRange = Double.parseDouble(values[j][0]);
			double upperRange = Double.parseDouble(values[j][values[j].length - 1]);
			if ((value < lowerRange) || (value > upperRange)) {
				logWarning("Parameter " + parameters[j] + " exceeds range (" + value + ")");
				if (ifExceedsRange == IGNORE) {
					logWarning("  but no measures taken. Check parameters manually!");
				} else if (ifExceedsRange == CLIP) {
					// clip to the violated bound
					value = (value > upperRange) ? upperRange : lowerRange;
				} else {
					ok = false;
				}
			}

			x.set(pc, 0, value);
			qValues[j] = String.valueOf(value);
			pc++;
		}

		getLogger().info("Optimised parameter set:");
		for (int j = 0; j < operators.length; j++) {
			operators[j].getParameters().setParameter(parameters[j], qValues[j]);
			getLogger().info("  " + operators[j] + "." + parameters[j] + " = " + qValues[j]);
		}
		if (ok) {
			PerformanceVector qPerformance = getPerformance(true);
			log("Old: " + (best.getPerformance().getMainCriterion().getFitness()));
			log("New: " + (qPerformance.getMainCriterion().getFitness()));
			if (qPerformance.compareTo(best.getPerformance()) > 0) {
				best = new ParameterSet(operators, parameters, qValues, qPerformance);
				log("Optimised parameter set does increase the performance");
			} else {
				log("Could not increase performance by quadratic optimization");
			}
		} else {
			getLogger().warning("Parameters outside admissible range, not using optimised parameter set.");
		}
	}

	/**
	 * Fills one row of the design matrix, and the corresponding fitness in {@code y}, for each
	 * point of the 3^k plan around the best grid point. Column order: constant, linear terms,
	 * interaction terms (i < j, outer loop over i), squared terms, all in parameter order over
	 * the parameters with at least three values.
	 */
	private void fillDesignMatrix(Matrix designMatrix, Matrix y, int[] bestIndex, ParameterSet[] allParameters) {
		// start at the lowest corner of the plan and locate it in allParameters
		int paramIndex = 0;
		for (int i = numberOfParameters - 1; i >= 0; i--) {
			currentIndex[i] = (values[i].length > 2) ? bestIndex[i] - 1 : bestIndex[i];
			paramIndex = paramIndex * values[i].length + currentIndex[i];
		}

		for (int row = 0; row < designMatrix.getRowDimension(); row++) {
			y.set(row, 0, allParameters[paramIndex].getPerformance().getMainCriterion().getFitness());

			designMatrix.set(row, 0, 1.0);
			int c = 1;
			// linear terms (iterate over ALL parameters, skipping those not optimized)
			for (int i = 0; i < numberOfParameters; i++) {
				if (values[i].length > 2) {
					designMatrix.set(row, c++, currentValue(i));
				}
			}
			// pairwise interaction terms
			for (int i = 0; i < numberOfParameters; i++) {
				if (values[i].length > 2) {
					for (int j = i + 1; j < numberOfParameters; j++) {
						if (values[j].length > 2) {
							designMatrix.set(row, c++, currentValue(i) * currentValue(j));
						}
					}
				}
			}
			// squared terms
			for (int i = 0; i < numberOfParameters; i++) {
				if (values[i].length > 2) {
					double v = currentValue(i);
					designMatrix.set(row, c++, v * v);
				}
			}

			// advance to the next plan point: an odometer over the optimized parameters, each
			// running from bestIndex-1 to bestIndex+1; paramIndex follows via the grid strides
			int k = 0;
			int stride = 1;
			while (k < numberOfParameters) {
				if (values[k].length > 2) {
					currentIndex[k]++;
					paramIndex += stride;
					if (currentIndex[k] > bestIndex[k] + 1) {
						currentIndex[k] = bestIndex[k] - 1;
						paramIndex -= 3 * stride;
						stride *= values[k].length;
						k++;
					} else {
						break;
					}
				} else {
					stride *= values[k].length;
					k++;
				}
			}
		}
	}

	/** Returns the numeric value currently selected (via {@code currentIndex}) for parameter {@code i}. */
	private double currentValue(int i) {
		return Double.parseDouble(values[i][currentIndex[i]]);
	}

	/**
	 * Computes the stationary point of the fitted model f(x) = b0 + b^T x + x^T P x, where P
	 * holds the squared-term coefficients on its diagonal and half of each interaction
	 * coefficient off the diagonal. Setting the gradient b + 2 P x to zero gives P x = -b/2.
	 *
	 * @return the stationary point as a column vector, or {@code null} if P is singular or the
	 *         solution is not finite
	 */
	private static Matrix computeStationaryPoint(Matrix beta, int nrParameters) {
		Matrix p = new Matrix(nrParameters, nrParameters);
		int betapos = nrParameters + 1;
		for (int j = 0; j < (nrParameters - 1); j++) {
			for (int i = 1 + j; i < nrParameters; i++) {
				p.set(i, j, (beta.get(betapos, 0) * 0.5));
				p.set(j, i, (beta.get(betapos, 0) * 0.5));
				betapos++;
			}
		}
		for (int i = 0; i < nrParameters; i++) {
			p.set(i, i, beta.get(betapos, 0));
			betapos++;
		}
		Matrix b = new Matrix(nrParameters, 1);
		for (int i = 0; i < nrParameters; i++) {
			b.set(i, 0, beta.get(i + 1, 0));
		}
		Matrix x = solveOrNull(p, b.times(-0.5));
		if (x == null) {
			return null;
		}
		for (int i = 0; i < nrParameters; i++) {
			double v = x.get(i, 0);
			if (Double.isNaN(v) || Double.isInfinite(v)) {
				return null;
			}
		}
		return x;
	}

	/**
	 * Solves {@code a * x = b} (least squares if {@code a} is not square). Returns {@code null}
	 * instead of throwing when Jama reports a singular or rank-deficient matrix.
	 */
	private static Matrix solveOrNull(Matrix a, Matrix b) {
		try {
			return a.solve(b);
		} catch (RuntimeException e) {
			// Jama signals singular / rank-deficient matrices with plain RuntimeExceptions
			return null;
		}
	}

	@Override
	public List<ParameterType> getParameterTypes() {
		List<ParameterType> types = super.getParameterTypes();
		types.add(new ParameterTypeCategory(PARAMETER_IF_EXCEEDS_REGION, "What to do if the region of interest is exceeded.", EXCEED_BEHAVIORS, CLIP));
		types.add(new ParameterTypeCategory(PARAMETER_IF_EXCEEDS_RANGE, "What to do if range is exceeded.", EXCEED_BEHAVIORS, FAIL));
		return types;
	}
}
