#!/usr/bin/python
# -*- coding: utf-8 -*-

# --------------------------------------------------
# File Name: exp_2a.py
# Location: 
# Purpose:
# Creation Date: 29-10-2017
# Last Modified: Wed, Jan 31, 2018  9:44:26 PM
# Author(s): Mike Stout 
# Copyright 2017 The Author(s) All Rights Reserved
# Credits: 
# --------------------------------------------------

import sys
import numpy as np
import pandas as pd
import re

# Console display settings for dataframes printed while exploring.
# Fully-qualified option names are used: the short key 'precision' became
# ambiguous (and raises OptionError) once pandas added
# 'styler.format.precision'.  The former max_colwidth=-1 setting was
# deprecated and was immediately overridden by 32, so only 32 is kept.
pd.set_option('display.max_rows', 15)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 100)
pd.set_option('display.chop_threshold', 0)
pd.set_option('display.precision', 4)
pd.set_option('display.colheader_justify', 'right')
pd.set_option('display.max_colwidth', 32)

from calcMetrics import * 

# Command-line arguments: csvfile divId edThreshold dtwThreshold.
# Fail with a usage line rather than an opaque unpacking ValueError.
if len(sys.argv) != 5:
    sys.exit('usage: %s csvfile divId edThreshold dtwThreshold' % sys.argv[0])
csvfile, divId, edThreshold, dtwThreshold = sys.argv[1:]

#------------------------------------------------------------------
# Read in the csv file and split out the required dataframes ...

# Separator characters for words(): space , ; : . < > ? ! # @ newline tab '
# (raw string, so the regex escapes are not mangled by Python's own string
# escaping).
_WORD_SEPARATORS = re.compile(r"[ ,;:.<>?!#@\n\t']")


def words(s):
    """Split *s* into words, dropping punctuation, markup and empty tokens.

    Returns a list of the non-empty pieces left after splitting on any
    character in _WORD_SEPARATORS.
    """
    return [w for w in _WORD_SEPARATORS.split(s) if w != '']

df = pd.read_csv(csvfile)

# The GE-annotated file ("df.csv", with additions marked up) labels its
# subset and its Q edition differently from the standard input files.
if csvfile == "df.csv":
    subset, qEdition = 'Full-a', 'Q'
else:
    subset, qEdition = 'Full', 'Q1'

df = df[df.Subset == subset]
qf = df[df.Edtn == qEdition]
ff = df[df.Edtn == 'F']

#qf = qf[:20]
#ff = ff[:20]

#------------------------------------------------------------------
# Extract speeches and labels for each from the q/f dataframes ...

def proc(r):
    """Return the speech text of dataframe row *r* without surrounding whitespace."""
    speech = r.txt
    return speech.strip()

def labels(r):
    """Build the '<Scene> <sp# - 1> <spkr>' label for dataframe row *r*."""
    speechNo = str(r['sp#'] - 1)
    return ' '.join([r.Scene, speechNo, r.spkr])

# Disabled experiment, kept only as an inert string literal: mkWindow built
# fixed-size windows of wSize = 200 words.  window() below uses whole
# speeches as the windows instead.
'''
def mkWindow(i,xs):
    print xs
    w = xs[i:i+wSize]
    s = ' '.join(w)
    #print s
    #exit()
    return s


wSize = 200
'''

def window(df, div=None):
    """Extract the speeches of one division from a Q or F dataframe.

    Each whole speech is used as a comparison "window".

    df  -- dataframe with (at least) the columns read here and by proc()
           and labels(): Div, txt, Scene, sp# and spkr.
    div -- division id to keep; defaults to the command-line ``divId``.

    Returns (labs, windows):
      labs    -- Series (named 'labs', indexed like the selected rows) of
                 labels built by labels();
      windows -- numpy object array of the speech texts from proc(), in the
                 same order.
    """
    if div is None:
        div = divId

    # Filter first, so proc()/labels() run only on the rows that are kept.
    # The old version applied them to the whole frame and relied on index
    # alignment.  That also raised SettingWithCopyWarning, and it failed
    # outright on an empty selection.
    rows = df[df.Div == int(div)]

    texts, labs = [], []
    for _, row in rows.iterrows():
        texts.append(proc(row))    # NB use speeches as the windows ....
        labs.append(labels(row))

    return (pd.Series(labs, index=rows.index, name='labs', dtype=object),
            np.array(texts, dtype=object))

# Labels (as plain arrays) and speech texts for the chosen division of each edition.
q_labs, q_ws = window(qf)
q_labs = np.array(q_labs)
f_labs, f_ws = window(ff)
f_labs = np.array(f_labs)

#------------------------------------------------------------------
# Compute Metrics of Q-F speech matrix ....

# qfSpMatrix[qi] is a (4, len(f_ws)) array of scores of Q speech qi against
# every F speech.  Row order is nmi, ed, jsd, dtw (see `metrics` below).
qfSpMatrix = []

from progressBar import progress

total = len(q_ws)

for k, q in enumerate(q_ws):
    progress(k, total, status='')

    pairs = []
    for f in f_ws:
        edist = ed(q, f)

        # DTW over space-separated tokens; `al` is unpacked below into the
        # two aligned sequences.  (`dist` is also read by later top-level
        # code, so keep this name.)
        dist, al = dtw(q.split(' '), f.split(' '))

        # nmi and jsd need equal-length (N x N) inputs, so score the
        # DTW-aligned sequences rather than the raw strings.
        q_, f_ = al

        nmi_score = nmi(q_, f_)
        jsd_score = jsd(q_, f_)
        pairs.append((nmi_score, edist, jsd_score, dist))

    qfSpMatrix.append(np.array(pairs).T)

qfSpMatrix = np.array(qfSpMatrix)

def selectMetric(a, i):
    """Take element *i* of every entry of *a*, rounded to 2 decimals, as an array."""
    picked = []
    for scores in a:
        picked.append(np.round(scores[i], 2))
    return np.array(picked)

#metricNames = "Normalised Mutual Information Score", "Edit Distance", "Jensen Shannon Divergence" # ,  "DTW Distance"


# Which metric drives the Q-F matching; metricId indexes `metrics`.
metrics = ['nmi', 'ed', 'jsd', 'dtw']
metricId = metrics.index('ed')
metric = metrics[metricId]
# After the transpose, rows are F speeches and columns are Q speeches.
qfSpEdMatrix = selectMetric(qfSpMatrix, metricId).T

#------------------------------------------------------------------
# Find the optimal matches in the Q-F speech edit dist matrix ....

def getPositions(qfSpEdMatrix, metric_name=None):
    """Pick the best-scoring column for every row of a score matrix.

    qfSpEdMatrix -- 2-D array of scores.  As built by the top-level code,
                    rows are F speeches and columns are Q speeches.
    metric_name  -- 'nmi' selects the maximum of each row (higher is
                    better); any other name selects the minimum.  Defaults
                    to the module-level ``metric``.

    Prints the metric name.  Returns a list of (qPos, fPos, val) tuples,
    one per row: qPos is the chosen column, fPos the row index and val the
    score at that cell.
    """
    if metric_name is None:
        metric_name = metric

    print(metric_name)
    pick = np.argmax if metric_name == 'nmi' else np.argmin

    qfSpEdMatchList = []
    for fPos, row in enumerate(qfSpEdMatrix):
        qPos = pick(row)
        qfSpEdMatchList.append((qPos, fPos, row[qPos]))
    return qfSpEdMatchList

# Best Q match (position, F position, score) for every F speech.
qfSpEdMatchList = getPositions(qfSpEdMatrix=qfSpEdMatrix)

#------------------------------------------------------------------
# Recover speech pairs ...

# Pair each F speech (one entry of qfSpEdMatchList each) with a Q speech.
# qoff/foff count positions treated as unmatched, shifting the sequential
# index used to look up the following speeches.
qoff = 0
foff = 0

speechPairs = []
for i,x in enumerate(qfSpEdMatchList):
    # q: best-matching Q index, f: F index, v: its score (see getPositions)
    q,f,v = x

    # Sequential Q/F positions expected for this step, given the offsets.
    q_ = i-qoff
    f_ = i-foff

    # NOTE(review): i-qoff can run past the end of q_ws (IndexError) when
    # there are more F speeches than Q speeches and qoff has not grown enough.
    q_txt = q_ws[i-qoff]
    f_txt = f_ws[i-foff]
    # Normalise the score by the mean length of the two texts.  Under
    # Python 2 this `/` is integer (floor) division.  k == 0 (both texts
    # empty) raises ZeroDivisionError.
    # NOTE(review): v scores F speech f against Q speech q, but the lengths
    # come from q_ws[i-qoff], which is a different speech when q != i-qoff.
    # Confirm this is intended.
    k = len(q_txt+f_txt)/2
    val = float(v) / k

    # Length ratio, shorter/longer (so <= 1).  An empty Q text raises
    # ZeroDivisionError here.
    f_len = float(len(f_txt)) 
    q_len = float(len(q_txt))
    lenRatio = f_len/q_len
    if len(f_txt) > len(q_txt): lenRatio = 1/lenRatio
    #print len(f_txt), len(q_txt), val


    # Snapshot taken before qoff/foff are updated below.
    metadata = i,q,f,qoff,foff,q_,f_,v,val,lenRatio

    # The best Q match is not the expected sequential Q speech: emit a
    # placeholder for the Q side and shift subsequent Q lookups.
    if q+qoff!=i: 
        qData = q_, "..." 
        qoff += 1
    else:
        qData = q_labs[q_], q_ws[q_]

    # NOTE(review): getPositions emits fPos equal to the enumeration index,
    # so f == i here.  foff starts at 0 and only changes inside this
    # branch, so the branch appears unreachable and fData is always the
    # real F speech.
    if f+foff!=i: 
        fData = f_, "..."
        foff += 1
    else:
        fData = f_labs[f_], f_ws[f_]

    data = val, lenRatio, metadata, qData, fData

    speechPairs.append(data)

#------------------------------------------------------------------
# Filter the additions only .....

def isAddition(speechPair, threshold=None):
    """Return True if a recovered Q/F speech pair looks like an addition.

    speechPair -- (val, lenRatio, metadata, qData, fData) tuple.  As built
                  by the speech-pair recovery loop, val is a length-normalised
                  match score and lenRatio is shorter/longer text length.
    threshold  -- score cutoff; defaults to the command-line edThreshold.

    A pair qualifies when val exceeds the threshold and lenRatio lies
    strictly between 0.1 and 0.9.
    """
    if threshold is None:
        threshold = edThreshold

    val, lenRatio, _metadata, _qData, _fData = speechPair

    # NOTE: an earlier predicate ('>' GE addition markup in either text)
    # was computed here but immediately overwritten, so it never affected
    # the result.  Only the score/length test is applied.
    return val > float(threshold) and .1 < lenRatio < .9
    

# Keep only the pairs that isAddition() accepts.
speechPairs = [pair for pair in speechPairs if isAddition(pair)]

#------------------------------------------------------------------
# Output the results ....

def isMultiLine(q, f):
    """Return True if either text has at least two '@' separators (more than two lines)."""
    qBreaks = q.count('@')
    fBreaks = f.count('@')
    return max(qBreaks, fBreaks) >= 2

def isAdd(pair):
    """Return True if either side of an (x, y) pair consists only of '_'.

    Takes the pair as one argument.  The old ``def isAdd((x, y))`` form
    relied on tuple-parameter unpacking, which is a SyntaxError in
    Python 3 (PEP 3113).  Presumably '_' is alignment gap filler; confirm
    against the alignment code.
    """
    x, y = pair
    return ''.join(set(x)) == '_' or ''.join(set(y)) == '_'
    
def fixAdd(pair):
    """Replace an all-'_' side of an (x, y) pair with the other side.

    x is checked first.  If both sides are all '_', both end up equal to
    the original y.  Returns the (possibly updated) (x, y) tuple.
    Takes the pair as one argument instead of the Python-2-only
    tuple-parameter form (PEP 3113).
    """
    x, y = pair
    if ''.join(set(x)) == '_':
        x = y
    if ''.join(set(y)) == '_':
        y = x
    return x, y
    
    

# Print each retained Q/F speech pair, separated by a rule.
#
# NOTE: an earlier version re-aligned each pair with dtw() and filtered on a
# length-normalised DTW distance against dtwThreshold.  That filter was
# disabled.  Its leftover normDist computation read stale globals from the
# earlier loops (and could divide by zero), so it has been removed.  Every
# pair accepted by isAddition() is printed.
for speechPair in speechPairs:
    val, lenRatio, metadata, qData, fData = speechPair

    qLabel, qTxt = qData
    fLabel, fTxt = fData

    print('%s %s' % (qLabel, qTxt))
    print('%s %s' % (fLabel, fTxt))
    print('-' * 60)