def image_cnn_latent_values(lib, feats_fn, latent_dict):
    """Extract a latent vector for every image in feats_fn with the trained
    CNN classifier and store it in latent_dict, keyed by image key."""
    x, labels, keys = data_library.load_image_data_from_npz(feats_fn)
    iterator = batching_library.unflattened_image_iterator(
        x, 1, shuffle_batches_every_epoch=False)
    indices = iterator.indices

    # Rebuild the CNN classifier graph
    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    key_counter = 0

    X = tf.placeholder(tf.float32, [None, 28, 28, 1])
    target = tf.placeholder(tf.float32, [None, lib["num_classes"]])
    train_flag = tf.placeholder_with_default(False, shape=())

    model = model_legos_library.cnn_classifier_architecture(
        X, train_flag, lib["enc"], lib["enc_strides"],
        model_setup_library.pooling_lib(), lib["pool_layers"], lib["latent"],
        lib, model_setup_library.activation_lib(), print_layer=True)
    output = model["output"]
    latent = model["latent"]

    model_fn = model_setup_library.get_model_fn(lib)
    saver = tf.train.Saver()

    # Restore the trained weights and extract one latent per image
    with tf.Session(config=config) as sesh:
        saver.restore(sesh, model_fn)
        for feats in tqdm(iterator, desc="Extracting latents",
                          ncols=COL_LENGTH):
            lat = sesh.run(latent, feed_dict={X: feats, train_flag: False})
            latent_dict[keys[indices[key_counter]]] = lat
            key_counter += 1

    print("Total number of keys: {}".format(key_counter))
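# A minimal usage sketch of image_cnn_latent_values, given as an illustrative
# example only: `lib` is whatever model config dict model_setup_library builds
# elsewhere, and the .npz path below is a hypothetical placeholder.
def _example_extract_image_latents(lib, feats_fn="data/test_images.npz"):
    latent_dict = {}
    image_cnn_latent_values(lib, feats_fn, latent_dict)
    # latent_dict now maps each image key to its extracted latent vector
    return latent_dict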
def one_shot_validation(episode_file=lib["validation_episode_list"],
                        data_x=val_x, data_keys=val_keys,
                        data_labels=val_labels, normalize=True,
                        print_normalization=False):
    episode_dict = generate_unimodal_image_episodes.read_in_episodes(
        episode_file)
    correct = 0
    total = 0
    episode_numbers = np.arange(1, len(episode_dict) + 1)
    np.random.shuffle(episode_numbers)

    saver = tf.train.Saver()
    with tf.Session() as sesh:
        saver.restore(sesh, model_fn)
        for episode in episode_numbers:
            episode_num = str(episode)

            # Get query iterator
            query = episode_dict[episode_num]["query"]
            query_data, query_keys, query_lab = generate_unimodal_image_episodes.episode_data(
                query["keys"], data_x, data_keys, data_labels)
            query_iterator = batching_library.unflattened_image_iterator(
                query_data, len(query_data),
                shuffle_batches_every_epoch=False)
            query_labels = [query_lab[i] for i in query_iterator.indices]

            # Get support set iterator
            support_set = episode_dict[episode_num]["support_set"]
            S_data, S_keys, S_lab = generate_unimodal_image_episodes.episode_data(
                support_set["keys"], data_x, data_keys, data_labels)
            S_iterator = batching_library.unflattened_image_iterator(
                S_data, len(S_data), shuffle_batches_every_epoch=False)
            S_labels = [S_lab[i] for i in S_iterator.indices]

            # Extract query and support latents
            for feats in query_iterator:
                lat = sesh.run([latent], feed_dict={X: feats})[0]
            for feats in S_iterator:
                S_lat = sesh.run([latent], feed_dict={X: feats})[0]

            if normalize:
                latents = (lat - lat.mean(axis=0)) / lat.std(axis=0)
                s_latents = (S_lat - S_lat.mean(axis=0)) / S_lat.std(axis=0)
                if print_normalization:
                    evaluation_library.normalization_visualization(
                        lat, latents, labels, 300,
                        path.join(lib["output_fn"], lib["model_name"]))
            else:
                latents = lat
                s_latents = S_lat

            # Cosine nearest-neighbour matching of queries to the support set
            distances = cdist(latents, s_latents, "cosine")
            indexes = np.argmin(distances, axis=1)
            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                query_labels, S_labels)
            for i in range(len(indexes)):
                total += 1
                if label_matches[i, indexes[i]]:
                    correct += 1

    # Accuracy is returned negated (and wrapped in a list), presumably so the
    # caller can minimise it like a validation loss
    return [-correct / total]
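# The per-episode scoring in one_shot_validation (and in the test functions
# below) reduces to a cosine nearest-neighbour lookup followed by a label
# comparison. The helper below is a compact numpy-only restatement of that
# step, given as an illustrative sketch: the plain label comparison stands in
# for few_shot_learning_library.label_matches_grid_generation_2D, and the
# per-episode mean/variance normalisation used above is omitted.
def _episode_accuracy_sketch(query_latents, support_latents, query_labels,
                             support_labels):
    indexes = np.argmin(cdist(query_latents, support_latents, "cosine"),
                        axis=1)
    matches = [query_labels[i] == support_labels[j]
               for i, j in enumerate(indexes)]
    return sum(matches) / len(matches)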
def zero_shot_multimodal_task(episode_fn, sp_test_x, sp_test_keys,
                              sp_test_labels, im_test_x, im_test_keys,
                              im_test_labels, normalize=True):
    episode_dict = generate_episodes.read_in_episodes(episode_fn)
    np.random.seed(rnd_seed)
    episode_numbers = np.arange(1, len(episode_dict) + 1)
    np.random.shuffle(episode_numbers)
    correct = 0
    total = 0

    saver = tf.train.Saver()
    with tf.Session() as sesh:
        saver.restore(sesh, model_fn)
        for episode in tqdm(
                episode_numbers,
                desc=f'\t0-shot multimodal matching tests on {len(episode_numbers)} episodes for random seed {rnd_seed:2.0f}',
                ncols=COL_LENGTH):
            episode_num = str(episode)

            # Get query iterator
            query = episode_dict[episode_num]["query"]
            query_data, query_keys, query_lab = generate_episodes.episode_data(
                query["keys"], sp_test_x, sp_test_keys, sp_test_labels)
            query_iterator = batching_library.speech_iterator(
                query_data, len(query_data),
                shuffle_batches_every_epoch=False)

            # Get matching set iterator
            matching_set = episode_dict[episode_num]["matching_set"]
            M_data, M_keys, M_lab = generate_episodes.episode_data(
                matching_set["keys"], im_test_x, im_test_keys, im_test_labels)
            M_image_iterator = batching_library.unflattened_image_iterator(
                M_data, len(M_data), shuffle_batches_every_epoch=False)

            # Extract speech query latents and image matching-set latents
            for feats, lengths in query_iterator:
                lat = sesh.run(
                    [speech_latent],
                    feed_dict={
                        speech_X: feats,
                        speech_X_lengths: lengths,
                        train_flag: False
                    })[0]
            for feats in M_image_iterator:
                S_lat = sesh.run(
                    [image_latent],
                    feed_dict={
                        image_X: feats,
                        train_flag: False
                    })[0]

            if normalize:
                query_latents = (lat - lat.mean(axis=0)) / lat.std(axis=0)
                matching_latents = (S_lat - S_lat.mean(axis=0)) / S_lat.std(axis=0)
            else:
                query_latents = lat
                matching_latents = S_lat

            query_latent_labels = [
                query_lab[i] for i in query_iterator.indices
            ]
            query_latent_keys = [
                query_keys[i] for i in query_iterator.indices
            ]
            matching_latent_labels = [
                M_lab[i] for i in M_image_iterator.indices
            ]
            matching_latent_keys = [
                M_keys[i] for i in M_image_iterator.indices
            ]

            # Cosine nearest-neighbour matching of speech queries to images
            distances = cdist(query_latents, matching_latents, "cosine")
            indexes = np.argmin(distances, axis=1)
            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                query_latent_labels, matching_latent_labels)
            for i in range(len(indexes)):
                total += 1
                if label_matches[i, indexes[i]]:
                    correct += 1

    return correct / total
def multimodal_task(episode_fn, sp_test_x, sp_test_keys, sp_test_labels,
                    im_test_x, im_test_keys, im_test_labels, normalize=True,
                    k=1):
    episode_dict = generate_episodes.read_in_episodes(episode_fn)
    np.random.seed(rnd_seed)
    episode_numbers = np.arange(1, len(episode_dict) + 1)
    np.random.shuffle(episode_numbers)
    correct = 0
    total = 0

    saver = tf.train.Saver()
    with tf.Session() as sesh:
        saver.restore(sesh, model_fn)
        for episode in tqdm(
                episode_numbers,
                desc=f'\t{k}-shot multimodal matching tests on {len(episode_numbers)} episodes for random seed {rnd_seed:2.0f}',
                ncols=COL_LENGTH):
            episode_num = str(episode)

            # Get query iterator
            query = episode_dict[episode_num]["query"]
            query_data, query_keys, query_lab = generate_episodes.episode_data(
                query["keys"], sp_test_x, sp_test_keys, sp_test_labels)
            query_iterator = batching_library.speech_iterator(
                query_data, len(query_data),
                shuffle_batches_every_epoch=False)

            # Get speech and image support sets
            support_set = episode_dict[episode_num]["support_set"]
            S_image_data, S_image_keys, S_image_lab = generate_episodes.episode_data(
                support_set["image_keys"], im_test_x, im_test_keys,
                im_test_labels)
            S_speech_data, S_speech_keys, S_speech_lab = generate_episodes.episode_data(
                support_set["speech_keys"], sp_test_x, sp_test_keys,
                sp_test_labels)

            # Pair each support speech key with its corresponding image key
            key_list = []
            for i in range(len(S_speech_keys)):
                key_list.append((S_speech_keys[i], S_image_keys[i]))

            S_speech_iterator = batching_library.speech_iterator(
                S_speech_data, len(S_speech_data),
                shuffle_batches_every_epoch=False)

            # Extract speech latents for the queries and the speech support set
            for feats, lengths in query_iterator:
                lat = sesh.run(
                    [speech_latent],
                    feed_dict={
                        speech_X: feats,
                        speech_X_lengths: lengths,
                        train_flag: False
                    })[0]
            for feats, lengths in S_speech_iterator:
                S_lat = sesh.run(
                    [speech_latent],
                    feed_dict={
                        speech_X: feats,
                        speech_X_lengths: lengths,
                        train_flag: False
                    })[0]

            if normalize:
                query_latents = (lat - lat.mean(axis=0)) / lat.std(axis=0)
                s_speech_latents = (S_lat - S_lat.mean(axis=0)) / S_lat.std(axis=0)
            else:
                query_latents = lat
                s_speech_latents = S_lat

            query_latent_labels = [
                query_lab[i] for i in query_iterator.indices
            ]
            query_latent_keys = [
                query_keys[i] for i in query_iterator.indices
            ]
            s_speech_latent_labels = [
                S_speech_lab[i] for i in S_speech_iterator.indices
            ]
            s_speech_latent_keys = [
                S_speech_keys[i] for i in S_speech_iterator.indices
            ]

            # Stage 1: match each speech query to its closest support speech item
            distances1 = cdist(query_latents, s_speech_latents, "cosine")
            indexes1 = np.argmin(distances1, axis=1)
            chosen_speech_keys = []
            for i in range(len(indexes1)):
                chosen_speech_keys.append(s_speech_latent_keys[indexes1[i]])

            S_image_iterator = batching_library.unflattened_image_iterator(
                S_image_data, len(S_image_data),
                shuffle_batches_every_epoch=False)

            # Get matching set iterator
            matching_set = episode_dict[episode_num]["matching_set"]
            M_data, M_keys, M_lab = generate_episodes.episode_data(
                matching_set["keys"], im_test_x, im_test_keys, im_test_labels)
            M_image_iterator = batching_library.unflattened_image_iterator(
                M_data, len(M_data), shuffle_batches_every_epoch=False)

            # Extract image latents for the support set and the matching set
            for feats in S_image_iterator:
                lat = sesh.run(
                    [image_latent],
                    feed_dict={
                        image_X: feats,
                        train_flag: False
                    })[0]
            for feats in M_image_iterator:
                S_lat = sesh.run(
                    [image_latent],
                    feed_dict={
                        image_X: feats,
                        train_flag: False
                    })[0]

            if normalize:
                s_image_latents = (lat - lat.mean(axis=0)) / lat.std(axis=0)
                matching_latents = (S_lat - S_lat.mean(axis=0)) / S_lat.std(axis=0)
            else:
                s_image_latents = lat
                matching_latents = S_lat

            s_image_latent_labels = [
                S_image_lab[i] for i in S_image_iterator.indices
            ]
            s_image_latent_keys = [
                S_image_keys[i] for i in S_image_iterator.indices
            ]
            matching_latent_labels = [
                M_lab[i] for i in M_image_iterator.indices
            ]
            matching_latent_keys = [
                M_keys[i] for i in M_image_iterator.indices
            ]

            # Look up the image latent paired with each chosen support speech key
            image_key_order_list = []
            s_image_latents_in_order = np.empty(
                (query_latents.shape[0], s_image_latents.shape[1]))
            s_image_labels_in_order = []
            for j, key in enumerate(chosen_speech_keys):
                for (sp_key, im_key) in key_list:
                    if key == sp_key:
                        image_key_order_list.append(im_key)
                        i = np.where(
                            np.asarray(s_image_latent_keys) == im_key)[0][0]
                        s_image_latents_in_order[j:j + 1, :] = s_image_latents[
                            i:i + 1, :]
                        s_image_labels_in_order.append(
                            s_image_latent_labels[i])
                        break

            # Stage 2: match the chosen support images to the matching set
            distances2 = cdist(s_image_latents_in_order, matching_latents,
                               "cosine")
            indexes2 = np.argmin(distances2, axis=1)
            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                query_latent_labels, matching_latent_labels)
            for i in range(len(indexes2)):
                total += 1
                if label_matches[i, indexes2[i]]:
                    correct += 1

    return correct / total
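# multimodal_task above performs a two-stage nearest-neighbour match: each
# speech query is matched to its closest support-set speech latent, and the
# image latent paired with that support item is then matched against the
# matching set. The standalone sketch below restates that flow with plain
# numpy/scipy arrays; the argument names are illustrative and assume the
# support speech and image latents are already aligned row-for-row.
def _two_stage_match_sketch(query_speech, support_speech, support_images,
                            matching_images):
    # Stage 1: speech query -> index of closest support speech latent.
    speech_idx = np.argmin(cdist(query_speech, support_speech, "cosine"),
                           axis=1)
    # Stage 2: the paired support image -> index of closest matching-set image.
    bridged_images = support_images[speech_idx]
    return np.argmin(cdist(bridged_images, matching_images, "cosine"), axis=1)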
def unimodal_image_test(episode_fn=lib["image_episode_1_list"],
                        data_x=image_test_x, data_keys=image_test_keys,
                        data_labels=image_test_labels, normalize=True, k=1):
    episode_dict = generate_unimodal_image_episodes.read_in_episodes(
        episode_fn)
    correct = 0
    total = 0
    episode_numbers = np.arange(1, len(episode_dict) + 1)
    np.random.shuffle(episode_numbers)

    saver = tf.train.Saver()
    with tf.Session() as sesh:
        saver.restore(sesh, model_fn)
        for episode in tqdm(
                episode_numbers,
                desc=f'\t{k}-shot unimodal image classification tests on {len(episode_numbers)} episodes for random seed {rnd_seed:2.0f}',
                ncols=COL_LENGTH):
            episode_num = str(episode)

            # Get query iterator
            query = episode_dict[episode_num]["query"]
            query_data, query_keys, query_lab = generate_unimodal_image_episodes.episode_data(
                query["keys"], data_x, data_keys, data_labels)
            query_iterator = batching_library.unflattened_image_iterator(
                query_data, len(query_data),
                shuffle_batches_every_epoch=False)
            query_labels = [query_lab[i] for i in query_iterator.indices]

            # Get support set iterator
            support_set = episode_dict[episode_num]["support_set"]
            S_data, S_keys, S_lab = generate_unimodal_image_episodes.episode_data(
                support_set["keys"], data_x, data_keys, data_labels)
            S_iterator = batching_library.unflattened_image_iterator(
                S_data, len(S_data), shuffle_batches_every_epoch=False)
            S_labels = [S_lab[i] for i in S_iterator.indices]

            # Extract query and support image latents
            for feats in query_iterator:
                lat = sesh.run(
                    [image_latent],
                    feed_dict={
                        image_X: feats,
                        train_flag: False
                    })[0]
            for feats in S_iterator:
                S_lat = sesh.run(
                    [image_latent],
                    feed_dict={
                        image_X: feats,
                        train_flag: False
                    })[0]

            if normalize:
                latents = (lat - lat.mean(axis=0)) / lat.std(axis=0)
                s_latents = (S_lat - S_lat.mean(axis=0)) / S_lat.std(axis=0)
            else:
                latents = lat
                s_latents = S_lat

            # Cosine nearest-neighbour matching of queries to the support set
            distances = cdist(latents, s_latents, "cosine")
            indexes = np.argmin(distances, axis=1)
            label_matches = few_shot_learning_library.label_matches_grid_generation_2D(
                query_labels, S_labels)
            for i in range(len(indexes)):
                total += 1
                if label_matches[i, indexes[i]]:
                    correct += 1

    return correct / total