#!/usr/bin/python
# -*- coding: utf-8 -*-

# --------------------------------------------------
# File Name: exp_1i.py
# Location: 
# Purpose:
# Creation Date: 29-10-2017
# Last Modified: Wed, Nov 15, 2017  5:54:01 PM
# Author(s): Mike Stout 
# Copyright 2017 The Author(s) All Rights Reserved
# Credits: 
# --------------------------------------------------

import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 15) 
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
pd.set_option('chop_threshold', 0)
pd.set_option('precision', 4)
pd.set_option('display.colheader_justify','right')
pd.set_option('display.max_colwidth', 64) 

import matplotlib.pyplot as plt 
plt.style.use('ggplot')


df = pd.read_csv("df.csv")

#print df # .info()

qc = df[(df.Edtn == 'Q') & (df.Subset == 'Common')]
fc = df[(df.Edtn == 'F') & (df.Subset == 'Common')]
columns = ['Act','Scene','Speaker', 'sp#','#chars', 'txt']
qc = qc[columns]
fc = fc[columns]


# Inner-join the common speeches; merge suffixes _x and _y denote Q and F.
res = qc.merge(fc, on=['Act','Scene','Speaker']) 

# N.B. The speech numbers must be at least fairly close ....
# this filter still lets through plenty of mismatches ....

res = res[ (abs(res['sp#_x'] - res['sp#_y']) < 6)] 

print qc.shape, fc.shape, res.shape


#-------------------------------------------------------------------------------
# https://github.com/aflc/editdistance

import editdistance

def calcEditDistance(r):
    # Levenshtein edit distance between the Q and F texts of a merged row.
    return editdistance.eval(r.txt_x, r.txt_y)
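
# Quick sanity check (illustrative only): the Levenshtein distance between
# 'kitten' and 'sitting' is the classic 3.
assert editdistance.eval('kitten', 'sitting') == 3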


res['ed'] = res.apply(calcEditDistance, axis=1)

# N.B. For speeches to match, the edit distance must be small relative to
# the speech length, so drop any pair whose e.d. exceeds two thirds of the
# Q text's character count ....
res = res[ (res.ed < res['#chars_x']/1.5 ) ] 

print res.shape

#-------------------------------------------------------------------------------
from mlpy import dtw_subsequence

# Case-insensitive lookup table mapping the letters a-z/A-Z to 0..25.
lc = 'abcdefghijklmnopqrstuvwxyz'
lc_lut = [ (a, i) for i,a in enumerate(lc) ] 
uc_lut = [ (a.upper(), i) for a,i in lc_lut ]
lut = dict( lc_lut + uc_lut )

#def encode(c): return ord(c.lower())*200

def enc(c):
    # Letters encode case-insensitively as 0..25; any other character falls
    # back to twice its character code.
    try: val = lut[c.lower()]
    except KeyError: val = ord(c) * 2.
    return val
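
# Illustrative check: letters map case-insensitively to 0..25, anything
# else (here a space) to twice its character code.
assert enc('A') == enc('a') == 0
assert enc(' ') == 64.0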

def encode(s):
    return np.array(map(enc, s))
    #return np.array(map(float, map(ord, s)))

def splits(n, f, xs):
    # Break xs into chunks of n characters, each prefixed with the label f.
    return [ f+":\t" + xs[i:i+n] for i in range(0, len(xs), n)]

def fix(xs, i, path):
    # A repeated index in the warping path marks a gap in the alignment.
    if i>0 and path[i]==path[i-1]: return '_' # gapChar
    else: return xs[path[i]]


def recover(xs, path):
    # Rebuild the aligned string by walking the warping path; newlines become spaces.
    return ''.join([ fix(xs, i, path) for i in xrange(len(path))]).replace("\n"," ")
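
# Worked example (illustrative only): a repeated index in the warping path
# becomes the gap character, e.g. aligning "abc" along the path [0, 0, 1, 2].
assert recover("abc", [0, 0, 1, 2]) == "a_bc"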

def dtw(r): 
    q = r.txt_x
    f = r.txt_y
    q_ = encode(q)
    f_ = encode(f)
    dist, cost, path = dtw_subsequence(q_, f_) 
    q_ = recover(q, path[0])
    f_ = recover(f, path[1])
    al = (q_,f_)
    return dist, al

# Run the DTW alignment once per row and keep the distance and aligned texts.
dtwResults = [ dtw(row) for _, row in res.iterrows() ]
res['dtw'] = [ dist for dist, _ in dtwResults ]
res['alignment'] = [ al for _, al in dtwResults ]

#-------------------------------------------------------------------------------
from scipy.stats import entropy # == KL Divergence

def jsd(r):
    x,y = r.alignment # ... the aligned texts

    # Normalise the encoded texts to probability distributions so that M is
    # a genuine 50/50 mixture of P and Q before taking the two KL terms.
    P = encode(x).astype(float)
    Q = encode(y).astype(float)
    P /= P.sum()
    Q /= Q.sum()
    M = 0.5 * (P + Q)
    return 0.5 * (entropy(P, M) + entropy(Q, M))
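
# Illustrative check (a minimal sketch, not part of the pipeline): identical
# aligned texts give zero divergence, since P == Q implies M == P and both
# KL terms vanish.
from collections import namedtuple
_Row = namedtuple('Row', ['alignment'])
assert abs(jsd(_Row(("To be", "To be")))) < 1e-12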


res['jsd'] = res.apply(lambda row: jsd(row), axis=1)

#-------------------------------------------------------------------------------
from sklearn.metrics import normalized_mutual_info_score

def nmi(r):
    x,y = r.alignment # ... the aligned texts

    # Treat the encoded characters as discrete labels and score how much
    # information one aligned text carries about the other.
    q = encode(x)
    f = encode(y)

    return normalized_mutual_info_score(q,f)
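
# Illustrative check (a minimal sketch, not part of the pipeline): identical
# aligned texts score (close to) 1, so higher NMI means a better match.
from collections import namedtuple
_Row = namedtuple('Row', ['alignment'])
assert nmi(_Row(("To be", "To be"))) > 0.99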

res['nmi'] = res.apply(lambda row: nmi(row), axis=1)

# N.B. Also, for speeches to match the normalised mutual information must
# be high, so drop any pair scoring 0.7 or below ....

res = res[ (res.nmi > .7) ]
print res.shape
#-------------------------------------------------------------------------------
# Main .... 

# OK, so we think we now have matching speeches in Q and F,
# so we can compute pairwise metrics on each pair ....

res.columns = ['Act','Scene','Speaker','sp#', '#chars_q', "q_txt", 'sp#_', '#chars_f', 'f_txt', 'ed', 'dtw','alignment', 'jsd','nmi']

print res
res.to_csv('res.csv', header=True, index=False, encoding='utf-8')

plotTitles = "Edit Distance", "DTW Distance", "Jensen Shannon Divergence", "Normalised Mutual Information Score"
columns = "ed", "dtw", "jsd", "nmi"

metrics = zip(plotTitles, columns)


import myPlot

for title, metric in metrics:

    # Lower bound of the heatmap colour scale; NMI has already been
    # filtered to > 0.7, so start its colour scale there.
    zero = 0
    if metric == 'nmi': zero = .7

    # Calc pivot table with margins ...
    pt = res.pivot_table(
        values=metric
        , columns=['Scene']
        , index=['Speaker']
        , aggfunc=np.mean
        #, fill_value=0 
        , margins=True
        )
    pt = pt.T  # Scenes become rows, Speakers (plus the 'All' margin) columns
    #print pt

    # Rescale JSD to percentages for a more readable heatmap.
    if metric == 'jsd':
        pt = pt * 100

    myPlot.hmap(title, title, "Speaker", "Scene", pt, 'YlOrRd', zero)

    myPlot.outputTable(title, pt)



#-------------------------------------------------------------------------------
# Find the speech with the worst (lowest) mutual information score ....

r_min = res['nmi'].idxmin()
r = res.loc[r_min]
print r
x,y =  r.alignment
#print x +'\n'+ y

n = 100  # characters per line in the side-by-side Q/F dump
aa = splits(n, "Q", x)
bb = splits(n, "F", y)

zz = [ "\n"+a+"\n"+b for a,b in zip(aa,bb) ]

for z in zz: print z

#-------------------------------------------------------------------------------
# Pairs plots of the information metrics,
# coloured by Speaker ...

import seaborn as sns 
sns.set(style="ticks", color_codes=True)
plt.xticks(rotation=90)

sns.pairplot(res[['Speaker','ed','dtw','jsd','nmi']], hue="Speaker", palette="husl")
filename = "Speaker by Scene Q F Information Metrics".replace(' ', '_')
plt.savefig('data/'+filename+'.png')
#plt.show()