#!/usr/bin/python
# -*- coding: utf-8 -*-

# --------------------------------------------------
# File Name: exp_1i.py
# Location: 
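# Purpose: Compare overlapping word windows of the Q and F editions with four metrics (NMI, edit distance, Jensen-Shannon divergence, DTW distance) and plot each score matrix as a heat map.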
# Creation Date: 29-10-2017
# Last Modified: Wed, Nov 15, 2017  6:28:25 PM
# Author(s): Mike Stout 
# Copyright 2017 The Author(s) All Rights Reserved
# Credits: 
# --------------------------------------------------

import numpy as np
import pandas as pd
import re
# DataFrame display options. Fully-qualified option names are used because
# short aliases such as 'precision' become ambiguous in newer pandas releases
# (e.g. 'display.precision' vs 'styler.format.precision').
pd.set_option('display.max_rows', 15)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 100)
pd.set_option('display.chop_threshold', 0)
pd.set_option('display.precision', 4)
pd.set_option('display.colheader_justify', 'right')
# Only the final max_colwidth value ever took effect; the previous -1 setting
# was immediately overwritten (and negative values are deprecated in pandas).
pd.set_option('display.max_colwidth', 32)

import matplotlib.pyplot as plt 
# Apply matplotlib's built-in 'ggplot' style sheet globally; it affects plots
# created after this point.
plt.style.use('ggplot')



#-------------------------------------------------------------------------------
# https://github.com/aflc/editdistance

import editdistance

# Short alias for editdistance.eval: the edit distance between two sequences.
# Used in the main loop to compare raw Q and F windows.
ed = editdistance.eval

#-------------------------------------------------------------------------------

from scipy.stats import entropy # == KL Divergence


#-------------------------------------------------------------------------------
from mlpy import dtw_subsequence, dtw_std

# Encode lower- and upper-case letters to the same values: a/A -> 0, ..., z/Z -> 25.
lc = 'abcdefghijklmnopqrstuvwxyz'
lc_lut = list(zip(lc, range(len(lc))))
uc_lut = [(letter.upper(), code) for letter, code in lc_lut]
lut = dict(lc_lut)
lut.update(uc_lut)

#def encode(c): return ord(c.lower())*200

def enc(c):
    """Map a single character to a numeric code used as a signal value.

    ASCII letters (either case) map to their alphabet index 0-25 via ``lut``.
    Any other character maps to ``ord(c) * 2.`` (a float).

    NOTE(review): control characters with ord(c) < 13 (e.g. tab -> 18.0,
    newline -> 20.0) collide with letter codes; printable ASCII (ord >= 32)
    maps to >= 64 and does not.
    """
    try:
        return lut[c.lower()]
    except KeyError:
        # Not a letter: fall back to a scaled code point. Only a missing
        # lookup is handled here; a bare except would also hide real errors.
        return ord(c) * 2.

def encode(s):
    """Encode string ``s`` as a 1-D numpy array with one ``enc`` value per character.

    An explicit list is built because handing a bare ``map`` object to
    ``np.array`` only works on Python 2 (where ``map`` returns a list); on
    Python 3 it yields a 0-d object array instead of a numeric vector.
    """
    return np.array([enc(ch) for ch in s])

def splits(n, f, xs):
    """Cut ``xs`` into consecutive chunks of ``n`` items, each prefixed by ``f + ":\\t"``.

    The last chunk may be shorter than ``n``; an empty ``xs`` gives ``[]``.
    """
    chunks = []
    for start in range(0, len(xs), n):
        chunks.append(f + ":\t" + xs[start:start + n])
    return chunks

def fix(xs, i, path):
    """Return the character of ``xs`` at alignment step ``i`` of ``path``.

    ``path[i]`` indexes into ``xs``. When the same index repeats at
    consecutive steps, the repeat is rendered as the gap character '_'
    instead of duplicating the character. Step 0 is never a gap.
    """
    stretched = i > 0 and path[i] == path[i - 1]
    return '_' if stretched else xs[path[i]]


def recover(xs, path):
    """Rebuild ``xs`` along an alignment ``path``, with '_' for repeated indices.

    One character is emitted per entry of ``path`` (see ``fix``); newlines in
    the result are replaced by spaces.
    """
    aligned = ''.join(fix(xs, step, path) for step in range(len(path)))
    return aligned.replace("\n", " ")

def dtw(x, y):
    """Align string ``x`` against string ``y`` with mlpy's subsequence DTW.

    Both strings are encoded with ``encode``. The returned warping path
    (``path[0]`` indexes ``x``, ``path[1]`` indexes ``y``) is used to rebuild
    each string with ``recover``, which marks repeated positions with '_'.

    Returns:
        (dist, (aligned_x, aligned_y)), where ``dist`` is the distance
        reported by ``dtw_subsequence``.
    """
    x_signal = encode(x)
    y_signal = encode(y)
    dist, _cost, path = dtw_subsequence(x_signal, y_signal)
    x_idx, y_idx = path[0], path[1]
    return dist, (recover(x, x_idx), recover(y, y_idx))

#-------------------------------------------------------------------------------
from scipy.stats import entropy # == KL Divergence
import numpy as np

def jsd(x, y):
    """Jensen-Shannon divergence between the encoded character arrays of x and y.

    The ``encode``d arrays are treated as (unnormalised) distributions P and Q.
    ``scipy.stats.entropy(p, m)`` normalises its inputs and returns KL(p || m).
    P + Q is taken elementwise, so ``x`` and ``y`` must have equal length.

    NOTE(review): a string made only of 'a'/'A' encodes to all zeros, which
    cannot be normalised (presumably yielding nan) - confirm.
    """
    p = encode(x)
    q = encode(y)
    m = (p + q) / 2.0
    kl_p = entropy(p, m)
    kl_q = entropy(q, m)
    return (kl_p + kl_q) / 2.0


#-------------------------------------------------------------------------------
from sklearn.metrics import normalized_mutual_info_score

def nmi(x, y):
    """Normalised mutual information between the encoded characters of x and y.

    Each character code from ``encode`` is used as a cluster label, and the
    two label sequences are compared position by position with scikit-learn's
    ``normalized_mutual_info_score``, so ``x`` and ``y`` must be equal length.
    """
    labels_x = encode(x)
    labels_y = encode(y)
    return normalized_mutual_info_score(labels_x, labels_y)

#-------------------------------------------------------------------------------

def words(s):
    """Split ``s`` into non-empty tokens.

    Separators are: space , ; : . > < ? ! # @ newline tab and the apostrophe.
    Hyphens and double quotes are NOT separators.

    A raw-string pattern avoids the invalid ``\\>``/``\\<`` escapes of the old
    literal, and a list (not a lazy ``filter`` object) is returned so callers
    can call ``len()`` on it and slice it on Python 3 as well as Python 2.
    """
    return [w for w in re.split(r"[ ,;:.><?!#@\n\t']", s) if w != '']



# Function-word list: the CSV's single column is renamed to 'word'.
# NOTE(review): f_words is not referenced elsewhere in this file's visible code.
f_words = pd.read_csv("CK_200_function_words.csv")
f_words.columns = ['word']

# Speech table restricted to the 'Full-a' subset, then split by edition:
# qf holds the rows with Edtn == 'Q', ff those with Edtn == 'F'.
df = pd.read_csv("df.csv")
df = df.loc[df['Subset'] == 'Full-a']
qf = df.loc[df['Edtn'] == 'Q']
ff = df.loc[df['Edtn'] == 'F']

def proc(r):
    """Render one speech row as "<spkr> <txt>" (speaker, space, text)."""
    speaker, line = r.spkr, r.txt
    return speaker + " " + line


def mkWindow(i, xs, size=None):
    """Join ``size`` consecutive words of ``xs``, starting at index ``i``, with spaces.

    ``size`` defaults to the module-level ``wSize`` (looked up at call time),
    so existing two-argument calls behave exactly as before. Near the end of
    ``xs`` the slice is shorter than ``size``; it is empty once ``i >= len(xs)``.
    """
    if size is None:
        size = wSize
    return ' '.join(xs[i:i + size])

# Number of words per window (read by mkWindow).
wSize = 200


def window(df, n_windows=100):
    """Build overlapping word windows from the speeches in ``df``.

    Each row is rendered with ``proc`` (speaker + text). The rows are joined
    and tokenised with ``words``, and a window of ``wSize`` words
    (``mkWindow``) starts at each word offset 0, 1, 2, ...

    Args:
        df: frame with the ``spkr`` and ``txt`` columns that ``proc`` reads.
        n_windows: maximum number of windows (default 100, the previous
            hard-coded count). Fewer are returned when the text has too few
            words for that many full-length windows.

    Returns:
        numpy array of window strings.
    """
    # iterrows yields one row Series, as df.apply(..., axis=1) did. For an
    # empty frame, apply may return an empty DataFrame, and iterating that
    # yields its column labels; iterrows yields nothing.
    samples = [proc(row) for _, row in df.iterrows()]
    # One join instead of repeated += (quadratic). The leading space the old
    # loop produced did not matter: words() drops empty tokens.
    ws = words(' '.join(samples))

    # The old message also printed the `ed` function object by mistake.
    print("#words: %d" % len(ws))

    # Only start windows that have wSize words available, so all windows have
    # equal length (short texts previously produced short or empty windows).
    n_full = max(len(ws) - wSize + 1, 0)
    windows = [mkWindow(i, ws) for i in range(min(n_windows, n_full))]
    return np.array(windows)

q_ws = window(qf)
f_ws = window(ff)
print("%s %s" % (q_ws.shape, f_ws.shape))

# res[k]   -> list of (best F-window index, score) per metric for Q window k
# res_a[k] -> (4, len(f_ws)) array of all scores for Q window k
res = []
res_a = []

for k, q in enumerate(q_ws):
    # Progress indicator: index of the Q window being processed.
    print(k)

    pairs = []
    for f in f_ws:
        edist = ed(q, f)
        dist, al = dtw(q, f)

        # NMI and JSD compare two strings position by position, so they need
        # equal-length inputs: use the DTW-aligned strings (windows of wSize
        # words, with '_' where the alignment repeats a position).
        q_, f_ = al
        nmi_score = nmi(q_, f_)
        jsd_score = jsd(q_, f_)

        pairs.append((nmi_score, edist, jsd_score, dist))

    # Rows: 0 = NMI, 1 = edit distance, 2 = JSD, 3 = DTW distance;
    # columns: F windows.
    pairs = np.array(pairs).T

    # NMI is a similarity (best = max); the other three are distances
    # (best = min). Indexing the result of map() only works on Python 2, and
    # taking argmax of every row just to keep row 0 was wasted work.
    posns = [np.argmax(pairs[0])] + [np.argmin(row) for row in pairs[1:]]
    scores = [pairs[i][p] for i, p in enumerate(posns)]

    res.append(list(zip(posns, scores)))
    res_a.append(pairs)

def selectMetric(a, i):
    """Collect item ``i`` of every element of ``a`` into a numpy array.

    With ``a`` shaped (n_q, n_metrics, n_f), this extracts the (n_q, n_f)
    score matrix of metric ``i``.
    """
    column = []
    for row in a:
        column.append(row[i])
    return np.array(column)

res = np.array(res)
res_a = np.array(res_a)
print(res_a)
print(res_a.shape)
# NOTE: res (best match per metric) is not exported; an earlier export via
# pd.DataFrame([res]).to_csv('res_scores.csv', ...) was left disabled.


metricNames = (
    "Normalised Mutual Information Score",
    "Edit Distance",
    "Jensen Shannon Divergence",
    "DTW Distance",
)
plotTitles = ["%s Map" % name for name in metricNames]

from myPlot import hmap

# One heat map per metric: rows are Q windows, columns are F windows.
for i in range(len(plotTitles)):
    title = plotTitles[i]
    arr = selectMetric(res_a, i)
    print("%s %s" % (title, arr.shape))
    hmap(title, title, "Q", "F", arr, 'YlOrRd', 0)