示例#1
0
 def test_raise_in_none_secs_and_steps(self):
     """Expects ValueError when every_n_steps and every_n_secs are both None."""
     hook_kwargs = dict(
         batch_size=256,
         every_n_steps=None,
         every_n_secs=None,
         metric_logger=self._logger)
     with self.assertRaises(ValueError):
         hooks.ExamplesPerSecondHook(**hook_kwargs)
  def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
    """Checks that 'exp/sec' is logged only on the expected steps.

    Runs `every_n_steps` train steps that must not log, one step that must
    log only if the global step is past `warm_steps`, and one further step to
    verify the hook resets correctly between triggers.

    Args:
      sess: session used to initialize variables, to read the global step and
        as the session wrapped by the hooked session.
      every_n_steps: step interval passed to ExamplesPerSecondHook.
      warm_steps: warm-up step count passed to ExamplesPerSecondHook; no
        'exp/sec' message is expected until the global step exceeds it.
    """
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.global_variables_initializer())

    # NOTE(review): self.logged_message is presumably updated by a logging
    # stub installed elsewhere in this test case -- confirm against setUp.
    self.logged_message = ''
    for _ in range(every_n_steps):
      mon_sess.run(self.train_op)
      self.assertNotIn('exp/sec', str(self.logged_message))

    mon_sess.run(self.train_op)
    global_step_val = sess.run(self.global_step)
    # 'exp/sec' is a plain literal, so substring assertions are equivalent to
    # the regexp ones and avoid the deprecated assertRegexpMatches alias.
    if global_step_val > warm_steps:
      self.assertIn('exp/sec', str(self.logged_message))
    else:
      self.assertNotIn('exp/sec', str(self.logged_message))

    # Add additional run to verify proper reset when called multiple times.
    self.logged_message = ''
    mon_sess.run(self.train_op)
    global_step_val = sess.run(self.global_step)
    if every_n_steps == 1 and global_step_val > warm_steps:
      self.assertIn('exp/sec', str(self.logged_message))
    else:
      self.assertNotIn('exp/sec', str(self.logged_message))

    hook.end(sess)
示例#3
0
    def _validate_log_every_n_secs(self, sess, every_n_secs):
        """Checks that 'exp/sec' is logged once `every_n_secs` seconds elapse.

        The first train step must not produce an 'exp/sec' message; after
        sleeping for `every_n_secs` seconds the next step must produce one.

        Args:
          sess: session used to initialize variables and as the session
            wrapped by the hooked session.
          every_n_secs: time interval, in seconds, passed to
            ExamplesPerSecondHook and used as the sleep duration.
        """
        hook = hooks.ExamplesPerSecondHook(batch_size=256,
                                           every_n_steps=None,
                                           every_n_secs=every_n_secs)
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        sess.run(tf.global_variables_initializer())

        # NOTE(review): self.logged_message is presumably updated by a logging
        # stub installed elsewhere in this test case -- confirm against setUp.
        self.logged_message = ''
        mon_sess.run(self.train_op)
        # 'exp/sec' is a plain literal, so substring assertions are equivalent
        # to the regexp ones and avoid the deprecated assertRegexpMatches alias.
        self.assertNotIn('exp/sec', str(self.logged_message))
        time.sleep(every_n_secs)

        self.logged_message = ''
        mon_sess.run(self.train_op)
        self.assertIn('exp/sec', str(self.logged_message))

        hook.end(sess)
示例#4
0
    def _validate_log_every_n_secs(self, every_n_secs):
        """Checks that metrics are logged only after `every_n_secs` seconds.

        Args:
          every_n_secs: time interval, in seconds, passed to
            ExamplesPerSecondHook and used as the sleep duration between the
            two train steps.
        """
        eps_hook = hooks.ExamplesPerSecondHook(
            batch_size=256,
            every_n_steps=None,
            every_n_secs=every_n_secs,
            metric_logger=self._logger)
        session_creator = tf.train.ChiefSessionCreator()

        with tf.train.MonitoredSession(session_creator, [eps_hook]) as session:

            def run_one_step():
                # global_step is fetched in its own run, after train_op, so
                # that the value read is accurate.
                session.run(self.train_op)
                session.run(self.global_step)

            run_one_step()
            # The logger's metric list must still be empty at this point.
            self.assertFalse(self._logger.logged_metric)

            time.sleep(every_n_secs)

            run_one_step()
            self._assert_metrics()
def get_examples_per_second_hook(every_n_steps=100,
                                 batch_size=128,
                                 warm_steps=5,
                                 **kwargs):  # pylint: disable=unused-argument
    """Build an ExamplesPerSecondHook from the given settings.

    Args:
      every_n_steps: `int`, print current and average examples per second
        every N steps.
      batch_size: `int`, total batch size used to calculate examples/second
        from global time.
      warm_steps: skip this number of steps before logging and running
        average.
      **kwargs: accepted but unused; extra keyword arguments are not
        forwarded to the hook.

    Returns:
      The `hooks.ExamplesPerSecondHook` constructed from `every_n_steps`,
      `batch_size` and `warm_steps`.
    """
    hook_settings = {
        'every_n_steps': every_n_steps,
        'batch_size': batch_size,
        'warm_steps': warm_steps,
    }
    return hooks.ExamplesPerSecondHook(**hook_settings)
示例#6
0
    def _validate_log_every_n_steps(self, every_n_steps, warm_steps):
        """Checks that metrics are logged only on the expected steps.

        Runs `every_n_steps` train steps that must log nothing, one step that
        must log only if the global step is past `warm_steps`, and one
        further step to verify the hook resets correctly between triggers.

        Args:
          every_n_steps: step interval passed to ExamplesPerSecondHook.
          warm_steps: warm-up step count passed to ExamplesPerSecondHook; no
            metric is expected until the global step exceeds it.
        """
        eps_hook = hooks.ExamplesPerSecondHook(
            batch_size=256,
            every_n_steps=every_n_steps,
            warm_steps=warm_steps,
            metric_logger=self._logger)
        session_creator = tf.train.ChiefSessionCreator()

        with tf.train.MonitoredSession(session_creator, [eps_hook]) as session:

            def run_one_step():
                # global_step is fetched in its own run, after train_op, so
                # that the value read is accurate.
                session.run(self.train_op)
                return session.run(self.global_step)

            for _ in range(every_n_steps):
                run_one_step()
                # The logger's metric list must still be empty.
                self.assertFalse(self._logger.logged_metric)

            current_step = run_one_step()
            if current_step > warm_steps:
                self._assert_metrics()
            else:
                # Still warming up, so the metric list must remain empty.
                self.assertFalse(self._logger.logged_metric)

            # One more step verifies the hook resets properly when it is
            # triggered multiple times.
            count_before = len(self._logger.logged_metric)
            current_step = run_one_step()

            # Each logging event appends two metrics; otherwise the list size
            # must not change.
            logs_again = every_n_steps == 1 and current_step > warm_steps
            expected_growth = 2 if logs_again else 0
            self.assertEqual(len(self._logger.logged_metric),
                             count_before + expected_growth)
 def test_raise_in_both_secs_and_steps(self):
   """Expects ValueError when every_n_steps and every_n_secs are both set."""
   conflicting_intervals = dict(every_n_steps=10, every_n_secs=20)
   with self.assertRaises(ValueError):
     hooks.ExamplesPerSecondHook(batch_size=256, **conflicting_intervals)