#!/usr/bin/python
# -*- coding: utf-8 -*-

# --------------------------------------------------
# File Name: exp_2a.py
# Location: 
# Purpose:
# Creation Date: 29-10-2017
# Last Modified: Fri, Dec  1, 2017  3:55:22 PM
# Author(s): Mike Stout 
# Copyright 2017 The Author(s) All Rights Reserved
# Credits: 
# --------------------------------------------------

import numpy as np
import pandas as pd
import re
pd.set_option('display.max_rows', 15) 
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 100)
pd.set_option('chop_threshold', 0)
pd.set_option('precision', 4)
pd.set_option('display.colheader_justify','right')
pd.set_option('display.max_colwidth', 32) 

import matplotlib.pyplot as plt 
plt.style.use('ggplot')


def outputTable(filename, df):
    df = pd.DataFrame(df)
    filename = filename.replace(' ', '_')
    html = df.to_html()
    with open('data/'+filename+'.txt', 'w') as text_file:
        text_file.write(html)

#-------------------------------------------------------------------------------
# https://github.com/aflc/editdistance

import editdistance

ed = editdistance.eval
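# Illustrative usage: ed computes the Levenshtein distance between two strings,
# e.g. ed("kitten", "sitting") -> 3.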

#-------------------------------------------------------------------------------

from scipy.stats import entropy # entropy(p, q) == KL divergence D(p || q)

#-------------------------------------------------------------------------------
from mlpy import dtw_subsequence, dtw_std

# Encode lower- and uppercase letters to the same values ...
lc = 'abcdefghijklmnopqrstuvwxyz'
lc_lut = [ (a, i) for i,a in enumerate(lc) ] 
uc_lut = [ (a.upper(), i) for a,i in lc_lut ]
lut = dict( lc_lut + uc_lut  ) 

#def encode(c): return ord(c.lower())*200

def enc(c):
    try: val = lut[c.lower()]
    except KeyError: val = ord(c) * 2.
    return val

def encode(s):
    return np.array(map(enc, s))
    #return np.array(map(float, map(ord, s)))
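
# Illustrative example: letters map to 0-25 (case-folded) and any other character
# to 2*ord(c), so encode("Ab!") gives array([ 0., 1., 66.]).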

def splits(n, f, xs):
    return [ f+":\t" + xs[i:i+n] for i in range(0, len(xs), n)]

def fix(xs, i, path):
    if i>0 and path[i]==path[i-1]: return '_' # gapChar
    else: return xs[path[i]]


def recover(xs, path):
    return ''.join([ fix(xs, i, path) for i in xrange(len(path))]).replace("\n"," ")
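
# fix/recover rebuild a readable string from a DTW warping path: when the path
# repeats an index (one character matched several times) the repeats are shown
# as the gap character '_', and newlines become spaces.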

def dtw(x,y): 
    x_ = encode(x)
    y_ = encode(y)
    dist, cost, path = dtw_subsequence(x_, y_) 
    x = recover(x, path[0])
    y = recover(y, path[1])
    al = (x,y)
    return dist, al
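
# Illustrative usage (exact values depend on the encoding above):
#   dist, (q_al, f_al) = dtw("the king", "then the king spoke")
# q_al and f_al are equal-length, gap-padded versions of the two inputs.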

#-------------------------------------------------------------------------------

def jsd(x,y):

    P = encode(x)
    Q = encode(y)

    M = 0.5 * (P + Q)
    return 0.5 * (entropy(P, M) + entropy(Q, M))
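
# This follows the Jensen-Shannon divergence formula on the encoded vectors:
#   JSD(P, Q) = 0.5*KL(P || M) + 0.5*KL(Q || M), with M = 0.5*(P + Q),
# where scipy's entropy(P, M) normalises P and M to probability distributions
# before computing the KL divergence.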


#-------------------------------------------------------------------------------
from sklearn.metrics import normalized_mutual_info_score

def nmi(x,y):
    x = encode(x)
    y = encode(y)

    return normalized_mutual_info_score(x,y)
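
# normalized_mutual_info_score treats the two encoded sequences as label
# assignments, so both inputs must have the same length; that is why the
# DTW-aligned strings are passed in below.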

#-------------------------------------------------------------------------------

def words(s): 
    #return s.split(' ')
    return filter(lambda x: x != '', re.split(r"[ ,;:.><?!#@\n\t']", s))
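
# Illustrative example: words("To be, or not to be") -> ['To', 'be', 'or', 'not', 'to', 'be'].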



#f_words=pd.read_csv("CK_200_function_words.csv")
#f_words.columns = ['word']

df=pd.read_csv("df.csv")
print df
df = df[ (df.Subset == 'Full-a')]
print df
qf = df[ (df.Edtn == 'Q')]
ff = df[ (df.Edtn == 'F')]
print qf
print ff

def proc(r): 
    #return r.spkr + " " + r.txt
    return r.txt

def labels(r): 
    return r.Scene + " " + str(r['sp#']) + " " + r.spkr

'''
def mkWindow(i,xs):
    print xs
    w = xs[i:i+wSize]
    s = ' '.join(w)
    #print s
    #exit()
    return s
'''

#wSize = 200
def window(df):

    # Use the first 30 speeches; each speech is one window/sample.
    df = df[:30]
    samples = df[['Scene','spkr','txt']].copy()
    samples['sample'] = df.apply(proc, axis=1)
    print samples[:10]

    # NB use speeches as the windows .... 
    windows = np.array(samples['sample'])

    samples['labs'] = df.apply(labels, axis=1)
    labs = samples['labs']

    return labs, windows

q_labs, q_ws = window(qf)
f_labs, f_ws = window(ff)
q_labs = np.array(q_labs)
f_labs = np.array(f_labs)



res = []
res_a = []

for k, q in enumerate(q_ws):

    print k

    pairs = []
    for f in f_ws:
        edist = ed(q,f)
        dist, al = dtw(q,f)

        # nmi and jsd need two inputs of the same length,
        # so they are computed on the DTW-aligned strings ...
        q_,f_ = al
        nmi_score = nmi(q_,f_)
        jsd_score = jsd(q_,f_)

        #dat = nmi_score, 1/float(edist+1), jsd_score, 1/float(dist+1)
        dat = nmi_score, edist, jsd_score, dist
        pairs.append(dat)

    pairs = np.array(pairs).T
    mins = map(np.argmin, pairs)[1:]

    # .. for nmi need to find max not min
    maxs = [map(np.argmax, pairs)[0]] 

    posns = maxs + mins

    scores = []
    for i,p in enumerate(posns):
        scores.append(pairs[i][p])

    res.append(zip(posns, scores))
    res_a.append(pairs) 
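
# res collects, for each Q window, the best-matching F index and score per metric;
# res_a keeps the full 4 x len(f_ws) score matrix for every Q window.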

def selectMetric(a,i):
    return np.array([ ys[i] for ys in a ])
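
# Illustrative usage: selectMetric(res_a, 0) pulls out the NMI scores as a
# (num Q windows) x (num F windows) matrix; indices follow the order of `dat`
# above (nmi, ed, jsd, dtw).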

res = np.array(res)
res_a = np.array(res_a)
print res_a
print res_a.shape
#res = pd.DataFrame([res])
#res.to_csv('res_scores.csv', header=True, index=False, encoding='utf-8')


metricNames = "Normalised Mutual Information Score", "Edit Distance", "Jensen Shannon Divergence",  "DTW Distance"
metrics = 'nmi ed jsd dtw'.split(' ')

from myPlot import hmap    

res = []
for i, metric in enumerate(metrics):
    arr = selectMetric(res_a, i)
    print metric, arr.shape
    title = metricNames[i] 

    hmap(title, title, "Q", "F", arr, q_labs, f_labs, 'YlOrRd', 0)
    outputTable(title, arr)