package Bachelorpackage;

/*
 * Copyright 2017 Andreas Sitta
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *  http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import java.io.BufferedWriter;
import java.io.File;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.rapidminer.kobra.topicmodels.SamplersLDA;

import gnu.trove.list.array.TIntArrayList;

/**
 * Link-content community model, based on the algorithm from:
 * Natarajan, Nagarajan, Prithviraj Sen, and Vineet Chaoji. "Community detection in content-sharing social networks."
 * Proceedings of the 2013 IEEE/ACM International Conference on Advances in Social Networks Analysis and Mining. ACM, 2013.
 *
 * Runs the collapsed Gibbs sampler on the data produced by Part2ReducedToInput.
 * Important: to compute the perplexity, set {@code perplexity = true}; the training
 * and test data sets ("2train.json" / "2test.json") are then used instead of "2users.json".
 *
 * @author Andreas Sitta
 * @version 1.0
 */
public class Part3LinkContent {

	static Gson gson = new Gson();
	static ArrayList<User> users;
	/** Community currently assigned to every friendship edge: entry u holds one community per friend of user u. */
	static ArrayList<TIntArrayList> friendshipcomms = new ArrayList<TIntArrayList>();
	/** Averaged estimates: delta[user][comm], theta[comm][topic], phi[topic][word], psi[comm][user]. */
	static double[][] delta, theta, phi, psi;
	/**
	 * Count matrices of the sampler: Muk[user][comm] documents, Mkz[comm][topic] documents,
	 * Fuk[user][comm] outgoing friendships, Fkv[comm][friend] friendships pointing at a user,
	 * Nzw[topic][word] word occurrences.
	 */
	static int[][] Muk, Mkz, Fuk, Fkv, Nzw;
	/** Marginal counts: documents per user / per community, friends per user / friendships per community, words per topic. */
	static int[] Mu, Mk, Fu, Fk, Nz;
	/** Dirichlet hyper-parameters, sized per community (nu), per user (my), per topic (alpha) and per word (beta). */
	static double[] nu, my, alpha, beta;
	/** Sums of the hyper-parameter vectors; computed in {@link #init()}. */
	static double nuSum, mySum, alphaSum, betaSum;
	static int numWords, numUsers;
	static BufferedWriter bW3;
	static String p2, dataName = "", dName = "", pName = "", pLDAName = "";

	//******** Parameters for experiments
	static int numComms = 5; // number of communities
	static int numTopics = 5; // number of topics
	static int burnedSamples = 50; // sweeps discarded as burn-in
	static int collectedSamples = 50; // sweeps that are collected
	static int skippedSamples = 1; // sweeps discarded before each collected one
	static double nuinit = 0.01; // initial value of every nu entry
	static double myinit = 0.01; // initial value of every my entry
	static double alphainit = 0.01; // initial value of every alpha entry
	static double betainit = 0.01; // initial value of every beta entry
	static boolean perplexity = false;
	static int seed = 2000;
	static Random rn = new Random(seed);

	/** Number of the next collected sample; drives the running average in {@link #collectSample}. */
	static int count = 1;

	/**
	 * Runs the complete pipeline: copies the Part 2 protocol, loads the input data,
	 * initialises the sampler, runs it and stores the latent variables.
	 *
	 * @param args currently unused
	 * @throws IOException if a protocol or input file cannot be read or written
	 * @throws NegativeProbabilityException if the sampler computes a negative weight
	 */
	public static void main(String[] args) throws IOException, NegativeProbabilityException {
		// Carry the protocol of Part 2 over into the protocol of Part 3.
		getFileN("output/Part2.txt");
		dName = "output/" + dataName + "/";
		pLDAName = "LDA-" + pName;
		pName += "-c" + numComms + "-t" + numTopics + "-b" + burnedSamples + "-cs" + collectedSamples + "-ss" + skippedSamples;

		MyTools.copyFile("output/Part2.txt", "output/Part3.txt", dataName + " " + pName);
		bW3 = new BufferedWriter(new FileWriter("output/Part3.txt", true));
		try {
			bW3.write("\r\n\r\n**********************************   Part 3   ****************************************\r\n\r\n");

			// Load the data.
			System.out.println("Laden");
			if (perplexity) {
				System.out.println("Perplexity: On");
				users = loadUsers("2train.json");
			} else {
				System.out.println("Perplexity: Off");
				users = loadUsers("2users.json");
			}
			numWords = loadVocabularySize("1vocabulary.json");
			numUsers = users.size();

			// Allocate the latent variables and the hyper-parameters.
			delta = new double[numUsers][numComms];
			theta = new double[numComms][numTopics];
			phi = new double[numTopics][numWords];
			psi = new double[numComms][numUsers];
			nu = new double[numComms];
			Arrays.fill(nu, nuinit);
			my = new double[numUsers];
			Arrays.fill(my, myinit);
			alpha = new double[numTopics];
			Arrays.fill(alpha, alphainit);
			beta = new double[numWords];
			Arrays.fill(beta, betainit);
			for (int i = 0; i < numUsers; i++) {
				friendshipcomms.add(new TIntArrayList()); // filled with one community per friendship in init()
			}

			// Sampling, with wall-clock timing.
			final long timeStart = System.currentTimeMillis();
			init();
			if (perplexity) {
				runPerplexity(burnedSamples, collectedSamples, skippedSamples);
			} else {
				runSampling(burnedSamples, collectedSamples, skippedSamples);
			}
			save();
			final long timeEnd = System.currentTimeMillis();
			System.out.println("Zeit: " + (timeEnd - timeStart));
		} finally {
			// Close the protocol even when loading or sampling fails.
			bW3.close();
		}
	}

	/**
	 * Reads a JSON list of users and closes the file afterwards.
	 *
	 * @param path file to read
	 * @return the parsed users, never {@code null}
	 * @throws IOException if the file cannot be read or contains no data
	 */
	private static ArrayList<User> loadUsers(String path) throws IOException {
		try (FileReader reader = new FileReader(path)) {
			ArrayList<User> loaded = gson.fromJson(reader, new TypeToken<ArrayList<User>>() {
			}.getType());
			if (loaded == null) {
				throw new IOException("No user data found in " + path);
			}
			return loaded;
		}
	}

	/**
	 * Reads the JSON vocabulary list and returns its size.
	 *
	 * @param path file to read
	 * @return number of vocabulary entries
	 * @throws IOException if the file cannot be read or contains no data
	 */
	private static int loadVocabularySize(String path) throws IOException {
		try (FileReader reader = new FileReader(path)) {
			ArrayList<String> vocabulary = gson.fromJson(reader, new TypeToken<ArrayList<String>>() {
			}.getType());
			if (vocabulary == null) {
				throw new IOException("No vocabulary found in " + path);
			}
			return vocabulary.size();
		}
	}

	/**
	 * Runs the sampling schedule.
	 *
	 * @param burnedSamples number of initial sweeps that are discarded
	 * @param collectedSamples number of sweeps that are collected
	 * @param skippedSamples number of sweeps discarded before every collected sweep;
	 *        note that this multiplies the run time accordingly
	 * @throws NegativeProbabilityException debugging aid; negative weights should never occur
	 */
	public static void runSampling(int burnedSamples, int collectedSamples, int skippedSamples) throws NegativeProbabilityException {
		for (int i = 1; i <= burnedSamples; i++) {
			Sampling(false, "burn-in/skipped: " + i + " ");
		}
		for (int i = 1; i <= collectedSamples; i++) {
			for (int j = 1; j <= skippedSamples; j++) {
				Sampling(false, "skip: " + j + " ");
			}
			Sampling(true, "sample: " + i + " ");
		}
	}

	/**
	 * Initialisation: allocates all counters, assigns a random community to every
	 * friendship and a random (community, topic) pair to every document.
	 */
	public static void init() {
		System.out.println("Init");
		Muk = new int[numUsers][numComms];
		Mu = new int[numUsers];
		Mk = new int[numComms];
		Mkz = new int[numComms][numTopics];
		Fuk = new int[numUsers][numComms];
		Fu = new int[numUsers];
		Fk = new int[numComms];
		Fkv = new int[numComms][numUsers];
		Nzw = new int[numTopics][numWords];
		Nz = new int[numTopics];

		for (int i = 0; i < numUsers; i++) {
			Mu[i] = users.get(i).docs.size(); // documents of user i
			Fu[i] = users.get(i).friends.size(); // friends of user i
		}

		// Start from zero so that a repeated call does not accumulate onto old sums.
		nuSum = 0;
		mySum = 0;
		alphaSum = 0;
		betaSum = 0;
		for (int i = 0; i < numComms; i++) {
			nuSum += nu[i];
		}
		for (int i = 0; i < numUsers; i++) {
			mySum += my[i];
		}
		for (int i = 0; i < numTopics; i++) {
			alphaSum += alpha[i];
		}
		for (int i = 0; i < numWords; i++) {
			betaSum += beta[i];
		}

		for (int u = 0; u < numUsers; u++) {
			User user = users.get(u);

			// Every friendship gets a random community.
			for (int v = 0; v < user.friends.size(); v++) {
				int friendID = user.friends.get(v);
				int pickedComm = rn.nextInt(numComms);
				Fuk[u][pickedComm]++;
				Fk[pickedComm]++;
				Fkv[pickedComm][friendID]++;
				friendshipcomms.get(u).add(pickedComm);
			}

			// Every document gets a random community and a random topic.
			for (int d = 0; d < user.docs.size(); d++) {
				Document currentDoc = user.docs.get(d);
				int pickedComm = rn.nextInt(numComms);
				currentDoc.comm = pickedComm;
				Muk[u][pickedComm]++;
				Mk[pickedComm]++;

				int pickedTopic = rn.nextInt(numTopics);
				currentDoc.topic = pickedTopic;
				Mkz[pickedComm][pickedTopic]++;
				for (int w = 0; w < currentDoc.words.size(); w++) {
					Word word = currentDoc.words.get(w);
					Nzw[pickedTopic][word.id] += word.freq;
					Nz[pickedTopic] += word.freq;
				}
			}
		}
	}

	/**
	 * One full Gibbs sweep: re-samples the community of every friendship and the
	 * (community, topic) pair of every document, then optionally collects the sample.
	 *
	 * @param collect whether the state after this sweep is folded into the averaged estimates
	 * @param str prefix for the progress output
	 * @throws NegativeProbabilityException debugging aid; negative weights should never occur
	 */
	public static void Sampling(boolean collect, String str) throws NegativeProbabilityException {
		double[] friendprobs = new double[numComms]; // unnormalised weight of each community for one friendship
		double[] docprobs = new double[numComms * numTopics]; // weight of pair (k, z), stored at index k * numTopics + z
		double[] topicLikelihood = new double[numTopics]; // word factor of the current document for each topic

		for (int u = 0; u < numUsers; u++) {
			if (u % 100 == 0) {
				System.out.println(str + " User: " + u + "/" + numUsers);
			}
			User user = users.get(u);
			TIntArrayList linkComms = friendshipcomms.get(u);

			for (int v = 0; v < user.friends.size(); v++) {
				int friendID = user.friends.get(v);
				int oldComm = linkComms.get(v);
				double totprob = 0;
				for (int k = 0; k < numComms; k++) {
					// The edge being re-sampled must not count towards its own current community.
					int sub1 = (oldComm == k) ? 1 : 0;

					// NOTE(review): the second denominator subtracts Fkv[k][u] + Fuk[u][k] rather than sub1;
					// kept as in the original implementation - confirm against the paper.
					friendprobs[k] = ((Muk[u][k] + Fuk[u][k] - sub1 + nu[k]) / (Mu[u] + Fu[u] - 1 + nuSum)) * ((Fkv[k][friendID] - sub1 + my[friendID]) / (Fk[k] - (Fkv[k][u] + Fuk[u][k]) + mySum));
					if (friendprobs[k] < 0) {
						throw new NegativeProbabilityException("u=" + u + " ; v=" + v + " ; k=" + k + "      " + "Muk:" + Muk[u][k] + " Fuk:" + Fuk[u][k] + " sub1:" + sub1 + " Mu:" + Mu[u] + " Fu:" + Fu[u] + " Fkv:" + Fkv[k][friendID] + " Fk:" + Fk[k]);
					}
					totprob += friendprobs[k];
				}

				int pickedComm = drawIndex(friendprobs, totprob);

				// Move the edge from its old community to the drawn one.
				Fuk[u][oldComm]--;
				Fk[oldComm]--;
				Fkv[oldComm][friendID]--;
				Fuk[u][pickedComm]++;
				Fk[pickedComm]++;
				Fkv[pickedComm][friendID]++;
				linkComms.replace(v, pickedComm);
			}

			for (int d = 0; d < user.docs.size(); d++) {
				Document currentDoc = user.docs.get(d);

				// The word factor depends only on the topic, so compute it once per topic
				// instead of once per (community, topic) pair.
				for (int z = 0; z < numTopics; z++) {
					topicLikelihood[z] = documentTopicLikelihood(currentDoc, z);
				}

				double totprob = 0;
				for (int k = 0; k < numComms; k++) {
					for (int z = 0; z < numTopics; z++) {
						// The document must not count towards its own current assignment.
						int sub1 = 0, MkzSub = 0;
						if (currentDoc.comm == k) {
							sub1 = 1;
							if (currentDoc.topic == z) {
								MkzSub = 1;
							}
						}
						double weight = ((Muk[u][k] - sub1 + Fuk[u][k] + nu[k]) / (Mu[u] - 1 + Fu[u] + nuSum) * ((Mkz[k][z] - MkzSub + alpha[z]) / (Mk[k] - sub1 + alphaSum)) * topicLikelihood[z]);
						if (weight < 0) {
							throw new NegativeProbabilityException("docprobs[" + k + "][" + z + "] ; d=" + d + " ; " + "u=" + u);
						}
						docprobs[k * numTopics + z] = weight;
						totprob += weight;
					}
				}

				// Each pair is drawn with probability proportional to its weight.
				int picked = drawIndex(docprobs, totprob);
				int pickedComm = picked / numTopics;
				int pickedTopic = picked % numTopics;

				// Move the document from its old assignment to the drawn one.
				Muk[u][currentDoc.comm]--;
				Mk[currentDoc.comm]--;
				Mkz[currentDoc.comm][currentDoc.topic]--;
				Muk[u][pickedComm]++;
				Mk[pickedComm]++;
				Mkz[pickedComm][pickedTopic]++;

				for (int w = 0; w < currentDoc.words.size(); w++) {
					Word word = currentDoc.words.get(w);
					Nzw[currentDoc.topic][word.id] -= word.freq;
					Nz[currentDoc.topic] -= word.freq;
					Nzw[pickedTopic][word.id] += word.freq;
					Nz[pickedTopic] += word.freq;
				}
				currentDoc.comm = pickedComm;
				currentDoc.topic = pickedTopic;
			}
		}
		if (collect) {
			collectSample(Muk, Mkz, Fuk, Fkv, Fk, Nzw, Mk, Nz);
		}
	}

	/**
	 * Draws an index with probability proportional to its weight (inverse-CDF sampling).
	 * Consumes exactly one value from {@link #rn}.
	 *
	 * @param weights non-negative, unnormalised weights
	 * @param total sum of {@code weights}
	 * @return the first index whose cumulative weight is not below the drawn threshold;
	 *         the last index if rounding prevents the threshold from being reached
	 */
	private static int drawIndex(double[] weights, double total) {
		double threshold = total * rn.nextDouble();
		double cumulative = 0;
		for (int i = 0; i < weights.length; i++) {
			cumulative += weights[i];
			if (!(cumulative < threshold)) {
				return i;
			}
		}
		return weights.length - 1;
	}

	/**
	 * Word factor of the document weight for one topic: the product, over all word
	 * occurrences of the document, of (word-in-topic count + beta) / (topic size + betaSum),
	 * where both counts exclude the document itself and grow by one per occurrence already placed.
	 *
	 * @param doc document whose assignment is being re-sampled
	 * @param z candidate topic
	 * @return the unnormalised word likelihood of {@code doc} under topic {@code z}
	 */
	private static double documentTopicLikelihood(Document doc, int z) {
		boolean ownTopic = doc.topic == z;

		// Total number of tokens the document currently contributes to Nz[z].
		int NzSub = 0;
		if (ownTopic) {
			for (int w = 0; w < doc.words.size(); w++) {
				NzSub += doc.words.get(w).freq;
			}
		}

		double likelihood = 1;
		int placed = 0; // tokens of earlier words that have already been accounted for
		for (int w = 0; w < doc.words.size(); w++) {
			Word word = doc.words.get(w);
			int NzwSub = ownTopic ? word.freq : 0;
			for (int i = 1; i <= word.freq; i++) {
				// The denominator advances by one per token placed so far: placed + (i - 1).
				likelihood = likelihood * ((Nzw[z][word.id] - NzwSub + i - 1 + beta[word.id]) / (Nz[z] - NzSub + placed + (i - 1) + betaSum));
			}
			placed += word.freq;
		}
		return likelihood;
	}

	/**
	 * Computes the point estimates for the current counter state and folds them
	 * into the running averages delta, theta, phi and psi.
	 *
	 * @param Muk documents per user and community
	 * @param Mkz documents per community and topic
	 * @param Fuk friendships per user and community
	 * @param Fkv friendships per community and target user
	 * @param Fk friendships per community
	 * @param Nzw word occurrences per topic and word
	 * @param Mk documents per community
	 * @param Nz word occurrences per topic
	 */
	public static void collectSample(int[][] Muk, int[][] Mkz, int[][] Fuk, int[][] Fkv, int[] Fk, int[][] Nzw, int[] Mk, int[] Nz) {
		// Running mean: old * (count - 1) / count + new / count, updated in place.
		for (int u = 0; u < numUsers; u++) {
			for (int k = 0; k < numComms; k++) {
				double estimate = (Muk[u][k] + Fuk[u][k] + nu[k]) / (Mu[u] + Fu[u] + nuSum);
				delta[u][k] = delta[u][k] * (count - 1) / (count) + estimate / count;
			}
		}
		for (int k = 0; k < numComms; k++) {
			for (int z = 0; z < numTopics; z++) {
				double estimate = (Mkz[k][z] + alpha[z]) / (Mk[k] + alphaSum);
				theta[k][z] = theta[k][z] * (count - 1) / (count) + estimate / count;
			}
		}
		for (int z = 0; z < numTopics; z++) {
			for (int w = 0; w < numWords; w++) {
				double estimate = (Nzw[z][w] + beta[w]) / (Nz[z] + betaSum);
				phi[z][w] = phi[z][w] * (count - 1) / (count) + estimate / count;
			}
		}
		for (int k = 0; k < numComms; k++) {
			for (int u = 0; u < numUsers; u++) {
				double estimate = (Fkv[k][u] + my[u]) / (Fk[k] + mySum);
				psi[k][u] = psi[k][u] * (count - 1) / (count) + estimate / count;
			}
		}
		count++;
	}

	/**
	 * Writes the averaged latent variables to "3latent.json".
	 * A failure is reported on the console and does not abort the program.
	 */
	public static void save() {
		try {
			try (FileWriter fw = new FileWriter("3latent.json")) {
				LatentVariables variablesToSave = new LatentVariables(delta, theta, phi, psi);
				gson.toJson(variablesToSave, fw);
			}
			System.out.println("ENDE!!!");
		} catch (java.io.IOException e) {
			System.out.println("Saving failed");
			System.err.println("Could not write 3latent.json: " + e);
		}
	}

	/**
	 * Computes the perplexity of the given users' documents under the given estimates.
	 * The result is NaN when the documents contain no words.
	 *
	 * @param users users whose documents are evaluated; user i is scored with {@code delta[i]}
	 * @param delta user-community distribution
	 * @param theta community-topic distribution
	 * @param phi topic-word distribution
	 * @param psi community-user distribution (not used by the computation)
	 * @return exp(-(sum of document log-likelihoods) / (total number of word occurrences))
	 */
	public static double getPerplexity(ArrayList<User> users, double[][] delta, double[][] theta, double[][] phi, double[][] psi) {
		double sumOverDocs = 0;
		int totalTokens = 0;
		double[] exponents = new double[numComms * numTopics]; // log-probability per (topic, community) pair
		double[] logWordLikelihood = new double[numTopics];

		for (int u = 0; u < users.size(); u++) {
			User user = users.get(u);
			for (int d = 0; d < user.docs.size(); d++) {
				Document doc = user.docs.get(d);

				// The word term depends only on the topic: compute it once per topic.
				for (int z = 0; z < numTopics; z++) {
					double prod = 0;
					for (int w = 0; w < doc.words.size(); w++) {
						Word word = doc.words.get(w);
						prod += Math.log(phi[z][word.id]) * (word.freq);
					}
					logWordLikelihood[z] = prod;
				}

				int counter = 0;
				for (int z = 0; z < numTopics; z++) {
					for (int k = 0; k < numComms; k++) {
						exponents[counter] = Math.log(delta[u][k]) + Math.log(theta[k][z]) + logWordLikelihood[z];
						counter++;
					}
				}
				// Log of the document probability, summed over all communities and topics.
				sumOverDocs += logSumOfExponentials(exponents);

				for (int w = 0; w < doc.words.size(); w++) {
					totalTokens += doc.words.get(w).freq;
				}
			}
		}
		return Math.exp(-(sumOverDocs / totalTokens));
	}

	/**
	 * Numerically stable log(sum(exp(x_i))): the maximum is factored out before exponentiating.
	 *
	 * @param xs log-domain values; entries equal to negative infinity are skipped
	 * @return the logarithm of the sum of the exponentials of {@code xs}
	 */
	public static double logSumOfExponentials(double[] xs) {
		if (xs.length == 1) {
			return xs[0];
		}
		double max = max(xs);
		double sum = 0.0;
		for (int i = 0; i < xs.length; ++i) {
			if (xs[i] != Double.NEGATIVE_INFINITY) {
				sum += java.lang.Math.exp(xs[i] - max);
			}
		}
		return max + java.lang.Math.log(sum);
	}

	/**
	 * Largest value of an array.
	 *
	 * @param array values to scan
	 * @return the maximum, or negative infinity for an empty array
	 */
	public static double max(double[] array) {
		double max = Double.NEGATIVE_INFINITY;
		for (double value : array) {
			if (max < value) {
				max = value;
			}
		}
		return max;
	}

	/**
	 * Runs the sampler and prints the test-set perplexity after every collected sample.
	 *
	 * @param burn number of burn-in sweeps
	 * @param sample number of samples to collect
	 * @param skip number of sweeps discarded after each collected sample
	 * @throws NegativeProbabilityException if the sampler computes a negative weight
	 * @throws FileNotFoundException if "2test.json" does not exist
	 * @throws IOException if "2test.json" cannot be read or is empty
	 */
	public static void runPerplexity(int burn, int sample, int skip) throws NegativeProbabilityException, FileNotFoundException, IOException {
		// Load the test set first so that a missing file fails before the long burn-in.
		ArrayList<User> usersPerp = loadUsers("2test.json");

		double[] perp = new double[sample];
		runSampling(burn, 0, 0);

		for (int i = 0; i < sample; i++) {
			runSampling(0, 1, 0);
			perp[i] = getPerplexity(usersPerp, delta, theta, phi, psi);
			System.out.println("Sample: " + i + ",  Perplexity: " + perp[i]);
			runSampling(skip, 0, 0);
		}
		System.out.println("ENDE:");

		for (int i = 0; i < sample; i++) {
			System.out.println(Math.round(perp[i]));
		}
	}

	/**
	 * Creates the directory "output/&lt;fName&gt;" if it does not exist yet, points
	 * {@code dName} at it and appends the parameter suffix to {@code pName}.
	 *
	 * @param fName name of the sub-directory below "output"
	 * @throws IOException if the directory cannot be created
	 */
	public static void makeOutputDirectories(String fName) throws IOException {
		File target = new File("output/" + fName);
		// mkdirs() also creates the parent "output" directory when needed.
		if (!target.isDirectory() && !target.mkdirs()) {
			throw new IOException("Could not create directory " + target.getPath());
		}
		dName = "output/" + fName + "/";
		pName += "-c" + numComms + "-t" + numTopics + "-b" + burnedSamples + "-cs" + collectedSamples + "-ss" + skippedSamples;
	}

	/**
	 * Reads the first line of a protocol file and stores its space-separated fields
	 * in {@code dataName}, {@code pName} and {@code p2} (as far as they are present).
	 *
	 * @param fromF protocol file to read
	 * @return the first line, or an empty string if the file is empty or cannot be read
	 * @throws IOException if the file cannot be opened
	 */
	public static String getFileN(String fromF) throws IOException {
		FileReader in = new FileReader(fromF); // a missing file is reported to the caller
		try (BufferedReader reader = new BufferedReader(in)) {
			String line = reader.readLine();
			if (line == null) {
				return ""; // empty file: leave the names untouched
			}
			String[] parts = line.split(" ");
			if (parts.length >= 1) {
				dataName = parts[0];
			}
			if (parts.length >= 2) {
				pName = parts[1];
			}
			if (parts.length >= 3) {
				p2 = parts[2];
			}
			return line;
		} catch (IOException e) {
			return "";
		}
	}
}
