def one_shot_validation(episode_file=lib["validation_episode_list"], data_x=val_x, data_keys=val_keys, data_labels=val_labels, normalize=True, print_normalization=False):
        """Score nearest-neighbour classification over the episodes in *episode_file*.

        For every episode the query items and the support-set items are fetched
        with ``generate_unimodal_image_episodes.episode_data``, embedded by
        running the ``latent`` tensor of the restored model, and each query is
        assigned to the support item with the smallest cosine distance. A
        prediction counts as correct when the grid returned by
        ``few_shot_learning_library.label_matches_grid_generation_2D`` is truthy
        for that (query, support) pair.

        ``model_fn``, ``latent`` and ``X`` are taken from the enclosing scope.

        Args:
            episode_file: Episode list handed to ``read_in_episodes``.
            data_x: Data handed to ``episode_data`` to look up episode items.
            data_keys: Keys handed to ``episode_data`` alongside ``data_x``.
            data_labels: Labels handed to ``episode_data`` alongside ``data_x``.
            normalize: When true, each latent dimension is standardised (zero
                mean, unit standard deviation) separately for the query batch
                and the support batch before distances are computed.
            print_normalization: When true (and ``normalize`` is true), calls
                ``evaluation_library.normalization_visualization`` for every
                episode.

        Returns:
            A one-element list holding the negated accuracy ``-correct/total``
            (presumably negated so that a lower value is better -- confirm
            against the caller).

        Raises:
            ZeroDivisionError: If no query was scored (``total`` stays 0).
        """

        def _standardize(mat):
            # A latent dimension that is constant over the batch has std == 0;
            # dividing by it yields NaN, which turns every cosine distance of
            # the row into NaN and makes argmin always pick support index 0.
            # The centred values of such a dimension are already 0, so
            # dividing by 1 instead leaves them at 0.
            spread = mat.std(axis=0)
            spread[spread == 0] = 1
            return (mat - mat.mean(axis=0)) / spread

        episode_dict = generate_unimodal_image_episodes.read_in_episodes(episode_file)
        correct = 0
        total = 0

        # Episodes are looked up by the strings "1" .. "N"; visit them in a
        # random order.
        episode_numbers = np.arange(1, len(episode_dict)+1)
        np.random.shuffle(episode_numbers)

        saver = tf.train.Saver()
        with tf.Session() as sesh:
            saver.restore(sesh, model_fn)

            for episode in episode_numbers:

                episode_num = str(episode)
                query = episode_dict[episode_num]["query"]
                query_data, _, query_lab = generate_unimodal_image_episodes.episode_data(
                    query["keys"], data_x, data_keys, data_labels
                    )
                # Batch size equals the number of items, so the loop below is
                # expected to yield a single batch -- TODO confirm against
                # batching_library.
                query_iterator = batching_library.unflattened_image_iterator(
                    query_data, len(query_data), shuffle_batches_every_epoch=False
                    )
                # Reorder the labels to follow the iterator's item order.
                query_labels = [query_lab[i] for i in query_iterator.indices]

                support_set = episode_dict[episode_num]["support_set"]
                S_data, _, S_lab = generate_unimodal_image_episodes.episode_data(
                    support_set["keys"], data_x, data_keys, data_labels
                    )
                S_iterator = batching_library.unflattened_image_iterator(
                    S_data, len(S_data), shuffle_batches_every_epoch=False
                    )
                S_labels = [S_lab[i] for i in S_iterator.indices]

                # Only the result of the last yielded batch is kept.
                for feats in query_iterator:
                    lat = sesh.run(
                        [latent], feed_dict={X: feats}
                        )[0]

                for feats in S_iterator:
                    S_lat = sesh.run(
                        [latent], feed_dict={X: feats}
                        )[0]

                if normalize:
                    latents = _standardize(lat)
                    s_latents = _standardize(S_lat)
                    if print_normalization:
                        # NOTE(review): `labels` is not defined in this
                        # function; it must exist in an enclosing scope or this
                        # raises NameError. `query_labels` may have been
                        # intended -- confirm.
                        evaluation_library.normalization_visualization(
                            lat, latents, labels, 300,
                            path.join(lib["output_fn"], lib["model_name"])
                            )
                else:
                    latents = lat
                    s_latents = S_lat

                # Rows are queries, columns are support items.
                distances = cdist(latents, s_latents, "cosine")
                indexes = np.argmin(distances, axis=1)
                label_matches = few_shot_learning_library.label_matches_grid_generation_2D(query_labels, S_labels)

                for row, nearest in enumerate(indexes):
                    total += 1
                    if label_matches[row, nearest]:
                        correct += 1

        return [-correct/total]
示例#2
0
                def unimodal_image_test(episode_fn=lib["image_episode_1_list"],
                                        data_x=image_test_x,
                                        data_keys=image_test_keys,
                                        data_labels=image_test_labels,
                                        normalize=True,
                                        k=1):
                    """Return nearest-neighbour accuracy over the episodes in *episode_fn*.

                    For every episode the query items and the support-set items
                    are fetched with
                    ``generate_unimodal_image_episodes.episode_data``, embedded
                    by running ``image_latent`` on the restored model with
                    ``train_flag`` fed as False, and each query is assigned to
                    the support item with the smallest cosine distance. A
                    prediction counts as correct when the grid returned by
                    ``few_shot_learning_library.label_matches_grid_generation_2D``
                    is truthy for that (query, support) pair.

                    ``model_fn``, ``image_latent``, ``image_X``, ``train_flag``,
                    ``rnd_seed`` and ``COL_LENGTH`` come from the enclosing
                    scope.

                    Args:
                        episode_fn: Episode list handed to ``read_in_episodes``.
                        data_x: Data handed to ``episode_data``.
                        data_keys: Keys handed to ``episode_data``.
                        data_labels: Labels handed to ``episode_data``.
                        normalize: When true, each latent dimension is
                            standardised (zero mean, unit standard deviation)
                            separately for the query and the support batch
                            before distances are computed.
                        k: Only used in the progress-bar description.

                    Returns:
                        ``correct / total`` over all scored queries.

                    Raises:
                        ZeroDivisionError: If no query was scored.
                    """

                    def _standardize(mat):
                        # A latent dimension that is constant over the batch
                        # has std == 0; dividing by it yields NaN, which turns
                        # every cosine distance of the row into NaN and makes
                        # argmin always pick support index 0. The centred
                        # values of such a dimension are already 0, so dividing
                        # by 1 instead leaves them at 0.
                        spread = mat.std(axis=0)
                        spread[spread == 0] = 1
                        return (mat - mat.mean(axis=0)) / spread

                    episode_dict = generate_unimodal_image_episodes.read_in_episodes(
                        episode_fn)
                    correct = 0
                    total = 0

                    # Episodes are looked up by the strings "1" .. "N"; visit
                    # them in a random order.
                    episode_numbers = np.arange(1, len(episode_dict) + 1)
                    np.random.shuffle(episode_numbers)

                    saver = tf.train.Saver()
                    with tf.Session() as sesh:
                        saver.restore(sesh, model_fn)

                        for episode in tqdm(
                                episode_numbers,
                                desc=
                                f'\t{k}-shot unimodal image classification tests on {len(episode_numbers)} episodes for random seed {rnd_seed:2.0f}',
                                ncols=COL_LENGTH):

                            episode_num = str(episode)
                            query = episode_dict[episode_num]["query"]
                            query_data, _, query_lab = generate_unimodal_image_episodes.episode_data(
                                query["keys"], data_x, data_keys, data_labels)
                            # Batch size equals the number of items, so the
                            # loop below is expected to yield a single batch
                            # -- TODO confirm against batching_library.
                            query_iterator = batching_library.unflattened_image_iterator(
                                query_data,
                                len(query_data),
                                shuffle_batches_every_epoch=False)
                            # Reorder the labels to follow the iterator's
                            # item order.
                            query_labels = [
                                query_lab[i] for i in query_iterator.indices
                            ]

                            support_set = episode_dict[episode_num][
                                "support_set"]
                            S_data, _, S_lab = generate_unimodal_image_episodes.episode_data(
                                support_set["keys"], data_x, data_keys,
                                data_labels)
                            S_iterator = batching_library.unflattened_image_iterator(
                                S_data,
                                len(S_data),
                                shuffle_batches_every_epoch=False)
                            S_labels = [S_lab[i] for i in S_iterator.indices]

                            # Only the result of the last yielded batch is
                            # kept.
                            for feats in query_iterator:
                                lat = sesh.run([image_latent],
                                               feed_dict={
                                                   image_X: feats,
                                                   train_flag: False
                                               })[0]

                            for feats in S_iterator:
                                S_lat = sesh.run([image_latent],
                                                 feed_dict={
                                                     image_X: feats,
                                                     train_flag: False
                                                 })[0]

                            if normalize:
                                latents = _standardize(lat)
                                s_latents = _standardize(S_lat)
                            else:
                                latents = lat
                                s_latents = S_lat

                            # Rows are queries, columns are support items.
                            distances = cdist(latents, s_latents, "cosine")
                            indexes = np.argmin(distances, axis=1)
                            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                                query_labels, S_labels)

                            for row, nearest in enumerate(indexes):
                                total += 1
                                if label_matches[row, nearest]:
                                    correct += 1

                    return correct / total