def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
    """Drive a step-triggered ExamplesPerSecondHook and check its log output.

    The hook is built with `every_n_steps` and `warm_steps`, attached to a
    hooked session, and `self.train_op` is run repeatedly. After each phase the
    captured `self.logged_message` is checked for an 'exp/sec' entry.

    Args:
      sess: session used to initialize variables, read `self.global_step`, and
        back the hooked session.
      every_n_steps: step interval passed to the hook.
      warm_steps: warm-up step count passed to the hook.
    """

    def check_log(expect_message):
        # The positive case uses a regex assertion; the negative case compares
        # the result of str.find against -1 rather than a "not regex" assertion.
        text = str(self.logged_message)
        if expect_message:
            self.assertRegexpMatches(text, 'exp/sec')
        else:
            self.assertEqual(text.find('exp/sec'), -1)

    hook = hooks.ExamplesPerSecondHook(
        batch_size=256, every_n_steps=every_n_steps, warm_steps=warm_steps)
    hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.global_variables_initializer())

    # Phase 1: no message is expected during the first `every_n_steps` runs.
    self.logged_message = ''
    for _ in range(every_n_steps):
        hooked_sess.run(self.train_op)
        check_log(False)

    # Phase 2: one more run; a message is expected only once the global step
    # has moved past `warm_steps`.
    hooked_sess.run(self.train_op)
    step = sess.run(self.global_step)
    check_log(step > warm_steps)

    # Phase 3: clear the capture and run once more to verify proper reset when
    # the hook is called multiple times; only a per-step hook
    # (every_n_steps == 1) past warm-up is expected to log again.
    self.logged_message = ''
    hooked_sess.run(self.train_op)
    step = sess.run(self.global_step)
    check_log(every_n_steps == 1 and step > warm_steps)

    hook.end(sess)
def _validate_log_every_n_secs(self, sess, every_n_secs):
    """Drive a time-triggered ExamplesPerSecondHook and check its log output.

    The hook is built with `every_n_steps=None` and `every_n_secs`. The test
    expects no 'exp/sec' message after the first run, and expects one after
    sleeping `every_n_secs` seconds and running again.

    Args:
      sess: session used to initialize variables and back the hooked session.
      every_n_secs: time interval (seconds) passed to the hook and slept for.
    """
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256, every_n_steps=None, every_n_secs=every_n_secs)
    hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.global_variables_initializer())

    def run_and_capture():
        # Clear the captured log, run one training step, return what was logged.
        self.logged_message = ''
        hooked_sess.run(self.train_op)
        return str(self.logged_message)

    first_log = run_and_capture()
    self.assertEqual(first_log.find('exp/sec'), -1)

    time.sleep(every_n_secs)
    second_log = run_and_capture()
    self.assertRegexpMatches(second_log, 'exp/sec')

    hook.end(sess)
def get_examples_per_second_hook(every_n_steps=100,
                                 batch_size=128,
                                 warm_steps=5,
                                 **kwargs):  # pylint: disable=unused-argument
    """Function to get ExamplesPerSecondHook.

    Args:
      every_n_steps: `int`, print current and average examples per second every
        N steps.
      batch_size: `int`, total batch size used to calculate examples/second from
        global time.
      warm_steps: skip this number of steps before logging and running average.
      **kwargs: Unused. Extra keyword arguments are accepted and silently
        ignored; they are NOT forwarded to ExamplesPerSecondHook (presumably so
        callers can pass one shared set of arguments to several hook factories
        -- confirm against callers).

    Returns:
      An ExamplesPerSecondHook constructed with `every_n_steps`, `batch_size`
      and `warm_steps`.
    """
    return hooks.ExamplesPerSecondHook(
        every_n_steps=every_n_steps,
        batch_size=batch_size,
        warm_steps=warm_steps)
def test_raise_in_none_secs_and_steps(self):
    """Constructing the hook with neither a step nor a time trigger must fail."""
    no_trigger_config = dict(
        batch_size=256, every_n_steps=None, every_n_secs=None)
    with self.assertRaises(ValueError):
        hooks.ExamplesPerSecondHook(**no_trigger_config)