def train_few_shot_model(
        train_iterator,
        train_feed_dict,
        train_flag,
        train_loss,
        train_metrics,
        train_optimizer,
        n_epochs,
        max_batches,
        val_iterator,
        val_feed_dict,
        model_embedding,
        embed_input,
        query_input,
        support_memory_input,
        nearest_neighbour,
        n_episodes,
        log_interval=1,  # number of batches to complete between logging
        model_dir='saved_models',  # saved models written to <model_dir>/checkpoints
        summary_dir='summaries/train',  # directory for writing summaries
        save_filename='trained_model',  # checkpoint filename
        restore_checkpoint=None):  # resume training from a specific checkpoint
    """Train a few-shot model, validating each epoch with episodic 1-NN tests.

    Periodic checkpoints are written to <model_dir>/checkpoints, and the best
    model (by validation few-shot accuracy) to <model_dir>/final_model.
    """
    # Get the global step
    global_step = tf.train.get_or_create_global_step()

    # Get tf.summary tensors to evaluate
    summaries = tf.summary.merge_all()
    val_acc_input = tf.placeholder(TF_FLOAT)
    val_summary = tf.summary.scalar('val_few_shot_accuracy', val_acc_input)

    # Define variables to store the best validation accuracy and epoch
    epoch_var = tf.Variable(0, name='best_epoch')
    accuracy_var = tf.Variable(0., name='best_accuracy')

    # Define savers for model checkpoints
    checkpoint_saver = tf.train.Saver(max_to_keep=5, save_relative_paths=True)
    best_saver = tf.train.Saver(save_relative_paths=True)

    # Define helper function for saving model and parameters
    def _save_checkpoint(epoch, save_best=False, save_first=False):
        checkpoint_path = os.path.join(model_dir, 'checkpoints', save_filename)
        saved = checkpoint_saver.save(sess, checkpoint_path, global_step=epoch)
        logging.info("Saved model checkpoint to file (epoch {}): {}"
                     "".format(epoch, saved))
        if save_best:
            best_path = os.path.join(model_dir, 'final_model', save_filename)
            best_saved = best_saver.save(sess, best_path)
            logging.info("Saved new best model to file (epoch {}): {}"
                         "".format(epoch, best_saved))
            return best_saved
        elif save_first:
            first_path = os.path.join(model_dir, 'checkpoints', 'initial_model')
            first_saved = best_saver.save(sess, first_path)
            logging.info("Saved randomly initialized base model to file: {}"
                         "".format(first_saved))

    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # allow GPU memory growth

    # Start tf.Session to train and validate model
    with tf.Session(config=config) as sess:
        # ----------------------------------------------------------------------
        # Load model and log some debug info:
        # ----------------------------------------------------------------------
        try:  # restore from model checkpoint
            if restore_checkpoint is not None:  # use specific checkpoint
                restore_path = os.path.join(model_dir, 'checkpoints',
                                            restore_checkpoint)
                if not os.path.isfile('{}.index'.format(restore_path)):
                    restore_path = restore_checkpoint  # possibly full path?
            else:  # use latest checkpoint if available
                restore_path = tf.train.latest_checkpoint(
                    os.path.join(model_dir, 'checkpoints'))
            checkpoint_saver.restore(sess, restore_path)
            logging.info(
                "Model restored from checkpoint: {}".format(restore_path))
            start_epoch = int(restore_path.split('-')[-1])
        except ValueError:  # no checkpoints, initialize variables from scratch
            sess.run(tf.global_variables_initializer())
            start_epoch = 0
            _save_checkpoint(start_epoch, save_first=True)  # reproducibility

        # Evaluate starting global step, and previous best accuracy and epoch
        step = sess.run(global_step)
        logging.info("Training from: Epoch: {}\tGlobal Step: {}".format(
            start_epoch + 1, step))
        best_epoch, best_val_acc = sess.run([epoch_var, accuracy_var])
        logging.info("Current best model: Epoch: {}\tValidation accuracy: "
                     "{:.5f}".format(best_epoch + 1, best_val_acc))

        # Create session summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(model_dir, summary_dir,
                         datetime.datetime.now().strftime("%Hh%Mm%Ss_%f")),
            sess.graph)

        # TODO(rpeloff) If siamese ...
        # Get some triplet pairs, and display on tensorboard
        # x_triplet, y_triplet = train_iterator.get_next()
        # sess.run(train_iterator.initializer, feed_dict=train_feed_dict)
        # anch_summ = tf.summary.image('triplet_anchor_images', x_triplet[0], 5)
        # same_summ = tf.summary.image('triplet_same_images', x_triplet[1], 5)
        # diff_summ = tf.summary.image('triplet_different_images', x_triplet[2], 5)
        # x_trip_batch, y_trip_batch, anch_images, same_images, diff_images = sess.run(
        #     [x_triplet, y_triplet, anch_summ, same_summ, diff_summ])
        # summary_writer.add_summary(anch_images, step)
        # summary_writer.add_summary(same_images, step)
        # summary_writer.add_summary(diff_images, step)
        # summary_writer.flush()

        # Get support/query few-shot set, and display one episode on tensorboard
        support_set, query_set = val_iterator.get_next()
        sess.run(val_iterator.initializer,
                 feed_dict=val_feed_dict)  # init validation set iterator
        s_summ = tf.summary.image('support_set_images', support_set[0], 10)
        q_summ = tf.summary.image('query_set_images', query_set[0], 10)
        support_batch, query_batch, s_images, q_images = sess.run(
            [support_set, query_set, s_summ, q_summ])
        summary_writer.add_summary(s_images, step)
        summary_writer.add_summary(q_images, step)
        summary_writer.flush()

        # Save figures to pdf for later use ...
        for image, label, speaker in zip(*support_batch):
            utils.save_image(
                np.squeeze(image, axis=-1),
                filename=os.path.join(
                    model_dir, 'train_images',
                    '{}_{}_{}_{}_{}.pdf'.format(
                        'support', 'l', label.decode("utf-8"),
                        's', speaker.decode("utf-8"))),
                cmap='inferno')
        for image, label, speaker in zip(*query_batch):
            utils.save_image(
                np.squeeze(image, axis=-1),
                filename=os.path.join(
                    model_dir, 'train_images',
                    '{}_{}_{}_{}_{}.pdf'.format(
                        'query', 'l', label.decode("utf-8"),
                        's', speaker.decode("utf-8"))),
                cmap='inferno')

        # ----------------------------------------------------------------------
        # Training:
        # ----------------------------------------------------------------------
        for epoch in range(start_epoch, n_epochs):
            logging.info("Epoch: {}/{} [Step: {}]"
                         "".format(epoch + 1, n_epochs, step))
            sess.run(train_iterator.initializer,
                     feed_dict=train_feed_dict)  # init train dataset iterator
            avg_loss = 0.
            n_batches_completed = 0
            for i in range(max_batches):
                try:
                    # TODO(rpeloff) Add embeddings and labels to visualize how
                    # embeddings change over training?
                    _, loss_val, summary_vals, metric_vals, step = sess.run(
                        [train_optimizer, train_loss, summaries,
                         [m for m in train_metrics.values()], global_step],
                        feed_dict={train_flag: True})
                    # Write summaries for tensorboard and log some info
                    summary_writer.add_summary(summary_vals, step)
                    summary_writer.flush()
                    n_batches_completed += 1
                    avg_loss += loss_val
                    if n_batches_completed % log_interval == 0:
                        batch_message = (
                            "\tTrain: [Batch: {}/{}]\tLoss: {:.7f}"
                            "".format(n_batches_completed, max_batches,
                                      loss_val))
                        for metric_key, metric_val in zip(
                                [k for k in train_metrics.keys()],
                                metric_vals):
                            batch_message += "\t{}: {}".format(metric_key,
                                                               metric_val)
                        logging.info(batch_message)
                except tf.errors.OutOfRangeError:  # catch pipeline out of range
                    break

            # ------------------------------------------------------------------
            # Few-shot validation:
            # ------------------------------------------------------------------
            total_queries = 0
            total_correct = 0
            sess.run(val_iterator.initializer,
                     feed_dict=val_feed_dict)  # init validation set iterator
            for episode in range(n_episodes):
                support_batch, query_batch = sess.run([support_set, query_set])
                # Get embeddings and classify queries with 1-NN on support set
                support_embeddings = sess.run(
                    model_embedding,
                    feed_dict={embed_input: support_batch[0],
                               train_flag: False})
                query_embeddings = sess.run(
                    model_embedding,
                    feed_dict={embed_input: query_batch[0],
                               train_flag: False})
                nearest_neighbour_indices = sess.run(
                    nearest_neighbour,
                    feed_dict={query_input: query_embeddings,
                               support_memory_input: support_embeddings})
                # Calculate and store number of correct predictions
                predicted_labels = support_batch[1][nearest_neighbour_indices]
                total_correct += np.sum(query_batch[1] == predicted_labels)
                total_queries += query_batch[1].shape[0]
                if episode % int(n_episodes / 5) == 0:
                    avg_acc = total_correct / total_queries
                    ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                                  "Average accuracy: {:.7f}".format(
                                      episode, n_episodes, avg_acc))
                    logging.info(ep_message)

            # ------------------------------------------------------------------
            # Print stats and early-stopping:
            # ------------------------------------------------------------------
            # Print epoch train stats
            avg_loss = avg_loss / n_batches_completed
            epoch_message = ("Epoch: {}/{} [Step: {}]\tTrain set: Average "
                             "loss: {:.5f}".format(epoch + 1, n_epochs, step,
                                                   avg_loss))
            logging.info(epoch_message)

            # Print epoch few-shot validation stats
            avg_acc = total_correct / total_queries
            few_shot_message = ("Epoch: {}/{} [Step: {}]\tValidation set (few-"
                                "shot): Average accuracy: {:.5f}".format(
                                    epoch + 1, n_epochs, step, avg_acc))
            logging.info(few_shot_message)
            val_summ = sess.run(val_summary,
                                feed_dict={val_acc_input: avg_acc})
            summary_writer.add_summary(val_summ, step)
            summary_writer.flush()

            # Check if this is a new best model
            if avg_acc > best_val_acc:
                best_epoch = epoch
                best_val_acc = avg_acc
                sess.run([tf.assign(epoch_var, best_epoch),
                          tf.assign(accuracy_var, best_val_acc)])
                _save_checkpoint(epoch + 1, save_best=True)  # save best model
                with open(os.path.join(model_dir, 'train_result.txt'),
                          'w') as res_file:
                    res_file.write("Epoch: {}\tTrain loss: {:.5f}\t"
                                   "Validation accuracy: {:.5f}".format(
                                       epoch + 1, avg_loss, avg_acc))

        # Training complete, print final (best) model stats:
        logging.info("Training complete. Best model found at epoch {} with "
                     "validation accuracy {:.5f}.".format(
                         best_epoch + 1, best_val_acc))
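# NOTE: the query_input, support_memory_input and nearest_neighbour tensors
# passed into the functions in this module are assumed to implement the "fast
# cosine 1-NN memory model" referred to in the comments below. The helper
# below is a minimal sketch, under that assumption, of how such a graph could
# be built; the function name and shapes are illustrative, not necessarily the
# implementation used elsewhere in this repository, and it relies on the
# module's existing TensorFlow import.
def _cosine_nearest_neighbour_sketch(embed_dim):
    """Return (query_input, support_memory_input, nearest_neighbour) tensors."""
    query_input = tf.placeholder(TF_FLOAT, shape=[None, embed_dim])
    support_memory_input = tf.placeholder(TF_FLOAT, shape=[None, embed_dim])
    # L2-normalise both sets so the dot product equals cosine similarity
    query_norm = tf.nn.l2_normalize(query_input, axis=1)
    support_norm = tf.nn.l2_normalize(support_memory_input, axis=1)
    similarities = tf.matmul(query_norm, support_norm, transpose_b=True)
    nearest_neighbour = tf.argmax(similarities, axis=1)  # 1-NN index per query
    return query_input, support_memory_input, nearest_neighbour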
def test_mulitmodal_few_shot_model(
        test_feed_dict,
        # Speech test params ...
        speech_graph,
        speech_train_flag,
        speech_test_iterator,
        speech_model_embedding,
        speech_embed_input,
        # Vision test params ...
        vision_graph,
        vision_train_flag,
        vision_test_iterator,
        vision_model_embedding,
        vision_embed_input,
        # Nearest neighbour params ...
        query_input,
        support_memory_input,
        nearest_neighbour,
        n_episodes,
        query_type,
        test_pixels=False,
        test_dtw=False,
        dtw_cost_func=None,
        dtw_post_process=None,
        test_invariance=False,
        # Other params ...
        log_interval=1,
        model_dir='saved_models',
        speech_model_dir='saved_models/speech',
        vision_model_dir='saved_models/vision',
        summary_dir='summaries/test',
        speech_restore_checkpoint=None,
        vision_restore_checkpoint=None):
    """Test multimodal (speech-vision) few-shot matching across modalities.

    Queries in one modality are matched to that modality's support set with
    1-NN (or DTW for speech), then the paired support item in the other
    modality is matched into that modality's matching set to predict a label.
    """
    # Create tf.Session objects for the speech, vision, and general models
    speech_session = tf.Session(graph=speech_graph)
    vision_session = tf.Session(graph=vision_graph)
    general_session = tf.Session()  # default graph

    # Get model global steps
    with speech_graph.as_default():
        speech_global_step = tf.train.get_or_create_global_step()
        speech_step = 0
    with vision_graph.as_default():
        vision_global_step = tf.train.get_or_create_global_step()
        vision_step = 0

    # --------------------------------------------------------------------------
    # Load models (unless using DTW or pixel matching) and log some debug info:
    # --------------------------------------------------------------------------
    if not test_dtw:
        with speech_session.as_default(), speech_graph.as_default():  #pylint: disable=E1129
            try:  # restore speech from model checkpoint
                speech_checkpoint_saver = tf.train.Saver(
                    save_relative_paths=True)
                if speech_restore_checkpoint is not None:  # specific checkpoint
                    restore_path = os.path.join(speech_model_dir, 'checkpoints',
                                                speech_restore_checkpoint)
                    if not os.path.isfile('{}.index'.format(restore_path)):
                        restore_path = speech_restore_checkpoint  # full path
                else:  # use best model if available
                    final_model_dir = os.path.join(speech_model_dir,
                                                   'final_model')
                    restore_path = tf.train.latest_checkpoint(final_model_dir)
                    if restore_path is None:
                        logging.info(
                            "No best model checkpoint could be found "
                            "in directory: {}".format(final_model_dir))
                        return  # exit ...
                speech_checkpoint_saver.restore(speech_session, restore_path)
                logging.info(
                    "Speech model restored from checkpoint: {}".format(
                        restore_path))
            except ValueError:  # no checkpoints, inform and exit ...
                logging.info(
                    "Speech model checkpoint could not be found at restore "
                    "path: {}".format(restore_path))
                return  # exit ...
            # Evaluate global step the speech model was trained to
            speech_step = speech_session.run(speech_global_step)
            logging.info("Testing speech model from: Global Step: {}"
                         "".format(speech_step))
    else:
        logging.info("Testing speech model with dynamic time warping (DTW).")

    if not test_pixels:
        with vision_session.as_default(), vision_graph.as_default():  #pylint: disable=E1129
            try:  # restore vision from model checkpoint
                vision_checkpoint_saver = tf.train.Saver(
                    save_relative_paths=True)
                if vision_restore_checkpoint is not None:  # specific checkpoint
                    restore_path = os.path.join(vision_model_dir, 'checkpoints',
                                                vision_restore_checkpoint)
                    if not os.path.isfile('{}.index'.format(restore_path)):
                        restore_path = vision_restore_checkpoint  # full path
                else:  # use best model if available
                    final_model_dir = os.path.join(vision_model_dir,
                                                   'final_model')
                    restore_path = tf.train.latest_checkpoint(final_model_dir)
                    if restore_path is None:
                        logging.info(
                            "No best model checkpoint could be found "
                            "in directory: {}".format(final_model_dir))
                        return  # exit ...
                vision_checkpoint_saver.restore(vision_session, restore_path)
                logging.info(
                    "Vision model restored from checkpoint: {}".format(
                        restore_path))
            except ValueError:  # no checkpoints, inform and exit ...
                logging.info(
                    "Vision model checkpoint could not be found at restore "
                    "path: {}".format(restore_path))
                return  # exit ...
            # Evaluate global step the vision model was trained to
            vision_step = vision_session.run(vision_global_step)
            logging.info("Testing vision model from: Global Step: {}".format(
                vision_step))
    else:
        logging.info("Testing vision model with pure pixel matching.")

    # Create general session summary writer and few-shot accuracy summary
    summary_writer = tf.summary.FileWriter(
        os.path.join(model_dir, summary_dir,
                     datetime.datetime.now().strftime("%Hh%Mm%Ss_%f")),
        general_session.graph)

    # Get tf.summary tensor to evaluate for few-shot accuracy
    test_acc_input = tf.placeholder(TF_FLOAT)
    test_summ = tf.summary.scalar('test_few_shot_accuracy', test_acc_input)

    # Get speech and vision support/query/matching few-shot sets, and save
    # one episode to tensorboard and files for debugging
    speech_support_set, speech_query_set = speech_test_iterator.get_next()
    vision_support_set, vision_query_set = vision_test_iterator.get_next()
    general_session.run(
        [speech_test_iterator.initializer, vision_test_iterator.initializer],
        feed_dict=test_feed_dict)  # init test set iterators
    speech_s_summ = tf.summary.image('speech_support_set_images',
                                     speech_support_set[0], 10)
    vision_s_summ = tf.summary.image('vision_support_set_images',
                                     vision_support_set[0], 10)
    if query_type == 'speech':  # speech query, image matching set
        speech_q_summ = tf.summary.image('speech_query_set_images',
                                         speech_query_set[0], 10)
        vision_q_summ = tf.summary.image('vision_matching_set_images',
                                         vision_query_set[0], 10)
    else:  # image query, speech matching set
        speech_q_summ = tf.summary.image('speech_matching_set_images',
                                         speech_query_set[0], 10)
        vision_q_summ = tf.summary.image('vision_query_set_images',
                                         vision_query_set[0], 10)
    (speech_s_batch, speech_q_batch, vision_s_batch, vision_q_batch,
     speech_s_images, speech_q_images, vision_s_images,
     vision_q_images) = general_session.run([
         speech_support_set, speech_query_set, vision_support_set,
         vision_query_set, speech_s_summ, speech_q_summ, vision_s_summ,
         vision_q_summ])

    # Save figures to pdf for later use ...
    for index, (image, label, speaker) in enumerate(zip(*speech_s_batch)):
        if test_dtw:
            image = dtw_post_process(image)
        else:
            image = np.squeeze(image, axis=-1)
        utils.save_image(
            image,
            filename=os.path.join(
                model_dir, 'multimodal_test_images',
                '{}_{}_{}_{}_{}_{}_{}.pdf'.format(
                    'speech', 'support', index, 'label',
                    label.decode("utf-8"), 'speaker',
                    speaker.decode("utf-8"))),
            cmap='inferno')
    for index, (image, label, speaker) in enumerate(zip(*speech_q_batch)):
        if test_dtw:
            image = dtw_post_process(image)
        else:
            image = np.squeeze(image, axis=-1)
        set_label = 'query' if query_type == 'speech' else 'matching'
        utils.save_image(
            image,
            filename=os.path.join(
                model_dir, 'multimodal_test_images',
                '{}_{}_{}_{}_{}_{}_{}.pdf'.format(
                    'speech', set_label, index, 'label',
                    label.decode("utf-8"), 'speaker',
                    speaker.decode("utf-8"))),
            cmap='inferno')
    for index, (image, label) in enumerate(zip(*vision_s_batch)):
        utils.save_image(
            np.squeeze(image, axis=-1),
            filename=os.path.join(
                model_dir, 'multimodal_test_images',
                '{}_{}_{}_{}_{}.pdf'.format(
                    'vision', 'support', index, 'label',
                    label.decode("utf-8"))),
            cmap='gray_r')
    for index, (image, label) in enumerate(zip(*vision_q_batch)):
        set_label = 'query' if query_type == 'vision' else 'matching'
        utils.save_image(
            np.squeeze(image, axis=-1),
            filename=os.path.join(
                model_dir, 'multimodal_test_images',
                '{}_{}_{}_{}_{}.pdf'.format(
                    'vision', set_label, index, 'label',
                    label.decode("utf-8"))),
            cmap='gray_r')

    # Save summary images with general summary writer
    summary_writer.add_summary(speech_s_images, speech_step)
    summary_writer.add_summary(speech_q_images, speech_step)
    summary_writer.add_summary(vision_s_images, vision_step)
    summary_writer.add_summary(vision_q_images, vision_step)
    summary_writer.flush()

    # --------------------------------------------------------------------------
    # Cross-modal few-shot testing:
    # --------------------------------------------------------------------------
    total_queries = 0
    total_correct = 0
    # Speaker invariance accuracy counters
    total_easy_queries = 0
    total_easy_correct = 0
    total_distractor_queries = 0
    total_distractor_correct = 0
    general_session.run(
        [speech_test_iterator.initializer, vision_test_iterator.initializer],
        feed_dict=test_feed_dict)  # init test set iterators
    few_shot_set = [(speech_support_set, speech_query_set),
                    (vision_support_set, vision_query_set)]
    for episode in range(n_episodes):
        # Get next few-shot batch
        episode_batch = general_session.run(few_shot_set)
        speech_batch = episode_batch[0]
        vision_batch = episode_batch[1]

        # Get embeddings and classify queries with 1-NN on support set
        with speech_session.as_default(), speech_graph.as_default():  #pylint: disable=E1129
            speech_support_embeddings = speech_session.run(
                speech_model_embedding,
                feed_dict={speech_embed_input: speech_batch[0][0],
                           speech_train_flag: False})
            speech_query_embeddings = speech_session.run(
                speech_model_embedding,
                feed_dict={speech_embed_input: speech_batch[1][0],
                           speech_train_flag: False})
        with vision_session.as_default(), vision_graph.as_default():  #pylint: disable=E1129
            vision_support_embeddings = vision_session.run(
                vision_model_embedding,
                feed_dict={vision_embed_input: vision_batch[0][0],
                           vision_train_flag: False})
            vision_query_embeddings = vision_session.run(
                vision_model_embedding,
                feed_dict={vision_embed_input: vision_batch[1][0],
                           vision_train_flag: False})

        # Speech query cross-modal matching to images
        if query_type == 'speech':
            pred_message = ""
            if not test_dtw:  # test speech with fast cosine 1-NN memory model
                s_nearest_neighbour_indices = general_session.run(
                    nearest_neighbour,
                    feed_dict={
                        query_input: speech_query_embeddings,
                        support_memory_input: speech_support_embeddings})
            else:  # test speech with dynamic time warping baseline
                costs = [[
                    dtw_cost_func(
                        dtw_post_process(speech_query_embeddings[i]),
                        dtw_post_process(speech_support_embeddings[j]), True)
                    for j in range(len(speech_support_embeddings))
                ] for i in range(len(speech_query_embeddings))]
                s_nearest_neighbour_indices = [
                    np.argmin(costs[i]) for i in range(len(costs))]
            # Get cross-modal matches in image support set and their labels
            query_cross_matches = vision_support_embeddings[
                s_nearest_neighbour_indices]
            actual_labels = speech_batch[1][1]
            # Find cross-modal matches in the image matching set
            v_nearest_neighbour_indices = general_session.run(
                nearest_neighbour,
                feed_dict={query_input: query_cross_matches,
                           support_memory_input: vision_query_embeddings})
            predicted_labels = vision_batch[1][1][v_nearest_neighbour_indices]
            # Images 'z' and 'o' treated same regardless of different speech classes
            actual_labels_update = np.array(
                [label if label != b'o' else b'z' for label in actual_labels])
            predicted_labels_update = np.array(
                [label if label != b'o' else b'z'
                 for label in predicted_labels])
            # Put together some debug info ...
            pred_message += "\t\tActual speech query labels:\t\t{}".format(
                actual_labels)
            pred_message += "\n\t\tPredicted speech support labels:\t{}".format(
                speech_batch[0][1][s_nearest_neighbour_indices])
            pred_message += "\n\t\tAssociated vision support labels:\t{}".format(
                vision_batch[0][1][s_nearest_neighbour_indices])
            pred_message += "\n\t\tPredicted vision matching labels:\t{}".format(
                predicted_labels)
            pred_message += "\n\t\tUpdated speech query labels ('o'=='z'):\t{}".format(
                actual_labels_update)
            pred_message += "\n\t\tUpdated vision match labels ('o'=='z'):\t{}".format(
                predicted_labels_update)
            # Update accuracy counters and log info
            total_correct += np.sum(
                actual_labels_update == predicted_labels_update)
            total_queries += speech_batch[1][1].shape[0]
            if test_invariance:
                # Count queries and predictions with easy/distractor speakers
                for q_index in range(speech_batch[1][1].shape[0]):
                    n_same_speaker = np.sum(
                        np.logical_and(
                            speech_batch[1][1][q_index] == speech_batch[0][1],
                            speech_batch[1][2][q_index] == speech_batch[0][2]))
                    if n_same_speaker > 0:  # easy speakers
                        total_easy_queries += 1
                        if actual_labels_update[
                                q_index] == predicted_labels_update[q_index]:
                            total_easy_correct += 1
                    else:  # distractor speakers
                        total_distractor_queries += 1
                        if actual_labels_update[
                                q_index] == predicted_labels_update[q_index]:
                            total_distractor_correct += 1
            if episode % log_interval == 0:
                avg_acc = total_correct / total_queries
                ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                              "Average accuracy: {:.7f}".format(
                                  episode, n_episodes, avg_acc))
                if test_invariance:
                    avg_easy_acc = (total_easy_correct / total_easy_queries
                                    if total_easy_queries != 0 else 0.)
                    avg_dist_acc = (
                        total_distractor_correct / total_distractor_queries
                        if total_distractor_queries != 0 else 0.)
                    ep_message += (
                        "\n\t\tEasy speaker accuracy: {:.7f}\tDistractor "
                        "speaker accuracy: {:.7f}".format(
                            avg_easy_acc, avg_dist_acc))
                    ep_message += (
                        "\n\t\tNum easy speakers: {}\tNum distractor speakers: {}"
                        .format(total_easy_queries, total_distractor_queries))
                logging.info(ep_message)
                # print("Speech support labels:", speech_batch[0][1])
                # print("Vision support labels:", vision_batch[0][1])
                logging.info(pred_message)

        # Image query cross-modal matching to speech
        else:
            pred_message = ""
            v_nearest_neighbour_indices = general_session.run(
                nearest_neighbour,
                feed_dict={query_input: vision_query_embeddings,
                           support_memory_input: vision_support_embeddings})
            # Get cross-modal matches in speech support set and their labels
            query_cross_matches = speech_support_embeddings[
                v_nearest_neighbour_indices]
            actual_labels = vision_batch[1][1]
            # Find cross-modal matches in the speech matching set
            if not test_dtw:  # test speech with fast cosine 1-NN memory model
                s_nearest_neighbour_indices = general_session.run(
                    nearest_neighbour,
                    feed_dict={
                        query_input: query_cross_matches,
                        support_memory_input: speech_query_embeddings})
            else:  # test speech with dynamic time warping baseline
                costs = [[
                    dtw_cost_func(
                        dtw_post_process(query_cross_matches[i]),
                        dtw_post_process(speech_query_embeddings[j]), True)
                    for j in range(len(speech_query_embeddings))
                ] for i in range(len(query_cross_matches))]
                s_nearest_neighbour_indices = [
                    np.argmin(costs[i]) for i in range(len(costs))]
            predicted_labels = speech_batch[1][1][s_nearest_neighbour_indices]
            # Images 'z' and 'o' treated same regardless of different speech classes
            actual_labels_update = np.array(
                [label if label != b'o' else b'z' for label in actual_labels])
            predicted_labels_update = np.array(
                [label if label != b'o' else b'z'
                 for label in predicted_labels])
            # Put together some debug info ...
            pred_message += "\t\tActual vision query labels:\t\t{}".format(
                actual_labels)
            pred_message += "\n\t\tPredicted vision support labels:\t{}".format(
                vision_batch[0][1][v_nearest_neighbour_indices])
            pred_message += "\n\t\tAssociated speech support labels:\t{}".format(
                speech_batch[0][1][v_nearest_neighbour_indices])
            pred_message += "\n\t\tPredicted speech matching labels:\t{}".format(
                predicted_labels)
            pred_message += "\n\t\tUpdated vision query labels:\t\t{}".format(
                actual_labels_update)
            pred_message += "\n\t\tUpdated speech prediction labels:\t{}".format(
                predicted_labels_update)
            # Update accuracy counters and log info
            total_correct += np.sum(
                actual_labels_update == predicted_labels_update)
            total_queries += vision_batch[1][1].shape[0]
            if episode % log_interval == 0:
                avg_acc = total_correct / total_queries
                ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                              "Average accuracy: {:.7f}".format(
                                  episode, n_episodes, avg_acc))
                logging.info(ep_message)
                logging.info(pred_message)

    # --------------------------------------------------------------------------
    # Print stats:
    # --------------------------------------------------------------------------
    avg_acc = total_correct / total_queries
    few_shot_message = ("Test set (few-shot): Average accuracy: "
                        "{:.5f}".format(avg_acc))
    if test_invariance:
        avg_easy_acc = total_easy_correct / total_easy_queries
        avg_dist_acc = total_distractor_correct / total_distractor_queries
        few_shot_message += ("\n\t\tEasy speaker accuracy: {:.5f}\tDistractor "
                             "speaker accuracy: {:.5f}".format(
                                 avg_easy_acc, avg_dist_acc))
        few_shot_message += (
            "\n\t\tNum easy speakers: {}\tNum distractor speakers: {}".format(
                total_easy_queries, total_distractor_queries))
    logging.info(few_shot_message)
    test_summ_val = general_session.run(test_summ,
                                        feed_dict={test_acc_input: avg_acc})
    summary_writer.add_summary(test_summ_val, 0)
    summary_writer.flush()
    with open(os.path.join(model_dir, 'test_result.txt'), 'w') as res_file:
        res_file.write(few_shot_message)

    # Testing complete
    logging.info("Testing complete.")
    speech_session.close()
    vision_session.close()
    general_session.close()
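# The speech-query branch of test_mulitmodal_few_shot_model reduces to three
# 1-NN lookups. The NumPy sketch below restates that logic outside the graph
# for clarity; it is illustrative only, the function and argument names are
# hypothetical, and it relies on the module's existing numpy import.
def _cross_modal_predict_sketch(speech_query_emb, speech_support_emb,
                                vision_support_emb, vision_matching_emb):
    """Return matching-set indices predicted for speech queries (sketch)."""
    def _cosine_nn(queries, memory):
        q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
        m = memory / np.linalg.norm(memory, axis=1, keepdims=True)
        return np.argmax(q @ m.T, axis=1)
    # 1. Match each speech query to its nearest speech support example
    s_idx = _cosine_nn(speech_query_emb, speech_support_emb)
    # 2. Take the image support embedding paired with that speech support item
    cross_matches = vision_support_emb[s_idx]
    # 3. Match that image embedding into the image matching set
    return _cosine_nn(cross_matches, vision_matching_emb)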
def test_few_shot_model(
        train_flag,
        test_iterator,
        test_feed_dict,
        model_embedding,
        embed_input,
        query_input,
        support_memory_input,
        nearest_neighbour,
        n_episodes,
        test_pixels=False,
        log_interval=1,
        model_dir='saved_models',
        output_dir='.',
        summary_dir='summaries/test',
        restore_checkpoint=None):
    """Test a vision few-shot model (or a raw pixel-matching baseline) with
    episodic 1-NN classification on support/query sets."""
    # Get the global step tensor and set initial step value
    global_step = tf.train.get_or_create_global_step()
    step = 0

    # Define a saver to load model checkpoint
    checkpoint_saver = tf.train.Saver(save_relative_paths=True)

    # Start tf.Session to test model
    with tf.Session() as sess:
        # ----------------------------------------------------------------------
        # Load model (unless using pixels) and log some debug info:
        # ----------------------------------------------------------------------
        if not test_pixels:
            try:  # restore from model checkpoint
                if restore_checkpoint is not None:  # use specific checkpoint
                    restore_path = os.path.join(model_dir, 'checkpoints',
                                                restore_checkpoint)
                    if not os.path.isfile('{}.index'.format(restore_path)):
                        restore_path = restore_checkpoint  # possibly full path?
                else:  # use best model if available
                    final_model_dir = os.path.join(model_dir, 'final_model')
                    restore_path = tf.train.latest_checkpoint(final_model_dir)
                    if restore_path is None:
                        logging.info(
                            "No best model checkpoint could be found "
                            "in directory: {}".format(final_model_dir))
                        return  # exit ...
                checkpoint_saver.restore(sess, restore_path)
                logging.info("Model restored from checkpoint file: "
                             "{}".format(restore_path))
            except ValueError:  # no checkpoints, inform and exit ...
                logging.info("Model checkpoint could not be found at restore "
                             "path: {}".format(restore_path))
                return  # exit ...
            # Evaluate global step model was trained to
            step = sess.run(global_step)
            logging.info("Testing from: Global Step: {}".format(step))
        else:
            logging.info("Testing vision model with pure pixel matching.")

        # Create session summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(output_dir, summary_dir,
                         datetime.datetime.now().strftime("%Hh%Mm%Ss_%f")),
            sess.graph)

        # Get tf.summary tensor to evaluate for few-shot accuracy
        test_acc_input = tf.placeholder(TF_FLOAT)
        test_summ = tf.summary.scalar('test_few_shot_accuracy', test_acc_input)

        # Get support/query few-shot set, and display one episode on tensorboard
        support_set, query_set = test_iterator.get_next()
        sess.run(test_iterator.initializer, feed_dict=test_feed_dict)
        s_summ = tf.summary.image('support_set_images', support_set[0], 10)
        q_summ = tf.summary.image('query_set_images', query_set[0], 10)
        support_batch, query_batch, s_images, q_images = sess.run(
            [support_set, query_set, s_summ, q_summ])
        summary_writer.add_summary(s_images, step)
        summary_writer.add_summary(q_images, step)
        summary_writer.flush()

        # Save figures to pdf for later use ...
        for index, (image, label) in enumerate(zip(*support_batch)):
            utils.save_image(
                np.squeeze(image, axis=-1),
                filename=os.path.join(
                    output_dir, 'test_images',
                    '{}_{}_{}_{}.pdf'.format(
                        'support', index, 'label', label.decode("utf-8"))),
                cmap='gray_r')
        for index, (image, label) in enumerate(zip(*query_batch)):
            utils.save_image(
                np.squeeze(image, axis=-1),
                filename=os.path.join(
                    output_dir, 'test_images',
                    '{}_{}_{}_{}.pdf'.format(
                        'query', index, 'label', label.decode("utf-8"))),
                cmap='gray_r')

        # ----------------------------------------------------------------------
        # Few-shot testing:
        # ----------------------------------------------------------------------
        total_queries = 0
        total_correct = 0
        sess.run(test_iterator.initializer, feed_dict=test_feed_dict)
        for episode in range(n_episodes):
            support_batch, query_batch = sess.run([support_set, query_set])
            # Get embeddings and classify queries with 1-NN on support set
            support_embeddings = sess.run(
                model_embedding,
                feed_dict={embed_input: support_batch[0],
                           train_flag: False})
            query_embeddings = sess.run(
                model_embedding,
                feed_dict={embed_input: query_batch[0],
                           train_flag: False})
            nearest_neighbour_indices = sess.run(
                nearest_neighbour,
                feed_dict={query_input: query_embeddings,
                           support_memory_input: support_embeddings})
            # Calculate and store number of correct predictions
            predicted_labels = support_batch[1][nearest_neighbour_indices]
            total_correct += np.sum(query_batch[1] == predicted_labels)
            total_queries += query_batch[1].shape[0]
            if episode % log_interval == 0:
                avg_acc = total_correct / total_queries
                ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                              "Average accuracy: {:.7f}".format(
                                  episode, n_episodes, avg_acc))
                logging.info(ep_message)

        # ----------------------------------------------------------------------
        # Print stats:
        # ----------------------------------------------------------------------
        avg_acc = total_correct / total_queries
        few_shot_message = ("Test set (few-shot): Average accuracy: "
                            "{:.5f}".format(avg_acc))
        logging.info(few_shot_message)
        test_summ_val = sess.run(test_summ,
                                 feed_dict={test_acc_input: avg_acc})
        summary_writer.add_summary(test_summ_val, step)
        summary_writer.flush()
        with open(os.path.join(output_dir, 'test_result.txt'),
                  'w') as res_file:
            res_file.write(few_shot_message)

        # Testing complete
        logging.info("Testing complete.")
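# When test_pixels=True above, no checkpoint is restored, so model_embedding is
# presumably an untrained op that maps images straight to vectors and 1-NN
# matching happens on raw pixels. The helper below is a minimal sketch under
# that assumption; the name and default image shape are illustrative only.
def _pixel_embedding_sketch(image_shape=(28, 28, 1)):
    """Hypothetical pixel-matching embedding: flatten raw pixels (sketch)."""
    embed_input = tf.placeholder(TF_FLOAT, shape=(None,) + image_shape)
    # Flatten each image to a single pixel vector per example
    model_embedding = tf.reshape(embed_input, [tf.shape(embed_input)[0], -1])
    return embed_input, model_embedding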
def test_few_shot_model(train_flag,
                        test_iterator,
                        test_feed_dict,
                        model_embedding,
                        embed_input,
                        query_input,
                        support_memory_input,
                        nearest_neighbour,
                        n_episodes,
                        test_dtw=False,
                        dtw_cost_func=None,
                        dtw_post_process=None,
                        test_invariance=False,
                        log_interval=1,
                        model_dir='saved_models',
                        output_dir='.',
                        summary_dir='summaries/test',
                        restore_checkpoint=None):
    """Test a speech few-shot model (or a DTW baseline) with episodic 1-NN
    classification, optionally splitting accuracy by easy/distractor speakers."""
    # Get the global step tensor and set initial step value
    global_step = tf.train.get_or_create_global_step()
    step = 0

    # Define a saver to load model checkpoint
    checkpoint_saver = tf.train.Saver(save_relative_paths=True)

    # Start tf.Session to test model
    with tf.Session() as sess:
        # ----------------------------------------------------------------------
        # Load model (unless using DTW) and log some debug info:
        # ----------------------------------------------------------------------
        if not test_dtw:
            try:  # restore from model checkpoint
                if restore_checkpoint is not None:  # use specific checkpoint
                    restore_path = os.path.join(model_dir, 'checkpoints',
                                                restore_checkpoint)
                    if not os.path.isfile('{}.index'.format(restore_path)):
                        restore_path = restore_checkpoint  # possibly full path?
                else:  # use best model if available
                    final_model_dir = os.path.join(model_dir, 'final_model')
                    restore_path = tf.train.latest_checkpoint(final_model_dir)
                    if restore_path is None:
                        logging.info(
                            "No best model checkpoint could be found "
                            "in directory: {}".format(final_model_dir))
                        return  # exit ...
                checkpoint_saver.restore(sess, restore_path)
                logging.info("Model restored from checkpoint file: "
                             "{}".format(restore_path))
            except ValueError:  # no checkpoints, inform and exit ...
                logging.info("Model checkpoint could not be found at restore "
                             "path: {}".format(restore_path))
                return  # exit ...
            # Evaluate global step model was trained to
            step = sess.run(global_step)
            logging.info("Testing from: Global Step: {}".format(step))
        else:
            logging.info(
                "Testing speech model with dynamic time warping (DTW).")

        # Create session summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(output_dir, summary_dir,
                         datetime.datetime.now().strftime("%Hh%Mm%Ss_%f")),
            sess.graph)

        # Get tf.summary tensor to evaluate for few-shot accuracy
        test_acc_input = tf.placeholder(TF_FLOAT)
        test_summ = tf.summary.scalar('test_few_shot_accuracy', test_acc_input)

        # Get support/query few-shot set, and display one episode on tensorboard
        support_set, query_set = test_iterator.get_next()
        sess.run(test_iterator.initializer, feed_dict=test_feed_dict)
        s_summ = tf.summary.image('support_set_images', support_set[0], 10)
        q_summ = tf.summary.image('query_set_images', query_set[0], 10)
        support_batch, query_batch, s_images, q_images = sess.run(
            [support_set, query_set, s_summ, q_summ])
        summary_writer.add_summary(s_images, step)
        summary_writer.add_summary(q_images, step)
        summary_writer.flush()

        # Save figures to pdf for later use ...
        for index, (image, label, speaker) in enumerate(zip(*support_batch)):
            if test_dtw:
                image = dtw_post_process(image)
            else:
                image = np.squeeze(image, axis=-1)
            utils.save_image(
                image,
                filename=os.path.join(
                    output_dir, 'test_images',
                    '{}_{}_{}_{}_{}_{}.pdf'.format(
                        'support', index, 'label', label.decode("utf-8"),
                        'speaker', speaker.decode("utf-8"))),
                cmap='inferno')
        for index, (image, label, speaker) in enumerate(zip(*query_batch)):
            if test_dtw:
                image = dtw_post_process(image)
            else:
                image = np.squeeze(image, axis=-1)
            utils.save_image(
                image,
                filename=os.path.join(
                    output_dir, 'test_images',
                    '{}_{}_{}_{}_{}_{}.pdf'.format(
                        'query', index, 'label', label.decode("utf-8"),
                        'speaker', speaker.decode("utf-8"))),
                cmap='inferno')

        # ----------------------------------------------------------------------
        # Few-shot testing:
        # ----------------------------------------------------------------------
        # Few-shot accuracy counters
        total_queries = 0
        total_correct = 0
        # Speaker invariance accuracy counters
        total_easy_queries = 0
        total_easy_correct = 0
        total_distractor_queries = 0
        total_distractor_correct = 0
        sess.run(test_iterator.initializer, feed_dict=test_feed_dict)
        for episode in range(n_episodes):
            support_batch, query_batch = sess.run([support_set, query_set])
            # Get embeddings and classify queries with 1-NN on support set
            support_embeddings = sess.run(
                model_embedding,
                feed_dict={embed_input: support_batch[0],
                           train_flag: False})
            query_embeddings = sess.run(
                model_embedding,
                feed_dict={embed_input: query_batch[0],
                           train_flag: False})
            if not test_dtw:  # test with fast cosine 1-NN memory model
                nearest_neighbour_indices = sess.run(
                    nearest_neighbour,
                    feed_dict={query_input: query_embeddings,
                               support_memory_input: support_embeddings})
            else:  # test with dynamic time warping
                costs = [[
                    dtw_cost_func(dtw_post_process(query_embeddings[i]),
                                  dtw_post_process(support_embeddings[j]),
                                  True)
                    for j in range(len(support_embeddings))
                ] for i in range(len(query_embeddings))]
                nearest_neighbour_indices = [
                    np.argmin(costs[i]) for i in range(len(costs))]
            # Calculate and store number of correct predictions
            predicted_labels = support_batch[1][nearest_neighbour_indices]
            total_correct += np.sum(query_batch[1] == predicted_labels)
            total_queries += query_batch[1].shape[0]
            if test_invariance:
                # Count queries and predictions with easy/distractor speakers
                for q_index in range(query_batch[1].shape[0]):
                    n_same_speaker = np.sum(
                        np.logical_and(
                            query_batch[1][q_index] == support_batch[1],
                            query_batch[2][q_index] == support_batch[2]))
                    if n_same_speaker > 0:  # easy speakers
                        # print("Support set labels:", support_batch[1])
                        # print("Support set speaker:", support_batch[2])
                        # print("Query set labels:", query_batch[1])
                        # print("Query set speaker:", query_batch[2])
                        total_easy_queries += 1
                        if query_batch[1][q_index] == predicted_labels[q_index]:
                            total_easy_correct += 1
                    else:  # distractor speakers
                        total_distractor_queries += 1
                        if query_batch[1][q_index] == predicted_labels[q_index]:
                            total_distractor_correct += 1
                # prediction_originators = support_batch[2][nearest_neighbour_indices]
                # total_easy_correct += np.sum(
                #     np.logical_and(query_batch[1] == predicted_labels,
                #                    query_batch[2] == prediction_originators))
                # total_easy_queries += np.sum(query_batch[2] == prediction_originators)
                # total_distractor_correct += np.sum(
                #     np.logical_and(query_batch[1] == predicted_labels,
                #                    query_batch[2] != prediction_originators))
                # total_distractor_queries += np.sum(query_batch[2] != prediction_originators)
            if episode % log_interval == 0:
                avg_acc = total_correct / total_queries
                ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                              "Average accuracy: {:.7f}".format(
                                  episode, n_episodes, avg_acc))
                if test_invariance:
                    avg_easy_acc = (total_easy_correct / total_easy_queries
                                    if total_easy_queries != 0 else 0.)
                    avg_dist_acc = (
                        total_distractor_correct / total_distractor_queries
                        if total_distractor_queries != 0 else 0.)
                    ep_message += (
                        "\n\t\tEasy speaker accuracy: {:.7f}\tDistractor "
                        "speaker accuracy: {:.7f}".format(
                            avg_easy_acc, avg_dist_acc))
                    ep_message += (
                        "\n\t\tNum easy speakers: {}\tNum distractor speakers: {}"
                        .format(total_easy_queries, total_distractor_queries))
                logging.info(ep_message)

        # ----------------------------------------------------------------------
        # Print stats:
        # ----------------------------------------------------------------------
        avg_acc = total_correct / total_queries
        few_shot_message = ("Test set (few-shot): Average accuracy: "
                            "{:.5f}".format(avg_acc))
        if test_invariance:
            avg_easy_acc = total_easy_correct / total_easy_queries
            avg_dist_acc = total_distractor_correct / total_distractor_queries
            few_shot_message += (
                "\n\t\tEasy speaker accuracy: {:.5f}\tDistractor "
                "speaker accuracy: {:.5f}".format(avg_easy_acc, avg_dist_acc))
            few_shot_message += (
                "\n\t\tNum easy speakers: {}\tNum distractor speakers: {}"
                .format(total_easy_queries, total_distractor_queries))
        logging.info(few_shot_message)
        test_summ_val = sess.run(test_summ,
                                 feed_dict={test_acc_input: avg_acc})
        summary_writer.add_summary(test_summ_val, step)
        summary_writer.flush()
        with open(os.path.join(output_dir, 'test_result.txt'),
                  'w') as res_file:
            res_file.write(few_shot_message)

        # Testing complete
        logging.info("Testing complete.")
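# The DTW baseline above only assumes that dtw_cost_func(query, support, True)
# returns a scalar alignment cost between two post-processed feature sequences
# of shape [n_frames, n_features]. The function below is a minimal NumPy sketch
# of such a cost (per-frame cosine distance, with the third argument taken to
# mean path-length normalisation); it is illustrative only and not necessarily
# the DTW implementation used by this repository.
def _dtw_cosine_cost_sketch(query, support, normalise=True):
    """Dynamic time warping cost between two [frames, features] sequences."""
    q = query / (np.linalg.norm(query, axis=1, keepdims=True) + 1e-8)
    s = support / (np.linalg.norm(support, axis=1, keepdims=True) + 1e-8)
    frame_dists = 1. - q @ s.T  # cosine distance between every frame pair
    n_q, n_s = frame_dists.shape
    # Standard DTW accumulation with (1,0), (0,1) and (1,1) steps
    acc = np.full((n_q + 1, n_s + 1), np.inf)
    acc[0, 0] = 0.
    for i in range(1, n_q + 1):
        for j in range(1, n_s + 1):
            acc[i, j] = frame_dists[i - 1, j - 1] + min(
                acc[i - 1, j], acc[i, j - 1], acc[i - 1, j - 1])
    cost = acc[n_q, n_s]
    return cost / (n_q + n_s) if normalise else cost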