Example #1
 def test_raise_in_none_secs_and_steps(self):
   with self.assertRaises(ValueError):
     hooks.ExamplesPerSecondHook(
         batch_size=256,
         every_n_steps=None,
         every_n_secs=None,
         metric_logger=self._logger)
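For context, the ValueError this test expects comes from the hook's constructor, which requires exactly one of the two trigger arguments to be set. A minimal sketch of that check (hypothetical; the real ExamplesPerSecondHook in the official-models hooks.py also sets up timers and step counters, omitted here):

import tensorflow as tf


class ExamplesPerSecondHook(tf.compat.v1.train.SessionRunHook):
  """Sketch of only the argument validation exercised by the test above."""

  def __init__(self, batch_size, every_n_steps=None, every_n_secs=None,
               metric_logger=None):
    # Exactly one of the two trigger arguments must be set; both-None and
    # both-set are rejected up front.
    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          'exactly one of every_n_steps and every_n_secs must be provided')
    self._batch_size = batch_size
    self._logger = metric_logger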
Example #2
  def _validate_log_every_n_secs(self, every_n_secs):
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=None,
        every_n_secs=every_n_secs,
        metric_logger=self._logger)

    with tf.compat.v1.train.MonitoredSession(
        tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
      # Explicitly run global_step after train_op to get the accurate
      # global_step value
      mon_sess.run(self.train_op)
      mon_sess.run(self.global_step)
      # Nothing should be in the list yet
      self.assertFalse(self._logger.logged_metric)
      time.sleep(every_n_secs)

      mon_sess.run(self.train_op)
      mon_sess.run(self.global_step)
      self._assert_metrics()
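The helper above is parameterized over the logging interval; a test class would typically drive it with a couple of concrete cases like these (the test names and intervals here are illustrative):

  def test_examples_per_sec_every_1_secs(self):
    self._validate_log_every_n_secs(1)

  def test_examples_per_sec_every_5_secs(self):
    self._validate_log_every_n_secs(5)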
Example #3
def get_examples_per_second_hook(every_n_steps=100,
                                 batch_size=128,
                                 warm_steps=5,
                                 **kwargs):  # pylint: disable=unused-argument
  """Function to get ExamplesPerSecondHook.

  Args:
    every_n_steps: `int`, print current and average examples per second every
      N steps.
    batch_size: `int`, total batch size used to calculate examples/second from
      global time.
    warm_steps: `int`, number of steps to skip before logging and before the
      running average is computed.
    **kwargs: additional keyword arguments that are accepted but not used
      (note the unused-argument pylint disable on the signature).

  Returns:
    An ExamplesPerSecondHook that logs the current and the running average
    examples per second.
  """
  return hooks.ExamplesPerSecondHook(
      batch_size=batch_size, every_n_steps=every_n_steps,
      warm_steps=warm_steps, metric_logger=logger.get_benchmark_logger())
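Since the factory only constructs the hook, it can be handed straight to an Estimator's train call. A hypothetical usage (estimator and train_input_fn are assumed to be defined elsewhere):

# Hypothetical wiring: log examples/sec every 100 steps during training.
estimator.train(
    input_fn=train_input_fn,
    hooks=[get_examples_per_second_hook(every_n_steps=100, batch_size=128)])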
Example #4
  def _validate_log_every_n_steps(self, every_n_steps, warm_steps):
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps,
        metric_logger=self._logger)

    with tf.compat.v1.train.MonitoredSession(
        tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
      for _ in range(every_n_steps):
        # Explicitly run global_step after train_op to get the accurate
        # global_step value
        mon_sess.run(self.train_op)
        mon_sess.run(self.global_step)
        # Nothing should be in the list yet
        self.assertFalse(self._logger.logged_metric)

      mon_sess.run(self.train_op)
      global_step_val = mon_sess.run(self.global_step)

      if global_step_val > warm_steps:
        self._assert_metrics()
      else:
        # Nothing should be in the list yet
        self.assertFalse(self._logger.logged_metric)

      # Add additional run to verify proper reset when called multiple times.
      prev_log_len = len(self._logger.logged_metric)
      mon_sess.run(self.train_op)
      global_step_val = mon_sess.run(self.global_step)

      if every_n_steps == 1 and global_step_val > warm_steps:
        # Each logging event records two metrics (current and average
        # examples/sec), so exactly two new entries should have been added.
        self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
      else:
        # No change in the size of the metric list.
        self.assertEqual(len(self._logger.logged_metric), prev_log_len)
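Both validation helpers above rely on a fixture that supplies self.train_op, self.global_step, and self._logger. A sketch of what that setup might look like (the logger class here is a hypothetical stand-in; the real test uses the benchmark logger from the official-models repo):

import tensorflow as tf


class MockMetricLogger(object):
  """Hypothetical in-memory logger exposing the logged_metric list read above."""

  def __init__(self):
    self.logged_metric = []

  def log_metric(self, name, value, unit=None, global_step=None, extras=None):
    self.logged_metric.append((name, value, global_step))


class ExamplesPerSecondHookTest(tf.test.TestCase):

  def setUp(self):
    self._logger = MockMetricLogger()
    # A global step plus a train op that increments it is enough for the
    # hook to observe step and wall-clock progress.
    tf.compat.v1.reset_default_graph()
    self.global_step = tf.compat.v1.train.get_or_create_global_step()
    self.train_op = tf.compat.v1.assign_add(self.global_step, 1)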