def train_few_shot_model(
    train_iterator,
    train_feed_dict,
    train_flag,
    train_loss,
    train_metrics,
    train_optimizer,
    n_epochs,
    max_batches,
    val_iterator,
    val_feed_dict,
    model_embedding,
    embed_input,
    query_input,
    support_memory_input,
    nearest_neighbour,
    n_episodes,
    log_interval=1,  # Number of batches to complete between logging
    model_dir='saved_models',  # Saved models written to <model_dir>/checkpoints
    summary_dir='summaries/train',  # Directory for writing summaries
    save_filename='trained_model',  # Checkpoint filename
    restore_checkpoint=None  # Resumes training from a specific checkpoint
):
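    """Train a few-shot model and validate it with few-shot episodes.

    Runs up to `n_epochs` of mini-batch training (at most `max_batches`
    batches per epoch). After each epoch the model is evaluated on
    `n_episodes` few-shot validation episodes: support and query items are
    embedded with `model_embedding`, and each query is classified by the
    `nearest_neighbour` op over the support embeddings. Regular checkpoints
    are written to <model_dir>/checkpoints, and the model with the best
    validation accuracy so far is saved to <model_dir>/final_model.
    """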
    # Get the global step
    global_step = tf.train.get_or_create_global_step()
    # Get tf.summary tensors to evaluate
    summaries = tf.summary.merge_all()
    val_acc_input = tf.placeholder(TF_FLOAT)
    val_summary = tf.summary.scalar('val_few_shot_accuracy', val_acc_input)
    # Define variables to store the best validation accuracy and epoch
    epoch_var = tf.Variable(0, trainable=False, name='best_epoch')
    accuracy_var = tf.Variable(0., trainable=False, name='best_accuracy')
    # Define a saver for model checkpoints
    checkpoint_saver = tf.train.Saver(max_to_keep=5, save_relative_paths=True)
    best_saver = tf.train.Saver(save_relative_paths=True)

    # Define helper function for saving model and parameters
    def _save_checkpoint(epoch, save_best=False, save_first=False):
        checkpoint_path = os.path.join(model_dir, 'checkpoints', save_filename)
        saved = checkpoint_saver.save(sess, checkpoint_path, global_step=epoch)
        logging.info("Saved model checkpoint to file (epoch {}): {}"
                     "".format(epoch, saved))
        if save_best:
            best_path = os.path.join(model_dir, 'final_model', save_filename)
            best_saved = best_saver.save(sess, best_path)
            logging.info("Saved new best model to file (epoch {}): {}"
                         "".format(epoch, best_saved))
            return best_saved
        elif save_first:
            first_path = os.path.join(model_dir, 'checkpoints',
                                      'initial_model')
            first_saved = best_saver.save(sess, first_path)
            logging.info("Saved randomly initialized base model to file: {}"
                         "".format(first_saved))

    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # allow gpu memory growth
    # Start tf.Session to train and validate model
    with tf.Session(config=config) as sess:
        # ----------------------------------------------------------------------
        # Load model and log some debug info:
        # ----------------------------------------------------------------------
        try:  # restore from model checkpoint
            if restore_checkpoint is not None:  # use specific checkpoint
                restore_path = os.path.join(model_dir, 'checkpoints',
                                            restore_checkpoint)
                if not os.path.isfile('{}.index'.format(restore_path)):
                    restore_path = restore_checkpoint  # possibly full path?
            else:  # use latest checkpoint if available
                restore_path = tf.train.latest_checkpoint(
                    os.path.join(model_dir, 'checkpoints'))
            checkpoint_saver.restore(sess, restore_path)
            logging.info(
                "Model restored from checkpoint: {}".format(restore_path))
            start_epoch = int(restore_path.split('-')[-1])
        except ValueError:  # no checkpoints, initialize variables from scratch
            sess.run(tf.global_variables_initializer())
            start_epoch = 0
            _save_checkpoint(start_epoch, save_first=True)  # reproducibility

        # Evaluate starting global step, and previous best accuracy and epoch
        step = sess.run(global_step)
        logging.info("Training from: Epoch: {}\tGlobal Step: {}".format(
            start_epoch + 1, step))
        best_epoch, best_val_acc = sess.run([epoch_var, accuracy_var])
        logging.info("Current best model: Epoch: {}\tValidation accuracy: "
                     "{:.5f}".format(best_epoch + 1, best_val_acc))

        # Create session summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(model_dir, summary_dir,
                         datetime.datetime.now().strftime("%Hh%Mm%Ss_%f")),
            sess.graph)

        # TODO(rpeloff) If siamese ...
        # Get some triplet pairs, and display on tensorboard
        # x_triplet, y_triplet = train_iterator.get_next()
        # sess.run(train_iterator.initializer, feed_dict=train_feed_dict)  # init validation set iterator
        # anch_summ = tf.summary.image('triplet_anchor_images', x_triplet[0], 5)
        # same_summ = tf.summary.image('triplet_same_images', x_triplet[1], 5)
        # diff_summ = tf.summary.image('triplet_different_images', x_triplet[2], 5)
        # x_trip_batch, y_trip_batch, anch_images, same_images, diff_images = sess.run(
        #     [x_triplet, y_triplet, anch_summ, same_summ, diff_summ])
        # summary_writer.add_summary(anch_images, step)
        # summary_writer.add_summary(same_images, step)
        # summary_writer.add_summary(diff_images, step)
        # summary_writer.flush()

        # Get support/query few-shot set, and display one episode on tensorboard
        support_set, query_set = val_iterator.get_next()
        sess.run(val_iterator.initializer,
                 feed_dict=val_feed_dict)  # init validation set iterator
        s_summ = tf.summary.image('support_set_images', support_set[0], 10)
        q_summ = tf.summary.image('query_set_images', query_set[0], 10)
        support_batch, query_batch, s_images, q_images = sess.run(
            [support_set, query_set, s_summ, q_summ])
        summary_writer.add_summary(s_images, step)
        summary_writer.add_summary(q_images, step)
        summary_writer.flush()
        # Save figures to pdf for later use ...
        for image, label, speaker in zip(*support_batch):
            utils.save_image(np.squeeze(image, axis=-1),
                             filename=os.path.join(
                                 model_dir, 'train_images',
                                 '{}_{}_{}_{}_{}.pdf'.format(
                                     'support', 'l', label.decode("utf-8"),
                                     's', speaker.decode("utf-8"))),
                             cmap='inferno')
        for image, label, speaker in zip(*query_batch):
            utils.save_image(np.squeeze(image, axis=-1),
                             filename=os.path.join(
                                 model_dir, 'train_images',
                                 '{}_{}_{}_{}_{}.pdf'.format(
                                     'query', 'l', label.decode("utf-8"), 's',
                                     speaker.decode("utf-8"))),
                             cmap='inferno')

        # ----------------------------------------------------------------------
        # Training:
        # ----------------------------------------------------------------------
        for epoch in range(start_epoch, n_epochs):
            logging.info("Epoch: {}/{} [Step: {}]"
                         "".format(epoch + 1, n_epochs, step))
            sess.run(train_iterator.initializer,
                     feed_dict=train_feed_dict)  # init train dataset iterator
            avg_loss = 0.
            n_batches_completed = 0
            for i in range(max_batches):
                try:
                    # TODO(rpeloff) Add embeddings and labels to visualize how
                    # embeddings change over training?
                    _, loss_val, summary_vals, metric_vals, step = sess.run(
                        [
                            train_optimizer, train_loss, summaries,
                            [m for m in train_metrics.values()], global_step
                        ],
                        feed_dict={train_flag: True})

                    # Write summaries for tensorboard and log some info
                    summary_writer.add_summary(summary_vals, step)
                    summary_writer.flush()
                    n_batches_completed += 1
                    avg_loss += loss_val
                    if n_batches_completed % log_interval == 0:
                        batch_message = (
                            "\tTrain: [Batch: {}/{}]\tLoss: {:.7f}"
                            "".format(n_batches_completed, max_batches,
                                      loss_val))
                        for metric_key, metric_val in zip(
                            [k for k in train_metrics.keys()], metric_vals):
                            batch_message += "\t{}: {}".format(
                                metric_key, metric_val)
                        logging.info(batch_message)
                except tf.errors.OutOfRangeError:  # catch pipeline out of range
                    break
            # ------------------------------------------------------------------
            # Few-shot validation:
            # ------------------------------------------------------------------
            total_queries = 0
            total_correct = 0
            sess.run(val_iterator.initializer,
                     feed_dict=val_feed_dict)  # init validation set iterator
            for episode in range(n_episodes):
                support_batch, query_batch = sess.run([support_set, query_set])
                # Get embeddings and classify queries with 1-NN on support set
                support_embeddings = sess.run(model_embedding,
                                              feed_dict={
                                                  embed_input:
                                                  support_batch[0],
                                                  train_flag: False
                                              })
                query_embeddings = sess.run(model_embedding,
                                            feed_dict={
                                                embed_input: query_batch[0],
                                                train_flag: False
                                            })
                nearest_neighbour_indices = sess.run(nearest_neighbour,
                                                     feed_dict={
                                                         query_input:
                                                         query_embeddings,
                                                         support_memory_input:
                                                         support_embeddings
                                                     })
                # Calculate and store number of correct predictions
                predicted_labels = support_batch[1][nearest_neighbour_indices]
                total_correct += np.sum(query_batch[1] == predicted_labels)
                total_queries += query_batch[1].shape[0]
                if episode % max(1, n_episodes // 5) == 0:  # log ~5 times
                    avg_acc = total_correct / total_queries
                    ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                                  "Average accuracy: {:.7f}".format(
                                      episode, n_episodes, avg_acc))
                    logging.info(ep_message)
            # ------------------------------------------------------------------
            # Print stats and early-stopping:
            # ------------------------------------------------------------------
            # Print epoch train stats
            avg_loss = avg_loss / n_batches_completed
            epoch_message = ("Epoch: {}/{} [Step: {}]\tTrain set: Average "
                             "loss: {:.5f}".format(epoch + 1, n_epochs, step,
                                                   avg_loss))
            logging.info(epoch_message)
            # Print epoch few-shot validation stats
            avg_acc = total_correct / total_queries
            few_shot_message = ("Epoch: {}/{} [Step: {}]\tValidation set (few-"
                                "shot): Average accuracy: {:.5f}".format(
                                    epoch + 1, n_epochs, step, avg_acc))
            logging.info(few_shot_message)
            val_summ = sess.run(val_summary,
                                feed_dict={val_acc_input: avg_acc})
            summary_writer.add_summary(val_summ, step)
            summary_writer.flush()
            # Check if this is a new best model
            if avg_acc > best_val_acc:
                best_epoch = epoch
                best_val_acc = avg_acc
                sess.run([
                    tf.assign(epoch_var, best_epoch),
                    tf.assign(accuracy_var, best_val_acc)
                ])
                _save_checkpoint(epoch + 1, save_best=True)  # save best model
                with open(os.path.join(model_dir, 'train_result.txt'),
                          'w') as res_file:
                    res_file.write("Epoch: {}\tTrain loss: {:.5f}\t"
                                   "Validation accuracy: {:.5f}".format(
                                       epoch + 1, avg_loss, avg_acc))
        # Training complete, print final (best) model stats:
        logging.info("Training complete. Best model found at epoch {} with "
                     "validation accuracy {:.5f}.".format(
                         best_epoch + 1, best_val_acc))


def test_multimodal_few_shot_model(
        test_feed_dict,
        # Speech test params ...
        speech_graph,
        speech_train_flag,
        speech_test_iterator,
        speech_model_embedding,
        speech_embed_input,
        # Vision test params ...
        vision_graph,
        vision_train_flag,
        vision_test_iterator,
        vision_model_embedding,
        vision_embed_input,
        # Nearest neighbour params ...
        query_input,
        support_memory_input,
        nearest_neighbour,
        n_episodes,
        query_type,
        test_pixels=False,
        test_dtw=False,
        dtw_cost_func=None,
        dtw_post_process=None,
        test_invariance=False,
        # Other params ...
        log_interval=1,
        model_dir='saved_models',
        speech_model_dir='saved_models/speech',
        vision_model_dir='saved_models/vision',
        summary_dir='summaries/test',
        speech_restore_checkpoint=None,
        vision_restore_checkpoint=None):
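    """Test cross-modal (speech <-> vision) few-shot matching.

    For each of `n_episodes` episodes, queries in the `query_type` modality
    are matched to their own support set with 1-NN (or DTW for speech when
    `test_dtw=True`), the index-aligned support embedding from the other
    modality is looked up, and that embedding is matched against the other
    modality's matching set to predict a label. With `test_invariance=True`
    (speech queries), accuracy is also split into easy and distractor
    speaker queries. Results are logged, summarised for TensorBoard, and
    written to <model_dir>/test_result.txt.
    """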
    # Create tf.Session's for speech, vision, and general models
    speech_session = tf.Session(graph=speech_graph)
    vision_session = tf.Session(graph=vision_graph)
    general_session = tf.Session()  # default graph
    # Get model global steps
    with speech_graph.as_default():
        speech_global_step = tf.train.get_or_create_global_step()
        speech_step = 0
    with vision_graph.as_default():
        vision_global_step = tf.train.get_or_create_global_step()
        vision_step = 0
    # --------------------------------------------------------------------------
    # Load models (unless using DTW or pixel matching) and log some debug info:
    # --------------------------------------------------------------------------
    if not test_dtw:
        with speech_session.as_default(), speech_graph.as_default():  #pylint: disable=E1129
            try:  # restore speech from model checkpoint
                speech_checkpoint_saver = tf.train.Saver(
                    save_relative_paths=True)
                if speech_restore_checkpoint is not None:  # specific checkpoint
                    restore_path = os.path.join(speech_model_dir,
                                                'checkpoints',
                                                speech_restore_checkpoint)
                    if not os.path.isfile('{}.index'.format(restore_path)):
                        restore_path = speech_restore_checkpoint  # full path
                else:  # use best model if available
                    final_model_dir = os.path.join(speech_model_dir,
                                                   'final_model')
                    restore_path = tf.train.latest_checkpoint(final_model_dir)
                    if restore_path is None:
                        logging.info(
                            "No best model checkpoint could be found "
                            "in directory: {}".format(final_model_dir))
                        return  # exit ...
                speech_checkpoint_saver.restore(speech_session, restore_path)
                logging.info(
                    "Speech model restored from checkpoint: {}".format(
                        restore_path))
            except ValueError:  # no checkpoints, inform and exit ...
                logging.info(
                    "Speech model checkpoint could not be found at restore "
                    "path: {}".format(restore_path))
                return  # exit ...
            # Evaluate the global step the speech model was trained to
            speech_step = speech_session.run(speech_global_step)
            logging.info("Testing speech model from: Global Step: {}"
                         "".format(speech_step))
    else:
        logging.info("Testing speech model with dynamic time warping (DTW).")

    if not test_pixels:
        with vision_session.as_default(), vision_graph.as_default():  #pylint: disable=E1129
            try:  # restore vision from model checkpoint
                vision_checkpoint_saver = tf.train.Saver(
                    save_relative_paths=True)
                if vision_restore_checkpoint is not None:  # specific checkpoint
                    restore_path = os.path.join(vision_model_dir,
                                                'checkpoints',
                                                vision_restore_checkpoint)
                    if not os.path.isfile('{}.index'.format(restore_path)):
                        restore_path = vision_restore_checkpoint  # full path
                else:  # use best model if available
                    final_model_dir = os.path.join(vision_model_dir,
                                                   'final_model')
                    restore_path = tf.train.latest_checkpoint(final_model_dir)
                    if restore_path is None:
                        logging.info(
                            "No best model checkpoint could be found "
                            "in directory: {}".format(final_model_dir))
                        return  # exit ...
                vision_checkpoint_saver.restore(vision_session, restore_path)
                logging.info(
                    "Vision model restored from checkpoint: {}".format(
                        restore_path))
            except ValueError:  # no checkpoints, inform and exit ...
                logging.info(
                    "Vision model checkpoint could not be found at restore "
                    "path: {}".format(restore_path))
                return  # exit ...
            # Evaluate the global step the vision model was trained to
            vision_step = vision_session.run(vision_global_step)
            logging.info("Testing vision model from: Global Step: {}".format(
                vision_step))
    else:
        logging.info("Testing vision model with pure pixel matching.")

    # Create general session summary writer and few-shot accuracy summary
    summary_writer = tf.summary.FileWriter(
        os.path.join(model_dir, summary_dir,
                     datetime.datetime.now().strftime("%Hh%Mm%Ss_%f")),
        general_session.graph)
    # Get tf.summary tensor to evaluate for few-shot accuracy
    test_acc_input = tf.placeholder(TF_FLOAT)
    test_summ = tf.summary.scalar('test_few_shot_accuracy', test_acc_input)
    # Get speech and vision support/query/matching few-shot sets, and save
    # one episode to tensorboard and files for debugging
    speech_support_set, speech_query_set = speech_test_iterator.get_next()
    vision_support_set, vision_query_set = vision_test_iterator.get_next()
    general_session.run(
        [speech_test_iterator.initializer, vision_test_iterator.initializer],
        feed_dict=test_feed_dict)  # init test set iterator
    speech_s_summ = tf.summary.image('speech_support_set_images',
                                     speech_support_set[0], 10)
    vision_s_summ = tf.summary.image('vision_support_set_images',
                                     vision_support_set[0], 10)
    if query_type == 'speech':  # speech query, image matching set
        speech_q_summ = tf.summary.image('speech_query_set_images',
                                         speech_query_set[0], 10)
        vision_q_summ = tf.summary.image('vision_matching_set_images',
                                         vision_query_set[0], 10)
    else:  # image query, speech matching set
        speech_q_summ = tf.summary.image('speech_matching_set_images',
                                         speech_query_set[0], 10)
        vision_q_summ = tf.summary.image('vision_query_set_images',
                                         vision_query_set[0], 10)
    (speech_s_batch, speech_q_batch, vision_s_batch, vision_q_batch,
     speech_s_images, speech_q_images, vision_s_images,
     vision_q_images) = (general_session.run([
         speech_support_set, speech_query_set, vision_support_set,
         vision_query_set, speech_s_summ, speech_q_summ, vision_s_summ,
         vision_q_summ
     ]))
    # Save figures to pdf for later use ...
    for index, (image, label, speaker) in enumerate(zip(*speech_s_batch)):
        if test_dtw:
            image = dtw_post_process(image)
        else:
            image = np.squeeze(image, axis=-1)
        utils.save_image(image,
                         filename=os.path.join(
                             model_dir, 'multimodal_test_images',
                             '{}_{}_{}_{}_{}_{}_{}.pdf'.format(
                                 'speech', 'support', index, 'label',
                                 label.decode("utf-8"), 'speaker',
                                 speaker.decode("utf-8"))),
                         cmap='inferno')
    for index, (image, label, speaker) in enumerate(zip(*speech_q_batch)):
        if test_dtw:
            image = dtw_post_process(image)
        else:
            image = np.squeeze(image, axis=-1)
        set_label = 'query' if query_type == 'speech' else 'matching'
        utils.save_image(image,
                         filename=os.path.join(
                             model_dir, 'multimodal_test_images',
                             '{}_{}_{}_{}_{}_{}_{}.pdf'.format(
                                 'speech', set_label, index, 'label',
                                 label.decode("utf-8"), 'speaker',
                                 speaker.decode("utf-8"))),
                         cmap='inferno')
    for index, (image, label) in enumerate(zip(*vision_s_batch)):
        utils.save_image(np.squeeze(image, axis=-1),
                         filename=os.path.join(
                             model_dir, 'multimodal_test_images',
                             '{}_{}_{}_{}_{}.pdf'.format(
                                 'vision', 'support', index, 'label',
                                 label.decode("utf-8"))),
                         cmap='gray_r')
    for index, (image, label) in enumerate(zip(*vision_q_batch)):
        set_label = 'query' if query_type == 'vision' else 'matching'
        utils.save_image(np.squeeze(image, axis=-1),
                         filename=os.path.join(
                             model_dir, 'multimodal_test_images',
                             '{}_{}_{}_{}_{}.pdf'.format(
                                 'vision', set_label, index, 'label',
                                 label.decode("utf-8"))),
                         cmap='gray_r')
    # Save summary images with general summary writer
    summary_writer.add_summary(speech_s_images, speech_step)
    summary_writer.add_summary(speech_q_images, speech_step)
    summary_writer.add_summary(vision_s_images, vision_step)
    summary_writer.add_summary(vision_q_images, vision_step)
    summary_writer.flush()

    # --------------------------------------------------------------------------
    # Cross-modal few-shot testing:
    # --------------------------------------------------------------------------
    total_queries = 0
    total_correct = 0
    # Speaker invariance accuracy counters
    total_easy_queries = 0
    total_easy_correct = 0
    total_distractor_queries = 0
    total_distractor_correct = 0
    general_session.run(
        [speech_test_iterator.initializer, vision_test_iterator.initializer],
        feed_dict=test_feed_dict)  # init test set iterator
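    # Pair the speech and vision support/query tensors so that a single run
    # call fetches the next episode from both modalities together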
    few_shot_set = [(speech_support_set, speech_query_set),
                    (vision_support_set, vision_query_set)]
    for episode in range(n_episodes):
        # Get next few-shot batch
        episode_batch = general_session.run(few_shot_set)
        speech_batch = episode_batch[0]
        vision_batch = episode_batch[1]
        # Get embeddings and classify queries with 1-NN on support set
        with speech_session.as_default(), speech_graph.as_default():  #pylint: disable=E1129
            speech_support_embeddings = speech_session.run(
                speech_model_embedding,
                feed_dict={
                    speech_embed_input: speech_batch[0][0],
                    speech_train_flag: False
                })
            speech_query_embeddings = speech_session.run(
                speech_model_embedding,
                feed_dict={
                    speech_embed_input: speech_batch[1][0],
                    speech_train_flag: False
                })
        with vision_session.as_default(), vision_graph.as_default():  #pylint: disable=E1129
            vision_support_embeddings = vision_session.run(
                vision_model_embedding,
                feed_dict={
                    vision_embed_input: vision_batch[0][0],
                    vision_train_flag: False
                })
            vision_query_embeddings = vision_session.run(
                vision_model_embedding,
                feed_dict={
                    vision_embed_input: vision_batch[1][0],
                    vision_train_flag: False
                })
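        # Cross-modal matching is done in two 1-NN steps: (1) match each
        # query to its own modality's support set; (2) take the support
        # embedding of the *other* modality at the matched index (the two
        # support sets are assumed index-aligned) and match it against the
        # other modality's matching set, whose label becomes the prediction.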
        # Speech query cross-modal matching to images
        if query_type == 'speech':
            pred_message = ""
            if not test_dtw:  # test speech with fast cosine 1-NN memory model
                s_nearest_neighbour_indices = general_session.run(
                    nearest_neighbour,
                    feed_dict={
                        query_input: speech_query_embeddings,
                        support_memory_input: speech_support_embeddings
                    })
            else:  # test speech with dynamic time warping baseline
                costs = [[
                    dtw_cost_func(
                        dtw_post_process(speech_query_embeddings[i]),
                        dtw_post_process(speech_support_embeddings[j]), True)
                    for j in range(len(speech_support_embeddings))
                ] for i in range(len(speech_query_embeddings))]
                s_nearest_neighbour_indices = [
                    np.argmin(costs[i]) for i in range(len(costs))
                ]
            # Get cross-modal matches in image support set and their labels
            query_cross_matches = vision_support_embeddings[
                s_nearest_neighbour_indices]
            actual_labels = speech_batch[1][1]
            # Find cross-modal matches in the image matching set
            v_nearest_neighbour_indices = general_session.run(
                nearest_neighbour,
                feed_dict={
                    query_input: query_cross_matches,
                    support_memory_input: vision_query_embeddings
                })
            predicted_labels = vision_batch[1][1][v_nearest_neighbour_indices]
            # Treat labels 'z' and 'o' as the same image class, even though
            # they are distinct speech classes
            actual_labels_update = np.array(
                [label if label != b'o' else b'z' for label in actual_labels])
            predicted_labels_update = np.array([
                label if label != b'o' else b'z' for label in predicted_labels
            ])
            # Put together some debug info ...
            pred_message += "\t\tActual speech query labels:\t\t{}".format(
                actual_labels)
            pred_message += "\n\t\tPredicted speech support labels:\t{}".format(
                speech_batch[0][1][s_nearest_neighbour_indices])
            pred_message += "\n\t\tAssociated vision support labels:\t{}".format(
                vision_batch[0][1][s_nearest_neighbour_indices])
            pred_message += "\n\t\tPredicted vision matching labels:\t{}".format(
                predicted_labels)
            pred_message += "\n\t\tUpdated speech query labels ('o'=='z'):\t{}".format(
                actual_labels_update)
            pred_message += "\n\t\tUpdated vision match labels ('o'=='z'):\t{}".format(
                predicted_labels_update)
            # Update accuracy counters and log info
            total_correct += np.sum(
                actual_labels_update == predicted_labels_update)
            total_queries += speech_batch[1][1].shape[0]
            if test_invariance:
                # Count queries and predictions with easy/distractor speakers
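                # A query counts as "easy" if some support item shares both
                # its label and its speaker; otherwise any same-label support
                # items come from distractor speakers only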
                for q_index in range(speech_batch[1][1].shape[0]):
                    n_same_speaker = np.sum(
                        np.logical_and(
                            speech_batch[1][1][q_index] == speech_batch[0][1],
                            speech_batch[1][2][q_index] == speech_batch[0][2]))
                    if n_same_speaker > 0:  # easy speakers
                        total_easy_queries += 1
                        if actual_labels_update[
                                q_index] == predicted_labels_update[q_index]:
                            total_easy_correct += 1
                    else:  # distractor speakers
                        total_distractor_queries += 1
                        if actual_labels_update[
                                q_index] == predicted_labels_update[q_index]:
                            total_distractor_correct += 1
            if episode % log_interval == 0:
                avg_acc = total_correct / total_queries
                ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                              "Average accuracy: {:.7f}".format(
                                  episode, n_episodes, avg_acc))
                if test_invariance:
                    avg_easy_acc = total_easy_correct / total_easy_queries if total_easy_queries != 0 else 0.
                    avg_dist_acc = total_distractor_correct / total_distractor_queries if total_distractor_queries != 0 else 0.
                    ep_message += (
                        "\n\t\tEasy speaker accuracy: {:.7f}\tDistractor "
                        "speaker accuracy: {:.7f}".format(
                            avg_easy_acc, avg_dist_acc))
                    ep_message += (
                        "\n\t\tNum easy speakers: {}\tNum distractor speakers: {}"
                        .format(total_easy_queries, total_distractor_queries))
                logging.info(ep_message)
                #                 print("Speech support labels:", speech_batch[0][1])
                #                 print("Vision support labels:", vision_batch[0][1])
                logging.info(pred_message)
        # Image query cross-modal matching to speech
        else:
            pred_message = ""
            v_nearest_neighbour_indices = general_session.run(
                nearest_neighbour,
                feed_dict={
                    query_input: vision_query_embeddings,
                    support_memory_input: vision_support_embeddings
                })
            # Get cross-modal matches in speech support set and their labels
            query_cross_matches = speech_support_embeddings[
                v_nearest_neighbour_indices]
            actual_labels = vision_batch[1][1]
            # Find cross-modal matches in the speech matching set
            if not test_dtw:  # test speech with fast cosine 1-NN memory model
                s_nearest_neighbour_indices = general_session.run(
                    nearest_neighbour,
                    feed_dict={
                        query_input: query_cross_matches,
                        support_memory_input: speech_query_embeddings
                    })
            else:  # test speech with dynamic time warping baseline
                costs = [[
                    dtw_cost_func(dtw_post_process(query_cross_matches[i]),
                                  dtw_post_process(speech_query_embeddings[j]),
                                  True)
                    for j in range(len(speech_query_embeddings))
                ] for i in range(len(query_cross_matches))]
                s_nearest_neighbour_indices = [
                    np.argmin(costs[i]) for i in range(len(costs))
                ]
            predicted_labels = speech_batch[1][1][s_nearest_neighbour_indices]
            # Treat labels 'z' and 'o' as the same image class, even though
            # they are distinct speech classes
            actual_labels_update = np.array(
                [label if label != b'o' else b'z' for label in actual_labels])
            predicted_labels_update = np.array([
                label if label != b'o' else b'z' for label in predicted_labels
            ])
            # Put together some debug info ...
            pred_message += "\t\tActual vision query labels:\t\t{}".format(
                actual_labels)
            pred_message += "\n\t\tPredicted vision support labels:\t{}".format(
                vision_batch[0][1][v_nearest_neighbour_indices])
            pred_message += "\n\t\tAssociated speech support labels:\t{}".format(
                speech_batch[0][1][v_nearest_neighbour_indices])
            pred_message += '\n\t\tPredicted speech matching labels:\t{}'.format(
                predicted_labels)
            pred_message += "\n\t\tUpdated vision query labels:\t\t{}".format(
                actual_labels_update)
            pred_message += '\n\t\tUpdated speech prediction labels:\t{}'.format(
                predicted_labels_update)
            # Update accuracy counters and log info
            total_correct += np.sum(
                actual_labels_update == predicted_labels_update)
            total_queries += vision_batch[1][1].shape[0]
            if episode % log_interval == 0:
                avg_acc = total_correct / total_queries
                ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                              "Average accuracy: {:.7f}".format(
                                  episode, n_episodes, avg_acc))
                logging.info(ep_message)
                logging.info(pred_message)

    # ------------------------------------------------------------------
    # Print stats:
    # ------------------------------------------------------------------
    avg_acc = total_correct / total_queries
    few_shot_message = ("Test set (few-shot): Average accuracy: "
                        "{:.5f}".format(avg_acc))
    if test_invariance:
        avg_easy_acc = total_easy_correct / total_easy_queries
        avg_dist_acc = total_distractor_correct / total_distractor_queries
        few_shot_message += ("\n\t\tEasy speaker accuracy: {:.5f}\tDistractor "
                             "speaker accuracy: {:.5f}".format(
                                 avg_easy_acc, avg_dist_acc))
        few_shot_message += (
            "\n\t\tNum easy speakers: {}\tNum distractor speakers: {}".format(
                total_easy_queries, total_distractor_queries))
    logging.info(few_shot_message)
    test_summ = general_session.run(test_summ,
                                    feed_dict={test_acc_input: avg_acc})
    summary_writer.add_summary(test_summ, 0)
    summary_writer.flush()
    with open(os.path.join(model_dir, 'test_result.txt'), 'w') as res_file:
        res_file.write(few_shot_message)
    # Testing complete
    logging.info("Testing complete.")

    speech_session.close()
    vision_session.close()
    general_session.close()


# Example 3
def test_few_shot_model(
        train_flag,
        test_iterator,
        test_feed_dict,
        model_embedding,
        embed_input,
        query_input,
        support_memory_input,
        nearest_neighbour,
        n_episodes,
        test_pixels=False,
        log_interval=1,
        model_dir='saved_models',
        output_dir='.',
        summary_dir='summaries/test',
        restore_checkpoint=None):
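    """Test a few-shot model with 1-NN classification over test episodes.

    Restores the best saved model from <model_dir>/final_model (or the given
    `restore_checkpoint`), unless `test_pixels=True`, in which case no
    checkpoint is restored (pixel-matching baseline). For each of
    `n_episodes` episodes, support and query sets are embedded and each query
    is classified by the `nearest_neighbour` op over the support embeddings.
    The average accuracy is logged, summarised for TensorBoard, and written
    to <output_dir>/test_result.txt.
    """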
    # Get the global step tensor and set the initial step value
    global_step = tf.train.get_or_create_global_step()
    step = 0
    # Define a saver to load model checkpoint
    checkpoint_saver = tf.train.Saver(save_relative_paths=True)
    # Start tf.Session to test model
    with tf.Session() as sess:
        # ----------------------------------------------------------------------
        # Load model (unless using pixels) and log some debug info:
        # ----------------------------------------------------------------------
        if not test_pixels:
            try:  # restore from model checkpoint
                if restore_checkpoint is not None:  # use specific checkpoint
                    restore_path = os.path.join(
                        model_dir, 'checkpoints', restore_checkpoint)
                    if not os.path.isfile('{}.index'.format(restore_path)):
                        restore_path = restore_checkpoint  # possibly full path?
                else:  # use best model if available
                    final_model_dir = os.path.join(model_dir, 'final_model')
                    restore_path = tf.train.latest_checkpoint(final_model_dir)
                    if restore_path is None:
                        logging.info("No best model checkpoint could be found "
                                     "in directory: {}".format(final_model_dir))
                        return  # exit ... 
                checkpoint_saver.restore(sess, restore_path)
                logging.info("Model restored from checkpoint file: "
                             "{}".format(restore_path))
            except ValueError:  # no checkpoints, inform and exit ...
                logging.info("Model checkpoint could not found at restore "
                            "path: {}".format(restore_path))
                return  # exit ... 
            
            # Evaluate global step model was trained to
            step = sess.run(global_step)
            logging.info("Testing from: Global Step: {}"
                        .format(step))
        else:
            logging.info("Testing vision model with pure pixel matching.")
        
        # Create session summary writer
        summary_writer = tf.summary.FileWriter(os.path.join(
            output_dir, summary_dir,
            datetime.datetime.now().strftime("%Hh%Mm%Ss_%f")), sess.graph)
        # Get tf.summary tensor to evaluate for few-shot accuracy
        test_acc_input = tf.placeholder(TF_FLOAT)
        test_summ = tf.summary.scalar('test_few_shot_accuracy', test_acc_input)
        # Get support/query few-shot set, and display one episode on tensorboard
        support_set, query_set = test_iterator.get_next()
        sess.run(test_iterator.initializer, feed_dict=test_feed_dict)
        s_summ = tf.summary.image('support_set_images', support_set[0], 10)
        q_summ = tf.summary.image('query_set_images', query_set[0], 10)
        support_batch, query_batch, s_images, q_images = sess.run(
            [support_set, query_set, s_summ, q_summ])
        summary_writer.add_summary(s_images, step)
        summary_writer.add_summary(q_images, step)
        summary_writer.flush()
        # Save figures to pdf for later use ...
        for index, (image, label) in enumerate(zip(*support_batch)):
            utils.save_image(np.squeeze(image, axis=-1),
                             filename=os.path.join(
                                 output_dir, 'test_images',
                                 '{}_{}_{}_{}.pdf'.format(
                                     'support', index, 'label',
                                     label.decode("utf-8"))),
                             cmap='gray_r')
        for index, (image, label) in enumerate(zip(*query_batch)):
            utils.save_image(np.squeeze(image, axis=-1),
                             filename=os.path.join(
                                 output_dir, 'test_images',
                                 '{}_{}_{}_{}.pdf'.format(
                                     'query', index, 'label',
                                     label.decode("utf-8"))),
                             cmap='gray_r')

        # ----------------------------------------------------------------------
        # Few-shot testing:
        # ----------------------------------------------------------------------
        total_queries = 0
        total_correct = 0
        sess.run(test_iterator.initializer, feed_dict=test_feed_dict)
        for episode in range(n_episodes):
            support_batch, query_batch = sess.run([support_set, query_set])
            # Get embeddings and classify queries with 1-NN on support set
            support_embeddings = sess.run(
                model_embedding,
                feed_dict={embed_input: support_batch[0], train_flag: False})
            query_embeddings = sess.run(
                model_embedding,
                feed_dict={embed_input: query_batch[0], train_flag: False})
            nearest_neighbour_indices = sess.run(
                nearest_neighbour,
                feed_dict={query_input: query_embeddings,
                           support_memory_input: support_embeddings})
            # Calculate and store number of correct predictions
            predicted_labels = support_batch[1][nearest_neighbour_indices]
            total_correct += np.sum(query_batch[1] == predicted_labels)
            total_queries += query_batch[1].shape[0]
            if episode % log_interval == 0:
                avg_acc = total_correct / total_queries
                ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                              "Average accuracy: {:.7f}".format(
                                  episode, n_episodes, avg_acc))
                logging.info(ep_message)
        # ----------------------------------------------------------------------
        # Print stats:
        # ----------------------------------------------------------------------
        avg_acc = total_correct / total_queries
        few_shot_message = ("Test set (few-shot): Average accuracy: "
                            "{:.5f}".format(avg_acc))
        logging.info(few_shot_message)
        test_summ = sess.run(test_summ, feed_dict={test_acc_input: avg_acc})
        summary_writer.add_summary(test_summ, step)
        summary_writer.flush()
        with open(os.path.join(output_dir, 'test_result.txt'), 'w') as res_file:
            res_file.write(few_shot_message)
    # Testing complete
    logging.info("Testing complete.")


# Example 4
def test_few_shot_model(train_flag,
                        test_iterator,
                        test_feed_dict,
                        model_embedding,
                        embed_input,
                        query_input,
                        support_memory_input,
                        nearest_neighbour,
                        n_episodes,
                        test_dtw=False,
                        dtw_cost_func=None,
                        dtw_post_process=None,
                        test_invariance=False,
                        log_interval=1,
                        model_dir='saved_models',
                        output_dir='.',
                        summary_dir='summaries/test',
                        restore_checkpoint=None):
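    """Test a speech few-shot model with 1-NN or DTW over test episodes.

    Restores the best saved model from <model_dir>/final_model (or the given
    `restore_checkpoint`), unless `test_dtw=True`; in that case no checkpoint
    is restored and queries are matched to support items by dynamic time
    warping with `dtw_cost_func` and `dtw_post_process` instead of the
    `nearest_neighbour` op. With `test_invariance=True`, accuracy is also
    reported separately for "easy" queries (a support item with the same
    label and speaker exists) and "distractor" queries. Results are logged,
    summarised for TensorBoard, and written to <output_dir>/test_result.txt.
    """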
    # Get the global step tensor and set the initial step value
    global_step = tf.train.get_or_create_global_step()
    step = 0
    # Define a saver to load model checkpoint
    checkpoint_saver = tf.train.Saver(save_relative_paths=True)
    # Start tf.Session to test model
    with tf.Session() as sess:
        # ----------------------------------------------------------------------
        # Load model (unless using DTW) and log some debug info:
        # ----------------------------------------------------------------------
        if not test_dtw:
            try:  # restore from model checkpoint
                if restore_checkpoint is not None:  # use specific checkpoint
                    restore_path = os.path.join(model_dir, 'checkpoints',
                                                restore_checkpoint)
                    if not os.path.isfile('{}.index'.format(restore_path)):
                        restore_path = restore_checkpoint  # possibly full path?
                else:  # use best model if available
                    final_model_dir = os.path.join(model_dir, 'final_model')
                    restore_path = tf.train.latest_checkpoint(final_model_dir)
                    if restore_path is None:
                        logging.info(
                            "No best model checkpoint could be found "
                            "in directory: {}".format(final_model_dir))
                        return  # exit ...
                checkpoint_saver.restore(sess, restore_path)
                logging.info("Model restored from checkpoint file: "
                             "{}".format(restore_path))
            except ValueError:  # no checkpoints, inform and exit ...
                logging.info("Model checkpoint could not found at restore "
                             "path: {}".format(restore_path))
                return  # exit ...

            # Evaluate global step model was trained to
            step = sess.run(global_step)
            logging.info("Testing from: Global Step: {}".format(step))
        else:
            logging.info(
                "Testing speech model with dynamic time warping (DTW).")

        # Create session summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(output_dir, summary_dir,
                         datetime.datetime.now().strftime("%Hh%Mm%Ss_%f")),
            sess.graph)
        # Get tf.summary tensor to evaluate for few-shot accuracy
        test_acc_input = tf.placeholder(TF_FLOAT)
        test_summ = tf.summary.scalar('test_few_shot_accuracy', test_acc_input)
        # Get support/query few-shot set, and display one episode on tensorboard
        support_set, query_set = test_iterator.get_next()
        sess.run(test_iterator.initializer, feed_dict=test_feed_dict)
        s_summ = tf.summary.image('support_set_images', support_set[0], 10)
        q_summ = tf.summary.image('query_set_images', query_set[0], 10)
        support_batch, query_batch, s_images, q_images = sess.run(
            [support_set, query_set, s_summ, q_summ])
        summary_writer.add_summary(s_images, step)
        summary_writer.add_summary(q_images, step)
        summary_writer.flush()
        # Save figures to pdf for later use ...
        for index, (image, label, speaker) in enumerate(zip(*support_batch)):
            if test_dtw:
                image = dtw_post_process(image)
            else:
                image = np.squeeze(image, axis=-1)
            utils.save_image(image,
                             filename=os.path.join(
                                 output_dir, 'test_images',
                                 '{}_{}_{}_{}_{}_{}.pdf'.format(
                                     'support', index, 'label',
                                     label.decode("utf-8"), 'speaker',
                                     speaker.decode("utf-8"))),
                             cmap='inferno')
        for index, (image, label, speaker) in enumerate(zip(*query_batch)):
            if test_dtw:
                image = dtw_post_process(image)
            else:
                image = np.squeeze(image, axis=-1)
            utils.save_image(image,
                             filename=os.path.join(
                                 output_dir, 'test_images',
                                 '{}_{}_{}_{}_{}_{}.pdf'.format(
                                     'query', index, 'label',
                                     label.decode("utf-8"), 'speaker',
                                     speaker.decode("utf-8"))),
                             cmap='inferno')

        # ----------------------------------------------------------------------
        # Few-shot testing:
        # ----------------------------------------------------------------------
        # Few-shot accuracy counters
        total_queries = 0
        total_correct = 0
        # Speaker invariance accuracy counters
        total_easy_queries = 0
        total_easy_correct = 0
        total_distractor_queries = 0
        total_distractor_correct = 0
        sess.run(test_iterator.initializer, feed_dict=test_feed_dict)
        for episode in range(n_episodes):
            support_batch, query_batch = sess.run([support_set, query_set])
            # Get embeddings and classify queries with 1-NN on support set
            support_embeddings = sess.run(model_embedding,
                                          feed_dict={
                                              embed_input: support_batch[0],
                                              train_flag: False
                                          })
            query_embeddings = sess.run(model_embedding,
                                        feed_dict={
                                            embed_input: query_batch[0],
                                            train_flag: False
                                        })
            if not test_dtw:  # test with fast cosine 1-NN memory model
                nearest_neighbour_indices = sess.run(nearest_neighbour,
                                                     feed_dict={
                                                         query_input:
                                                         query_embeddings,
                                                         support_memory_input:
                                                         support_embeddings
                                                     })
            else:  # test with dynamic time warping
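                # costs[i][j] is the DTW alignment cost between query i and
                # support item j; the lowest-cost support item is the match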
                costs = [[
                    dtw_cost_func(dtw_post_process(query_embeddings[i]),
                                  dtw_post_process(support_embeddings[j]),
                                  True) for j in range(len(support_embeddings))
                ] for i in range(len(query_embeddings))]
                nearest_neighbour_indices = [
                    np.argmin(costs[i]) for i in range(len(costs))
                ]

            # Calculate and store number of correct predictions
            predicted_labels = support_batch[1][nearest_neighbour_indices]
            total_correct += np.sum(query_batch[1] == predicted_labels)
            total_queries += query_batch[1].shape[0]
            if test_invariance:
                # Count queries and predictions with easy/distractor speakers
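                # "Easy" queries have a support item with the same label and
                # the same speaker; "distractor" queries do not, so a correct
                # match cannot rely on speaker identity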
                for q_index in range(query_batch[1].shape[0]):
                    n_same_speaker = np.sum(
                        np.logical_and(
                            query_batch[1][q_index] == support_batch[1],
                            query_batch[2][q_index] == support_batch[2]))
                    if n_same_speaker > 0:  # easy speakers
                        #                         print("Support set labels:", support_batch[1])
                        #                         print("Support set speaker:", support_batch[2])
                        #                         print("Query set labels:", query_batch[1])
                        #                         print("Query set speaker:", query_batch[2])
                        total_easy_queries += 1
                        if query_batch[1][q_index] == predicted_labels[
                                q_index]:
                            total_easy_correct += 1
                    else:  # distractor speakers
                        total_distractor_queries += 1
                        if query_batch[1][q_index] == predicted_labels[
                                q_index]:
                            total_distractor_correct += 1
                # prediction_originators = support_batch[2][nearest_neighbour_indices]
                # total_easy_correct += np.sum(
                #     np.logical_and(query_batch[1] == predicted_labels,
                #                    query_batch[2] == prediction_originators))
                # total_easy_queries += np.sum(query_batch[2] == prediction_originators)
                # total_distractor_correct += np.sum(
                #     np.logical_and(query_batch[1] == predicted_labels,
                #                    query_batch[2] != prediction_originators))
                # total_distractor_queries += np.sum(query_batch[2] != prediction_originators)
            if episode % log_interval == 0:
                avg_acc = total_correct / total_queries
                ep_message = ("\tFew-shot Test: [Episode: {}/{}]\t"
                              "Average accuracy: {:.7f}".format(
                                  episode, n_episodes, avg_acc))
                if test_invariance:
                    avg_easy_acc = total_easy_correct / total_easy_queries if total_easy_queries != 0 else 0.
                    avg_dist_acc = total_distractor_correct / total_distractor_queries if total_distractor_queries != 0 else 0.
                    ep_message += (
                        "\n\t\tEasy speaker accuracy: {:.7f}\tDistractor "
                        "speaker accuracy: {:.7f}".format(
                            avg_easy_acc, avg_dist_acc))
                    ep_message += (
                        "\n\t\tNum easy speakers: {}\tNum distractor speakers: {}"
                        .format(total_easy_queries, total_distractor_queries))
                logging.info(ep_message)
        # ----------------------------------------------------------------------
        # Print stats:
        # ----------------------------------------------------------------------
        avg_acc = total_correct / total_queries
        few_shot_message = ("Test set (few-shot): Average accuracy: "
                            "{:.5f}".format(avg_acc))
        if test_invariance:
            avg_easy_acc = total_easy_correct / total_easy_queries
            avg_dist_acc = total_distractor_correct / total_distractor_queries
            few_shot_message += (
                "\n\t\tEasy speaker accuracy: {:.5f}\tDistractor "
                "speaker accuracy: {:.5f}".format(avg_easy_acc, avg_dist_acc))
            few_shot_message += (
                "\n\t\tNum easy speakers: {}\tNum distractor speakers: {}".
                format(total_easy_queries, total_distractor_queries))
        logging.info(few_shot_message)
        test_summ_val = sess.run(test_summ,
                                 feed_dict={test_acc_input: avg_acc})
        summary_writer.add_summary(test_summ_val, step)
        summary_writer.flush()
        with open(os.path.join(output_dir, 'test_result.txt'),
                  'w') as res_file:
            res_file.write(few_shot_message)
    # Testing complete
    logging.info("Testing complete.")