Beispiel #1
0
    def samediff_val(normalise=True):
        """Score validation embeddings with the same-different task.

        Restores the model stored at `val_model_fn`, encodes `val_x` in one
        batch (the batch size is `len(val_x)`), computes pairwise cosine
        distances between the resulting embeddings and evaluates them with
        `samediff.average_precision_swdp`.

        Parameters
        ----------
        normalise : bool
            When true, subtract the per-dimension mean and divide by the
            per-dimension standard deviation (both taken over the batch)
            before computing distances.

        Returns
        -------
        list
            `[swdp_prb, -swdp_ap]`, where both values come from
            `samediff.average_precision_swdp` and relate to the same-word,
            different-speaker pairs. The average precision is negated,
            presumably so that smaller is better for the caller -- confirm
            against the call site.

        Raises
        ------
        ValueError
            If the batch iterator yields no batch at all.
        """
        # Embed validation
        np.random.seed(options_dict["rnd_seed"])
        # NOTE(review): the third positional argument (False) presumably
        # disables shuffling -- confirm against batching.SimpleIterator.
        val_batch_iterator = batching.SimpleIterator(val_x, len(val_x), False)
        labels = [val_labels[i] for i in val_batch_iterator.indices]
        speakers = [val_speakers[i] for i in val_batch_iterator.indices]

        np_z = None
        saver = tf.train.Saver()
        with tf.Session() as session:
            saver.restore(session, val_model_fn)
            for batch_x_padded, batch_x_lengths in val_batch_iterator:
                np_z = session.run(
                    [encoding],
                    feed_dict={
                        x: batch_x_padded,
                        x_lengths: batch_x_lengths
                    })[0]
                break  # the batch size is len(val_x): only one batch needed
        if np_z is None:
            raise ValueError("validation batch iterator produced no batches")

        # Same-different
        if normalise:
            # NOTE: a dimension with zero standard deviation over the batch
            # results in a division by zero here.
            np_z = (np_z - np_z.mean(axis=0)) / np_z.std(axis=0)
        distances = pdist(np_z, metric="cosine")

        # One entry per pair of items, aligned with `distances`.
        word_matches = samediff.generate_matches_array(labels)
        speaker_matches = samediff.generate_matches_array(speakers)
        same_word_same_speaker = np.logical_and(word_matches, speaker_matches)
        same_word_diff_speaker = np.logical_and(
            word_matches, np.logical_not(speaker_matches))
        different_word = np.logical_not(word_matches)

        # The first two returned values are not needed here.
        _, _, swdp_ap, swdp_prb = samediff.average_precision_swdp(
            distances[same_word_same_speaker],
            distances[same_word_diff_speaker],
            distances[different_word])
        return [swdp_prb, -swdp_ap]
def main():
    """Evaluate the embeddings in an NPZ archive on the same-different task.

    Reads the command-line options from `check_argv()`:

    - `npz_fn`: path of the NPZ archive; every key maps to one embedding.
      Keys are split on "_": the first field is used as the word label and
      the second field as the speaker.
    - `mvn`: if set, apply per-dimension mean and variance normalisation
      over all embeddings before computing distances.
    - `metric`: metric passed to `pdist`; the value "kl" selects
      `scipy.stats.entropy` as a callable metric.
    - `mean_ap`: if set, additionally report the results of
      `samediff.mean_average_precision`.

    Progress, timestamps and the scores returned by
    `samediff.average_precision_swdp` are printed to standard output.
    """
    args = check_argv()

    print(datetime.now())

    print("Reading:", args.npz_fn)
    # Use the archive as a context manager so that the underlying file is
    # closed once all arrays have been read into memory.
    with np.load(args.npz_fn) as npz:
        print(datetime.now())

        print("Ordering embeddings")
        ids = sorted(npz)
        X = np.array([npz[label] for label in ids])
    print("No. embeddings:", len(ids))
    print("Embedding dimensionality:", X.shape[1])

    if args.mvn:
        # NOTE: a dimension with zero standard deviation results in a
        # division by zero here.
        X = (X - X.mean(axis=0)) / X.std(axis=0)

    print(datetime.now())

    print("Calculating distances")
    metric = args.metric
    if metric == "kl":
        import scipy.stats
        metric = scipy.stats.entropy
    distances = pdist(X, metric=metric)

    print(datetime.now())

    print("Getting labels and speakers")
    # Indexing (rather than unpacking) keeps the IndexError for a key that
    # has no "_" separator.
    id_fields = [utt_id.split("_") for utt_id in ids]
    labels = [fields[0] for fields in id_fields]
    speakers = [fields[1] for fields in id_fields]

    if args.mean_ap:
        print(datetime.now())
        print("Calculating mean average precision")
        mean_ap, mean_prb, _ = samediff.mean_average_precision(
            distances, labels)
        print("Mean average precision:", mean_ap)
        print("Mean precision-recall breakeven:", mean_prb)

    print(datetime.now())

    print("Calculating average precision")
    # One entry per pair of embeddings, aligned with `distances`.
    word_matches = samediff.generate_matches_array(labels)
    speaker_matches = samediff.generate_matches_array(speakers)
    # np.sum avoids iterating over every pair at Python level.
    print("No. same-word pairs:", np.sum(word_matches))
    print("No. same-speaker pairs:", np.sum(speaker_matches))

    same_word_same_speaker = np.logical_and(word_matches, speaker_matches)
    same_word_diff_speaker = np.logical_and(
        word_matches, np.logical_not(speaker_matches))
    different_word = np.logical_not(word_matches)

    sw_ap, sw_prb, swdp_ap, swdp_prb = samediff.average_precision_swdp(
        distances[same_word_same_speaker],
        distances[same_word_diff_speaker],
        distances[different_word])
    print("-" * 79)
    print("Average precision: {:.8f}".format(sw_ap))
    print("Precision-recall breakeven: {:.8f}".format(sw_prb))
    print("SWDP average precision: {:.8f}".format(swdp_ap))
    print("SWDP precision-recall breakeven: {:.8f}".format(swdp_prb))
    print("-" * 79)

    print(datetime.now())