Ejemplo n.º 1
0
    def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
        """Check on which steps ExamplesPerSecondHook emits an 'exp/sec' message.

        Builds the hook with the given `every_n_steps` / `warm_steps`, runs
        `self.train_op` through a hooked session, and inspects
        `self.logged_message` for the 'exp/sec' marker after each run.

        Substring checks (`assertIn` / `assertNotIn`) are used because
        'exp/sec' contains no regex metacharacters. They are therefore equivalent
        to a regex search, and they avoid `assertRegexpMatches`, which was
        removed in Python 3.12.

        Args:
          sess: session used to initialize variables, to read
            `self.global_step`, and to back the hooked session.
          every_n_steps: passed to the hook as `every_n_steps`.
          warm_steps: passed to the hook as `warm_steps`.
        """
        hook = hooks.ExamplesPerSecondHook(batch_size=256,
                                           every_n_steps=every_n_steps,
                                           warm_steps=warm_steps)
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess, [hook])
        sess.run(tf.global_variables_initializer())

        # Expect no 'exp/sec' message during the first `every_n_steps` runs.
        self.logged_message = ''
        for _ in range(every_n_steps):
            mon_sess.run(self.train_op)
            self.assertNotIn('exp/sec', str(self.logged_message))

        # The next run is expected to log only once past the warm-up steps.
        mon_sess.run(self.train_op)
        global_step_val = sess.run(self.global_step)
        if global_step_val > warm_steps:
            self.assertIn('exp/sec', str(self.logged_message))
        else:
            self.assertNotIn('exp/sec', str(self.logged_message))

        # Add additional run to verify proper reset when called multiple times.
        self.logged_message = ''
        mon_sess.run(self.train_op)
        global_step_val = sess.run(self.global_step)
        if every_n_steps == 1 and global_step_val > warm_steps:
            self.assertIn('exp/sec', str(self.logged_message))
        else:
            self.assertNotIn('exp/sec', str(self.logged_message))

        hook.end(sess)
Ejemplo n.º 2
0
    def _validate_log_every_n_secs(self, sess, every_n_secs):
        """Check that ExamplesPerSecondHook logs after `every_n_secs` elapse.

        Builds a time-triggered hook (`every_n_steps=None`) and runs
        `self.train_op` twice, sleeping `every_n_secs` in between. No
        'exp/sec' message is expected after the first run, and one is
        expected after the second.

        Substring checks replace the removed `assertRegexpMatches`. This is
        equivalent here because 'exp/sec' has no regex metacharacters.

        Args:
          sess: session used to initialize variables and to back the hooked
            session.
          every_n_secs: passed to the hook as `every_n_secs`; also the sleep
            duration between the two runs.
        """
        hook = hooks.ExamplesPerSecondHook(batch_size=256,
                                           every_n_steps=None,
                                           every_n_secs=every_n_secs)
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess, [hook])
        sess.run(tf.global_variables_initializer())

        # First run: nothing should be logged yet.
        self.logged_message = ''
        mon_sess.run(self.train_op)
        self.assertNotIn('exp/sec', str(self.logged_message))
        time.sleep(every_n_secs)

        # After waiting out the interval, the next run is expected to log.
        self.logged_message = ''
        mon_sess.run(self.train_op)
        self.assertIn('exp/sec', str(self.logged_message))

        hook.end(sess)
def get_examples_per_second_hook(every_n_steps=100,
                                 batch_size=128,
                                 warm_steps=5,
                                 **kwargs):  # pylint: disable=unused-argument
    """Build an ExamplesPerSecondHook.

    Args:
      every_n_steps: `int`, print current and average examples per second every
        N steps. Forwarded to the hook (previously it was ignored in favor of
        a hard-coded 1).
      batch_size: `int`, total batch size used to calculate examples/second from
        global time.
      warm_steps: skip this number of steps before logging and running average.
      **kwargs: extra keyword arguments. They are accepted but ignored and are
        NOT forwarded to ExamplesPerSecondHook. Presumably this lets callers pass
        one shared argument dict to several hook factories (TODO confirm).

    Returns:
      A `hooks.ExamplesPerSecondHook` configured with `every_n_steps`,
      `batch_size` and `warm_steps`.
    """
    return hooks.ExamplesPerSecondHook(
        every_n_steps=every_n_steps,
        batch_size=batch_size,
        warm_steps=warm_steps)
Ejemplo n.º 4
0
 def test_raise_in_none_secs_and_steps(self):
     """Expect ValueError when neither every_n_steps nor every_n_secs is set."""
     self.assertRaises(ValueError,
                       hooks.ExamplesPerSecondHook,
                       batch_size=256,
                       every_n_steps=None,
                       every_n_secs=None)