Code Example #1
# TensorFlow 1.x-style graph code. The project-specific helper modules
# (data_library, batching_library, model_legos_library, model_setup_library)
# and the COL_LENGTH constant are assumed to come from the surrounding project.
import tensorflow as tf
from tqdm import tqdm


def speech_rnn_latent_values(
    lib,
    feats_fn,
    latent_dict,
):

    x, labels, lengths, keys = (
        data_library.load_speech_data_from_npz(feats_fn))

    data_library.truncate_data_dim(x, lengths, lib["input_dim"],
                                   lib["max_frames"])
    iterator = batching_library.speech_iterator(
        x, 1, shuffle_batches_every_epoch=False)
    indices = iterator.indices

    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    key_counter = 0

    X = tf.placeholder(tf.float32, [None, None, lib["input_dim"]])
    target = tf.placeholder(tf.int32, [None, lib["num_classes"]])
    X_lengths = tf.placeholder(tf.int32, [None])
    train_flag = tf.placeholder_with_default(False, shape=())
    training_placeholders = [X, X_lengths, target]

    model = model_legos_library.rnn_classifier_architecture(
        [X, X_lengths], train_flag, model_setup_library.activation_lib(), lib)

    output = model["output"]
    latent = model["latent"]

    model_fn = model_setup_library.get_model_fn(lib)

    saver = tf.train.Saver()
    with tf.Session(config=config) as sesh:
        saver.restore(sesh, model_fn)
        for feats, lengths in tqdm(iterator,
                                   desc="Extracting latents",
                                   ncols=COL_LENGTH):
            lat = sesh.run(latent,
                           feed_dict={
                               X: feats,
                               X_lengths: lengths,
                               train_flag: False
                           })

            # Batches are size 1 and never shuffled, so the running
            # key_counter indexes iterator.indices to recover the original key.
            latent_dict[keys[indices[key_counter]]] = lat
            key_counter += 1

    print("Total number of keys: {}".format(key_counter))
Code Example #2
                # Nested helper: resolves rnd_seed, model_fn, COL_LENGTH and
                # the graph tensors (speech_X, speech_X_lengths, speech_latent,
                # image_X, image_latent, train_flag) from the enclosing scope;
                # np, tf, tqdm, cdist and the project libraries are assumed to
                # be imported at module level.
                def zero_shot_multimodal_task(episode_fn,
                                              sp_test_x,
                                              sp_test_keys,
                                              sp_test_labels,
                                              im_test_x,
                                              im_test_keys,
                                              im_test_labels,
                                              normalize=True):

                    episode_dict = generate_episodes.read_in_episodes(
                        episode_fn)
                    np.random.seed(rnd_seed)
                    episode_numbers = np.arange(1, len(episode_dict) + 1)
                    np.random.shuffle(episode_numbers)
                    correct = 0
                    total = 0

                    saver = tf.train.Saver()
                    with tf.Session() as sesh:
                        saver.restore(sesh, model_fn)

                        for episode in tqdm(
                                episode_numbers,
                                desc=f'\t0-shot multimodal matching tests on '
                                f'{len(episode_numbers)} episodes for '
                                f'random seed {rnd_seed:2.0f}',
                                ncols=COL_LENGTH):

                            episode_num = str(episode)
                            # Get query iterator
                            query = episode_dict[episode_num]["query"]
                            query_data, query_keys, query_lab = generate_episodes.episode_data(
                                query["keys"], sp_test_x, sp_test_keys,
                                sp_test_labels)
                            query_iterator = batching_library.speech_iterator(
                                query_data,
                                len(query_data),
                                shuffle_batches_every_epoch=False)

                            matching_set = episode_dict[episode_num][
                                "matching_set"]
                            M_data, M_keys, M_lab = generate_episodes.episode_data(
                                matching_set["keys"], im_test_x, im_test_keys,
                                im_test_labels)
                            M_image_iterator = batching_library.unflattened_image_iterator(
                                M_data,
                                len(M_data),
                                shuffle_batches_every_epoch=False)

                            for feats, lengths in query_iterator:
                                lat = sesh.run(
                                    [speech_latent],
                                    feed_dict={
                                        speech_X: feats,
                                        speech_X_lengths: lengths,
                                        train_flag: False
                                    })[0]
                            for feats in M_image_iterator:
                                S_lat = sesh.run([image_latent],
                                                 feed_dict={
                                                     image_X: feats,
                                                     train_flag: False
                                                 })[0]

                            if normalize:
                                query_latents = (
                                    lat - lat.mean(axis=0)) / lat.std(axis=0)
                                matching_latents = (S_lat - S_lat.mean(axis=0)
                                                    ) / S_lat.std(axis=0)
                            else:
                                query_latents = lat
                                matching_latents = S_lat

                            query_latent_labels = [
                                query_lab[i] for i in query_iterator.indices
                            ]
                            query_latent_keys = [
                                query_keys[i] for i in query_iterator.indices
                            ]

                            matching_latent_labels = [
                                M_lab[i] for i in M_image_iterator.indices
                            ]
                            matching_latent_keys = [
                                M_keys[i] for i in M_image_iterator.indices
                            ]

                            distances = cdist(query_latents, matching_latents,
                                              "cosine")
                            indexes = np.argmin(distances, axis=1)

                            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                                query_latent_labels, matching_latent_labels)

                            for i in range(len(indexes)):
                                total += 1
                                if label_matches[i, indexes[i]]:
                                    correct += 1

                    return correct / total
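
The scoring step at the end of each episode is nearest-neighbour matching
under cosine distance. The following self-contained sketch isolates that
step; label_match_grid is a stand-in whose semantics are assumed to match
few_shot_learning_library.label_matches_grid_generation_2D:

import numpy as np
from scipy.spatial.distance import cdist

def label_match_grid(query_labels, match_labels):
    # Assumed semantics: grid[i, j] is True when query label i equals
    # matching-set label j.
    q = np.asarray(query_labels)[:, None]
    m = np.asarray(match_labels)[None, :]
    return q == m

query_latents = np.random.randn(4, 8)     # 4 query embeddings
matching_latents = np.random.randn(6, 8)  # 6 matching-set embeddings
query_labels = ["cat", "dog", "cat", "bird"]
match_labels = ["dog", "cat", "bird", "cat", "dog", "bird"]

distances = cdist(query_latents, matching_latents, "cosine")
nearest = np.argmin(distances, axis=1)    # closest matching item per query
grid = label_match_grid(query_labels, match_labels)
accuracy = sum(grid[i, j] for i, j in enumerate(nearest)) / len(nearest)
print("episode accuracy:", accuracy)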
Code Example #3
                # Nested helper with the same enclosing-scope dependencies as
                # zero_shot_multimodal_task above.
                def multimodal_task(episode_fn,
                                    sp_test_x,
                                    sp_test_keys,
                                    sp_test_labels,
                                    im_test_x,
                                    im_test_keys,
                                    im_test_labels,
                                    normalize=True,
                                    k=1):

                    episode_dict = generate_episodes.read_in_episodes(
                        episode_fn)
                    np.random.seed(rnd_seed)
                    episode_numbers = np.arange(1, len(episode_dict) + 1)
                    np.random.shuffle(episode_numbers)
                    correct = 0
                    total = 0

                    saver = tf.train.Saver()
                    with tf.Session() as sesh:
                        saver.restore(sesh, model_fn)

                        for episode in tqdm(
                                episode_numbers,
                                desc=f'\t{k}-shot multimodal matching tests '
                                f'on {len(episode_numbers)} episodes for '
                                f'random seed {rnd_seed:2.0f}',
                                ncols=COL_LENGTH):

                            episode_num = str(episode)
                            # Get query iterator
                            query = episode_dict[episode_num]["query"]
                            query_data, query_keys, query_lab = generate_episodes.episode_data(
                                query["keys"], sp_test_x, sp_test_keys,
                                sp_test_labels)
                            query_iterator = batching_library.speech_iterator(
                                query_data,
                                len(query_data),
                                shuffle_batches_every_epoch=False)

                            # Get the support set (paired speech and image items)
                            support_set = episode_dict[episode_num][
                                "support_set"]
                            S_image_data, S_image_keys, S_image_lab = generate_episodes.episode_data(
                                support_set["image_keys"], im_test_x,
                                im_test_keys, im_test_labels)
                            S_speech_data, S_speech_keys, S_speech_lab = generate_episodes.episode_data(
                                support_set["speech_keys"], sp_test_x,
                                sp_test_keys, sp_test_labels)
                            key_list = []
                            for i in range(len(S_speech_keys)):
                                key_list.append(
                                    (S_speech_keys[i], S_image_keys[i]))

                            S_speech_iterator = batching_library.speech_iterator(
                                S_speech_data,
                                len(S_speech_data),
                                shuffle_batches_every_epoch=False)

                            for feats, lengths in query_iterator:
                                lat = sesh.run(
                                    [speech_latent],
                                    feed_dict={
                                        speech_X: feats,
                                        speech_X_lengths: lengths,
                                        train_flag: False
                                    })[0]
                            for feats, lengths in S_speech_iterator:
                                S_lat = sesh.run(
                                    [speech_latent],
                                    feed_dict={
                                        speech_X: feats,
                                        speech_X_lengths: lengths,
                                        train_flag: False
                                    })[0]

                            if normalize:
                                query_latents = (
                                    lat - lat.mean(axis=0)) / lat.std(axis=0)
                                s_speech_latents = (S_lat - S_lat.mean(axis=0)
                                                    ) / S_lat.std(axis=0)
                            else:
                                query_latents = lat
                                s_speech_latents = S_lat

                            query_latent_labels = [
                                query_lab[i] for i in query_iterator.indices
                            ]
                            query_latent_keys = [
                                query_keys[i] for i in query_iterator.indices
                            ]
                            s_speech_latent_labels = [
                                S_speech_lab[i]
                                for i in S_speech_iterator.indices
                            ]
                            s_speech_latent_keys = [
                                S_speech_keys[i]
                                for i in S_speech_iterator.indices
                            ]

                            distances1 = cdist(query_latents, s_speech_latents,
                                               "cosine")
                            indexes1 = np.argmin(distances1, axis=1)

                            chosen_speech_keys = []
                            for i in range(len(indexes1)):
                                chosen_speech_keys.append(
                                    s_speech_latent_keys[indexes1[i]])

                            S_image_iterator = batching_library.unflattened_image_iterator(
                                S_image_data,
                                len(S_image_data),
                                shuffle_batches_every_epoch=False)
                            matching_set = episode_dict[episode_num][
                                "matching_set"]
                            M_data, M_keys, M_lab = generate_episodes.episode_data(
                                matching_set["keys"], im_test_x, im_test_keys,
                                im_test_labels)
                            M_image_iterator = batching_library.unflattened_image_iterator(
                                M_data,
                                len(M_data),
                                shuffle_batches_every_epoch=False)

                            for feats in S_image_iterator:
                                lat = sesh.run([image_latent],
                                               feed_dict={
                                                   image_X: feats,
                                                   train_flag: False
                                               })[0]
                            for feats in M_image_iterator:
                                S_lat = sesh.run([image_latent],
                                                 feed_dict={
                                                     image_X: feats,
                                                     train_flag: False
                                                 })[0]

                            if normalize:
                                s_image_latents = (
                                    lat - lat.mean(axis=0)) / lat.std(axis=0)
                                matching_latents = (S_lat - S_lat.mean(axis=0)
                                                    ) / S_lat.std(axis=0)
                            else:
                                s_image_latents = lat
                                matching_latents = S_lat

                            s_image_latent_labels = [
                                S_image_lab[i]
                                for i in S_image_iterator.indices
                            ]
                            s_image_latent_keys = [
                                S_image_keys[i]
                                for i in S_image_iterator.indices
                            ]

                            matching_latent_labels = [
                                M_lab[i] for i in M_image_iterator.indices
                            ]
                            matching_latent_keys = [
                                M_keys[i] for i in M_image_iterator.indices
                            ]

                            image_key_order_list = []
                            s_image_latents_in_order = np.empty(
                                (query_latents.shape[0],
                                 s_image_latents.shape[1]))
                            s_image_labels_in_order = []

                            for j, key in enumerate(chosen_speech_keys):
                                for (sp_key, im_key) in key_list:
                                    if key == sp_key:
                                        image_key_order_list.append(im_key)
                                        i = np.where(
                                            np.asarray(s_image_latent_keys) ==
                                            im_key)[0][0]
                                        s_image_latents_in_order[
                                            j:j + 1, :] = s_image_latents[i:i +
                                                                          1, :]
                                        s_image_labels_in_order.append(
                                            s_image_latent_labels[i])
                                        break

                            distances2 = cdist(s_image_latents_in_order,
                                               matching_latents, "cosine")
                            indexes2 = np.argmin(distances2, axis=1)
                            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                                query_latent_labels, matching_latent_labels)

                            for i in range(len(indexes2)):
                                total += 1
                                if label_matches[i, indexes2[i]]:
                                    correct += 1

                    return correct / total
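
The subtle step in multimodal_task is the two-hop lookup: each speech query
picks its nearest support speech item (distances1), that speech key is mapped
to its paired image key through key_list, and the paired support image latent
is then matched against the matching set (distances2). The reordering loop
can be written more directly with dictionaries; a sketch assuming every
support speech key pairs with exactly one image key:

import numpy as np

def reorder_support_image_latents(chosen_speech_keys, key_list,
                                  s_image_latent_keys, s_image_latents):
    # speech key -> paired image key, then image key -> latent row index
    sp_to_im = dict(key_list)
    row_of = {key: i for i, key in enumerate(s_image_latent_keys)}
    rows = [row_of[sp_to_im[key]] for key in chosen_speech_keys]
    return np.asarray(s_image_latents)[rows]  # one image latent per query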
Code Example #4
                # Nested helper: the default arguments are bound to variables
                # from the enclosing scope (lib, speech_test_x,
                # speech_test_keys, speech_test_labels) when the function is
                # defined.
                def unimodal_speech_test(
                        episode_fn=lib['speech_episode_1_list'],
                        data_x=speech_test_x,
                        data_keys=speech_test_keys,
                        data_labels=speech_test_labels,
                        normalize=False,
                        k=1):
                    episode_dict = generate_unimodal_speech_episodes.read_in_episodes(
                        episode_fn)

                    correct = 0
                    total = 0

                    episode_numbers = np.arange(1, len(episode_dict) + 1)
                    np.random.shuffle(episode_numbers)

                    saver = tf.train.Saver()
                    with tf.Session() as sesh:
                        saver.restore(sesh, model_fn)

                        for episode in tqdm(
                                episode_numbers,
                                desc=f'\t{k}-shot unimodal speech '
                                f'classification tests on '
                                f'{len(episode_numbers)} episodes for '
                                f'random seed {rnd_seed:2.0f}',
                                ncols=COL_LENGTH):

                            episode_num = str(episode)
                            query = episode_dict[episode_num]["query"]
                            query_data, query_keys, query_lab = generate_unimodal_speech_episodes.episode_data(
                                query["keys"], data_x, data_keys, data_labels)
                            query_iterator = batching_library.speech_iterator(
                                query_data,
                                len(query_data),
                                shuffle_batches_every_epoch=False)
                            query_labels = [
                                query_lab[i] for i in query_iterator.indices
                            ]

                            support_set = episode_dict[episode_num][
                                "support_set"]
                            S_data, S_keys, S_lab = generate_unimodal_speech_episodes.episode_data(
                                support_set["keys"], data_x, data_keys,
                                data_labels)
                            S_iterator = batching_library.speech_iterator(
                                S_data,
                                len(S_data),
                                shuffle_batches_every_epoch=False)
                            S_labels = [S_lab[i] for i in S_iterator.indices]

                            for feats, lengths in query_iterator:
                                lat = sesh.run(
                                    [speech_latent],
                                    feed_dict={
                                        speech_X: feats,
                                        speech_X_lengths: lengths,
                                        train_flag: False
                                    })[0]

                            for feats, lengths in S_iterator:
                                S_lat = sesh.run(
                                    [speech_latent],
                                    feed_dict={
                                        speech_X: feats,
                                        speech_X_lengths: lengths,
                                        train_flag: False
                                    })[0]

                            if normalize:
                                latents = (lat -
                                           lat.mean(axis=0)) / lat.std(axis=0)
                                s_latents = (S_lat - S_lat.mean(axis=0)
                                             ) / S_lat.std(axis=0)
                            else:
                                latents = lat
                                s_latents = S_lat

                            distances = cdist(latents, s_latents, "cosine")

                            indexes = np.argmin(distances, axis=1)
                            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                                query_labels, S_labels)

                            for i in range(len(indexes)):
                                total += 1
                                if label_matches[i, indexes[i]]:
                                    correct += 1

                    return correct / total
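
The episode-evaluation functions above all repeat the same optional
per-dimension standardisation before computing cosine distances. Factored
into a helper (a refactoring sketch, not part of the original code) it reads:

import numpy as np

def standardize(latents):
    # Zero-mean, unit-variance scaling per latent dimension, matching the
    # normalize branches above; assumes no dimension has zero variance.
    latents = np.asarray(latents)
    return (latents - latents.mean(axis=0)) / latents.std(axis=0)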