#!/usr/bin/python
# -*- coding: utf-8 -*-

# --------------------------------------------------
# File Name: exp_1i.py
# Location: 
# Purpose: PCA of function-word frequencies per text subset, shown as a 3D scatter plot
# Creation Date: 29-10-2017
# Last Modified: Wed, Nov 15, 2017  6:19:52 PM
# Author(s): Mike Stout 
# Copyright 2017 The Author(s) All Rights Reserved
# Credits: 
# --------------------------------------------------

import numpy as np
import pandas as pd
import re
# Console display settings for the DataFrame dumps printed by this script.
# Fully-qualified option names are used: short aliases such as 'precision'
# match several options in newer pandas and raise OptionError.
pd.set_option('display.max_rows', 15)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 100)
pd.set_option('display.chop_threshold', 0)
pd.set_option('display.precision', 4)
pd.set_option('display.colheader_justify', 'right')
# Truncate long text cells to 32 characters.  (A negative max_colwidth is
# deprecated/rejected by newer pandas; None would mean "unlimited".)
pd.set_option('display.max_colwidth', 32)

import matplotlib.pyplot as plt 
# Use matplotlib's ggplot look for every figure drawn below.
plt.style.use(style='ggplot')

# Main table of texts; later code reads its Subset, Edition and txt columns.
df = pd.read_csv("df.csv")
print(df)


# Function-word list.  The frame is forced to a single column named 'word'
# (assigning one name fails unless the CSV has exactly one column).
# NOTE(review): read_csv treats the first line as a header row; if the file
# has no header, the first function word is silently dropped -- confirm.
f_words = pd.read_csv("CK_200_function_words.csv")
f_words.columns = ['word']
def fix_pos(r):
    """Return the part of ``r.word`` before its first space.

    Entries in the word list may carry extra text after a space (presumably a
    part-of-speech tag -- confirm against the CSV); only the leading token is
    kept.  A string with no space is returned unchanged.
    """
    head, _sep, _rest = r.word.partition(' ')
    return head

# Keep only the leading token of each function-word entry.
f_words['word'] = f_words.apply(fix_pos, axis=1)


# Drop the 'Full-a' subset.  Alternative subset filters tried earlier are
# kept commented out for reference.
#print df # .info()
#df = df[ ((df.Subset == 'Only')) |  ((df.Edtn == 'Q') & (df.Subset == 'Common')) ] 
# .copy(): the filtered frame is assigned into below; without an explicit copy
# pandas emits SettingWithCopyWarning on every new column.
df = df[df.Subset != 'Full-a'].copy()
print(df)


#df = df[ (df.Subset == 'Only') ] 
# Group label "<Edition> <Subset>"; this becomes the plot grouping key.
df['Text'] = df['Edition'] + ' ' + df['Subset']

# Keep only the label and the raw text.  One frequency column per function
# word is added to this frame later, so take an independent copy.
columns = ['Text', 'txt']
df = df[columns].copy()
print(df)


# Token delimiters: space, tab, newline and common punctuation (including the
# apostrophe).  Raw string so backslash escapes reach the regex engine intact;
# compiled once because words() runs for every (row, function word) pair.
_WORD_DELIMS = re.compile(r"[ ,;:.<>?!#@\n\t']")


def words(s):
    """Split text ``s`` into tokens, dropping empty strings.

    The apostrophe is a delimiter, so "don't" yields ["don", "t"].
    Always returns a list (``filter`` is lazy on Python 3, which would break
    ``len()`` in callers).
    """
    return [tok for tok in _WORD_DELIMS.split(s) if tok != '']


def calcFreq(w, r):
    """Return occurrences of word ``w`` in ``r.txt`` per character of that text.

    ``r`` is any object with a ``txt`` string attribute (a DataFrame row here).
    Tokens come from ``words``; matching is exact and case-sensitive.
    Returns 0.0 for empty text instead of dividing by zero.

    NOTE(review): the denominator is the character length of the text, not
    the number of tokens, so this is not a conventional relative word
    frequency.  It is kept because the PCA features depend on it -- confirm
    the intent.
    """
    s = r.txt
    n = len(s)
    if n == 0:
        return 0.0
    hits = sum(1 for tok in words(s) if tok == w)
    return hits / float(n)

print(f_words)

# Add one feature column per function word, holding calcFreq's value for
# each row (apply runs immediately, so capturing fw in the lambda is safe).
for fw in f_words['word']:
    df[fw] = df.apply(lambda row: calcFreq(fw, row), axis=1)




# Feature matrix: every column except the label/text columns.
df_ = df[[c for c in df.columns if c not in columns]]

# PCA ... project the function-word features onto 3 principal components.
import sklearn.decomposition
pca = sklearn.decomposition.PCA(n_components=3)
projection = pca.fit(df_).transform(df_)  # fit() returns the estimator itself
x, y, z = projection[:, 0], projection[:, 1], projection[:, 2]

# Groups ... one point per row, labelled by its 'Text' value.
labels = df['Text']
df = pd.DataFrame({'x': x, 'y': y, 'z': z, 'label': labels})
groups = df.groupby('label')


# 3D Scatter Plot ...
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

# Figure.gca(projection=...) was deprecated and later removed from matplotlib;
# add_subplot(..., projection='3d') works on old and new versions.
plt3d = plt.figure().add_subplot(111, projection='3d')
# Center the axes on the origin with a fixed half-width ...
plt3d.autoscale(enable=False, axis='both')  # ... needed to change the Z-axis
o_x = 0
o_y = 0
o_z = 0
delta = .1
plt3d.set_xbound(o_x - delta, o_x + delta)
plt3d.set_ybound(o_y - delta, o_y + delta)
plt3d.set_zbound(o_z - delta, o_z + delta)
plt3d.set_xlabel('Pr Comp 1')
plt3d.set_ylabel('Pr Comp 2')
plt3d.set_zlabel('Pr Comp 3')


# One marker/colour per group; styles are cycled if there are more groups
# than styles, instead of raising IndexError.
markers = 'o', '^', 'D', '*'
import matplotlib.cm as cm
colors = 'magenta', 'b', 'r', 'g'


labels = []
proxy = []  # legend handles, built per group so they always match `labels`
for i, (name, group) in enumerate(groups):
    labels.append(name)
    marker = markers[i % len(markers)]
    color = colors[i % len(colors)]

    # Calculate centroids ....
    c_x = np.mean(group.x)
    c_y = np.mean(group.y)
    c_z = np.mean(group.z)
    print('%s %s %s' % (c_x, c_y, c_z))

    # Large opaque marker at the centroid, small translucent ones for members.
    plt3d.scatter(c_x, c_y, c_z, marker=marker, color=color, s=500)
    plt3d.scatter(group.x, group.y, group.z, marker=marker, color=color, s=20, alpha=.2)

    # Fake a legend for the 3d scatter: an invisible-line Line2D carrying the
    # same marker and colour as this group.
    proxy.append(plt.Line2D([0], [0], linestyle="none", color=color, marker=marker))

plt3d.legend(proxy, labels, numpoints=1)
plt.suptitle("PCA of 200 Function Words in Lr Speeches")

plt.show()


# Stop here: nothing below this point is executed.
# sys.exit() is the supported way to end a program; the exit() builtin is
# added by the site module for interactive use and is absent under `python -S`.
import sys
sys.exit()

#------------------------------------

import seaborn as sns
import matplotlib.pyplot as plt
#%matplotlib inline
# NOTE(review): unreachable -- the script exits before this point.
# Exploratory 2D scatter of the first two principal components.
sns.set(context = "paper", font = "monospace")

# NOTE(review): `df` was rebuilt above with columns x/y/z/label only, so
# df['Text'] would raise KeyError if this were re-enabled; df['label'] is
# presumably what was meant -- confirm.
labels = df['Text'] 
print labels
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='.', linestyle='', ms=12, label=name)
ax.legend()

# Disabled k-means (k=2) experiment, kept inside a string literal.
'''
from scipy.cluster.vq import kmeans2
centroids, ks = kmeans2(np.array(df), 2, 10)
print centroids
exit()

colors = ['r', 'g'] # , 'b']
plt.scatter(*df.T, c=np.choose(ks, colors))
plt.scatter(*centroids.T, c=colors, marker='v')
'''

# NOTE(review): saving after plt.show() can write an empty image on some
# backends once the window is closed; consider saving before showing.
plt.show()
fig.savefig('tmp.png')





exit()

# NOTE(review): this section is never executed (the script exits earlier),
# and it uses `qc` and `fc`, which are not defined anywhere in this script.
# They are presumably per-edition speech tables -- confirm before re-enabling.

# Disabled fuzzy-match helper, kept inside a string literal.
'''

def approx(r):
    return len(r['txt']) /10 
    

#qc['fuzzymatch'] = qc.apply(lambda r: approx(r), axis=1) 
#fc['fuzzymatch'] = fc.apply(lambda r: approx(r), axis=1) 
'''

# Pair up rows that share Act, Scene and Speaker (overlapping columns get
# pandas' default _x / _y suffixes) ...
res = qc.merge(fc, on=['Act','Scene','Speaker']) #  ,'fuzzymatch'])

# ... keeping only pairs whose 'sp#' values differ by fewer than 6.
res = res[ (abs(res['sp#_x'] - res['sp#_y']) < 6)]

print qc.shape, fc.shape, res.shape


#-------------------------------------------------------------------------------
# https://github.com/aflc/editdistance

import editdistance

def calcEditDistance(r):
    """Return ``editdistance.eval`` of the row's ``txt_x`` and ``txt_y`` texts."""
    left, right = r.txt_x, r.txt_y
    return editdistance.eval(left, right)

res['ed'] = res.apply(calcEditDistance, axis=1)

# Keep pairs whose edit distance is below two thirds of the first text's length.
res = res[res.ed < res['#chars_x'] / 1.5]
print(res.shape)

#-------------------------------------------------------------------------------
from mlpy import dtw_subsequence, dtw_std

lc = 'abcdefghijklmnopqrstuvwxyz'
lc_lut = [(a, i) for i, a in enumerate(lc)]
uc_lut = [(a.upper(), i) for a, i in lc_lut]
# Letter -> alphabet position (0-25), for both cases.
lut = dict(lc_lut + uc_lut)


def enc(c):
    """Map a single character to a number for sequence comparison.

    Letters (either case) map to their alphabet position 0-25; any other
    character maps to ``2 * ord(c)`` as a float, e.g. ' ' -> 64.0.
    """
    try:
        return lut[c.lower()]
    except KeyError:
        # Only a missing key means "not a letter"; a bare except would also
        # swallow KeyboardInterrupt and unrelated errors.
        return ord(c) * 2.


def encode(s):
    """Encode string ``s`` character-by-character with ``enc`` into a numpy array.

    A list comprehension is used because on Python 3 ``map`` is lazy, and
    ``np.array(map(...))`` would yield a 0-d object array.
    """
    return np.array([enc(c) for c in s])

def splits(n, f, xs):
    """Cut ``xs`` into consecutive ``n``-character chunks, each prefixed "<f>:\\t"."""
    tagged = []
    for start in range(0, len(xs), n):
        tagged.append(f + ":\t" + xs[start:start + n])
    return tagged

def fix(xs, i, path):
    """Return the character of ``xs`` for step ``i`` of warping ``path``.

    A step that repeats the previous step's index is a gap in this sequence's
    alignment and is rendered as '_'.
    """
    if i > 0 and path[i] == path[i - 1]:
        return '_'  # gapChar
    return xs[path[i]]


def recover(xs, path):
    """Rebuild ``xs`` along ``path`` as an aligned string, newlines shown as spaces.

    Iterates with ``range`` so it runs on Python 3 as well as Python 2.
    """
    return ''.join([fix(xs, i, path) for i in range(len(path))]).replace("\n", " ")

def dtw(r):
    """Align ``r.txt_x`` and ``r.txt_y`` with mlpy's subsequence DTW.

    Both texts are ``encode``d character-by-character.  Returns
    ``(dist, (q_aligned, f_aligned))``, where the aligned strings are rebuilt
    from the two halves of the warping path by ``recover``.
    """
    q = r.txt_x
    f = r.txt_y
    dist, _cost, path = dtw_subsequence(encode(q), encode(f))
    return dist, (recover(q, path[0]), recover(f, path[1]))

# DTW is expensive: compute it once per row and split the result, rather than
# re-running the full alignment separately for each output column.
_dtw_results = [dtw(_row) for _, _row in res.iterrows()]
res['dtw'] = pd.Series([_d for _d, _ in _dtw_results], index=res.index)
# Built via pd.Series so the (q, f) tuples stay as single cell values.
res['alignment'] = pd.Series([_al for _, _al in _dtw_results], index=res.index)

#-------------------------------------------------------------------------------
from scipy.stats import entropy # == KL Divergence
import numpy as np

def jsd(r):
    """Jensen-Shannon divergence between the two aligned texts in ``r.alignment``.

    Each aligned string is ``encode``d and the code vectors are used as weight
    vectors; ``scipy.stats.entropy`` normalises its inputs before computing KL.
    """
    aligned_q, aligned_f = r.alignment
    p = encode(aligned_q)
    q = encode(aligned_f)
    m = 0.5 * (p + q)
    kl_p = entropy(p, m)
    kl_q = entropy(q, m)
    return 0.5 * (kl_p + kl_q)


res['jsd'] = res.apply(jsd, axis=1)

#-------------------------------------------------------------------------------
from sklearn.metrics import normalized_mutual_info_score

def nmi(r):
    """Normalised mutual information between the ``encode``d aligned texts in ``r.alignment``."""
    aligned_q, aligned_f = r.alignment
    return normalized_mutual_info_score(encode(aligned_q), encode(aligned_f))

res['nmi'] = res.apply(nmi, axis=1)

# Keep only pairs whose NMI exceeds 0.7.
res = res[res.nmi > .7]
print(res.shape)
#-------------------------------------------------------------------------------

# Positional rename of all 14 columns; the list order must match the column
# order of `res` at this point.
renamed = ['Act','Scene','Speaker','sp#', '#chars_q', "q_txt", 'sp#_', '#chars_f', 'f_txt', 'ed', 'dtw','alignment', 'jsd','nmi']
res.columns = renamed

print(res)
res.to_csv('res.csv', header=True, index=False, encoding='utf-8')

# Heat-map titles paired with the metric column each one plots.
plotTitles = ("Edit Distance", "DTW Distance", "Jensen Shannon Divergence", "Normalised Mutual Information Score")
columns = ("ed", "dtw", "jsd", "nmi")
metrics = zip(plotTitles, columns)



import myPlot
for title, metric in metrics:
    # `zero` is passed to myPlot.hmap: 0.7 for NMI (the same cut-off applied
    # to res earlier), 0 for every other metric.
    zero = .7 if metric == 'nmi' else 0

    # Mean metric per Speaker (rows) x Scene (columns) with margins, then
    # transposed so Scene runs down the rows.
    pt = res.pivot_table(values=metric,
                         columns=['Scene'],
                         index=['Speaker'],
                         aggfunc=np.mean,
                         margins=True).T

    # Rescale JSD by 100 before plotting and tabulating.
    if metric == 'jsd':
        pt = pt * 100

    myPlot.hmap(title, title, "Speaker", "Scene", pt, 'YlOrRd', zero)
    myPlot.outputTable(title, pt)



# Inspect the pair with the lowest NMI.  idxmin returns the row *label*
# (Series.argmin returns a position in newer pandas, and DataFrame.ix has
# been removed), so the row is fetched with .loc.
r_min = res['nmi'].idxmin()
r = res.loc[r_min]
print(r)
x, y = r.alignment
#print x +'\n'+ y

# Print the two aligned texts interleaved in 100-character chunks.
n = 100
aa = splits(n, "Q", x)
bb = splits(n, "F", y)

zz = ["\n" + a + "\n" + b for a, b in zip(aa, bb)]

for z in zz:
    print(z)




import seaborn as sns 
sns.set(style="ticks", color_codes=True)

# Pairwise scatter plots of the four metrics, coloured by speaker.
grid = sns.pairplot(res[['Speaker','ed','dtw','jsd','nmi']], hue="Speaker", palette="husl")
# Rotate x tick labels on the pair-plot's own axes.  A plt.xticks() call made
# before pairplot creates its figure only affects the previous current axes.
for ax in grid.axes.flat:
    plt.setp(ax.get_xticklabels(), rotation=90)
#plt.show()
filename = "Speaker by Scene Q F Information Metrics".replace(' ', '_')
plt.savefig('data/'+filename+'.png')