def test_raise_in_none_secs_and_steps(self):
    """Constructing the hook with both triggers set to None raises ValueError."""
    self.assertRaises(
        ValueError,
        hooks.ExamplesPerSecondHook,
        batch_size=256,
        every_n_steps=None,
        every_n_secs=None,
        metric_logger=self._logger)
def _validate_log_every_n_secs(self, every_n_secs):
    """Check that ExamplesPerSecondHook logs only after `every_n_secs` pass.

    Runs one training step and asserts nothing has been logged. It then
    sleeps for `every_n_secs`, runs another step, and checks the logged
    metrics via `self._assert_metrics()`.

    Args:
      every_n_secs: time-based logging interval handed to the hook; also the
        number of seconds slept between the two training steps.
    """
    examples_hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=None,
        every_n_secs=every_n_secs,
        metric_logger=self._logger)
    session_creator = tf.compat.v1.train.ChiefSessionCreator()

    with tf.compat.v1.train.MonitoredSession(
            session_creator, [examples_hook]) as sess:

        def _train_step():
            # global_step is fetched in its own run, after train_op, so the
            # value read reflects the step that just completed.
            sess.run(self.train_op)
            sess.run(self.global_step)

        _train_step()
        # The interval has not elapsed yet, so the logger must still be empty.
        self.assertFalse(self._logger.logged_metric)

        time.sleep(every_n_secs)
        _train_step()
        self._assert_metrics()
def get_examples_per_second_hook(every_n_steps=100,
                                 batch_size=128,
                                 warm_steps=5,
                                 **kwargs):  # pylint: disable=unused-argument
    """Build an ExamplesPerSecondHook that reports to the benchmark logger.

    Args:
      every_n_steps: `int`, log current and average examples per second every
        N steps.
      batch_size: `int`, total batch size used to calculate examples/second
        from global time.
      warm_steps: skip this number of steps before logging and running
        average.
      **kwargs: ignored. These are NOT forwarded to ExamplesPerSecondHook.
        They are presumably accepted so this factory can be called with the
        same keyword arguments as other hook factories (confirm with callers).

    Returns:
      An `ExamplesPerSecondHook` configured with `batch_size`,
      `every_n_steps` and `warm_steps`, whose metric logger is
      `logger.get_benchmark_logger()`.
    """
    return hooks.ExamplesPerSecondHook(
        batch_size=batch_size,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps,
        metric_logger=logger.get_benchmark_logger())
def _validate_log_every_n_steps(self, every_n_steps, warm_steps):
    """Drive ExamplesPerSecondHook step by step and check when it logs.

    Runs a MonitoredSession with the hook attached and asserts that:
      * nothing is logged during the first `every_n_steps` training steps;
      * after one more step, metrics are logged (checked by
        `self._assert_metrics()`) if the global step exceeds `warm_steps`;
        otherwise nothing is logged yet;
      * one further step adds exactly two logged entries when
        `every_n_steps == 1` and the global step exceeds `warm_steps`, and
        adds none otherwise.

    Args:
      every_n_steps: `int`, step-based logging interval passed to the hook;
        also the number of steps run before the first log is expected.
      warm_steps: `int`, warm-up step count passed to the hook.
    """
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps,
        metric_logger=self._logger)
    with tf.compat.v1.train.MonitoredSession(
            tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
        for _ in range(every_n_steps):
            # Explicitly run global_step after train_op to get the accurate
            # global_step value
            mon_sess.run(self.train_op)
            mon_sess.run(self.global_step)
            # Nothing should be in the list yet
            self.assertFalse(self._logger.logged_metric)
        mon_sess.run(self.train_op)
        global_step_val = mon_sess.run(self.global_step)
        # Logging is expected only once the global step is past the warm-up.
        if global_step_val > warm_steps:
            self._assert_metrics()
        else:
            # Nothing should be in the list yet
            self.assertFalse(self._logger.logged_metric)
        # Add additional run to verify proper reset when called multiple times.
        prev_log_len = len(self._logger.logged_metric)
        mon_sess.run(self.train_op)
        global_step_val = mon_sess.run(self.global_step)
        # With every_n_steps == 1, the extra step should trigger another log
        # (once past warm-up); any larger interval should log nothing here.
        if every_n_steps == 1 and global_step_val > warm_steps:
            # Each time, we log two additional metrics. Did exactly 2 get added?
            self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
        else:
            # No change in the size of the metric list.
            self.assertEqual(len(self._logger.logged_metric), prev_log_len)