def one_shot_validation(episode_file=lib["validation_episode_list"], data_x=val_x, data_keys=val_keys, data_labels=val_labels, normalize=True, print_normalization=False):
    """Score one-shot image classification over validation episodes.

    Restores the checkpoint at ``model_fn`` into a fresh session. For every
    episode, visited in a random order, the query and support images are
    embedded with the ``latent`` tensor (fed through the ``X`` placeholder).
    Each query is then assigned the support item at the smallest cosine
    distance. A query counts as correct when
    ``few_shot_learning_library.label_matches_grid_generation_2D`` marks that
    (query, support) pair as a match.

    Args:
        episode_file: Episode list read by
            ``generate_unimodal_image_episodes.read_in_episodes``.
        data_x, data_keys, data_labels: Images, their keys and labels.
            Episode items are looked up in them by key.
        normalize: If True, standardize each set of latents per dimension
            (subtract the column mean, divide by the column std) before
            measuring distances.
        print_normalization: If True and ``normalize`` is True, call
            ``evaluation_library.normalization_visualization`` for each
            episode with the query latents before and after normalization.

    Returns:
        A one-element list ``[-correct / total]``, i.e. the negated accuracy.
        NOTE(review): negated presumably so a minimizing search can use it
        as an objective — confirm against the caller.

    Raises:
        ValueError: If no query items were evaluated (e.g. an empty episode
            list), since accuracy is undefined in that case.
    """
    episode_dict = generate_unimodal_image_episodes.read_in_episodes(episode_file)
    correct = 0
    total = 0

    # Episodes are looked up by the strings "1".."N"; visit them in random order.
    episode_numbers = np.arange(1, len(episode_dict) + 1)
    np.random.shuffle(episode_numbers)

    # NOTE(review): a new Saver is built on every call, which adds save/restore
    # ops to the default graph each time; consider hoisting if called repeatedly.
    saver = tf.train.Saver()
    with tf.Session() as sesh:
        saver.restore(sesh, model_fn)
        for episode in episode_numbers:
            episode_num = str(episode)

            query = episode_dict[episode_num]["query"]
            query_data, _, query_lab = generate_unimodal_image_episodes.episode_data(
                query["keys"], data_x, data_keys, data_labels
            )
            query_iterator = batching_library.unflattened_image_iterator(
                query_data, len(query_data), shuffle_batches_every_epoch=False
            )
            # Align labels with the ordering given by the iterator's `indices`.
            query_labels = [query_lab[i] for i in query_iterator.indices]

            support_set = episode_dict[episode_num]["support_set"]
            S_data, _, S_lab = generate_unimodal_image_episodes.episode_data(
                support_set["keys"], data_x, data_keys, data_labels
            )
            S_iterator = batching_library.unflattened_image_iterator(
                S_data, len(S_data), shuffle_batches_every_epoch=False
            )
            S_labels = [S_lab[i] for i in S_iterator.indices]

            # Each iterator receives len(data) as its second argument, presumably
            # the batch size, so it should yield a single batch. Only the last
            # batch's latents are kept. TODO confirm.
            for feats in query_iterator:
                lat = sesh.run([latent], feed_dict={X: feats})[0]
            for feats in S_iterator:
                S_lat = sesh.run([latent], feed_dict={X: feats})[0]

            if normalize:
                # NOTE(review): a latent dimension with zero std produces inf/NaN here.
                latents = (lat - lat.mean(axis=0)) / lat.std(axis=0)
                s_latents = (S_lat - S_lat.mean(axis=0)) / S_lat.std(axis=0)
                if print_normalization:
                    # Pass the labels that line up with the rows of `lat`
                    # (previously an undefined name `labels` was used here).
                    evaluation_library.normalization_visualization(
                        lat, latents, query_labels, 300,
                        path.join(lib["output_fn"], lib["model_name"])
                    )
            else:
                latents = lat
                s_latents = S_lat

            # Nearest support item (cosine distance) for every query row.
            distances = cdist(latents, s_latents, "cosine")
            indexes = np.argmin(distances, axis=1)
            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                query_labels, S_labels
            )
            for i, nearest in enumerate(indexes):
                total += 1
                if label_matches[i, nearest]:
                    correct += 1

    if total == 0:
        raise ValueError(
            "one_shot_validation: no query items were evaluated; "
            "accuracy is undefined (check episode_file)"
        )
    return [-correct / total]
def unimodal_image_test(episode_fn=lib["image_episode_1_list"], data_x=image_test_x, data_keys=image_test_keys, data_labels=image_test_labels, normalize=True, k=1):
    """Return k-shot unimodal image classification accuracy over test episodes.

    Restores the checkpoint at ``model_fn`` into a new session. For every
    episode, visited in a random order with a tqdm progress bar, the query
    and support images are pushed through ``image_latent`` with ``train_flag``
    set to False. Each query is assigned the support item at the smallest
    cosine distance. A query counts as correct when
    ``few_shot_learning_library.label_matches_grid_generation_2D`` marks that
    (query, support) pair as a match.

    Args:
        episode_fn: Episode list read by
            ``generate_unimodal_image_episodes.read_in_episodes``.
        data_x, data_keys, data_labels: Test images, their keys and labels.
            Episode items are looked up in them by key.
        normalize: If True, standardize each set of latents per dimension
            (subtract the column mean, divide by the column std) before
            measuring distances.
        k: Only used in the progress-bar description; it does not affect
            the computation in this function.

    Returns:
        Fraction of query items classified correctly, ``correct / total``.
    """

    def _standardize(matrix):
        # Column-wise z-score across the rows of one latent matrix.
        return (matrix - matrix.mean(axis=0)) / matrix.std(axis=0)

    episodes = generate_unimodal_image_episodes.read_in_episodes(episode_fn)
    hits = 0
    seen = 0

    # Episodes are looked up by the strings "1".."N"; shuffle the visiting order.
    episode_numbers = np.arange(1, len(episodes) + 1)
    np.random.shuffle(episode_numbers)

    restorer = tf.train.Saver()
    with tf.Session() as session:
        restorer.restore(session, model_fn)
        progress = tqdm(
            episode_numbers,
            desc=f'\t{k}-shot unimodal image classification tests on {len(episode_numbers)} episodes for random seed {rnd_seed:2.0f}',
            ncols=COL_LENGTH,
        )
        for number in progress:
            episode = episodes[str(number)]

            query_entry = episode["query"]
            query_images, _, query_raw_labels = generate_unimodal_image_episodes.episode_data(
                query_entry["keys"], data_x, data_keys, data_labels)
            query_batches = batching_library.unflattened_image_iterator(
                query_images, len(query_images), shuffle_batches_every_epoch=False)
            query_labels = [query_raw_labels[idx] for idx in query_batches.indices]

            support_entry = episode["support_set"]
            support_images, _, support_raw_labels = generate_unimodal_image_episodes.episode_data(
                support_entry["keys"], data_x, data_keys, data_labels)
            support_batches = batching_library.unflattened_image_iterator(
                support_images, len(support_images), shuffle_batches_every_epoch=False)
            support_labels = [support_raw_labels[idx] for idx in support_batches.indices]

            # Only the last batch's embeddings are kept. With len(data) passed
            # as the second argument (presumably the batch size), there should
            # be exactly one batch. TODO confirm.
            for batch in query_batches:
                query_embedding = session.run(
                    [image_latent], feed_dict={image_X: batch, train_flag: False})[0]
            for batch in support_batches:
                support_embedding = session.run(
                    [image_latent], feed_dict={image_X: batch, train_flag: False})[0]

            if not normalize:
                query_vecs, support_vecs = query_embedding, support_embedding
            else:
                query_vecs = _standardize(query_embedding)
                support_vecs = _standardize(support_embedding)

            nearest = np.argmin(cdist(query_vecs, support_vecs, "cosine"), axis=1)
            match_grid = few_shot_learning_library.label_matches_grid_generation_2D(
                query_labels, support_labels)
            seen += len(nearest)
            hits += sum(1 for row, col in enumerate(nearest) if match_grid[row, col])

    return hits / seen