def samediff_val(normalise=True):
    # Embed validation
    np.random.seed(options_dict["rnd_seed"])
    val_batch_iterator = batching.SimpleIterator(val_x, len(val_x), False)
    labels = [val_labels[i] for i in val_batch_iterator.indices]
    speakers = [val_speakers[i] for i in val_batch_iterator.indices]
    saver = tf.train.Saver()
    with tf.Session() as session:
        saver.restore(session, val_model_fn)
        for batch_x_padded, batch_x_lengths in val_batch_iterator:
            np_x = batch_x_padded
            np_x_lengths = batch_x_lengths
            np_z = session.run(
                [encoding], feed_dict={x: np_x, x_lengths: np_x_lengths}
                )[0]
            break  # single batch

    embed_dict = {}
    for i, utt_key in enumerate(
            [val_keys[i] for i in val_batch_iterator.indices]):
        embed_dict[utt_key] = np_z[i]

    # Same-different
    if normalise:
        np_z_normalised = (np_z - np_z.mean(axis=0)) / np_z.std(axis=0)
        distances = pdist(np_z_normalised, metric="cosine")
    else:
        distances = pdist(np_z, metric="cosine")
    # matches = samediff.generate_matches_array(labels)
    # ap, prb = samediff.average_precision(
    #     distances[matches == True], distances[matches == False]
    #     )
    word_matches = samediff.generate_matches_array(labels)
    speaker_matches = samediff.generate_matches_array(speakers)
    sw_ap, sw_prb, swdp_ap, swdp_prb = samediff.average_precision_swdp(
        distances[np.logical_and(word_matches, speaker_matches)],
        distances[np.logical_and(word_matches, speaker_matches == False)],
        distances[word_matches == False]
        )
    # return [sw_prb, -sw_ap, swdp_prb, -swdp_ap]
    return [swdp_prb, -swdp_ap]
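
# A minimal, self-contained sketch (not part of the recipe) of the
# same-different computation used in samediff_val() above: mean-variance
# normalise the embeddings, take pairwise cosine distances, and split the
# condensed distance vector by whether each pair shares a word label. The
# `_toy_matches_array` helper and the toy data are hypothetical stand-ins;
# the helper only assumes the pair ordering that pdist uses (i < j,
# row-major), which is what samediff.generate_matches_array also follows.
def _samediff_toy_example():
    # Local imports so this sketch stands alone
    import itertools
    import numpy as np
    from scipy.spatial.distance import pdist

    toy_z = np.random.randn(6, 4)  # 6 toy embeddings of dimensionality 4
    toy_labels = ["yes", "yes", "no", "no", "yes", "no"]

    def _toy_matches_array(labels):
        # True for every pair (in pdist's condensed order) with equal labels
        return np.array(
            [a == b for a, b in itertools.combinations(labels, 2)]
            )

    # Same normalisation and metric as samediff_val()
    toy_z = (toy_z - toy_z.mean(axis=0)) / toy_z.std(axis=0)
    toy_distances = pdist(toy_z, metric="cosine")
    toy_matches = _toy_matches_array(toy_labels)

    same_word_distances = toy_distances[toy_matches]
    diff_word_distances = toy_distances[toy_matches == False]
    return same_word_distances, diff_word_distances
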
def main():
    args = check_argv()

    print(datetime.now())

    print("Reading:", args.npz_fn)
    npz = np.load(args.npz_fn)

    print(datetime.now())

    print("Ordering embeddings")
    n_embeds = 0
    X = []
    ids = []
    for label in sorted(npz):
        ids.append(label)
        X.append(npz[label])
        n_embeds += 1
    X = np.array(X)
    print("No. embeddings:", n_embeds)
    print("Embedding dimensionality:", X.shape[1])

    if args.mvn:
        normed = (X - X.mean(axis=0)) / X.std(axis=0)
        X = normed

    print(datetime.now())

    print("Calculating distances")
    metric = args.metric
    if metric == "kl":
        import scipy.stats
        metric = scipy.stats.entropy
    distances = pdist(X, metric=metric)

    print(datetime.now())

    print("Getting labels and speakers")
    labels = []
    speakers = []
    for utt_id in ids:
        utt_id = utt_id.split("_")
        word = utt_id[0]
        speaker = utt_id[1]
        labels.append(word)
        speakers.append(speaker)

    if args.mean_ap:
        print(datetime.now())
        print("Calculating mean average precision")
        mean_ap, mean_prb, ap_dict = samediff.mean_average_precision(
            distances, labels
            )
        print("Mean average precision:", mean_ap)
        print("Mean precision-recall breakeven:", mean_prb)

    print(datetime.now())

    print("Calculating average precision")
    # matches = samediff.generate_matches_array(labels)  # Temp
    word_matches = samediff.generate_matches_array(labels)
    speaker_matches = samediff.generate_matches_array(speakers)
    print("No. same-word pairs:", sum(word_matches))
    print("No. same-speaker pairs:", sum(speaker_matches))
    sw_ap, sw_prb, swdp_ap, swdp_prb = samediff.average_precision_swdp(
        distances[np.logical_and(word_matches, speaker_matches)],
        distances[np.logical_and(word_matches, speaker_matches == False)],
        distances[word_matches == False]
        )
    print("-" * 79)
    print("Average precision: {:.8f}".format(sw_ap))
    print("Precision-recall breakeven: {:.8f}".format(sw_prb))
    print("SWDP average precision: {:.8f}".format(swdp_ap))
    print("SWDP precision-recall breakeven: {:.8f}".format(swdp_prb))
    print("-" * 79)

    print(datetime.now())
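
# A hedged, self-contained sketch (not part of the evaluation itself) of the
# "kl" metric branch in main(): scipy.spatial.distance.pdist accepts a
# callable metric, and scipy.stats.entropy(p, q) returns the KL divergence
# KL(p || q), so passing it as the metric gives pairwise KL distances. The
# toy posteriors below are illustrative assumptions only.
def _kl_metric_toy_example():
    # Local imports so this sketch stands alone
    import numpy as np
    import scipy.stats
    from scipy.spatial.distance import pdist

    # Rows must be non-negative; normalise each row to sum to one so it is a
    # valid probability distribution.
    toy_posteriors = np.abs(np.random.randn(5, 10))
    toy_posteriors /= toy_posteriors.sum(axis=1, keepdims=True)

    # Each condensed entry is KL(row_i || row_j) for i < j, in pdist order.
    # Note that KL is asymmetric; pdist only evaluates the i < j direction.
    return pdist(toy_posteriors, metric=scipy.stats.entropy)
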