コード例 #1
0
def evaluation_loop(video_id_batch,
                    prediction_batch,
                    label_batch,
                    summary_op,
                    saver,
                    summary_writer,
                    evl_metrics,
                    last_global_step_val,
                    total_ens_times=1):
    """Run the evaluation loop once, ensembling several inference rounds.

  Each round restores a checkpoint (FLAGS.load_this_checkpoint when set,
  otherwise get_latest_checkpoint()), runs batched inference until the input
  queue is exhausted, keeps only the examples whose last label column is 0 and
  drops that column from the labels. After every round the predictions of all
  rounds so far are aligned on the video ids of the first round, combined with
  a harmonic mean (hmean) and scored with evl_metrics; the resulting epoch
  metrics are written through summary_writer.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.
    total_ens_times: number of inference rounds to ensemble (default 1).

  Returns:
    The global_step used in the latest model.
  """
    global_step_val = -1
    final_df = []
    # Built once from the graph tensors. The fetched numpy values are bound to
    # different names below, so later ensemble rounds still fetch the tensors
    # instead of the previous round's arrays.
    fetches = [video_id_batch, prediction_batch, label_batch, summary_op]

    for ens_round in range(total_ens_times):
        print("\n\nEnsemble round ", ens_round, "\n\n")

        this_round_preds = []
        this_round_labels = []
        this_round_video_ids = []
        # Summary of the last batch of *this* round (None if no batch ran).
        summary_val = None

        with tf.Session() as sess:
            # An unset or empty flag means "use the most recent checkpoint".
            if not FLAGS.load_this_checkpoint:
                latest_checkpoint = get_latest_checkpoint()
            else:
                latest_checkpoint = (FLAGS.train_dir + '/model.ckpt-' +
                                     FLAGS.load_this_checkpoint)
            if not latest_checkpoint:
                logging.info("No checkpoint file found.")
                return global_step_val

            logging.info("Loading checkpoint for eval: " + latest_checkpoint)
            # Restores from checkpoint
            saver.restore(sess, latest_checkpoint)
            # Assuming model_checkpoint_path looks something like:
            # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
            global_step_val = os.path.basename(latest_checkpoint).split(
                "-")[-1]

            # Save model
            saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model"))

            if global_step_val == last_global_step_val:
                logging.info(
                    "skip this checkpoint global_step_val=%s "
                    "(same as the previous one).", global_step_val)
                return global_step_val

            sess.run([tf.local_variables_initializer()])

            # Start the queue runners.
            coord = tf.train.Coordinator()
            threads = []
            try:
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(
                        qr.create_threads(sess,
                                          coord=coord,
                                          daemon=True,
                                          start=True))
                logging.info("enter eval_once loop global_step_val = %s. ",
                             global_step_val)

                examples_processed = 0
                while not coord.should_stop():
                    video_id_val, predictions_val, labels_val, summary_val = (
                        sess.run(fetches))

                    ##### mapping 2017 takes place here #####
                    print("\n\nLabel dim: {}\n\n".format(labels_val.shape))

                    # Keep only the examples whose last label column is 0,
                    # then drop that column from the labels.
                    keep_mask = labels_val[:, -1] == 0
                    print(keep_mask)
                    video_id_val = video_id_val[keep_mask]
                    predictions_val = predictions_val[keep_mask]
                    labels_val = labels_val[keep_mask][:, :-1]

                    this_round_preds.append(predictions_val)
                    this_round_labels.append(labels_val)
                    this_round_video_ids.append(video_id_val)

                    examples_processed += labels_val.shape[0]
                    logging.info("examples_processed: %d ", examples_processed)

            except tf.errors.OutOfRangeError:
                logging.info(
                    "Done with batched inference for this ensemble round. ")

            except Exception as exc:  # pylint: disable=broad-except
                logging.info("Unexpected exception: " + str(exc))
                coord.request_stop(exc)

            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)

        if not this_round_preds:
            # Nothing was fetched; np.concatenate would fail on an empty list.
            logging.info("No examples were evaluated in ensemble round %d.",
                         ens_round)
            continue

        temp_preds = pd.DataFrame(np.concatenate(this_round_preds, axis=0))
        temp_preds.columns = ['preds'] * temp_preds.shape[1]
        temp_labels = pd.DataFrame(np.concatenate(this_round_labels, axis=0))
        temp_labels.columns = ['labels'] * temp_labels.shape[1]
        temp_df = pd.concat([temp_labels, temp_preds], axis=1)
        temp_df.index = np.concatenate(this_round_video_ids, axis=0)

        if FLAGS.save_preds:
            temp_df.to_csv(os.path.join(FLAGS.train_dir,
                                        'eval_predictions.csv'),
                           index=True)

        final_df.append(temp_df)

        # Align every round on the first round's video ids and combine the
        # per-round predictions with a harmonic mean.
        final_index = final_df[0].index
        final_labels = final_df[0]['labels'].values
        if len(final_df) == 1:
            final_preds = final_df[0]['preds'].values
        else:
            stacked_preds = np.stack(
                [round_df.loc[final_index]['preds'].values
                 for round_df in final_df],
                axis=2)
            final_preds = hmean(stacked_preds, axis=2)

        fake_loss_zeros = np.zeros((final_df[0].shape[0], 1))

        evl_metrics.clear()
        evl_metrics.accumulate(final_preds, final_labels, fake_loss_zeros)
        # calculate the metrics for the entire epoch
        epoch_info_dict = evl_metrics.get()
        epoch_info_dict["epoch_id"] = global_step_val

        if summary_val is not None:
            summary_writer.add_summary(summary_val, global_step_val)
        epochinfo = utils.AddEpochSummary(summary_writer,
                                          global_step_val,
                                          epoch_info_dict,
                                          summary_scope="Eval")
        logging.info(epochinfo)
        evl_metrics.clear()

    return global_step_val
コード例 #2
0
def evaluation_loop(fetches, saver, summary_writer, evl_metrics, checkpoint,
                    last_global_step_val):
    """Run the evaluation loop once.

  Args:
    fetches: a dict of tensors to be run within Session. The loop reads the
      keys "labels", "summary", "predictions" and "loss", plus
      "label_weights" when FLAGS.segment_labels is set.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    checkpoint: path of the checkpoint to evaluate; a falsy value means no
      checkpoint is available and nothing is evaluated.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """
    global_step_val = -1
    with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True))) as sess:
        if not checkpoint:
            logging.info("No checkpoint file found.")
            return global_step_val

        logging.info("Loading checkpoint for eval: %s", checkpoint)
        # Restores from checkpoint
        saver.restore(sess, checkpoint)
        # Assuming model_checkpoint_path looks something like:
        # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
        global_step_val = os.path.basename(checkpoint).split("-")[-1]

        # Save model
        saver.save(
            sess,
            os.path.join(FLAGS.train_dir, "inference_model",
                         "inference_model"))

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = []
        # Bound before the try so the handlers and the final log line work
        # even when no batch was fetched.
        examples_processed = 0
        summary_val = None
        try:
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            while not coord.should_stop():
                batch_start_time = time.time()
                output_data_dict = sess.run(fetches)
                seconds_per_batch = time.time() - batch_start_time
                labels_val = output_data_dict["labels"]
                summary_val = output_data_dict["summary"]
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                predictions = output_data_dict["predictions"]
                if FLAGS.segment_labels:
                    # This is a workaround to ignore the unrated labels.
                    predictions *= output_data_dict["label_weights"]
                iteration_info_dict = evl_metrics.accumulate(
                    predictions, labels_val, output_data_dict["loss"])
                iteration_info_dict["examples_per_second"] = example_per_second
                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            if summary_val is not None:
                summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: %s", str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
        logging.info("Total: examples_processed: %d", examples_processed)

        return global_step_val
コード例 #3
0
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val, hidden_layer_batch):
  """Run the evaluation loop once, also probing the hidden-layer embeddings.

  For every batch the hidden layer is sliced into two blocks of
  FLAGS.embedding_size columns (named emb_frames / emb_audio here). A random
  example is picked, and get_closest_embedding's answer is logged together
  with both examples' video ids and labels. Evaluation stops when the input
  is exhausted or once FLAGS.max_batches * FLAGS.batch_size examples have
  been processed; either way the epoch metrics are then written.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.
    hidden_layer_batch: a tensor of hidden-layer activations for the
      mini-batch; passed to evl_metrics.accumulate and used for the probe.

  Returns:
    The global_step_val alone when nothing was evaluated (no checkpoint, or
    the same checkpoint as last time); otherwise the tuple
    (global_step_val, video ids of the last fetched batch, or None).
  """
  global_step_val = -1
  video_id_batch_val = None
  with tf.Session() as sess:
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
    print(latest_checkpoint)
    if not latest_checkpoint:
      logging.info("No checkpoint file found.")
      return global_step_val

    logging.info("Loading checkpoint for eval: " + latest_checkpoint)
    # Restores from checkpoint
    saver.restore(sess, latest_checkpoint)
    # Assuming model_checkpoint_path looks something like:
    # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
    global_step_val = latest_checkpoint.split("/")[-1].split("-")[-1]

    if global_step_val == last_global_step_val:
      logging.info("skip this checkpoint global_step_val=%s "
                   "(same as the previous one).", global_step_val)
      return global_step_val

    sess.run([tf.local_variables_initializer()])
    fetches = [video_id_batch, prediction_batch, label_batch, loss,
               summary_op, hidden_layer_batch]
    max_examples = FLAGS.max_batches * FLAGS.batch_size
    coord = tf.train.Coordinator()
    threads = []
    summary_val = None
    # True when the loop ended normally (input exhausted or example budget
    # reached); only then are the epoch metrics computed.
    finished = False
    try:
      # Start the queue runners.
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(
            sess, coord=coord, daemon=True,
            start=True))
      logging.info("enter eval_once loop global_step_val = %s. ",
                   global_step_val)

      evl_metrics.clear()

      examples_processed = 0
      while not coord.should_stop():
        batch_start_time = time.time()
        (video_id_batch_val, predictions_val, labels_val, loss_val,
         summary_val, hidden_layer_val) = sess.run(fetches)

        # Dot product of the two embedding slices of the first example.
        emb_frames = hidden_layer_val[0, 0:FLAGS.embedding_size]
        emb_audio = hidden_layer_val[0, FLAGS.embedding_size:
                                     2 * FLAGS.embedding_size]
        logging.info(np.sum(np.multiply(emb_frames, emb_audio)))
        # From one random video and its image embedding, return the video_id
        # of the closest audio embedding (besides itself).
        index = np.random.randint(np.size(hidden_layer_val, 0))
        index_similar, max_correlation, original_correlation = (
            get_closest_embedding(index, hidden_layer_val))
        video_id_original = video_id_batch_val[index]
        video_id_similar = video_id_batch_val[index_similar]
        labels_original = np.where(labels_val[index] == 1)
        labels_similar = np.where(labels_val[index_similar] == 1)
        logging.info("Original video ID and labels: ")
        logging.info(video_id_original)
        logging.info(labels_original)
        logging.info("Closest video ID and labels: ")
        logging.info(video_id_similar)
        logging.info(labels_similar)
        logging.info("Original cosine distance: %.4f: ", original_correlation)
        logging.info("Closest cosine distance: %.4f: ", max_correlation)

        seconds_per_batch = time.time() - batch_start_time
        example_per_second = labels_val.shape[0] / seconds_per_batch
        examples_processed += labels_val.shape[0]

        iteration_info_dict = evl_metrics.accumulate(
            predictions_val, labels_val, loss_val, hidden_layer_val,
            FLAGS.hits)
        iteration_info_dict["examples_per_second"] = example_per_second

        iterinfo = utils.AddGlobalStepSummary(
            summary_writer,
            global_step_val,
            iteration_info_dict,
            summary_scope="Eval")
        logging.info("examples_processed: %d | %s", examples_processed,
                     iterinfo)
        # Stop once the requested number of examples has been evaluated.
        # A flag + break is used instead of raising, so genuine errors are
        # not mistaken for a normal finish.
        if examples_processed >= max_examples:
          finished = True
          break

    except tf.errors.OutOfRangeError:
      finished = True
    except Exception as e:  # pylint: disable=broad-except
      logging.info("Unexpected exception: " + str(e))
      coord.request_stop(e)

    if finished:
      logging.info(
          "Done with batched inference. Now calculating global performance "
          "metrics.")
      # calculate the metrics for the entire epoch
      epoch_info_dict = evl_metrics.get()
      epoch_info_dict["epoch_id"] = global_step_val

      if summary_val is not None:
        summary_writer.add_summary(summary_val, global_step_val)
      epochinfo = utils.AddEpochSummary(
          summary_writer,
          global_step_val,
          epoch_info_dict,
          summary_scope="Eval")
      logging.info(epochinfo)
      evl_metrics.clear()

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)

    # NOTE(review): this return shape differs from the early returns above
    # (a bare global_step_val); kept as-is for existing callers.
    return global_step_val, video_id_batch_val
コード例 #4
0
ファイル: eval.py プロジェクト: avpronkin/Kaggle_YouTube-8M
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
    """Run the evaluation loop once.

  Restores the latest checkpoint found in FLAGS.train_dir, feeds every
  mini-batch into evl_metrics and, once the input is exhausted, writes the
  epoch-level metrics through summary_writer.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

    global_step_val = -1
    with tf.Session() as sess:
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not latest_checkpoint:
            logging.info("No checkpoint file found.")
            return global_step_val

        logging.info("Loading checkpoint for eval: " + latest_checkpoint)
        # Restores from checkpoint
        saver.restore(sess, latest_checkpoint)
        # Assuming model_checkpoint_path looks something like:
        # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
        global_step_val = latest_checkpoint.split("/")[-1].split("-")[-1]

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_id_batch, prediction_batch, label_batch, loss, summary_op
        ]
        coord = tf.train.Coordinator()
        threads = []
        # Bound before the loop so the end-of-input handler can test it even
        # when the very first fetch already hits the end of the data.
        summary_val = None
        try:
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val, summary_val = sess.run(
                    fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            if summary_val is not None:
                summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
コード例 #5
0
ファイル: eval_avg.py プロジェクト: idoit/2nd-YouTube8M
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics):
    """Run the evaluation loop once on checkpoint-averaged weights.

  Instead of restoring a single checkpoint, the weights returned by
  averaging_checkpoint(FLAGS.checkpoint_prefix) are assigned to every global
  variable whose name (with ":0" removed) is a key of that result. Every
  mini-batch is then fed into evl_metrics, and the epoch-level metrics are
  written once the input is exhausted.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver; not used by this variant (weights come from
      averaging_checkpoint instead).
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.

  Returns:
    The global step value, which is always -1 here because no single
    checkpoint is restored.

  Raises:
    IOError: if FLAGS.checkpoint_prefix is None.
  """

    global_step_val = -1
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        print('number of trainable variables: %d' %
              len(tf.trainable_variables()))
        print('number of global variables: %d' % len(tf.global_variables()))
        if FLAGS.checkpoint_prefix is None:
            raise IOError(
                ("checkpoint_prefix %s is wrong.") % FLAGS.checkpoint_prefix)
        weights_avg_array = averaging_checkpoint(FLAGS.checkpoint_prefix)

        print('number of parameters after averaging: %d' %
              len(weights_avg_array.keys()))

        print('keys in weights_avg_array: ')
        print(weights_avg_array.keys())

        print('variables: ')
        print(tf.trainable_variables())

        # Overwrite each matching graph variable with its averaged value.
        for var in tf.global_variables():
            var_name_drop = var.name.replace(':0', '')  ## drop:0
            if var_name_drop in weights_avg_array.keys():
                sess.run(var.assign(weights_avg_array[var_name_drop]))

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_id_batch, prediction_batch, label_batch, loss, summary_op
        ]
        coord = tf.train.Coordinator()
        threads = []
        # Bound before the loop so the end-of-input handler can test it even
        # when the very first fetch already hits the end of the data.
        summary_val = None
        try:
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val, summary_val = sess.run(
                    fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            if summary_val is not None:
                summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
コード例 #6
0
def evaluation_loop(id_batch, prediction_batch, label_batch, loss, mean_iou,
                    num_examples, summary_op, saver, summary_writer,
                    last_global_step_val):
    """Run the evaluation loop once and dump predictions via write_to_record.

  Restores FLAGS.model_checkpoint_path (or the latest checkpoint in
  FLAGS.train_dir) and runs inference until the input is exhausted. The
  per-example ids, predictions and labels (the last two serialized with
  ndarray.tobytes) are handed to write_to_record in chunks of at least 16
  examples, and any remainder is written at the end. Finally the
  example-weighted mean IoU and average loss are reported.

  Args:
    id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    mean_iou: a tensor with the mean IoU of the mini-batch; it is weighted by
      num_examples when aggregated over the epoch.
    num_examples: a tensor with the number of leading examples of the
      mini-batch that are written out and counted.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

    global_step_val = -1
    if FLAGS.half_memory:
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    else:
        sess = tf.Session()

    # try/finally so the session is closed on every exit path, including the
    # early returns below.
    try:
        if FLAGS.model_checkpoint_path:
            checkpoint = FLAGS.model_checkpoint_path
        else:
            checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not checkpoint:
            logging.info("No checkpoint file found.")
            return global_step_val

        logging.info("Loading checkpoint for eval: " + checkpoint)
        # Restores from checkpoint
        saver.restore(sess, checkpoint)
        # Assuming model_checkpoint_path looks something like:
        # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
        global_step_val = checkpoint.split("/")[-1].split("-")[-1]

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        fetches = [
            id_batch, prediction_batch, label_batch, loss, mean_iou,
            num_examples, summary_op
        ]
        coord = tf.train.Coordinator()
        threads = []
        examples_processed = 0
        total_iou_val = 0.0
        total_loss_val = 0.0
        summary_val = None
        id_vals, predictions_vals, labels_vals = [], [], []
        file_id = 0
        try:
            # Start the queue runners.
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess, coord=coord, daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            while not coord.should_stop():
                (id_val, predictions_val, labels_val, loss_val, mean_iou_val,
                 num_examples_val, summary_val) = sess.run(fetches)

                for i in range(num_examples_val):
                    id_vals.append(id_val[i])
                    predictions_vals.append(predictions_val[i, :, :].tobytes())
                    labels_vals.append(labels_val[i, :, :].tobytes())

                if len(id_vals) >= 16:
                    write_to_record(id_vals, predictions_vals, labels_vals,
                                    file_id)
                    id_vals, predictions_vals, labels_vals = [], [], []
                    file_id += 1

                examples_processed += num_examples_val
                total_iou_val += mean_iou_val * num_examples_val
                total_loss_val += loss_val * num_examples_val

                logging.info("examples_processed: %d | mean_iou: %.5f",
                             examples_processed, mean_iou_val)

        except tf.errors.OutOfRangeError:
            # Write the examples left over after the last full chunk; without
            # this the final (<16) examples would never be written.
            if id_vals:
                write_to_record(id_vals, predictions_vals, labels_vals,
                                file_id)
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            if examples_processed == 0:
                # Avoid dividing by zero when the input was empty.
                logging.info("No examples were evaluated.")
            else:
                # calculate the metrics for the entire epoch
                epoch_info_dict = {}
                epoch_info_dict["epoch_id"] = global_step_val
                epoch_info_dict["mean_iou"] = total_iou_val / examples_processed
                epoch_info_dict["avg_loss"] = total_loss_val / examples_processed

                if summary_val is not None:
                    summary_writer.add_summary(summary_val, global_step_val)
                epochinfo = utils.AddEpochSummary(summary_writer,
                                                  epoch_info_dict,
                                                  summary_scope="Eval-train")
                logging.info(epochinfo)

        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
    finally:
        sess.close()

    return global_step_val
コード例 #7
0
ファイル: eval.py プロジェクト: ankitshah009/youtube-8m-1
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
    """Run the evaluation loop once.

  Restores the latest checkpoint in a CPU-only session and re-saves it under
  an inference-model name chosen from FLAGS. Then, unless the checkpoint was
  already evaluated or FLAGS.create_meta_only is set, it drains the input
  queues once while accumulating evaluation metrics.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model (a string parsed from the
    checkpoint file name), or -1 if no checkpoint was found.
  """

    global_step_val = -1
    # Evaluation runs with no GPU devices.
    config = tf.ConfigProto(device_count={'GPU': 0})

    with tf.Session(config=config) as sess:
        latest_checkpoint = get_latest_checkpoint()
        if not latest_checkpoint:
            logging.info("No checkpoint file found.")
            return global_step_val

        logging.info("Loading checkpoint for eval: " + latest_checkpoint)
        # Restores from checkpoint
        saver.restore(sess, latest_checkpoint)
        # Assuming model_checkpoint_path looks something like:
        # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from
        # it. The extra split('_') keeps only the text after the last '_'.
        global_step_val = os.path.basename(latest_checkpoint).split(
            "-")[-1].split('_')[-1]

        # Save model
        if FLAGS.force_output_model_name:
            saver.save(sess,
                       os.path.join(FLAGS.train_dir, "inference_model"),
                       write_meta_graph=False)

            # Export a meta graph restricted to these collections.
            selected_collections = [
                'global_step', 'input_batch', 'input_batch_raw', 'labels',
                'local_variables', 'loss', 'model_variables', 'num_frames',
                'predictions', 'regularization_losses', 'summaries',
                'summary_op', 'trainable_variables', 'variables'
            ]
            tf.train.export_meta_graph(
                filename=os.path.join(FLAGS.train_dir,
                                      "inference_model.meta"),
                collection_list=selected_collections)

        elif "inference_model" in FLAGS.checkpoint_file:
            if "ensemble" in FLAGS.checkpoint_file:
                # Encode the ensemble weights into the output file name,
                # e.g. the list [0.5, 0.5] becomes "05_05".
                weights_tag = str(FLAGS.ensemble_wts).replace(
                    ',', '').replace(' ', '_').replace('.', '').replace(
                        '[', '').replace(']', '')
                saver.save(
                    sess,
                    os.path.join(
                        FLAGS.train_dir,
                        FLAGS.checkpoint_file.replace(
                            'ensemble', 'ensemble_' + weights_tag)))
            # NOTE(review): a non-ensemble "inference_model" checkpoint is not
            # re-saved; confirm this is intended.
        elif "avg" not in FLAGS.checkpoint_file:
            saver.save(
                sess,
                os.path.join(
                    FLAGS.train_dir,
                    "inference_model_" + latest_checkpoint.split('-')[-1]))
        else:
            saver.save(
                sess,
                os.path.join(FLAGS.train_dir,
                             "inference_model_" + FLAGS.checkpoint_file))

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        if FLAGS.create_meta_only:
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_id_batch, prediction_batch, label_batch, loss, summary_op
        ]
        coord = tf.train.Coordinator()
        threads = []
        # Stays None when the input is exhausted before the first batch, so
        # the end-of-epoch handler must not assume a summary exists.
        summary_val = None
        try:
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val, summary_val = sess.run(
                    fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            if summary_val is not None:
                summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
コード例 #8
0
def evaluation_loop(model_nums, train_dirs, video_id_batch, prediction_batch,
                    label_batch, loss, summary_op, saver, summary_writer,
                    evl_metrics):
    """Load an ensemble of sub-models, save it, and evaluate it once.

  Args:
    model_nums: number of sub-models to load.
    train_dirs: per-model directories; entry i is passed to load_vars along
      with the name "model<i>".
    video_id_batch: a tensor of video ids mini-batch (fetched but unused).
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of labels mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: not used by this variant.
    saver: a tensorflow saver used to write the combined "inference_model".
    summary_writer: a tensorflow summary_writer.
    evl_metrics: an EvaluationMetrics object.

  Returns:
    Always -1; no global step is associated with the combined model.
  """
    step = -1
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as session:
        for model_idx in range(model_nums):
            load_vars(session, train_dirs[model_idx], "model" + str(model_idx))

        # Persist the freshly combined variables before evaluating.
        output_dir = FLAGS.Ensemble_Models + FLAGS.ensemble_output_path
        saver.save(session, os.path.join(output_dir, "inference_model"))

        session.run([tf.local_variables_initializer()])

        to_fetch = [video_id_batch, prediction_batch, label_batch, loss]
        coordinator = tf.train.Coordinator()
        try:
            runner_threads = []
            for runner in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                runner_threads.extend(
                    runner.create_threads(session,
                                          coord=coordinator,
                                          daemon=True,
                                          start=True))
            logging.info("enter eval_once loop global_step_val = %s.", step)
            evl_metrics.clear()

            num_seen = 0
            while not coordinator.should_stop():
                tic = time.time()
                _, preds, labels, batch_loss = session.run(to_fetch)
                elapsed = time.time() - tic
                batch_size = labels.shape[0]
                rate = batch_size / elapsed
                num_seen += batch_size

                info = evl_metrics.accumulate(preds, labels, batch_loss)
                info["examples_per_second"] = rate

                summary_line = utils.AddGlobalStepSummary(summary_writer,
                                                          step,
                                                          info,
                                                          summary_scope="Eval")
                logging.info("examples_processed: %d | %s", num_seen,
                             summary_line)
        except tf.errors.OutOfRangeError:
            logging.info(
                "Done with batched inference. Now calculating global performance metrics."
            )
            epoch_info = evl_metrics.get()
            epoch_info["epoch_id"] = step

            epoch_summary = utils.AddEpochSummary(summary_writer,
                                                  step,
                                                  epoch_info,
                                                  summary_scope="Eval")
            logging.info(epoch_summary)
            evl_metrics.clear()
        except Exception as err:
            logging.info("Unexpected exception:" + str(err))
            coordinator.request_stop(err)

        coordinator.request_stop()
        coordinator.join(runner_threads, stop_grace_period_secs=10)

        return step
コード例 #9
0
ファイル: eval.py プロジェクト: ZouJoshua/cv
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val, ema_tensors):
  """Run the evaluation loop once.

  Restores the latest checkpoint. When FLAGS.use_EMA is set, it overwrites
  the matching graph variables with their moving-average values read from
  the checkpoint. It then saves the result as "inference_model" in
  FLAGS.train_dir (exiting if FLAGS.build_only is set) and drains the input
  queues once while accumulating evaluation metrics.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.
    ema_tensors: tensor names of the moving-average shadows, presumably of
      the form "<var>/ExponentialMovingAverage:<n>"; only used when
      FLAGS.use_EMA is set.

  Returns:
    The global_step used in the latest model (a string parsed from the
    checkpoint file name), or -1 if no checkpoint was found.

  Raises:
    ValueError: if FLAGS.use_EMA is set but ema_tensors is empty, or names a
      tensor with no matching graph variable.
  """

  global_step_val = -1
  latest_checkpoint = get_latest_checkpoint()

  with tf.Session() as sess:

    if not latest_checkpoint:
      logging.info("No checkpoint file found.")
      return global_step_val

    logging.info("Loading checkpoint for eval: " + latest_checkpoint)
    # Restores from checkpoint
    saver.restore(sess, latest_checkpoint)
    # Assuming model_checkpoint_path looks something like:
    # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
    global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]

    if FLAGS.use_EMA:
      # Explicit check instead of assert: asserts vanish under `python -O`.
      if not ema_tensors:
        raise ValueError("Tensors got lost.")
      logging.info("####################")
      logging.info("USING EMA VARIABLES.")
      logging.info("####################")

      reader = pywrap_tensorflow.NewCheckpointReader(latest_checkpoint)
      # Index graph variables by name once instead of rescanning the whole
      # list for every EMA tensor.
      vars_by_name = {v.name: v for v in tf.global_variables()}

      for stensor in ema_tensors:
        # Strip the EMA scope to get the name of the variable it shadows.
        target_name = stensor.replace("/ExponentialMovingAverage:", ":")
        destination_t = vars_by_name.get(target_name)
        if destination_t is None:
          raise ValueError("No graph variable named %s for EMA tensor %s" %
                           (target_name, stensor))
        # Checkpoint keys are looked up without the ":<n>" output suffix.
        ema_source = reader.get_tensor(stensor.split(":")[0])
        destination_t.load(ema_source, session=sess)

    # Save model
    saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model"))
    if FLAGS.build_only:
      logging.info("Inference graph built. Exiting now.")
      # sys.exit() rather than the builtin exit(), which the site module only
      # provides for interactive use.
      import sys
      sys.exit()

    if global_step_val == last_global_step_val:
      logging.info("skip this checkpoint global_step_val=%s "
                   "(same as the previous one).", global_step_val)
      return global_step_val

    sess.run([tf.local_variables_initializer()])

    # Start the queue runners.
    fetches = [video_id_batch, prediction_batch, label_batch, loss, summary_op]
    coord = tf.train.Coordinator()
    threads = []
    # Stays None when the input is exhausted before the first batch.
    summary_val = None
    try:
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(
            sess, coord=coord, daemon=True,
            start=True))
      logging.info("enter eval_once loop global_step_val = %s. ",
                   global_step_val)

      evl_metrics.clear()

      examples_processed = 0
      while not coord.should_stop():
        batch_start_time = time.time()
        _, predictions_val, labels_val, loss_val, summary_val = sess.run(
            fetches)
        seconds_per_batch = time.time() - batch_start_time
        example_per_second = labels_val.shape[0] / seconds_per_batch
        examples_processed += labels_val.shape[0]

        iteration_info_dict = evl_metrics.accumulate(predictions_val,
                                                     labels_val, loss_val)
        iteration_info_dict["examples_per_second"] = example_per_second

        iterinfo = utils.AddGlobalStepSummary(
            summary_writer,
            global_step_val,
            iteration_info_dict,
            summary_scope="Eval")
        logging.info("examples_processed: %d | %s", examples_processed,
                     iterinfo)

    except tf.errors.OutOfRangeError:
      logging.info(
          "Done with batched inference. Now calculating global performance "
          "metrics.")
      # calculate the metrics for the entire epoch
      epoch_info_dict = evl_metrics.get()
      epoch_info_dict["epoch_id"] = global_step_val

      if summary_val is not None:
        summary_writer.add_summary(summary_val, global_step_val)
      epochinfo = utils.AddEpochSummary(
          summary_writer,
          global_step_val,
          epoch_info_dict,
          summary_scope="Eval")
      logging.info(epochinfo)
      # Also append the epoch summary to a local log file.
      with open('evallog.txt', 'a') as f:
        f.write(epochinfo + '\n')
      evl_metrics.clear()
    except Exception as e:  # pylint: disable=broad-except
      logging.info("Unexpected exception: " + str(e))
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)

    return global_step_val
コード例 #10
0
    def eval_loop(self, last_global_step_val, evl_metrics):
        """Run the evaluation loop once.

    If there is no new checkpoint, sleeps ``self.wait`` seconds and returns.
    Otherwise restores the checkpoint and evaluates one full pass over the
    input.

    Args:
      last_global_step_val: the global step used in the previous evaluation.
      evl_metrics: an EvaluationMetrics object; cleared before the pass and
        again after the end-of-epoch summary is written.

    Returns:
      The global_step used in the latest model, or last_global_step_val when
      no new checkpoint was available.
    """
        latest_checkpoint, global_step_val = self.get_checkpoint(
            last_global_step_val)
        logging.info("latest_checkpoint: {}".format(latest_checkpoint))

        if latest_checkpoint is None or global_step_val == last_global_step_val:
            time.sleep(self.wait)
            return last_global_step_val

        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            logging.info(
                "Loading checkpoint for eval: {}".format(latest_checkpoint))

            # Restores from checkpoint
            self.saver.restore(sess, latest_checkpoint)
            sess.run(tf.local_variables_initializer())

            evl_metrics.clear()

            # The fetch list does not change between batches; build it once.
            fetches = [
                self.logits, self.labels, self.labels_losses, self.summary_op
            ]
            examples_processed = 0
            # Stays None when the input is exhausted before the first batch.
            summary_val = None
            while True:
                try:
                    batch_start_time = time.time()
                    logits_val, labels_val, loss_val, summary_val = sess.run(
                        fetches)
                    seconds_per_batch = time.time() - batch_start_time
                    # Count the rows actually returned rather than the
                    # nominal self.batch_size, in case a final batch is
                    # smaller.
                    batch_examples = labels_val.shape[0]
                    examples_per_second = batch_examples / seconds_per_batch
                    examples_processed += batch_examples

                    iteration_info_dict = evl_metrics.accumulate(
                        logits_val, labels_val, loss_val)
                    iteration_info_dict[
                        "examples_per_second"] = examples_per_second

                    iterinfo = utils.AddGlobalStepSummary(self.summary_writer,
                                                          global_step_val,
                                                          iteration_info_dict,
                                                          summary_scope="Eval")
                    logging.info("examples_processed: %d | %s",
                                 examples_processed, iterinfo)

                except tf.errors.OutOfRangeError:
                    logging.info(
                        "Done with batched inference. Now calculating global performance "
                        "metrics.")
                    # calculate the metrics for the entire epoch
                    epoch_info_dict = evl_metrics.get()
                    epoch_info_dict["epoch_id"] = global_step_val

                    if summary_val is not None:
                        self.summary_writer.add_summary(summary_val,
                                                        global_step_val)
                    epochinfo = utils.AddEpochSummary(self.summary_writer,
                                                      global_step_val,
                                                      epoch_info_dict,
                                                      summary_scope="Eval")
                    logging.info(epochinfo)
                    evl_metrics.clear()

                    if FLAGS.stopped_at_n:
                        self.counter += 1
                    break

                except Exception as e:  # pylint: disable=broad-except
                    logging.error("Unexpected exception: {}".format(e))
                    # Non-zero status so the launcher can see the evaluation
                    # failed; exit code 0 would report success.
                    sys.exit(1)

            return global_step_val
def _write_prediction_shard(video_ids, video_labels, video_features, filenum,
                            num_examples, start_time):
  """Concatenate buffered per-batch arrays and write them as one shard."""
  write_to_record(np.concatenate(video_ids, axis=0),
                  np.concatenate(video_labels, axis=0),
                  np.concatenate(video_features, axis=0),
                  filenum, num_examples)
  logging.info("num examples processed: " + str(num_examples) +
               " elapsed seconds: " +
               "{0:.2f}".format(time.time() - start_time))


def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
  """Run the evaluation loop once and optionally dump predictions.

  Checkpoint restoring is disabled in this variant. As a result, `saver`
  and `last_global_step_val` are unused and the reported global step is
  always -1.

  When FLAGS.output_dir is non-empty, the fetched video ids, labels and
  predictions are buffered and written with write_to_record. Each shard
  holds FLAGS.file_size examples; the remainder goes to a final, smaller
  shard.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: unused (checkpoint restore is disabled).
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: unused (checkpoint restore is disabled).

  Returns:
    Always -1, since no checkpoint is restored.

  Raises:
    IOError: if FLAGS.output_dir is non-empty and already exists.
  """

  global_step_val = -1
  with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
    # NOTE: restoring/saving the checkpoint and the "skip if already
    # evaluated" check are disabled here; the graph is run as-is.

    sess.run([tf.local_variables_initializer()])

    # Start the queue runners.
    fetches = [video_id_batch, prediction_batch, label_batch, loss, summary_op]
    coord = tf.train.Coordinator()

    # Buffers for the prediction shards written to FLAGS.output_dir.
    start_time = time.time()
    video_ids = []
    video_labels = []
    video_features = []
    filenum = 0
    num_examples_processed = 0  # examples buffered for the current shard
    total_num_examples_processed = 0

    # output prediction dir
    directory = FLAGS.output_dir
    if directory != '':
      if os.path.exists(directory):
        raise IOError("Output path exists! path='" + directory + "'")
      os.makedirs(directory)

    threads = []
    # Stays None when the input is exhausted before the first batch.
    summary_val = None
    try:
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(
            sess, coord=coord, daemon=True,
            start=True))
      logging.info("enter eval_once loop global_step_val = %s. ",
                   global_step_val)

      evl_metrics.clear()

      examples_processed = 0
      while not coord.should_stop():
        batch_start_time = time.time()
        ids_val, predictions_val, labels_val, loss_val, summary_val = sess.run(
            fetches)
        seconds_per_batch = time.time() - batch_start_time
        example_per_second = labels_val.shape[0] / seconds_per_batch
        examples_processed += labels_val.shape[0]

        iteration_info_dict = evl_metrics.accumulate(predictions_val,
                                                     labels_val, loss_val)
        iteration_info_dict["examples_per_second"] = example_per_second

        iterinfo = utils.AddGlobalStepSummary(
            summary_writer,
            global_step_val,
            iteration_info_dict,
            summary_scope="Eval")
        logging.info("examples_processed: %d | %s", examples_processed,
                     iterinfo)

        # save predictions
        if directory != '':
          video_ids.append(ids_val)
          video_labels.append(labels_val)
          video_features.append(predictions_val)
          num_examples_processed += len(ids_val)

          if num_examples_processed >= FLAGS.file_size:
            # Shards must be exactly file_size long, which requires the batch
            # size to divide file_size. Explicit raise: asserts vanish
            # under -O.
            if num_examples_processed != FLAGS.file_size:
              raise ValueError("num_examples_processed should be equal to %d"
                               % FLAGS.file_size)
            _write_prediction_shard(video_ids, video_labels, video_features,
                                    filenum, num_examples_processed,
                                    start_time)
            video_ids = []
            video_labels = []
            video_features = []
            filenum += 1
            total_num_examples_processed += num_examples_processed
            num_examples_processed = 0

    except tf.errors.OutOfRangeError:
      logging.info(
          "Done with batched inference. Now calculating global performance "
          "metrics.")
      # calculate the metrics for the entire epoch
      epoch_info_dict = evl_metrics.get()
      epoch_info_dict["epoch_id"] = global_step_val

      if summary_val is not None:
        summary_writer.add_summary(summary_val, global_step_val)
      epochinfo = utils.AddEpochSummary(
          summary_writer,
          global_step_val,
          epoch_info_dict,
          summary_scope="Eval")
      logging.info(epochinfo)
      evl_metrics.clear()

      # Flush the last, possibly partial, shard of predictions.
      if directory != '':
        if 0 < num_examples_processed <= FLAGS.file_size:
          _write_prediction_shard(video_ids, video_labels, video_features,
                                  filenum, num_examples_processed, start_time)
          total_num_examples_processed += num_examples_processed
          num_examples_processed = 0

        logging.info("Done with inference. %d samples was written to %s" % (total_num_examples_processed, FLAGS.output_dir))
    except Exception as e:  # pylint: disable=broad-except
      logging.info("Unexpected exception: " + str(e))
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)

    return global_step_val
コード例 #12
0
ファイル: eval_loop.py プロジェクト: huan2016/yt8m-1
def evaluation_loop(self, saver, model_ckpt_path, output_scores=1):
    """Restore a checkpoint and run one pass over the input queue.

    Every batch of ``self.feed_out`` is accumulated into an
    ``eval_util.EvaluationMetrics`` object, and its per-batch metrics are
    logged. What else happens depends on ``output_scores``:

    * ``1`` -- each batch of predictions is written into a pre-allocated
      HDF5 dataset named ``'scores'`` (rows in arrival order). The matching
      list of video ids is pickled next to it once the input is exhausted.
    * ``2`` -- ``res["feats"]`` of every example is written, together with
      its labels and video id, into TFRecord shards of 1200 examples each.
    * any other value -- nothing is written. Once the input is exhausted,
      the epoch-level metrics from ``evl_metrics.get()`` are logged.

    Args:
      saver: a TensorFlow saver used to restore ``model_ckpt_path``.
      model_ckpt_path: path of the checkpoint to evaluate. The text after
        the last ``-`` of its basename is used as the global step. The
        parent directory name plus that suffix form the output file id in
        mode 1.
      output_scores: output mode, see above. Defaults to 1, the value that
        used to be hard-coded.
    """
    global_step_val = model_ckpt_path.split("/")[-1].split("-")[-1]
    evl_metrics = eval_util.EvaluationMetrics(self.model.num_classes,
                                              self.config.top_k)

    # Summaries are disabled. To re-enable them use
    # tf.summary.FileWriter(self.train_dir, graph=tf.get_default_graph()).
    summary_writer = None

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    video_ids = []
    # Output handles. They stay None in modes that do not use them, so the
    # error path below can close whatever was actually opened.
    pred_out = None
    tfrecord_writer = None
    if output_scores == 1:
        model_id = model_ckpt_path.split(
            "/")[-2] + "-" + model_ckpt_path.split("-")[-1]
        # num_insts rows are pre-allocated, and rows are filled in arrival
        # order. It therefore has to be >= the number of examples read
        # (presumably the size of the split; other splits used before:
        # train -> 4906660, test -> 700640).
        num_insts = 1401828
        stage = "validate"
        video_ids_pkl_path = "/data/D2DCRC/linchao/YT/scores/{}.{}.pkl".format(
            model_id, stage)
        pred_out = h5py.File(
            "/data/D2DCRC/linchao/YT/scores/{}.{}.h5".format(model_id, stage),
            "w")
        pred_dataset = pred_out.create_dataset('scores',
                                               shape=(num_insts,
                                                      self.model.num_classes),
                                               dtype=np.float32)
    elif output_scores == 2:
        output_prefix = "/data/uts700/linchao/yt8m/data/555_netvlad/train"
        examples_per_shard = 1200
        tfrecord_cntr = 0
        # Floor division keeps the shard name an integer on Python 2 and 3.
        output_filename = "{}/{}.tfrecord".format(
            output_prefix, tfrecord_cntr // examples_per_shard)
        tfrecord_writer = tf.python_io.TFRecordWriter(output_filename)

    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, model_ckpt_path)
        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                res = sess.run(self.feed_out)
                seconds_per_batch = time.time() - batch_start_time
                batch_size = res["dense_labels"].shape[0]
                example_per_second = batch_size / seconds_per_batch
                examples_processed += batch_size
                predictions = res["predictions"]
                video_id = res["video_id"].tolist()
                if output_scores == 1:
                    # The pickled video_ids list records which video each
                    # HDF5 row belongs to.
                    # NOTE(review): scores are stored before the list ->
                    # array transform_preds step below. Confirm predictions
                    # is already an array whenever this mode is used.
                    offset = len(video_ids)
                    pred_dataset[offset:offset + len(video_id), :] = predictions
                    video_ids += video_id
                elif output_scores == 2:
                    for i, vid in enumerate(video_id):
                        sparse_label = np.array(res["dense_labels"][i],
                                                dtype=np.int32)
                        sparse_label = np.where(sparse_label == 1)[0].tolist()
                        feat = res["feats"][i]
                        print(sparse_label, vid)
                        example = matrix_to_tfexample(feat,
                                                      labels=sparse_label,
                                                      video_id=vid)
                        tfrecord_writer.write(example.SerializeToString())
                        # Count the record *before* testing for rollover.
                        # Testing first closed and reopened (truncated)
                        # shard 0 right after its very first record.
                        tfrecord_cntr += 1
                        if tfrecord_cntr % examples_per_shard == 0:
                            tfrecord_writer.close()
                            output_filename = "{}/{}.tfrecord".format(
                                output_prefix,
                                tfrecord_cntr // examples_per_shard)
                            tfrecord_writer = tf.python_io.TFRecordWriter(
                                output_filename)

                # Multi-output models return a list; turn it into the form
                # that the metrics expect.
                if isinstance(predictions, list):
                    predictions = eval_util.transform_preds(self, predictions)

                iteration_info_dict = evl_metrics.accumulate(
                    predictions, res["dense_labels"], res["loss"])
                iteration_info_dict["examples_per_second"] = example_per_second

                gap = eval_util.calculate_gap(predictions, res["dense_labels"])
                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s | gap: %s",
                             examples_processed, iterinfo, gap)

        except tf.errors.OutOfRangeError as e:
            # The input queue is exhausted: finalize outputs / metrics.
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            if output_scores == 1:
                pred_out.close()
                # Pickles are binary; the context manager closes the file.
                with open(video_ids_pkl_path, "wb") as fout:
                    pkl.dump(video_ids, fout)
            elif output_scores == 2:
                tfrecord_writer.close()
            else:
                # calculate the metrics for the entire epoch
                epoch_info_dict = evl_metrics.get()
                epoch_info_dict["epoch_id"] = global_step_val
                if summary_writer:
                    summary_writer.add_summary(res["summary"], global_step_val)
                epochinfo = utils.AddEpochSummary(summary_writer,
                                                  global_step_val,
                                                  epoch_info_dict,
                                                  summary_scope="Eval")
                logging.info(epochinfo)
                evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)
            # The OutOfRangeError path above did not run, so release the
            # output files here rather than leaking them.
            if pred_out is not None:
                pred_out.close()
            if tfrecord_writer is not None:
                tfrecord_writer.close()

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)