#!/usr/bin/python
# -*- coding: utf-8 -*-

# --------------------------------------------------
# File Name: exp_2a.py
# Location: 
# Purpose:
# Creation Date: 29-10-2017
# Last Modified: Wed, Jan 31, 2018  9:18:26 PM
# Author(s): Mike Stout 
# Copyright 2017 The Author(s) All Rights Reserved
# Credits: 
# --------------------------------------------------

import sys
import numpy as np

#----------------------------------------------------------------
# https://github.com/aflc/editdistance

import editdistance

def ed(q,f):
    """Return the edit distance between *q* and *f* as a float.

    The count reported by ``editdistance.eval`` is divided by a scale of
    1.0, so the result is the raw (not normalised) distance.
    """
    distance = editdistance.eval(q, f)
    # Scale of 1.0 == not normalised; float(len(q)+len(f)) would normalise
    # by the combined length of both inputs.
    scale = 1.
    return distance / scale

#----------------------------------------------------------------

from mlpy import dtw_subsequence, dtw_std

# Map lowercase and uppercase letters to the same value: the letter's
# 0-based position in the alphabet ('a'/'A' -> 0 ... 'z'/'Z' -> 25).
lc = 'abcdefghijklmnopqrstuvwxyz'
lc_lut = list(zip(lc, range(len(lc))))
uc_lut = [(letter.upper(), code) for letter, code in lc_lut]
lut = dict(lc_lut + uc_lut)

#def encode(c): return ord(c.lower())*200

def encChar(c):
    """Return the numeric code of the single character *c*.

    Letters present in the module-level ``lut`` table are looked up
    case-insensitively; any other character falls back to ``ord(c)``.

    Only a failed table lookup (``KeyError``) triggers the fallback, so
    unrelated errors (interrupts, a missing table, a non-string argument)
    are no longer silently converted into an ``ord`` call.
    """
    try:
        return lut[c.lower()]
    except KeyError:
        # Not a letter in the table: use the raw code point instead.
        return ord(c)  #* 2

def isOK(c):
    """Return True when *c* occurs in the string ``"aeiouAEIOUst"``.

    That is: vowels in either case, plus lowercase ``s`` and ``t``.  The
    check is a plain string containment test.
    """
    kept = "aeiouAEIOUst"
    return c in kept

def encString(s):
    """Encode the string *s* as a single float.

    Characters rejected by ``isOK`` are ignored; the remaining ones are
    converted with ``encChar`` and summed.  The result is the reciprocal
    of that sum, or ``0.`` when the sum is zero (for example when no
    character is kept).
    """
    # NB: only the characters accepted by isOK contribute to the code.
    total = sum(encChar(ch) for ch in s if isOK(ch))
    if total == 0:
        return 0.
    return 1.0 / total



def encode(s):
    """Return a 1-D NumPy array with ``encString`` applied to each item of *s*.

    *s* is any iterable of strings; the output has one float per item.
    """
    # Build a real list first: on Python 3 ``map`` is lazy and
    # ``np.array(map(...))`` would produce a 0-d object array instead of a
    # vector of floats.  On Python 2 the result is unchanged.
    return np.array([encString(item) for item in s])
    #return np.array(map(float, map(ord, s)))

'''
def splits(n, f, xs):
    return [ f+":\t" + xs[i:i+n] for i in range(0, len(xs), n)]

def fix(xs, i, path):
    if i+1<len(path) and path[i]!=path[i+1]: return xs[path[i]]
    else: return "_" # xs[path[i]]


def recover(xs, ys, xp, yp):

    xl,yl,flip = len(xp),len(yp),False
    k = xl
    if xl<yl: xl,yl,Flip = yl,xl.True

    res = []
    xoff = 0, 0
    for i in xrange(yl):
        while xp[i+xoff]==yp[i+yoff]: res.append((xp[i+xoff],xp[i+yoff]))
        while xp[i+xoff]!=yp[i+yoff]: xoff+=1
             
    

    res = []
    for i in xrange(k):

        x = xs[i]
        y = ys[i]
        
        if xs[i]==ys[i]: = x,y # res.append((x,y))
        if xs[i+1]==xs[i]: 
            for xy in recover(xs, ys[i+1:], xs, yp[i+1:]):
                
    if ys[i+1]==ys[i]: return recover(xs, ys[i+1:], xs, yp[i+1:])
    re
        

    path_ = (path + path[:1]) # [::-1]
    xs_ = (xs + xs[:1])
    
    return ([ fix(xs_, i, path_) for i in xrange(len(path_))])[::-1]
'''


def rl_enc(input_string):
    """Run-length encode *input_string*.

    Args:
        input_string: any iterable (a string, or a list of items).

    Returns:
        A list of ``(item, count)`` tuples, one per run of consecutive
        equal items, in order of appearance.  An empty input gives ``[]``.

    Runs are delimited purely by equality of neighbouring items, so falsy
    items (``''``, ``0``) are counted like any other value.
    """
    # Function-scope import keeps the block self-contained.
    from itertools import groupby

    return [(value, sum(1 for _ in run)) for value, run in groupby(input_string)]

def rl_dec(lst):
    """Flatten ``(item, count)`` pairs into a list with run placeholders.

    Each pair contributes its item once.  When the count is not 1, the
    item is followed by one placeholder string ``str((i, count - 1))``,
    where ``i`` is the position of the pair in *lst*.  Runs are therefore
    not expanded back to their original length.
    """
    flat = []
    for idx, pair in enumerate(lst):
        symbol, run_len = pair
        flat.append(symbol)
        if run_len != 1:
            flat.append(str((idx, run_len - 1)))
    return flat

''' 
def rl_dec(lst):
    q = ""
    for character, count in lst:
        q += character * count
    return q
'''

def recover(xs):
    """Collapse consecutive repeats in *xs* via ``rl_enc`` then ``rl_dec``.

    The sequence is run-length encoded and the resulting pairs are passed
    straight to ``rl_dec``, whose output is returned.
    """
    runs = rl_enc(xs)
    return rl_dec(runs)

def decode(xs,x):
    """Return the list of elements of *x* selected by the indices in *xs*.

    The output has one entry per index, in the order the indices appear.
    """
    picked = []
    for idx in xs:
        picked.append(x[idx])
    return picked



def fixLen(pair):
    """Replace ``'_'`` placeholders in an enumerated ``(x, y)`` pair.

    Args:
        pair: a tuple ``(i, (x, y))`` as produced by
            ``enumerate(zip(xs, ys))``.

    Returns:
        The tuple ``(x, y)`` in which any element equal to ``'_'`` has been
        replaced by ``str(i)``.

    The argument is unpacked inside the body because tuple parameters in a
    ``def`` signature are a syntax error on Python 3; callers still pass a
    single ``(i, (x, y))`` tuple exactly as before.
    """
    i, (x, y) = pair
    if x=='_': x = str(i) # *(len(y)) # -len(x))
    if y=='_': y = str(i) # *(len(x)) # -len(y))
    return x,y

def dtw(a,b):
    """Align sequences *a* and *b* with ``dtw_std`` and return the result.

    Returns a tuple ``(dist, (a_aligned, b_aligned))``: ``dist`` is the first
    value returned by ``dtw_std`` and the aligned sequences are tuples given
    in the same order as the arguments (``a`` first, ``b`` second).

    Internally the shorter input is always passed to ``dtw_std`` first; the
    ``flip`` flag remembers whether the inputs were swapped so the outputs
    can be swapped back.
    """

    #a = a[::-1]
    #b = b[::-1]

    # x is the shorter (or equal-length) sequence, y the longer one.
    x,y,flip = a,b, False
    if len(a) > len(b): x,y,flip = b,a,True

    # Both sequences are encoded in reversed order.
    xe = encode(x[::-1])
    ye = encode(y[::-1])
    #dist, cost, path = dtw_subsequence(xe, ye) 
    # presumably (distance, accumulated cost matrix, warping path) -- confirm
    # against the mlpy documentation; only dist and path are used below.
    dist, cost, path = dtw_std(xe, ye, dist_only=False) 
    # path unpacks into one index sequence per input.
    p0, p1 = path
    # NOTE(review): the path was computed on the REVERSED encodings, but the
    # indices are applied to the un-reversed x and y here (and the reversal
    # of xa/ya further down is commented out) -- confirm this is intended.
    xa = recover(decode(p0,x))
    ya = recover(decode(p1,y))
    # recover() may return lists of different lengths; zip() truncates to the
    # shorter one.  fixLen swaps any literal '_' entry for its position index.
    xa,ya = zip(*map(fixLen,  enumerate(zip(xa,ya))))


    #xa = xa[::-1]
    #ya = ya[::-1]
   
    # Restore the caller's argument order if the inputs were swapped above.
    al = xa,ya 
    if flip: al = ya,xa

    
    return dist, al

#----------------------------------------------------------------
from scipy.stats import entropy # == KL Divergence

def jsd(x,y):
    """Return the Jensen-Shannon divergence between the encodings of *x* and *y*.

    Both inputs are converted with ``encode`` and compared as
    ``0.5 * (KL(P || M) + KL(Q || M))`` with ``M = 0.5 * (P + Q)``, using
    ``scipy.stats.entropy`` for the KL terms.

    *x* and *y* must encode to arrays of compatible (normally equal) length,
    otherwise the element-wise sum ``P + Q`` fails.
    """
    # np.asarray(..., dtype=float) yields a proper float vector; wrapping a
    # lazy ``map`` object in np.array (as on Python 3) would instead create a
    # 0-d object array on which ``P + Q`` cannot be computed.
    P = np.asarray(encode(x), dtype=float)
    Q = np.asarray(encode(y), dtype=float)

    # Mixture distribution used as the common reference for both KL terms.
    M = 0.5 * (P + Q)

    return 0.5 * (entropy(P, M) + entropy(Q, M))


#----------------------------------------------------------------
from sklearn.metrics import normalized_mutual_info_score

def nmi(x,y):
    """Return ``normalized_mutual_info_score`` of the encodings of *x* and *y*.

    Each input is first converted with ``encode``; the two resulting arrays
    are passed to the scorer in argument order.
    """
    encoded_x = encode(x)
    encoded_y = encode(y)
    score = normalized_mutual_info_score(encoded_x, encoded_y)
    return score