package com.rapidminer.operator.learner.meta;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.ExecutionUnit;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.PredictionModelMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.RandomGenerator;

/**
 * Common base class for Bayesian boosting meta learners.
 * <p>
 * Besides the inner learner that {@link AbstractMetaLearner} expects in subprocess 0, this operator
 * owns a second subprocess ({@link #GET_RANDOM_SAMPLE_SUBPROCESS}). That subprocess is executed to
 * fetch a training sample whenever no example set is connected to the regular example set input.
 * An optional model delivered to the {@code "model"} port can be used to initialise the example
 * weights (see {@link #readOptionalParameters()} and {@link #applyPriorModel(ExampleSet, List)}).
 * <p>
 * This class never assigns {@link #pmClass}, {@link #pmConstructor} or {@link #fuzzyReweighting};
 * subclasses must set them before the weighting methods are used. The constructor is invoked with
 * a single {@link ExampleSet} and its result is cast to {@link AbstractWeightedPerformanceMeasures}.
 * The class must provide a static method
 * {@code reweightExamples(ExampleSet, ContingencyMatrix, boolean, boolean)} whose result is
 * unboxed as a {@code double}.
 */
public abstract class AbstractBayesianBoosting extends AbstractMetaLearner {

	/** Optional input port for a model used to initialise the example weights. */
	protected final InputPort modelInput = getInputPorts().createPort("model");

	/** Field for visualizing performance: the current boosting iteration. */
	protected int currentIteration;

	/**
	 * Index of the subprocess that draws a random sample. Subprocess 0 must not be used, because
	 * {@link AbstractMetaLearner} expects the actual inner learner in subprocess 0.
	 */
	protected final int GET_RANDOM_SAMPLE_SUBPROCESS = 1;

	/** Inner sink of the sampling subprocess that delivers the drawn example set. */
	protected InputPort innerRandomSamplePort;

	/**
	 * A backup of the original weights of the training set, taken by
	 * {@link #prepareWeights(ExampleSet)}, to restore them after learning. {@code null} if the
	 * example set had no weight attribute.
	 */
	protected double[] oldWeights;

	/**
	 * A performance measure to be visualized. It is set to the total example weight by
	 * {@link #prepareWeights(ExampleSet)} and to the result of the reweighting by
	 * {@link #reweightExamplesWrapper(ExampleSet, boolean)}.
	 */
	double performance = 0;

	/** A model to initialise the example weights; read from {@link #modelInput}, may be null. */
	protected Model startModel;

	/**
	 * Boolean parameter to specify whether the label priors should be equally likely after the
	 * first iteration.
	 */
	public static final String PARAMETER_RESCALE_LABEL_PRIORS = "rescale_label_priors";

	/**
	 * Name of the parameter holding the fraction of examples used for training. A value strictly
	 * between 0 and 1 activates internal bootstrapping.
	 */
	public static final String PARAMETER_USE_SUBSET_FOR_TRAINING = "use_subset_for_training";

	/**
	 * Boolean parameter that switches between KBS (if set to false) and a boosting-like
	 * reweighting.
	 */
	public static final String PARAMETER_ALLOW_MARGINAL_SKEWS = "allow_marginal_skews";

	/** Name of the inner sink port of the sampling subprocess. */
	public static final String PORT_INNER_EXAMPLE_SET = "Example Set";

	/** Boolean parameter: count tp, fp etc. based on confidences instead of crisp predictions. */
	public static final String PARAMETER_FUZZY_PARTITION_SIZES = "fuzzy_partition_sizes";

	/** Boolean parameter: compute the example weights in a fuzzy way. */
	public static final String PARAMETER_FUZZY_EXAMPLE_REWEIGHTING = "fuzzy_reweighting";

	/** Constructor of the weighted performance measures; must accept a single {@link ExampleSet}. */
	protected Constructor<?> pmConstructor;

	/** Class providing the static {@code reweightExamples} method invoked via reflection. */
	protected Class<?> pmClass;

	/** Passed as the last argument to the static {@code reweightExamples} method. */
	protected boolean fuzzyReweighting;

	/**
	 * Creates the operator.
	 * <p>
	 * Registers a precondition on the model port, an "iteration" value, and the sampling
	 * subprocess.
	 */
	public AbstractBayesianBoosting(OperatorDescription description) {
		super(description);

		modelInput.addPrecondition(new SimplePrecondition(modelInput, new PredictionModelMetaData(PredictionModel.class, new ExampleSetMetaData()), false));

		addValue(new ValueDouble("iteration", "The current iteration.") {
			@Override
			public double getDoubleValue() {
				return currentIteration;
			}
		});
		// NOTE: this calls an overridable method from the constructor; subclass overrides run
		// before the subclass' own fields are initialised.
		initializeSubprocesses();
	}

	/**
	 * Trains the boosting model on the given training set.
	 *
	 * @param trainingSet
	 *            the (weighted) training set
	 * @param classPriors
	 *            the class priors, presumably as returned by {@link #prepareWeights(ExampleSet)}
	 * @return the trained boosting model
	 * @throws OperatorException
	 *             if training fails
	 */
	protected abstract BayBoostModel trainBoostingModel(ExampleSet trainingSet, final double[] classPriors)
			throws OperatorException;

	/**
	 * Adds the sampling subprocess ("fetch random sample") and creates its inner example set sink.
	 */
	protected void initializeSubprocesses() {
		ExecutionUnit getRandomSample = addSubprocess(GET_RANDOM_SAMPLE_SUBPROCESS);

		getRandomSample.setName("fetch random sample");

		innerRandomSamplePort = getSubprocess(GET_RANDOM_SAMPLE_SUBPROCESS).getInnerSinks().createPort(PORT_INNER_EXAMPLE_SET, ExampleSet.class);
	}

	/**
	 * Adds a real-valued weight attribute to the example set meta data, then delegates to the
	 * super class.
	 */
	@Override
	protected MetaData modifyExampleSetMetaData(ExampleSetMetaData unmodifiedMetaData) {
		AttributeMetaData weightAttribute = new AttributeMetaData("weight", Ontology.REAL, Attributes.WEIGHT_NAME);
		unmodifiedMetaData.addAttribute(weightAttribute);
		return super.modifyExampleSetMetaData(unmodifiedMetaData);
	}

	/**
	 * Overrides the method of the super class. Returns false for {@code NO_LABEL},
	 * {@code UPDATABLE} and {@code FORMULA_PROVIDER}, and true for every other capability
	 * (including numerical and polynominal labels).
	 */
	@Override
	public boolean supportsCapability(OperatorCapability lc) {
		switch (lc) {
		case NO_LABEL:
		case UPDATABLE:
		case FORMULA_PROVIDER:
			return false;
		default:
			return true;
		}
	}

	/**
	 * Ensures that an example set is available before delegating to the super class.
	 * <p>
	 * If no example set is connected to the example set input, a new sample is drawn from the
	 * sampling subprocess and fed into that input.
	 */
	@Override
	public void doWork() throws OperatorException {
		if (exampleSetInput.getDataOrNull() == null) {
			ExampleSet exampleSet = getNewSample(null);
			exampleSetInput.receive(exampleSet);
		}
		super.doWork();
	}

	/**
	 * Draws a new sample by executing the sampling subprocess and reading its inner sink.
	 * <p>
	 * The weights of the sample are prepared via {@link #prepareWeights(ExampleSet)} and the
	 * optional start model is applied. Then every given base model is applied in order, and the
	 * examples are reweighted after each application.
	 * <p>
	 * NOTE(review): unlike {@link #applyPriorModel(ExampleSet, List)}, the prediction attributes
	 * created by the base models are not removed from the returned set — confirm this is intended.
	 *
	 * @param modelInfo
	 *            base models to apply to the new sample, in order; may be {@code null}
	 * @return the newly drawn example set, or {@code null} if internal bootstrapping is active
	 *         (not implemented yet)
	 * @throws OperatorException
	 *             if the subprocess, a model application or the reweighting fails
	 * @throws UserError
	 *             if reading the sample or a parameter fails with a user error
	 */
	protected ExampleSet getNewSample(Vector<BayBoostBaseModelInfo> modelInfo)
			throws OperatorException, UserError {
		// run inner process and retrieve sample:
		ExecutionUnit samplingSubprocess = getSubprocess(GET_RANDOM_SAMPLE_SUBPROCESS);
		samplingSubprocess.execute();
		ExampleSet trainingSet = innerRandomSamplePort.getData(ExampleSet.class);

		// create new / backup old weight attribute:
		prepareWeights(trainingSet);

		applyPriorModel(trainingSet, null);

		boolean bootstrap = getBootstrap();
		if (bootstrap) {
			// TODO getNewSample with bootstrapping is not implemented yet.
			return null;
		}
		if (modelInfo != null) {
			for (BayBoostBaseModelInfo current : modelInfo) {
				Model model = current.getModel();
				trainingSet = model.apply(trainingSet);
				reweightExamplesWrapper(trainingSet, bootstrap);
			}
		}
		return trainingSet;
	}

	/**
	 * Creates a weight attribute if not yet done.
	 * <p>
	 * If the example set has no weight attribute, a new one is created and filled (see
	 * {@link #createNewWeightAttribute(ExampleSet)}). Otherwise the old weights are backed up
	 * in {@link #oldWeights} for restoring them later, and the priors are computed from the
	 * existing weights. Rescaling to equal class priors is only applied when a new weight
	 * attribute is created.
	 * <p>
	 * Examples whose label is missing or outside the label mapping get weight 0 and do not
	 * contribute to the priors. If the total weight is 0, the priors are left unnormalised (all 0)
	 * instead of becoming NaN.
	 *
	 * TODO Rescale label priors, even if weights are present?
	 *
	 * @param exampleSet
	 *            the example set to be prepared
	 * @return a <code>double[]</code> array containing the class priors.
	 */
	protected double[] prepareWeights(ExampleSet exampleSet) {
		Attribute weightAttr = exampleSet.getAttributes().getWeight();
		if (weightAttr == null) {
			this.oldWeights = null;

			// example weights are initialized so that the total weight
			// is equal to the number of examples:
			this.performance = exampleSet.size();

			return this.createNewWeightAttribute(exampleSet);
		}

		// Back up old weights and compute priors:
		this.oldWeights = new double[exampleSet.size()];
		double[] priors = new double[exampleSet.getAttributes().getLabel().getMapping().size()];
		double totalWeight = 0;
		Iterator<Example> reader = exampleSet.iterator();

		for (int i = 0; reader.hasNext() && i < oldWeights.length; i++) {
			Example example = reader.next();
			if (example == null) {
				continue;
			}
			double weight = example.getWeight();
			this.oldWeights[i] = weight;

			double labelValue = example.getLabel();
			// (int) NaN == 0 in Java, so a missing label must be detected before the cast;
			// otherwise it would silently be counted as the first class.
			int label = Double.isNaN(labelValue) ? -1 : (int) labelValue;

			if (0 <= label && label < priors.length) {
				priors[label] += weight;
				totalWeight += weight;
			} else {
				// Unrecognized or missing label: try to ignore the example.
				example.setWeight(0);
			}
		}
		this.performance = totalWeight;

		// Normalize; skipped when there is no weight at all, which would yield NaN priors.
		if (totalWeight != 0) {
			for (int i = 0; i < priors.length; i++) {
				priors[i] /= totalWeight;
			}
		}

		return priors;
	}

	/**
	 * Creates a weight attribute and initialises it.
	 * <p>
	 * Without {@link #PARAMETER_RESCALE_LABEL_PRIORS} every example gets weight 1. Otherwise the
	 * weights are rescaled so that all classes become equally likely.
	 * <p>
	 * NOTE: a missing label (NaN) casts to class index 0 here.
	 *
	 * @param exampleSet
	 *            the example set without a weight attribute
	 * @return the relative class frequencies (class priors) of the example set
	 */
	private double[] createNewWeightAttribute(ExampleSet exampleSet) {
		com.rapidminer.example.Tools.createWeightAttribute(exampleSet);

		Iterator<Example> exRead = exampleSet.iterator();
		int numClasses = exampleSet.getAttributes().getLabel().getMapping().getValues().size();
		double[] classPriors = new double[numClasses];

		int total = exampleSet.size();
		double invTotal = 1.0d / total;

		if (!this.getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS)) {
			while (exRead.hasNext()) {
				Example example = exRead.next();
				example.setWeight(1);
				classPriors[(int) (example.getLabel())] += invTotal;
			}
		} else {
			// first count the class frequencies, then derive the per-class weights from them
			while (exRead.hasNext()) {
				classPriors[(int) (exRead.next().getLabel())] += invTotal;
			}
			this.rescaleToEqualPriors(exampleSet, classPriors);
		}
		return classPriors;
	}

	/**
	 * Applies the start model (if any) and reweights the example set accordingly.
	 * <p>
	 * The model is applied to a clone of the training set, which is then reweighted. The
	 * prediction attributes are removed from the clone afterwards.
	 *
	 * @param trainingSet
	 *            the example set to reweight
	 * @param modelInfo
	 *            the list to which the start model and its contingency matrix are added; ignored
	 *            if null
	 * @throws OperatorException
	 *             if applying the model, creating the performance measures or reweighting fails
	 */
	protected void applyPriorModel(ExampleSet trainingSet, List<BayBoostBaseModelInfo> modelInfo) throws OperatorException {
		if (this.startModel == null) {
			return;
		}

		ExampleSet resultSet = this.startModel.apply((ExampleSet) trainingSet.clone());

		// Initial values and the input model are stored in the output model.
		AbstractWeightedPerformanceMeasures wp = createPerformanceMeasures(resultSet);

		this.reweightExamples(wp, resultSet);
		if (modelInfo != null) {
			modelInfo.add(new BayBoostBaseModelInfo(this.startModel, wp.getContingencyMatrix()));
		}
		PredictionModel.removePredictedLabel(resultSet);
	}

	/**
	 * Returns whether internal bootstrapping is active, i.e. whether
	 * {@link #PARAMETER_USE_SUBSET_FOR_TRAINING} lies strictly between 0 and 1.
	 */
	private boolean getBootstrap() throws UndefinedParameterError {
		final double splitRatio = this.getParameterAsDouble(PARAMETER_USE_SUBSET_FOR_TRAINING);
		return (splitRatio > 0) && (splitRatio < 1.0);
	}

	/**
	 * Reweights the example set, which must contain a predicted label attribute.
	 * <p>
	 * Without bootstrapping, the performance measures are computed on the whole set and the set
	 * is reweighted via {@link #reweightExamples(AbstractWeightedPerformanceMeasures, ExampleSet)}.
	 * <p>
	 * With bootstrapping, the example set must be a {@link SplittedExampleSet}. It is reweighted
	 * using measures computed on its current selection. Afterwards subset 1 is selected (the
	 * selection stays changed), and the measures computed on that subset are returned.
	 *
	 * @param exampleSet
	 *            the example set to reweight
	 * @param bootstrap
	 *            whether internal bootstrapping is active
	 * @return the weighted performance measures (on subset 1 when bootstrapping)
	 * @throws OperatorException
	 *             if creating the performance measures or reweighting fails
	 */
	protected AbstractWeightedPerformanceMeasures reweightExamplesWrapper(ExampleSet exampleSet, boolean bootstrap)
			throws OperatorException {
		if (bootstrap) {
			SplittedExampleSet splittedSet = (SplittedExampleSet) exampleSet;

			AbstractWeightedPerformanceMeasures wp = createPerformanceMeasures(splittedSet);
			this.performance = invokeReweightExamples(wp, splittedSet, getParameterAsBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS));

			// Compute the measures on the held-out subset 1 and return them to the caller.
			splittedSet.selectSingleSubset(1);
			return createPerformanceMeasures(splittedSet);
		}

		// get the weighted performance value of the example set with respect to the model
		AbstractWeightedPerformanceMeasures wp = createPerformanceMeasures(exampleSet);

		// Reweight the example set with respect to the weighted
		// performance values. The return value is the total weight
		// of the example set as a performance measure:
		this.performance = this.reweightExamples(wp, exampleSet);

		return wp;
	}

	/**
	 * Sets the weight of every example so that all classes become equally likely.
	 * <p>
	 * The weight of class i is (1 / #classes) / (relative frequency of class i).
	 */
	private void rescaleToEqualPriors(ExampleSet exampleSet, double[] currentPriors) {
		double[] weights = new double[currentPriors.length];
		for (int i = 0; i < weights.length; i++) {
			weights[i] = 1.0d / (weights.length * (currentPriors[i]));
		}

		Iterator<Example> exRead = exampleSet.iterator();
		while (exRead.hasNext()) {
			Example example = exRead.next();
			example.setWeight(weights[(int) (example.getLabel())]);
		}
	}

	/**
	 * This method reweights the example set with respect to the
	 * <code>WeightedPerformanceMeasures</code> object.
	 * <p>
	 * The weights are not reset at any time, because they continuously change from one iteration
	 * to the next. The work is delegated to the static {@code reweightExamples} method of
	 * {@link #pmClass}, using {@link #PARAMETER_ALLOW_MARGINAL_SKEWS} and
	 * {@link #fuzzyReweighting}.
	 *
	 * @param wp
	 *            the WeightedPerformanceMeasures to use
	 * @param exampleSet
	 *            <code>ExampleSet</code> to be reweighted
	 * @return the total weight of examples as an error estimate
	 * @throws OperatorException
	 *             if the reflective call fails; an <code>OperatorException</code> thrown by the
	 *             invoked method is rethrown unchanged
	 */
	protected double reweightExamples(AbstractWeightedPerformanceMeasures wp, ExampleSet exampleSet)
			throws OperatorException {
		boolean allowMarginalSkews = this.getParameterAsBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS);
		return invokeReweightExamples(wp, exampleSet, allowMarginalSkews);
	}

	/**
	 * Instantiates the weighted performance measures for the given example set via
	 * {@link #pmConstructor}.
	 */
	private AbstractWeightedPerformanceMeasures createPerformanceMeasures(ExampleSet exampleSet) throws OperatorException {
		try {
			return (AbstractWeightedPerformanceMeasures) pmConstructor.newInstance(exampleSet);
		} catch (InvocationTargetException e) {
			throw reflectionFailure("cannot create weighted performance measures", e);
		} catch (Exception e) {
			throw new OperatorException("cannot create weighted performance measures", e);
		}
	}

	/**
	 * Calls the static {@code reweightExamples(ExampleSet, ContingencyMatrix, boolean, boolean)}
	 * method of {@link #pmClass} and returns its result.
	 * <p>
	 * This helper is private so that it cannot be overridden, unlike
	 * {@link #reweightExamples(AbstractWeightedPerformanceMeasures, ExampleSet)}.
	 */
	private double invokeReweightExamples(AbstractWeightedPerformanceMeasures wp, ExampleSet exampleSet, boolean allowMarginalSkews)
			throws OperatorException {
		try {
			return (Double) pmClass.getMethod(
					"reweightExamples",
					ExampleSet.class,
					ContingencyMatrix.class,
					boolean.class,
					boolean.class).invoke(
							null,
							exampleSet,
							wp.getContingencyMatrix(),
							allowMarginalSkews,
							fuzzyReweighting);
		} catch (InvocationTargetException e) {
			throw reflectionFailure("cannot call reweightExamples", e);
		} catch (Exception e) {
			throw new OperatorException("cannot call reweightExamples", e);
		}
	}

	/**
	 * Converts a reflective invocation failure into an {@link OperatorException}.
	 * <p>
	 * An <code>OperatorException</code> thrown by the invoked code is returned unchanged.
	 * Anything else is wrapped, keeping the original cause.
	 */
	private static OperatorException reflectionFailure(String message, InvocationTargetException e) {
		Throwable cause = e.getCause();
		if (cause instanceof OperatorException) {
			return (OperatorException) cause;
		}
		return new OperatorException(message, cause != null ? cause : e);
	}

	/**
	 * Helper method reading a start model from the input if present; logs if none is found.
	 *
	 * @throws UserError
	 *             if reading the port fails
	 */
	protected void readOptionalParameters() throws UserError {
		this.startModel = modelInput.getDataOrNull();
		if (this.startModel == null) {
			log(getName() + ": No model found in input.");
		}
	}

	/**
	 * Runs the &quot;embedded&quot; learner on the example set and returns a model.
	 *
	 * @param exampleSet
	 *            an <code>ExampleSet</code> to train a model for
	 * @return a <code>Model</code>
	 * @throws OperatorException
	 *             if the inner learner fails
	 */
	protected Model trainBaseModel(ExampleSet exampleSet) throws OperatorException {
		return applyInnerLearner(exampleSet);
	}

	/**
	 * Adds the boosting parameters to the parameters of the super class.
	 * <p>
	 * The added parameters are: training subset fraction, label prior rescaling, marginal skews,
	 * fuzzy reweighting and fuzzy partition sizes, followed by the random generator parameters.
	 */
	@Override
	public List<ParameterType> getParameterTypes() {
		List<ParameterType> types = super.getParameterTypes();
		ParameterType type = new ParameterTypeDouble(PARAMETER_USE_SUBSET_FOR_TRAINING, "Fraction of examples used for training, remaining ones are used to estimate the confusion matrix. Set to 1 to turn off test set.", 0, 1, 1);
		type.setExpert(false);
		types.add(type);

		types.add(new ParameterTypeBoolean(PARAMETER_RESCALE_LABEL_PRIORS, "Specifies whether the proportion of labels should be equal by construction after first iteration.", false));
		types.add(new ParameterTypeBoolean(PARAMETER_ALLOW_MARGINAL_SKEWS, "Allow to skew the marginal distribution (P(x)) during learning.", true));

		types.add(new ParameterTypeBoolean(PARAMETER_FUZZY_EXAMPLE_REWEIGHTING, "Specifies whether the example weights should be calculated in a fuzzy way.", false));
		types.add(new ParameterTypeBoolean(PARAMETER_FUZZY_PARTITION_SIZES, "Specifies if the counting of tp, np etc is based on confidences instead of crisp predictions.", false));

		types.addAll(RandomGenerator.getRandomGeneratorParameters(this));

		return types;
	}

}
