#!/usr/bin/python
# -*- coding: utf-8 -*-
# --------------------------------------------------
# File Name: exp_2a.py
# Location:
# Purpose:
# Creation Date: 29-10-2017
# Last Modified: Wed, Jan 31, 2018 9:44:26 PM
# Author(s): Mike Stout
# Copyright 2017 The Author(s) All Rights Reserved
# Credits:
# --------------------------------------------------
import sys
import numpy as np
import pandas as pd
import re
# Pandas display settings used when DataFrames are printed while debugging.
# Full option keys are used: the short key 'precision' is ambiguous in newer
# pandas (display.precision vs styler.format.precision) and raises OptionError.
pd.set_option('display.max_rows', 15)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 100)
pd.set_option('display.chop_threshold', 0)
pd.set_option('display.precision', 4)
pd.set_option('display.colheader_justify', 'right')
# The former max_colwidth=-1 call was immediately overridden (and -1 is no
# longer accepted by recent pandas), so only the effective value is set.
pd.set_option('display.max_colwidth', 32)
from calcMetrics import *
# Command-line arguments: input CSV file, division id, edit-distance
# threshold (used by isAddition) and DTW threshold (only referenced in
# commented-out code in this script).
if len(sys.argv) != 5:
    sys.exit('usage: %s csvfile divId edThreshold dtwThreshold' % sys.argv[0])
csvfile, divId, edThreshold, dtwThreshold = sys.argv[1:]
#------------------------------------------------------------------
# Read in the csv file and split out the required dataframes ...
def words(s):
    """Split *s* into a list of non-empty tokens.

    Tokens are separated by spaces, the punctuation characters , ; : . ? !,
    the angle brackets < and >, '#', '@', apostrophes, newlines and tabs.
    Empty strings produced by adjacent separators are dropped.
    """
    # Raw string: the old non-raw pattern relied on invalid escapes (\> \< \').
    return [w for w in re.split(r"[ ,;:.\>\<?!#@\n\t']", s) if w]
# The GE-annotated data file ("df.csv") labels its subset and quarto edition
# differently from the other input files.
if csvfile == "df.csv":
    subsetName, quartoEdition = 'Full-a', 'Q'
else:
    subsetName, quartoEdition = 'Full', 'Q1'

df = pd.read_csv(csvfile)
df = df[df.Subset == subsetName]
# Quarto and Folio rows of the selected subset.
qf = df[df.Edtn == quartoEdition]
ff = df[df.Edtn == 'F']
#------------------------------------------------------------------
# Extract speeches and labels for each from the q/f dataframes ...
def proc(r):
    """Return the speech text of row *r* without leading/trailing whitespace."""
    speech = r.txt
    return speech.strip()
def labels(r):
    """Build a "<Scene> <sp# - 1> <spkr>" label for row *r*.

    The speech number is decremented by one (presumably converting a
    1-based speech number to 0-based -- confirm against the data).
    """
    parts = [r.Scene, str(r['sp#'] - 1), r.spkr]
    return " ".join(parts)
def window(df, div=None):
    """Extract (labels, speech texts) for one division of an edition's rows.

    Each speech is used as one comparison "window" (no fixed-size word
    windows).

    Args:
        df: rows for one edition; must provide Div, txt, Scene, sp# and spkr
            columns (read by proc() and labels()).
        div: division id to select; defaults to the module-level ``divId``
            command-line argument. Converted with int().

    Returns:
        (labs, windows): a pandas Series of labels indexed like the selected
        rows, and an object-dtype numpy array of stripped speech texts.
    """
    if div is None:
        div = divId
    # Filter first, then process only the selected rows. The old code applied
    # proc/labels to the whole frame and assigned into a slice of it.
    rows = df[df.Div == int(div)]
    texts, labs = [], []
    for _, r in rows.iterrows():
        texts.append(proc(r))
        labs.append(labels(r))
    windows = np.array(texts, dtype=object)
    return pd.Series(labs, index=rows.index, name='labs', dtype=object), windows
# Labels and speech texts for the Q and F editions; labels are converted to
# plain numpy arrays so later code can index them positionally.
(q_labs, q_ws), (f_labs, f_ws) = window(qf), window(ff)
q_labs, f_labs = np.array(q_labs), np.array(f_labs)
#------------------------------------------------------------------
# Compute Metrics of Q-F speech matrix ....
# For every Q speech, score it against every F speech.
# Each entry of qfSpMatrix is the transposed per-pair table, so row i holds
# metric i across all F speeches, in the order (nmi, ed, jsd, dtw)
# (assuming each metric function returns a scalar).
from progressBar import progress
qfSpMatrix = []
total = len(q_ws)
for k, q in enumerate(q_ws):
    progress(k, total, status='')
    pairs = []
    for f in f_ws:
        edist = ed(q, f)
        dist, al = dtw(q.split(' '), f.split(' '))
        # For nmi and jsd the matrix must be NxN,
        # so the DTW-aligned token sequences are compared.
        q_, f_ = al
        nmi_score = nmi(q_, f_)
        jsd_score = jsd(q_, f_)
        pairs.append((nmi_score, edist, jsd_score, dist))
    # The former per-row argmin/argmax bookkeeping was never used, and it
    # crashed with IndexError when f_ws was empty, so it is removed.
    qfSpMatrix.append(np.array(pairs).T)
qfSpMatrix = np.array(qfSpMatrix)
def selectMetric(a, i):
    """Take element *i* of every entry of *a*, rounded to 2 decimals.

    Presumably each entry is a per-Q score table whose row i is metric i;
    the results are stacked into a numpy array.
    """
    picked = []
    for scores in a:
        picked.append(np.round(scores[i], 2))
    return np.array(picked)
# Metric names in the same order as the per-pair score tuple
# (nmi, ed, jsd, dtw) built in the Q-F matrix loop.
metrics = ['nmi', 'ed', 'jsd', 'dtw']
metricId = metrics.index('ed')  # edit distance
metric = metrics[metricId]
# Transposed so that rows are F speeches and columns are Q speeches.
qfSpEdMatrix = selectMetric(qfSpMatrix, metricId).T
#------------------------------------------------------------------
# Find the optimal matches in the Q-F speech edit dist matrix ....
def getPositions(qfSpEdMatrix, metric_name=None):
    """Find, for each row, the column with the best score.

    Args:
        qfSpEdMatrix: 2-D array of pair scores; as built by the caller, rows
            are F speeches and columns are Q speeches.
        metric_name: which metric the scores come from; defaults to the
            module-level ``metric``. For 'nmi' higher is better (argmax);
            for every other metric lower is better (argmin).

    Returns:
        A list with one (qPos, fPos, val) tuple per row: qPos is the best
        column, fPos the row index and val the score at that position.

    Prints the metric name as a side effect.
    """
    if metric_name is None:
        metric_name = metric
    print(metric_name)
    pick = np.argmax if metric_name == 'nmi' else np.argmin
    qfSpEdMatchList = []
    for fPos, r in enumerate(qfSpEdMatrix):
        qPos = pick(r)
        qfSpEdMatchList.append((qPos, fPos, r[qPos]))
    return qfSpEdMatchList
# One (qPos, fPos, val) best-match tuple per F speech (row of qfSpEdMatrix).
qfSpEdMatchList = getPositions(qfSpEdMatrix)
#------------------------------------------------------------------
# Recover speech pairs ...
# qoff/foff count how often a best match fell off the expected position;
# they shift the indices used for subsequent rows into q_ws / f_ws.
qoff = 0
foff = 0
speechPairs = []
for i,x in enumerate(qfSpEdMatchList):
    # q: best-matching Q index, f: F (row) index, v: metric score at (f, q)
    q,f,v = x
    q_ = i-qoff
    f_ = i-foff
    # NOTE(review): i ranges over F speeches, so q_ws[i-qoff] can raise
    # IndexError when there are more F speeches than Q speeches.
    q_txt = q_ws[i-qoff]
    f_txt = f_ws[i-foff]
    # Normalise the score by the mean character length of the two texts.
    # NOTE(review): k is 0 when the combined length is 0 (or 1 under Python 2
    # integer division), which makes the next line raise ZeroDivisionError.
    k = len(q_txt+f_txt)/2
    val = float(v) / k
    f_len = float(len(f_txt))
    q_len = float(len(q_txt))
    # Length ratio of the shorter text to the longer one (always <= 1).
    # NOTE(review): raises ZeroDivisionError if q_txt is empty.
    lenRatio = f_len/q_len
    if len(f_txt) > len(q_txt): lenRatio = 1/lenRatio
    #print len(f_txt), len(q_txt), val
    metadata = i,q,f,qoff,foff,q_,f_,v,val,lenRatio
    # Best Q match not at the offset-adjusted position: emit a placeholder
    # and advance the Q offset for later rows.
    if q+qoff!=i:
        qData = q_, "..."
        qoff += 1
    else:
        qData = q_labs[q_], q_ws[q_]
    # NOTE(review): getPositions numbers rows sequentially, so f == i, and
    # foff starts at 0; this branch therefore appears never to be taken.
    if f+foff!=i:
        fData = f_, "..."
        foff += 1
    else:
        fData = f_labs[f_], f_ws[f_]
    data = val, lenRatio, metadata, qData, fData
    speechPairs.append(data)
#------------------------------------------------------------------
# Filter the additions only .....
def isAddition(speechPair, threshold=None, minRatio=.1, maxRatio=.9):
    """Decide whether a recovered speech pair looks like an addition.

    Args:
        speechPair: (val, lenRatio, metadata, qData, fData) tuple, where
            qData and fData are (label, text) pairs.
        threshold: minimum normalised score; defaults to the module-level
            ``edThreshold`` command-line argument. Converted with float().
        minRatio, maxRatio: exclusive bounds on the shorter/longer length
            ratio.

    Returns:
        True when val exceeds the threshold and the length ratio lies
        strictly between minRatio and maxRatio.
    """
    val, lenRatio, metadata, qData, fData = speechPair
    qLabel, qTxt = qData
    fLabel, fTxt = fData
    # NOTE(review): a check for GE '>' addition markup in qTxt/fTxt used to be
    # computed here but was immediately overwritten, so it never had any
    # effect. It was removed; reinstate it deliberately if it is wanted.
    if threshold is None:
        threshold = edThreshold
    return val > float(threshold) and minRatio < lenRatio < maxRatio
# Keep only the pairs that isAddition accepts.
speechPairs = [sp for sp in speechPairs if isAddition(sp)]
#------------------------------------------------------------------
# Output the results ....
def isMultiLine(q, f):
    """Return True if either text has more than two '@'-separated segments."""
    segmentCounts = [len(text.split('@')) for text in (q, f)]
    return max(segmentCounts) > 2
def isAdd(pair):
    """Return True if either side of an (x, y) pair consists only of '_'.

    A side qualifies when its distinct elements concatenate to '_', e.g.
    '___' or ['_', '_'] (presumably gap fillers from an alignment -- confirm).
    Takes the pair as one argument: tuple-parameter unpacking in the old
    signature is a SyntaxError on Python 3 (PEP 3113).
    """
    x, y = pair
    return ''.join(set(x)) == '_' or ''.join(set(y)) == '_'
def fixAdd(pair):
    """Replace an all-'_' side of an (x, y) pair with the other side.

    A side qualifies when its distinct elements concatenate to '_'. x is
    checked first, so if both sides qualify the pair is returned as
    (y, y). Takes the pair as one argument: tuple-parameter unpacking is a
    SyntaxError on Python 3 (PEP 3113).
    """
    x, y = pair
    if ''.join(set(x)) == '_':
        x = y
    if ''.join(set(y)) == '_':
        y = x
    return x, y
# Print every retained Q/F speech pair (label then text), separated by a rule.
# The former normDist computation read stale variables (q_txt, f_txt, dist)
# left over from earlier loops, was never used, and could divide by zero, so
# it was removed. Neither dtwThreshold nor isMultiLine filters the output;
# every pair accepted by isAddition is printed.
for speechPair in speechPairs:
    val, lenRatio, metadata, qData, fData = speechPair
    qLabel, qTxt = qData
    fLabel, fTxt = fData
    # %-formatting keeps the output identical under Python 2 and 3.
    print('%s %s' % (qLabel, qTxt))
    print('%s %s' % (fLabel, fTxt))
    print('-' * 60)