def index(request):
    """Accumulate posted accelerometer samples and classify each full window.

    Every POST body is parsed as JSON and must contain ``accel_x``,
    ``accel_y`` and ``accel_z``; that sample is appended (as a ``Vector``) to
    the module-level ``accel_data`` buffer. Once the buffer holds
    ``window_size`` samples it is classified with ``dtw.test(..., cosine_cost)``
    using the template loaded from ``model.npy``, and the buffer is cleared.

    NOTE(review): ``accel_data`` is module-level state shared by all requests;
    it is not protected against concurrent requests.

    :param request: incoming HTTP request.
    :return: ``HttpResponse`` containing the predicted label when a window was
        just classified, ``'Not Start'`` while the window is still filling, or
        ``'Wrong POST'`` for any non-POST request.
    """
    global accel_data
    window_size = 150
    label = 'Not Start'

    if request.method != 'POST':
        return HttpResponse('Wrong POST')

    received_json_data = json.loads(request.body)
    accel_data.append(
        Vector((received_json_data['accel_x'],
                received_json_data['accel_y'],
                received_json_data['accel_z'])))

    # `>=` rather than `==`: if the buffer ever overshoots it still gets
    # classified and cleared instead of growing forever.
    if len(accel_data) >= window_size:
        # Load the template only when a window is actually classified rather
        # than reading model.npy from disk on every request.
        dtw = DTW()
        dtw.tamplate_load('model.npy')
        label = dtw.test(accel_data, cosine_cost)
        accel_data = []
    return HttpResponse(label)
def get_dtw_alignment(feat1, feat2):
    """Align two feature sequences with DTW under a cosine-distance cost.

    The precomputed ``cosine_distance(feat1, feat2)`` matrix is handed to
    ``DTW`` as ``dist_array``. ``DTW`` must return a 3-tuple whose last item
    holds three entries; the second and third are returned as the two
    alignment paths (presumably index sequences into feat1 and feat2).

    Raises AssertionError if the two paths differ in length.
    """
    costs = cosine_distance(feat1, feat2)
    *_, alignment = DTW(feat1, feat2, return_alignment=True, dist_array=costs)
    _, first_path, second_path = alignment
    assert len(first_path) == len(second_path)
    return first_path, second_path
def score_ntdw():
    """Calculate nDTW between every pair of experimental settings (paper Table 3).

    For each experiment ``i`` used as reference (ground truth) and each
    experiment ``j`` used as prediction, ``ndtw[i, j]`` is the ``'ndtw'``
    entry of ``SimEvaluation(ref_i, CONN).score(pred_j)``'s summary.
    Reference items get a ``path_id`` derived from the prefix of their
    ``instr_id`` before the first underscore.

    The matrix is printed and also returned.

    :return: ``numpy`` array of shape ``(len(EXPS), len(EXPS))``.
    """
    EXPS = [
        'sim_results/baseline/submit_coda.json',
        'sim_results/baseline/submit_coda_theta_1.json',
        'sim_results/baseline/submit_coda_theta_2.json',
        'sim_results/baseline/submit_coda_theta_3.json',
        'sim_results/baseline_color_jittered/submit_coda_theta_1.json',
        'sim_results/baseline_color_jittered/submit_coda_theta_2.json',
        'sim_results/baseline_color_jittered/submit_coda_theta_3.json',
        'bags/wmap/submit_coda_robot_wmap.json',  # Robot with map
        'bags/nomap/submit_coda_robot_nomap.json',  # Robot no map
    ]
    ndtw = np.zeros((len(EXPS), len(EXPS)))
    print('\nCalculating ntdw matrix for %s' % EXPS)
    for i, ref_path in enumerate(EXPS):
        with open('src/vln_evaluation/data/%s' % ref_path) as f:
            ref = json.load(f)
        for item in ref:
            item['path_id'] = int(item['instr_id'].split('_')[0])
        evaluator = SimEvaluation(ref, CONN)
        for j, pred_path in enumerate(EXPS):
            # Re-read each prediction file per reference: score() may mutate
            # its input, so a fresh copy is used every time.
            with open('src/vln_evaluation/data/%s' % pred_path) as f:
                pred = json.load(f)
            summary, _ = evaluator.score(pred)
            ndtw[i, j] = summary['ndtw']
    print(ndtw)
    return ndtw
def half_asym_fn(A, B):
    """Distance between A and B via ``DTW.dist_damn_adv`` with a half-asymmetric pattern.

    Five step offsets are allowed. Judging by matching lengths, each weight
    list holds one weight per unit move of the corresponding step.
    """
    steps = [(-1, -3), (-1, -2), (-1, -1), (-2, -1), (-3, -1)]
    weights = [[1 / 3] * 3, [1 / 2] * 2, [1], [1, 1], [1, 1, 1]]
    return DTW.dist_damn_adv(A, B, steps, weights)
def __init__(self, gt_data, nav_graph_dir):
    """Index ground-truth episodes and precompute shortest-path lengths.

    :param gt_data: iterable of ground-truth episode dicts. Each dict is
        modified in place. A missing ``'scan'`` defaults to ``'yZVvKaJZghh'``.
        ``'trajectory'`` is replaced by ``self.path_to_points(...)``, built
        from ``'path'`` when present, otherwise from the existing
        ``'trajectory'``.
    :param nav_graph_dir: location passed to ``load_nav_graph``.
    """
    self.error_margin = 3.0
    self.dtw = DTW(threshold=self.error_margin)
    self.graphs = load_nav_graph(nav_graph_dir)
    self.gt = {}
    self.instr_ids = []
    for item in gt_data:
        if 'scan' not in item:
            item['scan'] = 'yZVvKaJZghh'
        # 'path' takes precedence over 'trajectory' when both are present,
        # so only the winning source is converted.
        if 'path' in item:
            item['trajectory'] = self.path_to_points(item['path'], item['scan'])
        elif 'trajectory' in item:
            item['trajectory'] = self.path_to_points(item['trajectory'], item['scan'])
        if 'instr_id' in item:
            # NOTE(review): ids registered here are not added to
            # self.instr_ids — confirm that is intended.
            self.gt[item['instr_id']] = item
        else:
            # Presumably three instructions per path. gt keys and instr_ids
            # share one format so they always match each other.
            instr_ids = ['%s_%d' % (item['path_id'], i) for i in range(3)]
            for instr_id in instr_ids:
                self.gt[instr_id] = item
            self.instr_ids += instr_ids
    self.instr_ids = set(self.instr_ids)
    # Per scan: shortest-path length between every pair of nodes.
    self.distances = {}
    for scan, G in self.graphs.items():
        self.distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G))
def calculate_signature_dissimilarity(self, target_signature):
    """
    Calculate the dissimilarity of target_signature to the user's enrolment signatures.

    :param target_signature: list of the normalized feature vectors of the target signature to compare
        to this user's enrolment signatures
    :return: mean (NumPy float) of the DTW distances between target_signature and each enrolment
        signature, computed with window ``self.WINDOW_SIZE``; ``nan`` (with a NumPy RuntimeWarning)
        when the user has no enrolment signatures
    """
    dtw = DTW()
    dissimilarities = []
    for enrolment_signature in self.enrolment_signatures.values():
        # distance() returns a pair; only the first element (the dissimilarity) is used.
        dissim, _ = dtw.distance(enrolment_signature, target_signature, self.WINDOW_SIZE)
        dissimilarities.append(dissim)
    return np.mean(np.asarray(dissimilarities))
def RunDTW(self):
    """Compare every record of the database against the test vector.

    Each line of ``self.arquivo`` is split on single spaces. The first field
    is kept as the record label and the remaining fields are parsed as
    floats. For each line a row ``[label, weight, 0]`` is produced, where
    ``weight`` is ``DTW(values, self.vetordeteste).peso()`` (the weight of the
    DTW path).

    :return: list of ``[label, weight, 0]`` rows, in file order.
    """
    rows = []
    # Read the database and compare each record against the test vector.
    for record in self.arquivo:
        fields = record.split(" ")
        label = fields[0]
        sequence = [float(value) for value in fields[1:]]
        weight = DTW(sequence, self.vetordeteste).peso()  # weight of the path
        rows.append([label, weight, 0])
    return rows
def get_dtw_alignement(intvl1, intvl2):
    """Align the features of two intervals with DTW under a cosine-distance cost.

    Features come from ``features_getter.get_features``. If ``DTW`` fails,
    the feature and distance shapes plus both intervals are printed for
    debugging and the original exception is re-raised.

    :return: ``(path1, path2)``, the last two entries of the alignment returned
        by ``DTW``.
    :raises AssertionError: if the two paths differ in length.
    """
    feats1, feats2 = map(features_getter.get_features, [intvl1, intvl2])
    distance_array = cosine_distance(feats1, feats2)
    try:
        _, _, paths = DTW(feats1, feats2, return_alignment=True, dist_array=distance_array)
    except Exception:
        # Dump context, then propagate the real error instead of falling
        # through to an undefined `paths` (which raised a misleading NameError).
        print('\n\n'.join(
            str(x) for x in [
                feats1.shape, feats2.shape, distance_array.shape, intvl1, intvl2
            ]))
        raise
    path1, path2 = paths[1:]
    assert len(path1) == len(path2)
    return path1, path2
def classifySeq(self, input_seq):
    """Assign a class label to input_seq by a K-nearest-neighbour vote.

    The DTW distance from input_seq to every stored sequence in ``self.X`` is
    written into ``self.distances``. ``bubble_sort`` is then called with
    ``self.K``, the distances and a fresh copy of ``self.c_labels``
    (presumably reordering the first K entries in place — confirm). The most
    frequent label among the first K copied labels wins. ``np.bincount``
    requires non-negative integer labels, and ``argmax`` breaks ties towards
    the smallest label.

    :return: the assigned class label (also stored in ``self.assigned_classlabel``).
    """
    self.class_labels = np.copy(self.c_labels)
    self.distances = np.zeros(len(self.X), dtype=float)
    warper = DTW()
    warper.setY(input_seq)
    for position, reference in enumerate(self.X):
        warper.setX(reference)
        self.distances[position] = warper.perform_dtw()
    bubble_sort(self.K, self.distances, self.class_labels)
    nearest_labels = self.class_labels[:self.K]
    self.bins = np.bincount(nearest_labels)
    self.assigned_classlabel = self.bins.argmax()
    return self.assigned_classlabel
def main():
    """Run the DTW keyword detector and publish detections over ZeroMQ.

    One ``DTW`` predictor is trained per model returned by
    ``get_models(sys.argv[1])``. An ``AudioHandler`` is started with those
    predictors, and every detection result is published as ``'true'`` or
    ``'false'`` on the PUB socket connected to ``DATA_ENDPOINT``. The loop
    stops on SIGINT or when ``'done'`` arrives on the SUB socket connected
    to ``COMMAND_ENDPOINT``. The handler is stopped and both sockets closed
    even if the loop raises.
    """
    context = zmq.Context()
    socket: zmq.Socket = context.socket(zmq.PUB)
    done: zmq.Socket = context.socket(zmq.SUB)
    done.subscribe('')
    # recv blocks at most 2 s so the SIGINT flag is re-checked regularly.
    done.setsockopt(zmq.RCVTIMEO, 2000)
    socket.connect(DATA_ENDPOINT)
    done.connect(COMMAND_ENDPOINT)

    predictors = []
    models, _ = get_models(sys.argv[1])
    for model in models:
        x = [recording for recording, _ in model]
        sample_len = int(
            np.ceil(np.average([recording_len for _, recording_len in model])))
        predictor = DTW(dtw_cost, sample_len)
        predictor.compile(dtw_k)
        predictor.train(x, [[0.0, 1.0] for _ in range(len(model))])
        predictors.append(predictor)

    def cb(found):
        socket.send_string('true' if found else 'false')
        print('found!', found)

    audio_handler = AudioHandler(cb, predictors)
    audio_handler.start()
    try:
        signal_received = False

        def signal_cb(_, __):
            nonlocal signal_received
            signal_received = True

        signal.signal(signal.SIGINT, signal_cb)
        print('started')
        socket.send_string('started')
        while not signal_received:
            try:
                string = done.recv_string()
            except zmq.Again:
                # RCVTIMEO expired: no command yet, which is expected.
                continue
            except zmq.ZMQError as err:
                print('Error while receiving: ', err)
                time.sleep(1)
                continue
            if string == 'done':
                signal_received = True
    finally:
        audio_handler.stop()
        socket.close()
        done.close()
def one_sym_fn(A, B):
    """Distance between A and B via ``DTW.dist_damn_adv`` with a symmetric pattern.

    Steps (-1, -2), (-1, -1) and (-2, -1) carry weight lists [2, 1], [2] and
    [2, 1] respectively.
    """
    steps = [(-1, -2), (-1, -1), (-2, -1)]
    weights = [[2, 1], [2], [2, 1]]
    return DTW.dist_damn_adv(A, B, steps, weights)
# Read segment listings from "SysE.zs". For every pair of segments under the
# same "Class" header, run DTW on their filterbank features, record the pair
# in `pairs`, and report the fraction of pairs whose speakers match.
# NOTE(review): `pairs`, FBANKS_RATE, do_fbank, DTW and joblib must be
# defined/imported before this point.
same_spkrs = 0
diff_spkrs = 0
with open("SysE.zs") as rf:
    cword = ''  # word label of the current "Class" section
    fs = []  # (filename, features) of segments already seen in this section
    for line in rf:
        l = line.rstrip('\n')
        if l == '':
            continue
        if "Class" in line:
            # New section: second whitespace-separated token is the word.
            cword = l.split()[1]
            fs = []
        else:
            fname, start, end = l.split()
            # Presumably start/end are in seconds; convert to frame indices.
            start = int(float(start) * FBANKS_RATE)
            end = int(float(end) * FBANKS_RATE)
            tmp = do_fbank(fname)[start:end+1]  # inclusive end frame
            # Compare against every earlier segment of the same word.
            for (fname2, tmp2) in fs:
                dtw = DTW(tmp, tmp2, return_alignment=1)
                # First three characters of the filename are treated as the speaker id.
                spkr1 = fname[:3]
                spkr2 = fname2[:3]
                if spkr1 == spkr2:
                    same_spkrs += 1
                else:
                    diff_spkrs += 1
                pairs.append((cword, spkr1, spkr2, tmp, tmp2, dtw[0], dtw[-1][1], dtw[-1][2]))
            fs.append((fname, tmp))
joblib.dump(pairs, "from_aren.joblib", compress=3, cache_size=512)
# NOTE(review): raises ZeroDivisionError if no pair was formed.
print "ratio same spkrs / all:", float(same_spkrs) / (same_spkrs + diff_spkrs)
def normal_asym_fn(A, B):
    """Distance between A and B via ``DTW.dist_damn_adv`` with unit-weight basic steps.

    Allowed steps: (0, -1), (-1, -1) and (-1, 0), each with weight list [1].
    """
    steps = [(0, -1), (-1, -1), (-1, 0)]
    weights = [[1] for _ in steps]
    return DTW.dist_damn_adv(A, B, steps, weights)
# Tail of the benchmark-job list literal (presumably `jobs`) opened earlier.
        "train": File.open_with_label("Beef_TRAIN"),
        "test": File.open_with_label("Beef_TEST"),
        "fn": two_sym_fn,
    }, {
        "title": "Beef.P=2.asym",
        "train": File.open_with_label("Beef_TRAIN"),
        "test": File.open_with_label("Beef_TEST"),
        "fn": two_asym_fn,
    },
]
# Run every job: score the test split against the train split with
# DTW.predict_list using the job's distance function, time it, and write
# the results to "<title>.json".
for i, job in enumerate(jobs):
    print("job:", i, "of", len(jobs))
    print("job:", job["title"])
    result = {"title": job["title"]}
    start_time = time.process_time()  # CPU time of this process, not wall-clock
    accuracy, correctness = DTW.predict_list(job["train"], job["test"], job["fn"])
    end_time = time.process_time()
    result["accuracy"] = accuracy
    result["correctness"] = correctness
    result["total"] = len(job["test"])
    result["time_elapsed"] = end_time - start_time
    File.write_json(job["title"] + ".json", result)
def do_dtw(x1, x2):
    """Run DTW between x1 and x2 with alignment output.

    :return: ``(cost, mapping_1, mapping_2)``, i.e. element 0 of the DTW result
        and entries 1 and 2 of its last element (the x-to-y and y-to-x
        mappings).
    """
    result = DTW(x1, x2, return_alignment=1)
    mappings = result[-1]
    return result[0], mappings[1], mappings[2]
# For each (source, target) utterance pair, DTW-align the source MFCCs to the
# target MFCCs, stack [warped source + deltas, target + deltas] rows into
# learn_data, update running per-dimension mean and mean-of-squares of the
# target MFCCs, then fit a 2-component full-covariance GMM on all rows.
# NOTE(review): learn_data, mean, square_mean, source_list, target_list and
# DIMENSION must be initialised before this point.
for i in xrange(len(source_list)):
    target = STF()
    target.loadfile(target_list[i])
    mfcc = MFCC(target.SPEC.shape[1] * 2, target.frequency, dimension = DIMENSION)
    target_mfcc = numpy.array([mfcc.mfcc(target.SPEC[frame]) for frame in xrange(target.SPEC.shape[0])])
    target_data = numpy.hstack([target_mfcc, mfcc.delta(target_mfcc)])

    source = STF()
    source.loadfile(source_list[i])
    # Same MFCC dimension as the target: the two sequences are compared frame
    # by frame in DTW and stacked side by side below.
    mfcc = MFCC(source.SPEC.shape[1] * 2, source.frequency, dimension = DIMENSION)
    source_mfcc = numpy.array([mfcc.mfcc(source.SPEC[frame]) for frame in xrange(source.SPEC.shape[0])])

    # NOTE(review): the window is 0 when both utterances have the same length — confirm
    # that DTW treats this as intended.
    dtw = DTW(source_mfcc, target_mfcc, window = abs(source.SPEC.shape[0] - target.SPEC.shape[0]) * 2)
    warp_mfcc = dtw.align(source_mfcc)
    warp_data = numpy.hstack([warp_mfcc, mfcc.delta(warp_mfcc)])

    data = numpy.hstack([warp_data, target_data])
    # NOTE(review): vstack in a loop copies all previous rows each iteration.
    if learn_data is None:
        learn_data = data
    else:
        learn_data = numpy.vstack([learn_data, data])

    # Incremental running averages weighted by row count.
    square_mean = (square_mean * (learn_data.shape[0] - target_mfcc.shape[0]) + (target_mfcc ** 2).sum(axis = 0)) / learn_data.shape[0]
    mean = (mean * (learn_data.shape[0] - target_mfcc.shape[0]) + target_mfcc.sum(axis = 0)) / learn_data.shape[0]

# NOTE(review): sklearn.mixture.GMM was removed in scikit-learn 0.20
# (replacement: GaussianMixture).
gmm = sklearn.mixture.GMM(n_components = 2, covariance_type = 'full')
gmm.fit(learn_data)
}
# using Euclidean
# 1-nearest-neighbour evaluation: each test row is [label, *values]. It is
# assigned the label of the training candidate with the smallest
# DTW.euclid_dist, and `correctness` counts correct predictions.
print('using Euclidean...')
correctness = 0
for i, each in enumerate(testing):
    print(i, "of", len(testing))
    label = each[0]
    data = each[1:]
    min_dist = float('Inf')
    min_pos = -1
    for j, candidate in enumerate(training_with_label):
        dist = DTW.euclid_dist(data, candidate['data'])
        if dist < min_dist:
            min_dist = dist
            min_pos = j
    # NOTE(review): if training_with_label is empty this raises IndexError;
    # if no distance is < inf, min_pos stays -1 and the LAST candidate is used.
    predicted = training_with_label[min_pos]['label']
    if predicted == label:
        correctness += 1
    print('min_dist:', min_dist)
    print('min_pos:', min_pos)
    print('predicted label:', predicted)
    print('actual label:', label)
def do_dtw_pair(p1, p2):
    """Run DTW between the feature sequences of two word tokens.

    ``p1`` and ``p2`` are indexed records: [0] word, [1] talker, [2] features.

    :return: ``(word, talkerX, talkerY, x, y, cost_dtw, dtw_x_to_y_mapping,
        dtw_y_to_x_mapping)``, with the word taken from ``p1``.
    """
    x, y = p1[2], p2[2]
    result = DTW(x, y, return_alignment=1)
    alignment = result[-1]
    return (p1[0], p1[1], p2[1], x, y,
            result[0], alignment[1], alignment[2])
abc = [a, b, c] zeros = abc.count(0) if zeros == 3: continue if zeros == 2: if b == 0: continue print('c:', c) start_time = time.process_time() accuracy, correctness = DTW.predict_list(job['train'], job['test'], [a, b, c], [ [-1, 0], [0, -1], [-1, -1] ]) print('[', a, b, c, ']', 'accuracy:', accuracy, 'correctness:', correctness) end_time = time.process_time() time_elasped = end_time - start_time print('time_elapsed:', time_elasped) result = { 'a': a, 'b': b, 'c': c, 'accuracy': accuracy, 'correctness': correctness, 'total': len(job['test']), 'time_elapsed': time_elasped
def two_asym_fn(A, B):
    """Distance between A and B via ``DTW.dist_damn_adv`` with an asymmetric pattern.

    Steps (-2, -3), (-1, -1) and (-3, -2) carry weight lists
    [2/3, 2/3, 2/3], [1] and [1, 1, 1] respectively.
    """
    steps = [(-2, -3), (-1, -1), (-3, -2)]
    weights = [[2 / 3] * 3, [1], [1] * 3]
    return DTW.dist_damn_adv(A, B, steps, weights)
def one_asym_fn(A, B):
    """Distance between A and B via ``DTW.dist_damn_adv`` with an asymmetric pattern.

    Steps (-1, -2), (-1, -1) and (-2, -1) carry weight lists [1/2, 1/2], [1]
    and [1, 1] respectively.
    """
    steps = [(-1, -2), (-1, -1), (-2, -1)]
    weights = [[1 / 2] * 2, [1], [1] * 2]
    return DTW.dist_damn_adv(A, B, steps, weights)
def main():
    """Print the DTW distance between a 12x39 feature matrix and a copy of itself.

    X and Y hold identical values, so this runs the ``DTW`` warper on a
    self-alignment. Y is an independent copy (not an alias), so the two
    inputs remain distinct arrays. The data is written only once to avoid
    two large literals drifting apart.
    """
    X = np.array([
        [-1.870935e-01, -1.941152e+01, -5.882407e+00, -6.821157e+00, -1.960176e+01, -9.122212e+00,
         -5.398831e+00, 1.076720e+01, 6.568266e+00, -4.519125e+00, 3.895056e+00, 5.133921e+00,
         8.817558e-01, -3.344995e-01, -5.670179e-01, 1.977534e-01, -2.825252e+00, 5.670564e-02,
         -4.054463e-01, -1.570377e+00, 2.622198e-01, -1.543853e+00, -2.045755e-01, -3.527825e-01,
         -9.646030e-01, 2.183463e-02, -1.032675e-01, -2.037469e-01, 6.428432e-02, 3.759445e-01,
         1.421797e-01, 6.430891e-01, 4.654990e-01, 3.725418e-01, 1.287469e-01, -1.014837e-01,
         -2.791396e-01, -1.542485e-01, 2.186694e-03],
        [-1.019952e+00, -1.993099e+01, -4.057678e+00, -1.526845e+01, -2.177533e+01, -1.245389e+01,
         -9.609365e+00, 1.133379e+01, 1.232076e+00, -6.754516e+00, 2.665905e+00, 2.272804e+00,
         9.301512e-01, -4.683660e-01, -9.745451e-01, 2.767194e-01, -2.455450e+00, 5.393829e-01,
         9.000364e-01, -8.480091e-01, 1.262016e+00, -1.286797e+00, -2.833939e-01, -7.601787e-01,
         -1.068123e+00, 3.039335e-02, -1.688788e-01, -1.540181e-01, 1.790668e-01, 9.829438e-01,
         1.804260e-01, 8.561022e-01, 8.271446e-01, 4.591559e-01, 3.011147e-01, -1.383719e-01,
         -4.674556e-01, -2.046140e-01, -3.690498e-04],
        [-1.443142e+00, -2.198701e+01, -5.805995e+00, -1.672378e+01, -1.823138e+01, -9.483617e+00,
         -1.114543e+01, 1.179495e+01, 1.517057e+00, -4.424240e+00, 2.745863e+00, 1.741368e+00,
         9.667338e-01, -7.838881e-01, -1.382024e+00, 4.797013e-01, -1.130435e+00, 5.262859e-01,
         2.157257e+00, 3.958666e-01, 1.625053e+00, -1.028640e+00, -6.725525e-01, -1.544723e+00,
         -1.684104e+00, 2.848946e-02, -1.810844e-01, 7.951448e-02, 2.598753e-01, 1.174168e+00,
         6.566963e-02, 4.949601e-01, 8.592530e-01, 3.002176e-01, 3.606367e-01, -2.307723e-01,
         -4.826202e-01, -2.030101e-01, -4.236597e-03],
        [-1.900802e+00, -2.299645e+01, -4.536950e+00, -1.414699e+01, -1.758996e+01, -4.441329e+00,
         -6.765486e+00, 1.656342e+01, 2.659910e+00, -5.983455e+00, 6.687205e-01, 1.489522e+00,
         9.912346e-01, -9.541904e-01, -9.296216e-01, 9.521492e-01, 1.242009e+00, 7.240354e-01,
         2.593670e+00, 1.582148e+00, 1.876561e+00, -2.959121e-01, -6.624181e-01, -2.094018e+00,
         -1.627936e+00, 1.666238e-02, -6.888758e-02, 3.390231e-01, 3.715679e-02, 6.889561e-01,
         -7.894406e-02, -1.904353e-01, 3.151930e-01, -2.784935e-01, 3.556462e-01, -9.185118e-02,
         -8.689702e-02, 1.021361e-01, -7.807227e-03],
        [-3.666164e+00, -2.478892e+01, -3.244268e+00, -1.303400e+01, -1.906302e+01, -2.341908e+00,
         -4.841555e+00, 1.627757e+01, 7.112659e-01, -8.267406e+00, -2.829990e+00, -2.894825e+00,
         9.936614e-01, -9.970276e-01, -1.918873e-01, 1.159399e+00, 1.196834e+00, 2.927031e-01,
         1.222421e+00, 1.510809e+00, 1.455974e+00, -2.361124e-01, -1.168944e+00, -2.098960e+00,
         -1.699771e+00, 7.517332e-03, 1.816315e-01, 3.902620e-01, -3.455281e-01, -2.452208e-01,
         -1.617719e-01, -7.106320e-01, -5.168278e-01, -8.029260e-01, 1.800197e-01, 7.171424e-02,
         3.953680e-01, 5.738304e-01, -7.401578e-03],
        [-4.679300e+00, -2.317812e+01, -5.777994e-01, -1.090323e+01, -1.773932e+01, -3.056363e+00,
         -4.850668e+00, 1.847520e+01, 1.554637e-01, -8.145046e+00, -5.016462e+00, -3.548824e+00,
         1.000000e+00, -7.062657e-01, 1.255049e-01, 1.227021e-01, -1.742893e-01, 2.614480e-01,
         4.152198e-01, 1.704610e-01, -4.592960e-02, 9.517462e-02, -4.944558e-01, -9.176090e-01,
         -5.496670e-01, 1.843350e-03, 4.214767e-01, 2.275055e-01, -5.783656e-01, -8.267415e-01,
         -2.690008e-01, -6.347027e-01, -9.392050e-01, -1.088750e+00, -6.555431e-02, 2.584541e-01,
         8.261142e-01, 8.615100e-01, -5.349388e-03],
        [-5.039043e+00, -2.285567e+01, -1.988579e+00, -1.236157e+01, -1.669317e+01, -4.063885e+00,
         -4.548844e+00, 1.811908e+01, 1.588668e+00, -9.188303e+00, -4.906452e+00, -4.238289e+00,
         9.999387e-01, 3.074194e-04, 4.170531e-02, -8.331995e-01, -1.648457e+00, -5.129647e-02,
         -3.066546e-01, -1.482461e+00, -1.428380e+00, -3.240738e-01, -3.979310e-01, -1.560897e-01,
         6.459251e-01, -1.109232e-03, 5.914558e-01, 1.447010e-01, -3.698424e-01, -2.614930e-01,
         -2.379599e-01, -3.071502e-01, -8.654104e-01, -9.751623e-01, -2.588414e-01, 3.338255e-01,
         9.100384e-01, 8.658913e-01, -4.769736e-03],
        [-4.745595e+00, -2.333554e+01, -4.551358e+00, -1.535441e+01, -1.746766e+01, -1.504144e+00,
         -6.059576e+00, 1.541317e+01, 2.697018e+00, -7.995371e+00, -2.880983e+00, -5.871685e-01,
         9.973097e-01, 6.545821e-01, 9.114037e-02, -9.433860e-01, -1.469063e+00, -4.489590e-01,
         1.847529e-01, -1.617276e+00, -2.124964e+00, -5.797045e-01, 2.443566e-01, 1.065121e+00,
         1.506727e+00, -5.771380e-03, 5.974856e-01, 3.222021e-01, 1.184963e-01, 7.826490e-01,
         -7.407951e-02, -4.043330e-01, -3.859515e-01, -5.287699e-01, -3.860550e-01, -1.046112e-01,
         4.986993e-01, 3.908327e-01, -6.098467e-03],
        [-3.631593e+00, -2.450166e+01, -5.423533e+00, -1.905079e+01, -1.945535e+01, -4.651355e+00,
         -1.164923e+01, 1.066649e+01, -2.179882e+00, -1.033172e+01, -4.678272e+00, -1.146063e+00,
         9.894610e-01, 1.279829e+00, 5.487746e-01, -1.567750e-01, 5.368466e-01, -5.418938e-01,
         -1.980120e-01, -1.922416e+00, -2.380239e+00, -1.192883e+00, 1.307755e-01, 1.459865e+00,
         1.601482e+00, -1.252421e-02, 3.835591e-01, 4.835010e-01, 3.622566e-01, 1.399534e+00,
         1.932966e-01, -4.817923e-01, 1.580537e-01, -4.482182e-02, -1.587683e-01, -4.263003e-01,
         7.625756e-02, -4.941015e-02, -6.395079e-03],
        [-2.110161e+00, -2.189944e+01, -3.577273e+00, -1.490368e+01, -1.860307e+01, -1.838648e+00,
         -9.386616e+00, 1.157644e+01, -8.588143e-01, -6.351494e+00, 1.952272e-01, 2.438842e+00,
         9.763833e-01, 1.641409e+00, 1.482953e+00, 3.769910e-01, 2.646360e+00, 1.363502e-01,
         -1.660754e+00, -1.539307e+00, -2.213823e+00, -1.400694e+00, -1.281913e+00, 7.679644e-01,
         9.266631e-01, -2.294159e-02, 6.459616e-02, 3.304029e-01, 2.195285e-01, 9.616777e-01,
         3.012913e-01, -6.065879e-01, 2.680732e-01, 2.129893e-01, 4.940115e-03, -6.909447e-01,
         -4.549990e-01, -3.945748e-01, -3.847875e-03],
        [4.239744e-02, -2.082980e+01, -3.259567e+00, -9.902978e+00, -1.883495e+01, -4.886611e+00,
         -1.249734e+01, 8.136093e+00, -2.597832e+00, -9.356349e+00, 8.548814e-01, 2.256215e+00,
         9.477813e-01, 1.424646e+00, 1.763293e+00, 3.178869e-01, 3.291507e+00, 6.225432e-01,
         -1.792925e+00, -7.311083e-01, -1.608112e+00, -7.074323e-01, -1.766296e+00, 3.738186e-01,
         6.888947e-01, -2.449871e-02, -1.485645e-01, 8.271299e-02, -3.127924e-02, 2.166191e-01,
         2.373298e-01, -4.095209e-01, 2.766629e-01, 2.636068e-01, 1.393205e-01, -5.765115e-01,
         -5.688236e-01, -4.159141e-01, -9.079840e-04],
        [1.624431e+00, -1.775662e+01, -3.748471e+00, -6.696738e+00, -1.709612e+01, -9.690301e+00,
         -1.333205e+01, 5.609426e+00, -4.097603e+00, -1.489259e+01, -1.807804e+00, 2.345169e+00,
         9.034401e-01, 9.051237e-01, 1.135900e+00, -8.310343e-02, 1.962012e+00, 4.752693e-01,
         -2.050676e+00, -8.725550e-01, -1.446113e+00, -7.977447e-01, -2.261846e+00, -6.668284e-01,
         -9.852156e-03, -1.902324e-02, -1.992105e-01, -1.321529e-01, -1.321248e-01, -2.698160e-01,
         5.306010e-02, -1.037717e-01, 1.192216e-01, 1.697413e-01, 1.115534e-01, -2.455495e-01,
         -3.910235e-01, -2.571701e-01, 1.331323e-03],
    ])
    # Y holds exactly the same values as X; copy so the warper gets two
    # independent arrays.
    Y = X.copy()
    my_warper = DTW()
    my_warper.setX(X)
    my_warper.setY(Y)
    my_warper.perform_dtw()
    print(my_warper.get_dtw_dist())
def two_sym_fn(A, B):
    """Distance between A and B via ``DTW.dist_damn_adv`` with a symmetric pattern.

    Steps (-2, -3), (-1, -1) and (-3, -2) carry weight lists [2, 2, 1], [2]
    and [2, 2, 1] respectively.
    """
    steps = [(-2, -3), (-1, -1), (-3, -2)]
    weights = [[2, 2, 1], [2], [2, 2, 1]]
    return DTW.dist_damn_adv(A, B, steps, weights)
print("End of creation of the training dictionnary")
with open(folder_unprocessed + "/task/keywords.txt") as keywords_file:
    keywords = keywords_file.read().splitlines()
validation_dict = data.create_set('data/test')
transcriptions_dict = data.get_transcriptions(folder_unprocessed + '/ground-truth/transcription.txt')
print("Number of element in the validation dict: %d" % len(validation_dict))

# For each keyword line (presumably "<name>,<training image key>"), rank every
# validation image by DTW cost against the keyword's training image, lowest
# cost first, and collect [name, key1, cost1, key2, cost2, ...].
output_lines = []
for keyword in keywords:
    # Named `fields`, not `data`: rebinding `data` would shadow the `data`
    # module used above (and make it an unbound local inside a function).
    fields = keyword.split(',')
    output_line = [fields[0]]
    score_dict = {}
    image = training_dict[fields[1]]
    dtw_o = DTW(image)
    for i, (key, image2) in enumerate(validation_dict.items(), start=1):
        score = dtw_o.calculate_cost(image2)
        score_dict[key] = score
        if i % 100 == 0:
            print(i)  # progress every 100 validation images
    sorted_score = dict(sorted(score_dict.items(), key = operator.itemgetter(1)))
    for name, score in sorted_score.items():
        output_line.append(name)
        output_line.append(score)
    output_lines.append(output_line)
def bad_fn(A, B):
    """Distance between A and B via ``DTW.dist_damn_adv`` allowing only the diagonal step.

    The single step (-1, -1) carries weight list [1].
    """
    only_diagonal = [(-1, -1)]
    unit_weight = [[1]]
    return DTW.dist_damn_adv(A, B, only_diagonal, unit_weight)
def do_dtw(word, x, y):
    """Run DTW between x and y and bundle the result with its inputs.

    :return: ``(word, x, y, cost_dtw, dtw_x_to_y_mapping, dtw_y_to_x_mapping)``,
        where the cost is element 0 of the DTW result and the mappings are
        entries 1 and 2 of its last element.
    """
    result = DTW(x, y, return_alignment=1)
    cost = result[0]
    x_to_y, y_to_x = result[-1][1], result[-1][2]
    return word, x, y, cost, x_to_y, y_to_x