#!/usr/bin/python
# -*- coding: utf-8 -*-

# --------------------------------------------------
# File Name: dtw.py
# Purpose:
# Creation Date: 10-03-2017
# Last Modified: Thu, Mar 23, 2017  8:31:16 PM
# Author(s): Mike Stout 
# Copyright 2017 The Author(s) All Rights Reserved
# Credits: 
# --------------------------------------------------


import sys
from mlpy import dtw_subsequence, dtw_std
#from cpu_sDTW  import dtw_subsequence



import numpy as np


import mlpy
import matplotlib.pyplot as plt
import matplotlib.cm as cm

def read(a):
    """Return the entire contents of the file at path *a* as one string."""
    with open(a) as handle:
        contents = handle.read()
    return contents

# Command-line arguments: START END FILE1 FILE2 GAPCHAR.
# START/END bound the slice [START:END] taken from each file before alignment;
# GAPCHAR is the symbol printed wherever one sequence does not advance along
# the warping path.
if len(sys.argv) < 6:
    # Fail with a readable message instead of an IndexError traceback.
    sys.exit("usage: %s START END FILE1 FILE2 GAPCHAR" % sys.argv[0])
m = int(sys.argv[1])
n = int(sys.argv[2])
f1 = sys.argv[3]
f2 = sys.argv[4]
gapChar = sys.argv[5]

def zscore(xs):
    """Z-normalise a character sequence by its code points.

    Each character ``c`` of *xs* is mapped to ``float(ord(c))`` and the
    result is ``(v - mean) / std`` as a NumPy float array, so it has mean 0
    and (population) standard deviation 1.

    A constant sequence has zero spread; it is returned as all zeros rather
    than dividing by zero (which would yield NaNs). An empty input returns
    an empty array.
    """
    values = np.array([float(ord(c)) for c in xs])
    if values.size == 0:
        return values
    # Normalise by the standard deviation; dividing by the variance would
    # scale the signal by an extra factor of 1/std.
    spread = np.std(values)
    if spread == 0:
        return np.zeros_like(values)
    return (values - np.mean(values)) / spread

# Load the requested slice [m:n] of each file and z-normalise it. The raw
# slices x and y are kept for rendering the alignment later.
x = read(f1)[m:n]
x_ = zscore(x)
y = read(f2)[m:n]
y_ = zscore(y)

# Align with mlpy's subsequence DTW (see mlpy docs for the exact matching
# semantics). It returns the distance, the accumulated cost matrix and the
# warping path (path[0] indexes x, path[1] indexes y).
# For a standard end-to-end alignment use instead:
#   dist, cost, path = dtw_std(x_, y_, dist_only=False)
dist, cost, path = dtw_subsequence(x_, y_)

# Show the accumulated-cost matrix (cost's first index on the horizontal
# axis, its second on the vertical) with the warping path overlaid in white.
# Note: plt.show() typically blocks in non-interactive mode, so the textual
# alignment below is printed only after the window is closed.
fig = plt.figure(1)
ax = fig.add_subplot(111)
ax.imshow(cost.T, origin='lower', cmap=cm.gray, interpolation='nearest')
ax.plot(path[0], path[1], 'w')
# Pad by half a cell so the edge cells are fully visible.
ax.set_xlim((-0.5, cost.shape[0] - 0.5))
ax.set_ylim((-0.5, cost.shape[1] - 0.5))
plt.show()

def splits(n, f, xs):
    """Chop *xs* into consecutive chunks of at most *n* characters.

    Every chunk is prefixed with the label ``f + ":\\t"``; the last chunk may
    be shorter than *n*. An empty *xs* yields an empty list.
    """
    offsets = range(0, len(xs), n)
    return [f + ":\t" + xs[lo:lo + n] for lo in offsets]

def fix(xs, i, path):
    """Return the aligned symbol for step *i* along one axis of a DTW path.

    When step *i* repeats the previous step's index (``path[i] == path[i-1]``)
    this sequence did not advance, so the module-level ``gapChar`` is
    returned; otherwise the character ``xs[path[i]]``. Step 0 is never a gap.
    """
    is_gap = i > 0 and path[i] == path[i - 1]
    return gapChar if is_gap else xs[path[i]]

def recover(xs, path):
    """Build the gapped, single-line rendering of *xs* along one axis of a DTW path.

    Position k of the result is ``fix(xs, k, path)``: the character of *xs*
    at index ``path[k]``, or the gap character when ``path[k]`` repeats the
    previous index. Newlines are replaced by spaces so each aligned row
    prints on a single line.
    """
    # range (not the Python-2-only xrange) keeps this portable.
    symbols = [fix(xs, k, path) for k in range(len(path))]
    return "".join(symbols).replace("\n", " ")


# Render each sequence with gaps inserted along its axis of the warping path,
# then print the two as interleaved pairs of labelled rows, `width` columns
# per row. A dedicated name is used so the END argument `n` is not clobbered.
width = 200
x__ = recover(x, path[0])
y__ = recover(y, path[1])
aa = splits(width, f1, x__)
bb = splits(width, f2, y__)

zz = ["\n" + a + "\n" + b for a, b in zip(aa, bb)]

for z in zz:
    print(z)

def mrg(pair):
    """Concatenate the two elements of a 2-tuple ``(x, y)`` as ``x + y``.

    Takes the pair as a single argument so it can be mapped over ``zip``
    output. The pair is unpacked explicitly because the tuple-parameter form
    ``def mrg((x, y))`` is Python-2-only (removed by PEP 3113).
    """
    x, y = pair
    return x + y

def merge(xs, ys):
    """Interleave two strings pairwise: xs[0] + ys[0] + xs[1] + ys[1] + ...

    Stops at the end of the shorter input (``zip`` semantics).
    """
    return ''.join(a + b for a, b in zip(xs, ys))

# Dump the two aligned rows, interleaved character by character, to ./tmp.
with open("tmp", 'w') as out:
    out.write(merge(x__, y__))