Example #1
def evaluation_loop(fetches, saver, summary_writer, evl_metrics, checkpoint,
                    last_global_step_val):
    """Run the evaluation loop once.
  Args:
    fetches: a dict of tensors to be run within Session.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    checkpoint: path of the checkpoint to restore the model from.
    last_global_step_val: the global step used in the previous evaluation.
  Returns:
    The global_step used in the latest model.
  """

    global_step_val = -1
    with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True))) as sess:
        #checkpoint = get_checkpoint()
        if checkpoint:

            print("*" * 20)
            print("*" * 20)
            logging.info("Loading checkpoint for eval: %s", checkpoint)
            # Restores from checkpoint
            saver.restore(sess, checkpoint)
            # Assuming model_checkpoint_path looks something like:
            # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
            global_step_val = os.path.basename(checkpoint).split("-")[-1]

            # Save model
            saver.save(
                sess,
                os.path.join(FLAGS.train_dir, "inference_model",
                             "inference_model"))
        else:
            logging.info("No checkpoint file found.")
            return global_step_val

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                output_data_dict = sess.run(fetches)
                seconds_per_batch = time.time() - batch_start_time
                labels_val = output_data_dict["labels"]
                summary_val = output_data_dict["summary"]
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                predictions = output_data_dict["predictions"]
                #breakpoint()
                if FLAGS.segment_labels:
                    # This is a workaround to ignore the unrated labels.
                    predictions *= output_data_dict["label_weights"]
                iteration_info_dict = evl_metrics.accumulate(
                    predictions, labels_val, output_data_dict["loss"])
                iteration_info_dict["examples_per_second"] = example_per_second
                #breakpoint()
                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError as e:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: %s", str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
        logging.info("Total: examples_processed: %d", examples_processed)

        return global_step_val
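A minimal calling sketch for this variant might look like the following; eval_util.EvaluationMetrics, the graph collection names, num_classes, top_k, and the run_once flag are assumptions drawn from the surrounding YouTube-8M starter code, not part of the snippet above.

import tensorflow as tf

# Hedged usage sketch: build the fetches dict from graph collections and keep
# re-evaluating the latest checkpoint until run_once tells us to stop.
fetches = {
    "predictions": tf.get_collection("predictions")[0],
    "labels": tf.get_collection("labels")[0],
    "loss": tf.get_collection("loss")[0],
    "summary": tf.summary.merge_all(),
}
saver = tf.train.Saver(tf.global_variables())
summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                       graph=tf.get_default_graph())
evl_metrics = eval_util.EvaluationMetrics(num_classes, top_k)  # assumed helper
last_global_step_val = -1
while True:
    checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
    last_global_step_val = evaluation_loop(fetches, saver, summary_writer,
                                           evl_metrics, checkpoint,
                                           last_global_step_val)
    if FLAGS.run_once:  # assumed flag
        break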
Example #2
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val, hidden_layer_batch):
  """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.
    hidden_layer_batch: a tensor of hidden-layer activations for the mini-batch.

  Returns:
    The global_step used in the latest model and the video ids of the last
    evaluated mini-batch.
  """
  global_step_val = -1
  with tf.Session() as sess:
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
    print(latest_checkpoint)
    if latest_checkpoint:
      logging.info("Loading checkpoint for eval: " + latest_checkpoint)
      # Restores from checkpoint
      saver.restore(sess, latest_checkpoint)
      # Assuming model_checkpoint_path looks something like:
      # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
      global_step_val = latest_checkpoint.split("/")[-1].split("-")[-1]
    else:
      logging.info("No checkpoint file found.")
      return global_step_val

    if global_step_val == last_global_step_val:
      logging.info("skip this checkpoint global_step_val=%s "
                   "(same as the previous one).", global_step_val)
      return global_step_val

    sess.run([tf.local_variables_initializer()])
    # Start the queue runners.
    fetches = [video_id_batch, prediction_batch, label_batch, loss, summary_op, hidden_layer_batch]
    coord = tf.train.Coordinator()
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(
            sess, coord=coord, daemon=True,
            start=True))
      logging.info("enter eval_once loop global_step_val = %s. ",
                   global_step_val)

      evl_metrics.clear()

      examples_processed = 0
      while not coord.should_stop():
        batch_start_time = time.time()
        video_id_batch_val, predictions_val, labels_val, loss_val, summary_val, hidden_layer_val = sess.run(
            fetches)

        emb_frames = hidden_layer_val[0, 0:FLAGS.embedding_size]
        emb_audio = hidden_layer_val[0, FLAGS.embedding_size:2 * FLAGS.embedding_size]
        logging.info(np.sum(np.multiply(emb_frames, emb_audio)))
        # From one random video and its image embedding, return the video_id of the closest audio embedding (besides itself)
        index = np.random.randint(np.size(hidden_layer_val, 0))
        index_similar, max_correlation, original_correlation = get_closest_embedding(index, hidden_layer_val)
        video_id_original = video_id_batch_val[index]
        video_id_similar = video_id_batch_val[index_similar]
        labels_original = np.where(labels_val[index] == 1)
        labels_similar = np.where(labels_val[index_similar] == 1)
        logging.info("Original video ID and labels: ")
        logging.info(video_id_original)
        logging.info(labels_original)
        logging.info("Closest video ID and labels: ")
        logging.info(video_id_similar)
        logging.info(labels_similar)
        logging.info("Original cosine distance: %.4f: ",original_correlation)
        logging.info("Closest cosine distance: %.4f: ",max_correlation)

        seconds_per_batch = time.time() - batch_start_time
        example_per_second = labels_val.shape[0] / seconds_per_batch
        examples_processed += labels_val.shape[0]

        iteration_info_dict = evl_metrics.accumulate(predictions_val,
                                                     labels_val, loss_val, hidden_layer_val, FLAGS.hits)
        iteration_info_dict["examples_per_second"] = example_per_second

        iterinfo = utils.AddGlobalStepSummary(
            summary_writer,
            global_step_val,
            iteration_info_dict,
            summary_scope="Eval")
        logging.info("examples_processed: %d | %s", examples_processed,
                     iterinfo)
        # Raise an exception (caught below together with OutOfRangeError) once
        # max_batches is reached, to finish the evaluation early.
        if examples_processed >= (FLAGS.max_batches * FLAGS.batch_size):
            raise ValueError('Time to finish')

    except (tf.errors.OutOfRangeError, ValueError) as e:
      logging.info(
          "Done with batched inference. Now calculating global performance "
          "metrics.")
      # calculate the metrics for the entire epoch
      epoch_info_dict = evl_metrics.get()
      epoch_info_dict["epoch_id"] = global_step_val

      summary_writer.add_summary(summary_val, global_step_val)
      epochinfo = utils.AddEpochSummary(
          summary_writer,
          global_step_val,
          epoch_info_dict,
          summary_scope="Eval")
      logging.info(epochinfo)
      evl_metrics.clear()
    except Exception as e:  # pylint: disable=broad-except
      logging.info("Unexpected exception: " + str(e))
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)

    return global_step_val, video_id_batch_val
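This example relies on a get_closest_embedding helper that is not shown. A hedged NumPy sketch of what it might compute, based on how its return values are used above, follows; treating the two embedding halves with cosine similarity is an assumption.

import numpy as np

# Hedged sketch of get_closest_embedding: match the frame embedding of the video
# at `index` against the audio embeddings of every video in the batch.
def get_closest_embedding(index, hidden_layer_val):
  d = FLAGS.embedding_size  # assumed to be available, as in the snippet above
  frames = hidden_layer_val[:, :d]
  audio = hidden_layer_val[:, d:2 * d]
  query = frames[index]
  sims = np.dot(audio, query) / (
      np.linalg.norm(audio, axis=1) * np.linalg.norm(query) + 1e-8)
  original_correlation = sims[index]
  sims[index] = -np.inf  # exclude the video itself
  index_similar = int(np.argmax(sims))
  max_correlation = sims[index_similar]
  return index_similar, max_correlation, original_correlation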
Example #3
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
    """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

    global_step_val = -1
    with tf.Session() as sess:
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
        if latest_checkpoint:
            logging.info("Loading checkpoint for eval: " + latest_checkpoint)
            # Restores from checkpoint
            saver.restore(sess, latest_checkpoint)
            # Assuming model_checkpoint_path looks something like:
            # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
            global_step_val = latest_checkpoint.split("/")[-1].split("-")[-1]
        else:
            logging.info("No checkpoint file found.")
            return global_step_val

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_id_batch, prediction_batch, label_batch, loss, summary_op
        ]
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val, summary_val = sess.run(
                    fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError as e:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
Example #4
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics):
    """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.

  Returns:
    The global_step used in the latest model.
  """

    global_step_val = -1
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        print('number of trainable variables: %d' %
              len(tf.trainable_variables()))
        print('number of global variables: %d' % len(tf.global_variables()))
        if FLAGS.checkpoint_prefix is None:
            raise IOError(
                ("checkpoint_prefix %s is wrong.") % FLAGS.checkpoint_prefix)
        weights_avg_array = averaging_checkpoint(FLAGS.checkpoint_prefix)

        print('number of parameters after averaging: %d' %
              len(weights_avg_array.keys()))

        print('keys in weights_avg_array: ')
        print(weights_avg_array.keys())

        print('variables: ')
        print(tf.trainable_variables())

        for var in tf.global_variables():
            var_name_drop = var.name.replace(':0', '')  # drop the trailing ":0"
            if var_name_drop in weights_avg_array.keys():
                sess.run(var.assign(weights_avg_array[var_name_drop]))

        # latest_checkpoint = get_latest_checkpoint()
        # if latest_checkpoint:
        #   logging.info("Loading checkpoint for eval: " + latest_checkpoint)
        #   # Restores from checkpoint
        #   saver.restore(sess, latest_checkpoint)
        #   # Assuming model_checkpoint_path looks something like:
        #   # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
        #   global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]

        #   # Save model
        #   saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model"))
        # else:
        #   logging.info("No checkpoint file found.")
        #   return global_step_val

        # if global_step_val == last_global_step_val:
        #   logging.info("skip this checkpoint global_step_val=%s "
        #                "(same as the previous one).", global_step_val)
        #   return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_id_batch, prediction_batch, label_batch, loss, summary_op
        ]
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val, summary_val = sess.run(
                    fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError as e:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
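The averaging_checkpoint helper used above is not shown. A hedged sketch that averages the weights of every checkpoint sharing a common prefix is given below; the way candidate checkpoints are listed is an assumption.

import os
import tensorflow as tf

# Hedged sketch of averaging_checkpoint: read every checkpoint whose path starts
# with checkpoint_prefix and return {variable_name: averaged_value}.
def averaging_checkpoint(checkpoint_prefix):
    ckpt_state = tf.train.get_checkpoint_state(os.path.dirname(checkpoint_prefix))
    ckpt_paths = [p for p in ckpt_state.all_model_checkpoint_paths
                  if p.startswith(checkpoint_prefix)]
    weights_avg = {}
    for path in ckpt_paths:
        reader = tf.train.NewCheckpointReader(path)
        for name in reader.get_variable_to_shape_map():
            weights_avg[name] = (weights_avg.get(name, 0.0) +
                                 reader.get_tensor(name) / len(ckpt_paths))
    return weights_avg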
def evaluation_loop(model_nums, train_dirs, video_id_batch, prediction_batch,
                    label_batch, loss, summary_op, saver, summary_writer,
                    evl_metrics):

    global_step_val = -1
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        for i in range(model_nums):
            load_vars(sess, train_dirs[i], "model" + str(i))
        # new load
        saver.save(
            sess,
            os.path.join(FLAGS.Ensemble_Models + FLAGS.ensemble_output_path,
                         "inference_model"))

        sess.run([tf.local_variables_initializer()])

        fetches = [video_id_batch, prediction_batch, label_batch, loss]
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s.",
                         global_step_val)
            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val = sess.run(fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)
        except tf.errors.OutOfRangeError as e:
            logging.info(
                "Done with batched inference. Now calculating global performance metrics."
            )
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            #summary_writer.add_summary(summary_val,global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:
            logging.info("Unexpected exception:" + str(e))
            coord.request_stop(e)
        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
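load_vars is likewise external to this snippet. A hedged sketch, assuming each ensemble member's variables live under a "modelN/" name scope in the current graph, is:

import tensorflow as tf

# Hedged sketch of load_vars: restore the variables of one ensemble member from
# its own train_dir, stripping the per-model scope to match checkpoint names.
def load_vars(sess, train_dir, scope_name):
    ckpt = tf.train.latest_checkpoint(train_dir)
    scope_vars = [v for v in tf.global_variables()
                  if v.name.startswith(scope_name + "/")]
    # "modelN/foo/bar" in the graph maps back to "foo/bar" in the checkpoint
    var_map = {v.op.name[len(scope_name) + 1:]: v for v in scope_vars}
    tf.train.Saver(var_map).restore(sess, ckpt)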
Example #6
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
    """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

    global_step_val = -1
    config = tf.ConfigProto(device_count={'GPU': 0})

    with tf.Session(config=config) as sess:
        latest_checkpoint = get_latest_checkpoint()
        if latest_checkpoint:
            logging.info("Loading checkpoint for eval: " + latest_checkpoint)
            # Restores from checkpoint
            saver.restore(sess, latest_checkpoint)
            # Assuming model_checkpoint_path looks something like:
            # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
            global_step_val = os.path.basename(latest_checkpoint).split(
                "-")[-1].split('_')[-1]

            # Save model
            if FLAGS.force_output_model_name:
                saver.save(sess,
                           os.path.join(FLAGS.train_dir, "inference_model"),
                           write_meta_graph=False)

                selected_collections = [
                    'global_step', 'input_batch', 'input_batch_raw', 'labels',
                    'local_variables', 'loss', 'model_variables', 'num_frames',
                    'predictions', 'regularization_losses', 'summaries',
                    'summary_op', 'trainable_variables', 'variables'
                ]
                tf.train.export_meta_graph(
                    filename=os.path.join(FLAGS.train_dir,
                                          "inference_model.meta"),
                    collection_list=selected_collections)

            elif "inference_model" in FLAGS.checkpoint_file:
                if "ensemble" in FLAGS.checkpoint_file:
                    saver.save(
                        sess,
                        os.path.join(
                            FLAGS.train_dir,
                            FLAGS.checkpoint_file.replace(
                                'ensemble',
                                'ensemble_' + str(FLAGS.ensemble_wts).replace(
                                    ',', '').replace(' ', '_').replace(
                                        '.', '').replace('[', '').replace(
                                            ']', ''))))
            else:
                if "avg" not in FLAGS.checkpoint_file:
                    saver.save(
                        sess,
                        os.path.join(
                            FLAGS.train_dir, "inference_model_" +
                            latest_checkpoint.split('-')[-1]))
                else:
                    saver.save(
                        sess,
                        os.path.join(
                            FLAGS.train_dir,
                            "inference_model_" + FLAGS.checkpoint_file))
        else:
            logging.info("No checkpoint file found.")
            return global_step_val

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        if FLAGS.create_meta_only:
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_id_batch, prediction_batch, label_batch, loss, summary_op
        ]
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val, summary_val = sess.run(
                    fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError as e:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
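When force_output_model_name is set, the exported inference_model.meta above keeps only the listed collections. A hedged sketch of loading it back for inference, with collection names taken from that list, is:

import os
import tensorflow as tf

# Hedged sketch: reload the exported meta graph and pull the inference tensors
# back out of its collections.
with tf.Graph().as_default(), tf.Session() as sess:
    meta_path = os.path.join(FLAGS.train_dir, "inference_model.meta")
    loader = tf.train.import_meta_graph(meta_path, clear_devices=True)
    loader.restore(sess, os.path.join(FLAGS.train_dir, "inference_model"))
    input_batch_raw = tf.get_collection("input_batch_raw")[0]
    num_frames = tf.get_collection("num_frames")[0]
    predictions = tf.get_collection("predictions")[0]
    # predictions can now be evaluated by feeding input_batch_raw and num_frames.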
Example #7
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val, ema_tensors):
  """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.
    ema_tensors: a list of names of ExponentialMovingAverage shadow variables to
      load in place of the raw weights when use_EMA is set.

  Returns:
    The global_step used in the latest model.
  """

  global_step_val = -1
  latest_checkpoint = get_latest_checkpoint()

  with tf.Session() as sess:

    if latest_checkpoint:
      logging.info("Loading checkpoint for eval: " + latest_checkpoint)
      # Restores from checkpoint
      saver.restore(sess, latest_checkpoint)
      # Assuming model_checkpoint_path looks something like:
      # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
      global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]

      if FLAGS.use_EMA:
        assert len(ema_tensors) > 0, "Tensors got lost."
        logging.info("####################")
        logging.info("USING EMA VARIABLES.")
        logging.info("####################")

        reader = pywrap_tensorflow.NewCheckpointReader(latest_checkpoint)
        global_vars = tf.global_variables()

        for stensor in ema_tensors:
          destination_t = [x for x in global_vars if x.name == stensor.replace("/ExponentialMovingAverage:", ":")]
          assert len(destination_t) == 1
          destination_t = destination_t[0]
          ema_source = reader.get_tensor(stensor.split(":")[0])
          # Session to take care of
          destination_t.load(ema_source, session=sess)

      # Save model
      saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model"))
      if FLAGS.build_only:
          logging.info("Inference graph built. Existing now.")
          exit()
    else:
      logging.info("No checkpoint file found.")
      return global_step_val

    if global_step_val == last_global_step_val:
      logging.info("skip this checkpoint global_step_val=%s "
                   "(same as the previous one).", global_step_val)
      return global_step_val

    sess.run([tf.local_variables_initializer()])

    # Start the queue runners.
    fetches = [video_id_batch, prediction_batch, label_batch, loss, summary_op]
    coord = tf.train.Coordinator()
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(
            sess, coord=coord, daemon=True,
            start=True))
      logging.info("enter eval_once loop global_step_val = %s. ",
                   global_step_val)

      evl_metrics.clear()

      examples_processed = 0
      while not coord.should_stop():
        batch_start_time = time.time()
        _, predictions_val, labels_val, loss_val, summary_val = sess.run(
            fetches)
        seconds_per_batch = time.time() - batch_start_time
        example_per_second = labels_val.shape[0] / seconds_per_batch
        examples_processed += labels_val.shape[0]

        iteration_info_dict = evl_metrics.accumulate(predictions_val,
                                                     labels_val, loss_val)
        iteration_info_dict["examples_per_second"] = example_per_second

        iterinfo = utils.AddGlobalStepSummary(
            summary_writer,
            global_step_val,
            iteration_info_dict,
            summary_scope="Eval")
        logging.info("examples_processed: %d | %s", examples_processed,
                     iterinfo)

    except tf.errors.OutOfRangeError as e:
      logging.info(
          "Done with batched inference. Now calculating global performance "
          "metrics.")
      # calculate the metrics for the entire epoch
      epoch_info_dict = evl_metrics.get()
      epoch_info_dict["epoch_id"] = global_step_val

      summary_writer.add_summary(summary_val, global_step_val)
      epochinfo = utils.AddEpochSummary(
          summary_writer,
          global_step_val,
          epoch_info_dict,
          summary_scope="Eval")
      logging.info(epochinfo)
      with open('evallog.txt', 'a') as f:
        f.write(epochinfo + '\n')
      evl_metrics.clear()
    except Exception as e:  # pylint: disable=broad-except
      logging.info("Unexpected exception: " + str(e))
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)

    return global_step_val
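The ema_tensors argument of the loop above is a list of shadow-variable names such as "tower/fc1/weights/ExponentialMovingAverage:0". A hedged sketch of how such a list could be produced on the training side is below; the decay value and the existing training graph are assumptions.

import tensorflow as tf

# Hedged sketch: track the trainable variables with an EMA and record the names
# of the shadow variables so the evaluation loop can load them later.
ema = tf.train.ExponentialMovingAverage(decay=0.999)
maintain_averages_op = ema.apply(tf.trainable_variables())  # run next to the train op
ema_tensors = [ema.average(v).name for v in tf.trainable_variables()]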
    def eval_loop(self, last_global_step_val, evl_metrics):
        """Run the evaluation loop once.

    Args:
      last_global_step_val: the global step used in the previous evaluation.
      evl_metrics: an EvaluationMetrics object.

    Returns:
      The global_step used in the latest model.
    """
        latest_checkpoint, global_step_val = self.get_checkpoint(
            last_global_step_val)
        logging.info("latest_checkpoint: {}".format(latest_checkpoint))

        if latest_checkpoint is None or global_step_val == last_global_step_val:
            time.sleep(self.wait)
            return last_global_step_val

        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            logging.info(
                "Loading checkpoint for eval: {}".format(latest_checkpoint))

            # Restores from checkpoint
            self.saver.restore(sess, latest_checkpoint)
            sess.run(tf.local_variables_initializer())

            evl_metrics.clear()

            train_gpu = FLAGS.train_num_gpu
            train_batch_size = FLAGS.train_batch_size
            n_train_files = self.reader.n_train_files
            if train_gpu:
                epoch = ((global_step_val * train_batch_size * train_gpu) /
                         n_train_files)
            else:
                epoch = ((global_step_val * train_batch_size) / n_train_files)

            examples_processed = 0
            while True:
                try:
                    batch_start_time = time.time()

                    fetches = [
                        self.logits, self.labels, self.labels_losses,
                        self.summary_op
                    ]
                    logits_val, labels_val, loss_val, summary_val = sess.run(
                        fetches)
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = self.batch_size / seconds_per_batch
                    examples_processed += self.batch_size

                    iteration_info_dict = evl_metrics.accumulate(
                        logits_val, labels_val, loss_val)
                    iteration_info_dict[
                        "examples_per_second"] = examples_per_second

                    iterinfo = utils.AddGlobalStepSummary(self.summary_writer,
                                                          global_step_val,
                                                          iteration_info_dict,
                                                          summary_scope="Eval")
                    logging.info("examples_processed: %d | %s",
                                 examples_processed, iterinfo)

                except tf.errors.OutOfRangeError as e:
                    logging.info(
                        "Done with batched inference. Now calculating global performance "
                        "metrics.")
                    # calculate the metrics for the entire epoch
                    epoch_info_dict = evl_metrics.get()
                    epoch_info_dict["epoch_id"] = global_step_val

                    self.summary_writer.add_summary(summary_val,
                                                    global_step_val)
                    epochinfo = utils.AddEpochSummary(self.summary_writer,
                                                      global_step_val,
                                                      epoch_info_dict,
                                                      summary_scope="Eval")
                    logging.info(epochinfo)
                    evl_metrics.clear()

                    if FLAGS.stopped_at_n:
                        self.counter += 1
                    break

                except Exception as e:
                    logging.info("Unexpected exception: {}".format(e))
                    sys.exit(0)

            return global_step_val
Example #9
  def run(self, start_new_model=False):
    """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
    if self.is_master and start_new_model:
      self.remove_training_directory(self.train_dir)

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default() as graph:

      if meta_filename:
        saver = self.recover_model(meta_filename)

      with tf.device(device_fn):
        if not meta_filename:
          saver = self.build_model(self.model, self.reader)

        global_step = tf.get_collection("global_step")[0]
        loss = tf.get_collection("loss")[0]
        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        train_op = tf.get_collection("train_op")[0]
        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=self.is_master,
        global_step=global_step,
        save_model_secs=15 * 60,
        save_summaries_secs=120,
        saver=saver)

    logging.info("%s: Starting managed session.", task_as_string(self.task))
    with sv.managed_session(target, config=self.config) as sess:
      try:
        logging.info("%s: Entering training loop.", task_as_string(self.task))
        while (not sv.should_stop()) and (not self.max_steps_reached):
          batch_start_time = time.time()
          _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
              [train_op, global_step, loss, predictions, labels])
          seconds_per_batch = time.time() - batch_start_time
          examples_per_second = labels_val.shape[0] / seconds_per_batch

          if self.max_steps and self.max_steps <= global_step_val:
            self.max_steps_reached = True

          if self.is_master and global_step_val % 10 == 0 and self.train_dir:
            eval_start_time = time.time()
            hit_at_one = eval_util.calculate_hit_at_one(predictions_val, labels_val)
            perr = eval_util.calculate_precision_at_equal_recall_rate(predictions_val,
                                                                      labels_val)
            gap = eval_util.calculate_gap(predictions_val, labels_val)
            eval_end_time = time.time()
            eval_time = eval_end_time - eval_start_time

            logging.info("training step " + str(global_step_val) + " | Loss: " + ("%.2f" % loss_val) +
              " Examples/sec: " + ("%.2f" % examples_per_second) + " | Hit@1: " +
              ("%.2f" % hit_at_one) + " PERR: " + ("%.2f" % perr) +
              " GAP: " + ("%.2f" % gap))

            info_dict = {"hit_at_one": hit_at_one,
                         "perr": perr,
                         "gap": gap,
                         "loss": loss_val,
                         "examples_per_second": examples_per_second,
                         }

            utils.AddGlobalStepSummary(sv.summary_writer,
                                       global_step_val,
                                       info_dict)

            # Exporting the model every x steps
            time_to_export = ((self.last_model_export_step == 0) or
                (global_step_val - self.last_model_export_step
                 >= self.export_model_steps))

            if self.is_master and time_to_export:
              self.export_model(global_step_val, sv.saver, sv.save_path, sess)
              self.last_model_export_step = global_step_val
          else:
            logging.info("training step " + str(global_step_val) + " | Loss: " +
              ("%.2f" % loss_val) + " Examples/sec: " + ("%.2f" % examples_per_second))
      except tf.errors.OutOfRangeError:
        logging.info("%s: Done training -- epoch limit reached.",
                     task_as_string(self.task))

    logging.info("%s: Exited training loop.", task_as_string(self.task))
    sv.Stop()
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
  """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

  global_step_val = -1
  with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
#     latest_checkpoint = get_latest_checkpoint()
#     if latest_checkpoint:
#       logging.info("Loading checkpoint for eval: " + latest_checkpoint)
#       # Restores from checkpoint
#       saver.restore(sess, latest_checkpoint)
#       # Assuming model_checkpoint_path looks something like:
#       # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
#       global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]

#       # Save model
#       saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model"))
#     else:
#       logging.info("No checkpoint file found.")
#       return global_step_val

#     if global_step_val == last_global_step_val:
#       logging.info("skip this checkpoint global_step_val=%s "
#                    "(same as the previous one).", global_step_val)
#       return global_step_val

    sess.run([tf.local_variables_initializer()])

    # Start the queue runners.
    fetches = [video_id_batch, prediction_batch, label_batch, loss, summary_op]
    coord = tf.train.Coordinator()
    
    
    # output results
    start_time = time.time()
    video_ids = []
    video_labels = []
    video_features = []
    filenum = 0
    num_examples_processed = 0
    total_num_examples_processed = 0
    
    # output prediction dir
    directory = FLAGS.output_dir
    if directory != '':
      if not os.path.exists(directory):
          os.makedirs(directory)
      else:
          raise IOError("Output path exists! path='" + directory + "'")
    
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(
            sess, coord=coord, daemon=True,
            start=True))
      logging.info("enter eval_once loop global_step_val = %s. ",
                   global_step_val)

      evl_metrics.clear()

      examples_processed = 0
      while not coord.should_stop():
        batch_start_time = time.time()
        ids_val, predictions_val, labels_val, loss_val, summary_val = sess.run(
            fetches)
        seconds_per_batch = time.time() - batch_start_time
        example_per_second = labels_val.shape[0] / seconds_per_batch
        examples_processed += labels_val.shape[0]

        iteration_info_dict = evl_metrics.accumulate(predictions_val,
                                                     labels_val, loss_val)
        iteration_info_dict["examples_per_second"] = example_per_second

        iterinfo = utils.AddGlobalStepSummary(
            summary_writer,
            global_step_val,
            iteration_info_dict,
            summary_scope="Eval")
        logging.info("examples_processed: %d | %s", examples_processed,
                     iterinfo)
        
        # save predictions
        if directory != '':
          video_ids.append(ids_val)
          video_labels.append(labels_val)
          video_features.append(predictions_val)
          num_examples_processed += len(ids_val)

          if num_examples_processed >= FLAGS.file_size:
            assert num_examples_processed==FLAGS.file_size, "num_examples_processed should be equal to %d"%FLAGS.file_size
            video_ids = np.concatenate(video_ids, axis=0)
            video_labels = np.concatenate(video_labels, axis=0)
            video_features = np.concatenate(video_features, axis=0)
            write_to_record(video_ids, video_labels, video_features, filenum, num_examples_processed)

            video_ids = []
            video_labels = []
            video_features = []
            filenum += 1
            total_num_examples_processed += num_examples_processed

            now = time.time()
            logging.info("num examples processed: " + str(num_examples_processed) + " elapsed seconds: " + "{0:.2f}".format(now-start_time))
            num_examples_processed = 0

    except tf.errors.OutOfRangeError as e:
      logging.info(
          "Done with batched inference. Now calculating global performance "
          "metrics.")
      # calculate the metrics for the entire epoch
      epoch_info_dict = evl_metrics.get()
      epoch_info_dict["epoch_id"] = global_step_val

      summary_writer.add_summary(summary_val, global_step_val)
      epochinfo = utils.AddEpochSummary(
          summary_writer,
          global_step_val,
          epoch_info_dict,
          summary_scope="Eval")
      logging.info(epochinfo)
      evl_metrics.clear()
      
      # save prediction
      if directory != '':
        # if ids_val is not None:
        #   video_ids.append(ids_val)
        #   video_labels.append(labels_val)
        #   video_features.append(predictions_val)
        #   num_examples_processed += len(ids_val)

        if 0 < num_examples_processed <= FLAGS.file_size:
          video_ids = np.concatenate(video_ids, axis=0)
          video_labels = np.concatenate(video_labels, axis=0)
          video_features = np.concatenate(video_features, axis=0)
          write_to_record(video_ids, video_labels, video_features, filenum, num_examples_processed)
          total_num_examples_processed += num_examples_processed

          now = time.time()
          logging.info("num examples processed: " + str(num_examples_processed) + " elapsed seconds: " + "{0:.2f}".format(now-start_time))
          num_examples_processed = 0

        logging.info("Done with inference. %d samples was written to %s" % (total_num_examples_processed, FLAGS.output_dir))
    except Exception as e:  # pylint: disable=broad-except
      logging.info("Unexpected exception: " + str(e))
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)

    return global_step_val
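The write_to_record helper that persists each chunk of ids, labels, and scores is not shown. A hedged sketch using TFRecords is below; the file naming and feature keys are assumptions.

import os
import numpy as np
import tensorflow as tf

# Hedged sketch of write_to_record: dump one chunk of video ids, labels, and
# prediction scores into <output_dir>/predictions-<filenum>.tfrecord.
def write_to_record(video_ids, video_labels, video_features, filenum,
                    num_examples_processed):
  path = os.path.join(FLAGS.output_dir, "predictions-%04d.tfrecord" % filenum)
  writer = tf.python_io.TFRecordWriter(path)
  for i in range(num_examples_processed):
    example = tf.train.Example(features=tf.train.Features(feature={
        "video_id": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[video_ids[i]])),  # assumes bytes ids
        "labels": tf.train.Feature(
            int64_list=tf.train.Int64List(
                value=np.nonzero(video_labels[i])[0].tolist())),
        "predictions": tf.train.Feature(
            float_list=tf.train.FloatList(value=video_features[i].tolist())),
    }))
    writer.write(example.SerializeToString())
  writer.close()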
Example #11
def evaluation_loop(self, saver, model_ckpt_path):
    global_step_val = model_ckpt_path.split("/")[-1].split("-")[-1]
    evl_metrics = eval_util.EvaluationMetrics(self.model.num_classes,
                                              self.config.top_k)

    # summary_writer = tf.summary.FileWriter(
    # self.train_dir, graph=tf.get_default_graph())
    summary_writer = None

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    video_ids = []
    output_scores = 1  # 1->output score, 2-> output features
    if output_scores == 1:
        model_id = model_ckpt_path.split(
            "/")[-2] + "-" + model_ckpt_path.split("-")[-1]
        # num_insts = 4906660
        # stage = "train"
        num_insts = 1401828
        stage = "validate"
        # num_insts = 700640
        # stage = "test"
        video_ids_pkl_path = "/data/D2DCRC/linchao/YT/scores/{}.{}.pkl".format(
            model_id, stage)
        # video_ids_pkl_path = pkl.load(open("/data/D2DCRC/linchao/YT/{}_vids_dict.pkl".format(stage)))
        # log_path = "/data/D2DCRC/linchao/YT/scores/{}.{}.touch".format(model_id, stage)
        pred_out = h5py.File(
            "/data/D2DCRC/linchao/YT/scores/{}.{}.h5".format(model_id, stage),
            "w")
        pred_dataset = pred_out.create_dataset('scores',
                                               shape=(num_insts,
                                                      self.model.num_classes),
                                               dtype=np.float32)
    elif output_scores == 2:
        output_prefix = "/data/uts700/linchao/yt8m/data/555_netvlad/train"
        tfrecord_cntr = 0
        output_filename = "{}/{}.tfrecord".format(output_prefix,
                                                  tfrecord_cntr / 1200)
        tfrecord_writer = tf.python_io.TFRecordWriter(output_filename)

    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, model_ckpt_path)
        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                res = sess.run(self.feed_out)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = res["dense_labels"].shape[
                    0] / seconds_per_batch
                examples_processed += res["dense_labels"].shape[0]
                predictions = res["predictions"]
                video_id = res["video_id"].tolist()
                if output_scores == 1:
                    # for i in xrange(len(video_id)):
                    # pred_dataset[video_ids_pkl_path[video_id[i]], :] = predictions[i]
                    pred_dataset[len(video_ids):len(video_ids) +
                                 len(video_id), :] = predictions
                    video_ids += video_id
                elif output_scores == 2:
                    for i in xrange(len(video_id)):
                        sparse_label = np.array(res["dense_labels"][i],
                                                dtype=np.int32)
                        sparse_label = np.where(sparse_label == 1)[0].tolist()
                        feat = res["feats"][i]
                        print(sparse_label, video_id[i])
                        example = matrix_to_tfexample(feat,
                                                      labels=sparse_label,
                                                      video_id=video_id[i])
                        tfrecord_writer.write(example.SerializeToString())
                        if tfrecord_cntr % 1200 == 0:
                            tfrecord_writer.close()
                            output_filename = "{}/{}.tfrecord".format(
                                output_prefix, tfrecord_cntr / 1200)
                            tfrecord_writer = tf.python_io.TFRecordWriter(
                                output_filename)
                        tfrecord_cntr += 1

                if type(predictions) == list:
                    predictions = eval_util.transform_preds(self, predictions)

                iteration_info_dict = evl_metrics.accumulate(
                    predictions, res["dense_labels"], res["loss"])
                iteration_info_dict["examples_per_second"] = example_per_second

                gap = eval_util.calculate_gap(predictions, res["dense_labels"])
                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                '''
        p = [str(_) for _ in np.where(res["dense_labels"][0, :] > 0)[0].tolist()]
        print_labels = "+".join(p)
        p = np.argsort(res["predictions"][0, :])[-20:]
        p = np.sort(p).tolist()
        p = [str(_) for _ in p]
        pred_labels = "+".join(p)
        logging.info("vid: %s; gap: %s; labels %s; predictions %s" % (
            res['video_id'][0], gap, print_labels, pred_labels))
        '''
                logging.info("examples_processed: %d | %s | gap: %s",
                             examples_processed, iterinfo, gap)

        except tf.errors.OutOfRangeError as e:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            if output_scores == 1:
                pred_out.close()
                # with open(log_path, 'w') as fout:
                # print>>fout , "Done"
                pkl.dump(video_ids, open(video_ids_pkl_path, "w"))
            elif output_scores == 2:
                tfrecord_writer.close()
            else:
                # calculate the metrics for the entire epoch
                epoch_info_dict = evl_metrics.get()
                epoch_info_dict["epoch_id"] = global_step_val
                if summary_writer:
                    summary_writer.add_summary(res["summary"], global_step_val)
                epochinfo = utils.AddEpochSummary(summary_writer,
                                                  global_step_val,
                                                  epoch_info_dict,
                                                  summary_scope="Eval")
                logging.info(epochinfo)
                evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
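matrix_to_tfexample is also external to this snippet. A hedged sketch that serializes one feature matrix, its sparse label indices, and the video id into a tf.train.Example is below; the feature keys are assumptions.

import numpy as np
import tensorflow as tf

# Hedged sketch of matrix_to_tfexample: flatten the feature matrix and store it
# together with the sparse labels and the video id.
def matrix_to_tfexample(feat, labels, video_id):
    return tf.train.Example(features=tf.train.Features(feature={
        "video_id": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[video_id])),  # assumes a bytes id
        "labels": tf.train.Feature(
            int64_list=tf.train.Int64List(value=labels)),
        "feats": tf.train.Feature(
            float_list=tf.train.FloatList(
                value=np.asarray(feat).reshape(-1).tolist())),
    }))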