#!/usr/bin/python
# -*- coding: utf-8 -*-
# --------------------------------------------------
# File Name: exp_2a.py
# Location:
# Purpose: Score every Q-edition speech against every F-edition speech
#          (nmi, edit distance, jsd, dtw) and pick the best Q match
#          for each F speech.
# Creation Date: 29-10-2017
# Last Modified: Wed, Jan 31, 2018 9:44:26 PM
# Author(s): Mike Stout
# Copyright 2017 The Author(s) All Rights Reserved
# Credits:
# --------------------------------------------------

import re
import sys

import numpy as np
import pandas as pd

# Supplies ed, dtw, nmi and jsd used below.  Kept as a star import because
# the module's full export list is not visible from this script.
from calcMetrics import *
from progressBar import progress

# Fully-qualified option names: the short forms ('precision',
# 'chop_threshold') are pattern-matched by pandas and can be ambiguous.
pd.set_option('display.max_rows', 15)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 100)
pd.set_option('display.chop_threshold', 0)
pd.set_option('display.precision', 4)
pd.set_option('display.colheader_justify', 'right')
pd.set_option('display.max_colwidth', 32)

if len(sys.argv) != 5:
    sys.exit("usage: %s <csvfile> <divId> <edThreshold> <dtwThreshold>"
             % sys.argv[0])
csvfile, divId, edThreshold, dtwThreshold = sys.argv[1:]

#------------------------------------------------------------------
# Read in the csv file and split out the required dataframes ...


def words(s):
    """Split *s* on spaces and punctuation and return the non-empty tokens.

    The separators are: space , ; : . > < ? ! # @ newline tab and the
    apostrophe.  Always returns a list.
    """
    return [w for w in re.split(r"[ ,;:.><?!#@\n\t']", s) if w != '']


df = pd.read_csv(csvfile)

if csvfile == "df.csv":  # using GE additions marked data ...
    subset, q_edition = 'Full-a', 'Q'
else:
    subset, q_edition = 'Full', 'Q1'

df = df[df.Subset == subset]
qf = df[df.Edtn == q_edition]
ff = df[df.Edtn == 'F']

#------------------------------------------------------------------
# Extract speeches and labels for each from the q/f dataframes ...


def proc(r):
    """Return the text of row *r* (its ``txt`` field) with outer whitespace stripped."""
    return r.txt.strip()


def labels(r):
    """Build the display label of row *r*.

    The label is ``"<Scene> <sp# - 1> <spkr>"``, read from the row's
    ``Scene``, ``sp#`` and ``spkr`` fields.
    """
    return r.Scene + " " + str(r['sp#'] - 1) + " " + r.spkr


# NOTE: a sliding-window sampler (mkWindow, wSize = 200) was prototyped here
# and left disabled; each speech is used as its own window instead.


def window(df):
    """Return ``(labs, windows)`` for the rows of *df* in division ``divId``.

    Parameters:
        df: frame with at least the columns ``Div``, ``Scene``, ``sp#``,
            ``spkr`` and ``txt``.

    Returns:
        labs:    pandas Series of label strings (see ``labels``), indexed
                 like the selected rows.
        windows: 1-D numpy object array of the stripped speech texts, in
                 the same row order.
    """
    # Filter first so proc/labels only ever see rows of the wanted division,
    # and nothing is assigned onto a slice of the caller's frame.
    selected = df[df.Div == int(divId)]
    rows = [row for _, row in selected.iterrows()]

    # NB use speeches as the windows ....
    windows = np.array([proc(row) for row in rows], dtype=object)
    labs = pd.Series([labels(row) for row in rows],
                     index=selected.index, name='labs', dtype=object)
    return labs, windows


q_labs, q_ws = window(qf)
f_labs, f_ws = window(ff)
q_labs = np.array(q_labs)
f_labs = np.array(f_labs)

#------------------------------------------------------------------
# Compute Metrics of Q-F speech matrix ....
#
# After the loop qfSpMatrix has one entry per Q speech; each entry is a
# 4 x len(f_ws) array whose rows are, in order: nmi, ed, jsd, dtw.

qfSpMatrix = []
total = len(q_ws)
for k, q in enumerate(q_ws):
    progress(k, total, status='')
    pairs = []
    for f in f_ws:
        edist = ed(q, f)
        dist, al = dtw(q.split(' '), f.split(' '))
        # For nmi and jsd the matrix must by NxN
        # ... so strings must be aligned ...
        # (al is unpacked as the aligned Q and F sequences -- presumably
        # equal-length token lists; confirm against calcMetrics.dtw)
        q_, f_ = al
        pairs.append((nmi(q_, f_), edist, jsd(q_, f_), dist))
    qfSpMatrix.append(np.array(pairs).T)

qfSpMatrix = np.array(qfSpMatrix)


def selectMetric(a, i):
    """Return row *i* of every entry of *a*, rounded to 2 decimals, as an array."""
    return np.array([np.round(ys[i], 2) for ys in a])


#metricNames = "Normalised Mutual Information Score", "Edit Distance", "Jensen Shannon Divergence" # , "DTW Distance"
metricId = 1  # ed
metrics = 'nmi ed jsd dtw'.split(' ')
metric = metrics[metricId]

# Transposed so that each row corresponds to one F speech and each column
# to one Q speech.
qfSpEdMatrix = selectMetric(qfSpMatrix, metricId).T

#------------------------------------------------------------------
# Find the optimal matches in the Q-F speech edit dist matrix ....


def getPositions(qfSpEdMatrix):
    """Pick the best-scoring column of every row of *qfSpEdMatrix*.

    Prints the active metric name, then returns a list of
    ``(qPos, fPos, val)`` tuples, one per row: ``fPos`` is the row index,
    ``qPos`` the column of the best score and ``val`` that score.  "Best"
    is the maximum when the module-level ``metric`` is ``'nmi'`` and the
    minimum for every other metric.
    """
    print(metric)
    want_max = metric == 'nmi'
    matches = []
    for fPos, row in enumerate(qfSpEdMatrix):
        qPos = row.argmax() if want_max else row.argmin()
        matches.append((qPos, fPos, row[qPos]))
    return matches


qfSpEdMatchList = getPositions(qfSpEdMatrix)

#------------------------------------------------------------------
# Recover speech pairs ...
# Walk the match list, keeping a running offset per edition.  Whenever the
# matched position (plus the offset so far) disagrees with the running index
# the speech is replaced by a "..." placeholder and that edition's offset
# grows by one.
qoff = 0
foff = 0
speechPairs = []
for i, (q, f, v) in enumerate(qfSpEdMatchList):
    q_ = i - qoff
    f_ = i - foff
    q_txt = q_ws[q_]
    f_txt = f_ws[f_]

    # Normalise the score by the mean text length.  Floor division keeps the
    # integer result the original Python 2 "/" produced; the clamp to 1
    # avoids a ZeroDivisionError when the two texts total < 2 characters.
    k = max(len(q_txt + f_txt) // 2, 1)
    val = float(v) / k

    # Length ratio, folded so that it never exceeds 1.
    q_len = float(len(q_txt))
    f_len = float(len(f_txt))
    if q_len == 0:
        # An empty Q text would divide by zero: equal (both empty) -> 1,
        # otherwise the shorter/longer ratio is 0.
        lenRatio = 1.0 if f_len == 0 else 0.0
    else:
        lenRatio = f_len / q_len
        if f_len > q_len:
            lenRatio = 1 / lenRatio

    # Recorded before the offsets are updated below.
    metadata = i, q, f, qoff, foff, q_, f_, v, val, lenRatio

    if q + qoff != i:
        qData = q_, "..."
        qoff += 1
    else:
        qData = q_labs[q_], q_ws[q_]

    if f + foff != i:
        fData = f_, "..."
        foff += 1
    else:
        fData = f_labs[f_], f_ws[f_]

    speechPairs.append((val, lenRatio, metadata, qData, fData))

#------------------------------------------------------------------
# Filter the additions only .....


def isAddition(speechPair):
    """Return True when *speechPair* passes the addition filter.

    A pair passes when its normalised score exceeds ``edThreshold`` and its
    length ratio lies strictly between 0.1 and 0.9.

    NOTE(review): the original also computed a "'>' in qTxt or '>' in fTxt"
    markup test (speeches marked up as additions by GE) but overwrote the
    result before returning, so it never had any effect.  It is left out
    here; confirm whether it was meant to be combined with the thresholds.
    """
    val, lenRatio, metadata, qData, fData = speechPair
    return val > float(edThreshold) and (lenRatio > .1 and lenRatio < .9)


speechPairs = [pair for pair in speechPairs if isAddition(pair)]

#------------------------------------------------------------------
# Output the results ....


def isMultiLine(q, f):
    """Return True when *q* or *f* contains at least two '@' characters."""
    return q.count('@') >= 2 or f.count('@') >= 2


def _is_gap(seq):
    """Return True when the distinct items of *seq* join to exactly '_'."""
    return ''.join(set(seq)) == '_'


def isAdd(pair):
    """Return True when either member of the 2-item *pair* is a gap.

    Takes the pair as a single argument, as before, but unpacks it in the
    body because tuple parameters are not valid syntax on Python 3.
    """
    x, y = pair
    return _is_gap(x) or _is_gap(y)


def fixAdd(pair):
    """Return *pair* with a gap member replaced by the other member.

    x is replaced first, then y is tested, so a pair of two gaps comes
    back unchanged.
    """
    x, y = pair
    if _is_gap(x):
        x = y
    if _is_gap(y):
        y = x
    return x, y


# NOTE: a DTW-based alignment display (using isAdd / fixAdd / isMultiLine and
# a dtwThreshold cut-off) was prototyped in this loop and left disabled.  The
# "normDist" it relied on was computed from q_txt, f_txt and dist left over
# from earlier loops rather than from the current pair, and was never used,
# so it has been dropped.
for speechPair in speechPairs:
    val, lenRatio, metadata, qData, fData = speechPair
    qLabel, qTxt = qData
    fLabel, fTxt = fData

    # %-formatting prints the same text under Python 2 and Python 3.
    print("%s %s" % (qLabel, qTxt))
    print("%s %s" % (fLabel, fTxt))
    print('-' * 60)