package com.rapidminer.operator.learner.meta;

import java.util.Iterator;

import sun.util.logging.resources.logging;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.Tools;

/**
 * Base class for objects that cache weighted performance statistics of a
 * model on an example set: the per-label values ({@link #labels}), the
 * per-prediction values ({@link #predictions}) and the values of every
 * prediction/label combination ({@link #pred_label}). The arrays are allocated
 * here; filling them is left to {@link #initStatistics(ExampleSet)}.
 * <p>
 * On top of these arrays the class offers derived measures (lift, pn-ratios,
 * lift ratio matrix) and a static helper for reweighting example sets.
 */
public abstract class AbstractWeightedPerformanceMeasures {

	/** This constant is used to express that no examples have been observed. */
	public static final double RULE_DOES_NOT_APPLY = Double.NaN;

	/** One entry per prediction (internal index), see {@link #getProbabilityPrediction(int)}. */
	protected double[] predictions;

	/** One entry per class label (internal index), see {@link #getProbabilityLabel(int)}. */
	protected double[] labels;

	/** One entry per combination, indexed as <code>[prediction][label]</code>. */
	protected double[][] pred_label;

	/**
	 * The total number of examples without considering any weights, indexed
	 * as <code>[prediction][label]</code>. Because of real-valued confidences
	 * these can also be real-valued.
	 */
	protected double[][] unweighted_num_pred_label;

	/**
	 * Constructor. Allocates the internal arrays according to the number of
	 * label values of the example set and then delegates to
	 * {@link #initStatistics(ExampleSet)}, which is expected to calculate the
	 * weighted performance values and cache them for later requests.
	 * <p>
	 * Note for subclasses: <code>initStatistics</code> is invoked from within
	 * this constructor, i.e. before any field initializer or constructor body
	 * of the subclass has run. Implementations must not rely on subclass
	 * state.
	 * 
	 * @param exampleSet
	 *            the <code>ExampleSet</code> this object shall hold the
	 *            performance measures for
	 */
	public AbstractWeightedPerformanceMeasures(ExampleSet exampleSet) throws OperatorException {
		int numberOfClasses = exampleSet.getAttributes().getLabel().getMapping().getValues().size();
		this.labels = new double[numberOfClasses];

		// It is not necessary to interpret the result of the embedded
		// learner as predictions. In particular, it is not mandatory to have
		// as many "predictions" as labels. However, without any further
		// information let's assume the simple case, namely that the learner
		// tries to predict the label with the result of the model:
		this.predictions = new double[numberOfClasses];

		// This array stores all combinations:
		this.pred_label = new double[this.predictions.length][this.labels.length];

		// The same array for unweighted examples:
		this.unweighted_num_pred_label = new double[this.predictions.length][this.labels.length];

		initStatistics(exampleSet);
	}

	/**
	 * Hook invoked by the constructor after the internal arrays have been
	 * allocated. Implementations are expected to fill them from the given
	 * example set.
	 * 
	 * @param exampleSet
	 *            the example set passed to the constructor
	 */
	protected abstract void initStatistics(ExampleSet exampleSet);

	/**
	 * Method to query for the unweighted absolute number of covered examples of
	 * each class, given a specific prediction.
	 * 
	 * @param prediction
	 *            the value predicted by the model (internal index number)
	 * @return a <code>double[]</code> array with the number of examples of
	 *         class <code>i</code> (internal index number) stored at index
	 *         <code>i</code>. For a known prediction this is the internally
	 *         cached row (not a copy); for an unknown prediction a fresh
	 *         all-zero array with one entry per label is returned.
	 */
	public double[] getCoveredExamplesNumForPred(int prediction) {
		if (prediction >= 0 && prediction < this.unweighted_num_pred_label.length) {
			return this.unweighted_num_pred_label[prediction];
		}
		// Unknown prediction: no instances covered. The result is indexed by
		// label, so it has to be sized by the number of labels, which is not
		// necessarily the number of predictions.
		return new double[this.labels.length];
	}

	/**
	 * @return the number of classes, namely different values of this object's
	 *         example set's label attribute
	 */
	public int getNumberOfLabels() {
		return this.labels.length;
	}

	/**
	 * @return number of predictions or nominal classes predicted by the
	 *         embedded learner. Not necessarily the same as the number of class
	 *         labels.
	 */
	public int getNumberOfPredictions() {
		return this.predictions.length;
	}

	/**
	 * Method to query for the probability of one of the prediction/label
	 * subsets. The indices are not range-checked.
	 * 
	 * @param label
	 *            the (correct) class label of the example as it comes from the
	 *            internal index
	 * @param prediction
	 *            the value predicted by the model (premise) (internal index
	 *            number)
	 * @return the joint probability of label and prediction
	 */
	public double getProbability(int label, int prediction) {
		return this.pred_label[prediction][label];
	}

	/**
	 * Method to query for the &quot;prior&quot; probability of one of the
	 * labels. The index is not range-checked.
	 * 
	 * @param label
	 *            the nominal class label (internal index number)
	 * @return the probability of seeing an example with this label
	 */
	public double getProbabilityLabel(int label) {
		return this.labels[label];
	}

	/**
	 * Method to query for the &quot;prior&quot; probability of one of the
	 * predictions. The index is not range-checked.
	 * 
	 * @param premise
	 *            the prediction of a model (internal index number)
	 * @return the probability of drawing an example so that the model makes
	 *         this prediction
	 */
	public double getProbabilityPrediction(int premise) {
		return this.predictions[premise];
	}

	/**
	 * The lift of the rule specified by the nominal variable's indices.
	 * <code>RULE_DOES_NOT_APPLY</code> is returned to indicate that no such
	 * example has ever been observed, <code>Double.POSITIVE_INFINITY</code>
	 * is returned if the class membership can deterministically be concluded
	 * from the prediction.
	 * 
	 * Important: In the multi-class case some of the classes might not be
	 * observed at all when a specific rule applies, but still the rule does not
	 * necessarily have a deterministic part. In this case the remaining number
	 * of classes is considered to be the complete set of classes when
	 * calculating the default values and lifts! This does not affect the
	 * prediction of the most likely class label, because the classes not
	 * observed have a probability of one, the other estimates increase
	 * proportionally. However, to calculate probabilities it is necessary to
	 * normalize the estimates in the class <code>BayBoostModel</code>.
	 * 
	 * @param label
	 *            the true label
	 * @param prediction
	 *            the predicted label
	 * @return the LIFT, which is a value &gt;= 0, positive infinity if all
	 *         examples with this prediction belong to that class (deterministic
	 *         rule), or <code>RULE_DOES_NOT_APPLY</code> if no prediction can
	 *         be made.
	 */
	public double getLift(int label, int prediction) {
		double prLabel = this.getProbabilityLabel(label);
		double prPred = this.getProbabilityPrediction(prediction);
		double prJoint = this.getProbability(label, prediction);

		if (prPred == 0) {
			// the prediction has never been observed
			return RULE_DOES_NOT_APPLY;
		}
		if (prJoint == 0) {
			// the prediction has never been observed together with this label
			return 0;
		}
		if (Tools.isEqual(prJoint, prPred)) {
			// every example with this prediction carries this label
			return Double.POSITIVE_INFINITY;
		}

		return prJoint / (prLabel * prPred);
	}

	/**
	 * The factor to be applied (pn-ratio) for each label if the model yields
	 * the specific prediction.
	 * 
	 * @param prediction
	 *            the predicted class
	 * @return a <code>double[]</code> array containing one factor for each
	 *         class. The result should either consist of well defined 
	 *         lifts &gt;= 0, or all fields should mutually contain the constant
	 *         <code>RULE_DOES_NOT_APPLY</code>.
	 */
	public double[] getPnRatios(int prediction) {
		double[] lifts = new double[this.labels.length];
		for (int label = 0; label < lifts.length; label++) {
			double lift = this.getLift(label, prediction);
			if (lift == 0 || lift == Double.POSITIVE_INFINITY) {
				// deterministic cases are stored as they are
				lifts[label] = lift;
				continue;
			}

			// In this case the corresponding lift of the remaining classes
			// should also be defined. Using the odds avoids calculating the
			// probability of premises.
			double negLabel = 1 - this.getProbabilityLabel(label);
			double probPred = this.getProbabilityPrediction(prediction);
			double probPredLabel = this.getProbability(label, prediction);
			double negLabelPred = probPred - probPredLabel;

			double oppositeLift = negLabelPred / (negLabel * probPred);

			// What is stored is 
			// Lift( pred -> label) / Lift( pred -> neg(label) ).
			// If the rule does not apply, lift is NaN and so is the quotient.
			lifts[label] = lift / oppositeLift;
		}
		return lifts;
	}

	/**
	 * @return a matrix with one pn-factor per prediction/label combination, or
	 *         the priors of predictions for the case of soft base classifiers.
	 *         Row <code>i</code> is the result of {@link #getPnRatios(int)}
	 *         for prediction <code>i</code>.
	 */
	public double[][] createLiftRatioMatrix() {
		int numPredictions = this.getNumberOfPredictions();
		double[][] liftRatioMatrix = new double[numPredictions][];
		for (int i = 0; i < numPredictions; i++) {
			liftRatioMatrix[i] = this.getPnRatios(i);
		}
		return liftRatioMatrix;
	}

	/**
	 * @return a newly allocated <code>double[]</code> with the prior
	 *         probabilities of all class labels.
	 */
	public double[] getLabelPriors() {
		double[] priors = new double[this.getNumberOfLabels()];
		for (int i = 0; i < priors.length; i++) {
			priors[i] = this.getProbabilityLabel(i);
		}
		return priors;
	}

	/** 
	 * @return the number of classes with strictly positive weight 
	 */
	public int getNumberOfNonEmptyClasses() {
		int nonEmpty = 0;
		for (int i = 0; i < this.getNumberOfLabels(); i++) {
			if (this.getProbabilityLabel(i) > 0) {
				nonEmpty++;
			}
		}
		return nonEmpty;
	}

	/**
	 * Converts the deprecated representation into the new form. The internal
	 * <code>[prediction][label]</code> array is transposed, so the matrix
	 * handed to the <code>ContingencyMatrix</code> is indexed as
	 * <code>[label][prediction]</code>. A warning is logged for every cell
	 * that is NaN or lies outside of [0, 1]; the value is copied nevertheless.
	 * 
	 * @return the contingency matrix, built from an empty array if there are
	 *         no predictions or no labels
	 */
	public ContingencyMatrix getContingencyMatrix() {

		if (this.pred_label.length == 0 || this.pred_label[0].length == 0) {
			return new ContingencyMatrix(new double[0][0]); // doesn't make sense
		}

		double[][] matrix = new double[this.pred_label[0].length][this.pred_label.length];
		for (int i = 0; i < matrix.length; i++) {
			for (int j = 0; j < matrix[i].length; j++) {

				final double predLabelJi = this.pred_label[j][i]; 
				// Errors like this are hard to find, so this is worth a warning message: 
				if (Double.isNaN(predLabelJi) || predLabelJi < 0 || predLabelJi > 1) {
					LogService.getGlobal().log("Found illegal value in contingency matrix!", LogService.WARNING);
				}

				matrix[i][j] = predLabelJi;
			}
		}

		return new ContingencyMatrix(matrix);
	}

	/**
	 * Helper method of the <code>BayesianBoosting</code> operator
	 * 
	 * This method reweights the example set with respect to the
	 * <code>WeightedPerformanceMeasures</code> object. Please note that the
	 * weights will not be reset at any time, because they continuously change
	 * from one iteration to the next. This method does not change the priors of
	 * the classes.
	 * <p>
	 * Per example the lift is determined first: an illegal lift (NaN or
	 * negative) only yields a warning and leaves the weight untouched, a lift
	 * of 0 or infinity sets the weight to 0, and any other lift leads to a
	 * regular reweighting.
	 * 
	 * @param exampleSet
	 *            <code>ExampleSet</code> to be reweighted
	 * @param cm
	 *            the <code>ContingencyMatrix</code> as e.g. returned by
	 *            <code>WeightedPerformanceMeasures</code>
	 * @param allowMarginalSkews
	 * 		      indicates whether the weight of covered and uncovered subsets
	 *            are allowed to change. If <code>false</code>, the weight is
	 *            simply divided by the lift.
	 * @param fuzzy
	 *            only consulted if <code>allowMarginalSkews</code> is set: if
	 *            <code>true</code>, the weight factor is a mixture over the
	 *            positive and the negative class, weighted by the confidences
	 *            of the example; otherwise only the crisp prediction is used.
	 * @return the sum of the weights newly assigned by the regular
	 *         reweighting. Examples skipped because of an illegal lift keep
	 *         their weight, but are not part of this sum.
	 */
	public static double reweightExamples(ExampleSet exampleSet, ContingencyMatrix cm, boolean allowMarginalSkews, boolean fuzzy)
	throws OperatorException
	{
		Iterator<Example> reader = exampleSet.iterator();
		double totalWeight = 0;

		Attribute labelAttribute = exampleSet.getAttributes().getLabel();
		Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
		Attribute weightAttribute = exampleSet.getAttributes().getWeight();

		while (reader.hasNext()) {
			Example example = reader.next();
			int label = (int) example.getValue(labelAttribute);

			int predicted = (int) example.getValue(predictedLabel);
			double lift = getLiftForExample(example, cm);

			if (Double.isNaN(lift) || lift < 0) {
				// == RULE_DOES_NOT_APPLY || serious error
				LogService.getGlobal().log("Applied rule with an illegal lift of "
						+ lift + " during reweighting!", LogService.WARNING);
			}
			else if (lift == 0 || Double.isInfinite(lift)) {
				// In both cases the model predicts deterministically, so we can
				// remove the example from further investigation.
				// lift = 0: model misclassifies, cannot happen for the original
				// training set, but in other contexts
				// Infinite: model classifies correctly
				example.setValue(weightAttribute, 0);
			}
			else {
				// this is the normal setting, just make sure that the weights are ok 

				double weight = example.getValue(weightAttribute);
				double newWeight;

				if (Double.isNaN(weight) || Double.isInfinite(weight) || weight < 0) {
					// Infinite, NaN, and negative weights cannot be processed any further
					// in a meaningful way!
					LogService.getGlobal().log("Found illegal weight: " + weight, LogService.WARNING);
					newWeight = 0; // try to continue anyway
				}
				else if (weight == 0) {
					// nothing to do
					continue;
				}
				else if (allowMarginalSkews) {
					if (fuzzy) {
						String positiveClass = labelAttribute.getMapping().getPositiveString();
						String negativeClass = labelAttribute.getMapping().getNegativeString();
						int positiveIdx = labelAttribute.getMapping().getPositiveIndex();
						int negativeIdx = labelAttribute.getMapping().getNegativeIndex();
						double positiveConfidence = example.getConfidence(positiveClass);
						double negativeConfidence = example.getConfidence(negativeClass);
						
						double prec_pos = cm.getPrecision(label, positiveIdx);
						double prec_neg = cm.getPrecision(label, negativeIdx);
						
						// beta = epsilon / (1 - epsilon), once per possible prediction
						double beta_pos = (1-prec_pos)/prec_pos;
						double beta_neg = (1-prec_neg)/prec_neg;
						
						if ( predicted == negativeIdx && (Double.isInfinite(beta_pos) || Double.isNaN(beta_pos)) )
						{
							// fall back to unfuzzy weights: only the negative
							// term contributes (beta_pos is multiplied by 0)
							negativeConfidence = 1;
							positiveConfidence = 0;
							beta_pos = 1;
						}
						else if ( predicted == positiveIdx && (Double.isInfinite(beta_neg) || Double.isNaN(beta_neg)) )
						{
							// fall back to unfuzzy weights: only the positive
							// term contributes (beta_neg is multiplied by 0)
							negativeConfidence = 0;
							positiveConfidence = 1;
							beta_neg = 1;
						}

						// Sanity check: prec > 0 because lift > 0, beta has to be a regular double >= 0.
						// Only a warning is issued, the weight is computed anyway.
						if (beta_pos <= 0 || beta_neg <= 0 || Double.isInfinite(beta_pos) || Double.isNaN(beta_pos) || Double.isInfinite(beta_neg) || Double.isNaN(beta_neg)) {
							LogService.getGlobal().log(("Reweighting uses invalid value: predicted is " + predicted 
									+ ", beta_pos is " + beta_pos + ", beta_neg is " + beta_neg
									), LogService.WARNING);
						}
						newWeight = weight * (  positiveConfidence * Math.sqrt(beta_pos) + negativeConfidence * Math.sqrt(beta_neg)  );
					}
					else {
						double prec = cm.getPrecision(label, predicted); // ~ Acc = 1 - epsilon
						double invPrec = 1 - prec; // epsilon

						double beta = invPrec / prec; // beta = epsilon / ( 1 - epsilon)
						
						// Sanity check: prec > 0 because lift > 0, beta has to be a regular double >= 0.
						// Only a warning is issued, the weight is computed anyway.
						if (prec <= 0 || invPrec < 0 || Double.isInfinite(beta) || Double.isNaN(beta)) {
							LogService.getGlobal().log(("Reweighting uses invalid value:"
									+ "Precision is " + prec + ", inverse precision is " + invPrec
									+ ", beta is " + beta), LogService.WARNING);
						}
						newWeight = weight * Math.sqrt(beta);					
					}

				}
				else {
					newWeight = weight / lift;
				}

				// set the new weight and remember it for the total weight
				example.setValue(weightAttribute, newWeight);
				totalWeight += newWeight;
			}
		}
		return totalWeight;
	}

	/**
	 * Computes the lift of a single example as a mixture of the lift for the
	 * predicted class and the lift for the other class, each weighted by the
	 * confidence the example holds for that class.
	 * <p>
	 * The other class is addressed as <code>1 - predicted</code>, so this
	 * method only supports labels with exactly two values (indices 0 and 1).
	 * If the lift for the other class is not defined (NaN), it is ignored and
	 * the lift for the predicted class is used with full confidence.
	 * 
	 * @param example
	 *            the example, which needs a label, a predicted label and
	 *            confidences
	 * @param cm
	 *            the contingency matrix providing the lifts
	 * @return the confidence-weighted lift, NaN if the lift for the predicted
	 *         class is not defined
	 */
	private static double getLiftForExample(Example example,
			ContingencyMatrix cm) {
		int label = (int) example.getLabel();

		int predicted = (int) example.getPredictedLabel();

		Attribute labelAttribute = example.getAttributes().getLabel();
		String predictedClass = labelAttribute.getMapping().getValues().get(  predicted);
		String otherClass     = labelAttribute.getMapping().getValues().get(1-predicted);

		double predictedConfidence = example.getConfidence(predictedClass);
		double otherConfidence     = example.getConfidence(otherClass);

		double liftPredicted = cm.getLift(label, predicted);
		double liftOther = cm.getLift(label, 1-predicted);
		if ( Double.isNaN(liftOther) ) {
			liftOther = 0;
			predictedConfidence = 1;
		}
		return weightLift(liftPredicted, predictedConfidence) + weightLift(liftOther, otherConfidence);
	}

	/**
	 * Multiplies a lift with its confidence. An infinite lift with a
	 * confidence of exactly 0 contributes 0: the plain product would be NaN
	 * (infinity times zero), which would make an otherwise valid mixture look
	 * like an illegal lift.
	 */
	private static double weightLift(double lift, double confidence) {
		if (confidence == 0 && Double.isInfinite(lift)) {
			return 0;
		}
		return lift * confidence;
	}
}
