def test_illegal_args(self):
  """Checks that invalid LoggingMetricHook constructor arguments are rejected.

  Every case below is expected to raise a ValueError whose message matches
  the pattern given to assertRaisesRegexp.
  """
  # Zero and negative values for every_n_iter.
  for bad_every_n_iter in (0, -10):
    with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):
      metric_hook.LoggingMetricHook(
          tensors=["t"], every_n_iter=bad_every_n_iter)

  # Both triggers supplied at once, then no trigger supplied at all.
  for trigger_kwargs in ({"every_n_iter": 5, "every_n_secs": 5}, {}):
    with self.assertRaisesRegexp(ValueError, "xactly one of"):
      metric_hook.LoggingMetricHook(tensors=["t"], **trigger_kwargs)

  # A single trigger is given, but no metric_logger.
  with self.assertRaisesRegexp(ValueError, "metric_logger"):
    metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)
def test_log_tensors(self):
  """Checks that tensors passed as objects are logged only when end() runs.

  Three steps are run through a hook configured with at_end=True; the logger
  is expected to stay empty during those steps. After hook.end(), one metric
  per tensor is expected, in the order the tensors were given, each with a
  unit of None and a global step of 0.
  """
  graph = tf.Graph()
  with graph.as_default():
    with tf.compat.v1.Session() as sess:
      tf.compat.v1.train.get_or_create_global_step()
      foo = tf.constant(42.0, name="foo")
      bar = tf.constant(43.0, name="bar")
      step_op = tf.constant(3)

      logging_hook = metric_hook.LoggingMetricHook(
          tensors=[foo, bar], at_end=True, metric_logger=self._logger)
      logging_hook.begin()
      hooked_sess = monitored_session._HookedSession(sess, [logging_hook])  # pylint: disable=protected-access
      sess.run(tf.compat.v1.global_variables_initializer())

      # Nothing is expected in the logger while steps are still running.
      for _ in range(3):
        hooked_sess.run(step_op)
        self.assertEqual(self._logger.logged_metric, [])

      logging_hook.end(sess)
      logged = self._logger.logged_metric
      self.assertEqual(len(logged), 2)

      # (name pattern, value) expected for each logged metric, in order.
      expectations = (("foo", 42.0), ("bar", 43.0))
      for metric, (name_pattern, value) in zip(logged, expectations):
        self.assertRegexpMatches(str(metric["name"]), name_pattern)
        self.assertEqual(metric["value"], value)
        self.assertEqual(metric["unit"], None)
        self.assertEqual(metric["global_step"], 0)
def _validate_print_every_n_secs(self, sess, at_end):
  """Drives a LoggingMetricHook with every_n_secs=1.0 and checks when it logs.

  Args:
    sess: Session used to run the ops; it is also wrapped by the hooked
      session and passed to hook.end().
    at_end: Forwarded to the hook's `at_end` argument. Also selects whether a
      log entry is expected after hook.end() is called.
  """
  foo = tf.constant(42.0, name="foo")
  step_op = tf.constant(3)
  logging_hook = metric_hook.LoggingMetricHook(
      tensors=[foo.name],
      every_n_secs=1.0,
      at_end=at_end,
      metric_logger=self._logger)
  logging_hook.begin()
  hooked_sess = monitored_session._HookedSession(sess, [logging_hook])  # pylint: disable=protected-access
  sess.run(tf.compat.v1.global_variables_initializer())

  def clear_log():
    self._logger.logged_metric = []

  def expect_logged():
    self.assertRegexpMatches(str(self._logger.logged_metric), foo.name)

  def expect_not_logged():
    # assertNotRegexpMatches is not supported by python 3.1 and later, so a
    # substring search on the stringified log is used instead.
    self.assertEqual(str(self._logger.logged_metric).find(foo.name), -1)

  # The very first run is expected to log.
  hooked_sess.run(step_op)
  expect_logged()

  # A run immediately afterwards is expected to log nothing.
  clear_log()
  hooked_sess.run(step_op)
  expect_not_logged()

  # After waiting out the one-second interval, logging is expected again.
  time.sleep(1.0)
  clear_log()
  hooked_sess.run(step_op)
  expect_logged()

  # end() is expected to log only when at_end was requested.
  clear_log()
  logging_hook.end(sess)
  if at_end:
    expect_logged()
  else:
    expect_not_logged()
def get_logging_metric_hook(benchmark_log_dir=None,
                            tensors_to_log=None,
                            every_n_secs=600,
                            **kwargs):  # pylint: disable=unused-argument
  """Function to get LoggingMetricHook.

  Args:
    benchmark_log_dir: `string`, directory path to save the metric log.
      Required, even though it defaults to None.
    tensors_to_log: List of tensor names or dictionary mapping labels to tensor
      names. If not set, log _TENSORS_TO_LOG by default.
    every_n_secs: `int`, the frequency for logging the metric. Default to every
      10 mins.
    **kwargs: Ignored by this function.

  Returns:
    A LoggingMetricHook constructed with `log_dir=benchmark_log_dir`.

  Raises:
    ValueError: If `benchmark_log_dir` is None.
  """
  if benchmark_log_dir is None:
    # Name the actual parameter so the caller knows which argument to pass.
    raise ValueError(
        "benchmark_log_dir should be provided to use metric logger")

  if tensors_to_log is None:
    tensors_to_log = _TENSORS_TO_LOG
  return metric_hook.LoggingMetricHook(
      tensors=tensors_to_log,
      log_dir=benchmark_log_dir,
      every_n_secs=every_n_secs)
def _validate_print_every_n_steps(self, sess, at_end):
  """Drives a LoggingMetricHook with every_n_iter=10 and checks when it logs.

  Args:
    sess: Session used to run the ops; it is also wrapped by the hooked
      session and passed to hook.end().
    at_end: Forwarded to the hook's `at_end` argument. Also selects whether a
      log entry is expected after hook.end() is called.
  """
  foo = tf.constant(42.0, name="foo")
  step_op = tf.constant(3)
  logging_hook = metric_hook.LoggingMetricHook(
      tensors=[foo.name],
      every_n_iter=10,
      at_end=at_end,
      metric_logger=self._logger)
  logging_hook.begin()
  hooked_sess = monitored_session._HookedSession(sess, [logging_hook])  # pylint: disable=protected-access
  sess.run(tf.compat.v1.global_variables_initializer())

  def clear_log():
    self._logger.logged_metric = []

  def expect_logged():
    self.assertRegexpMatches(str(self._logger.logged_metric), foo.name)

  def expect_not_logged():
    # assertNotRegexpMatches is not supported by python 3.1 and later, so a
    # substring search on the stringified log is used instead.
    self.assertEqual(str(self._logger.logged_metric).find(foo.name), -1)

  # The very first run is expected to log.
  hooked_sess.run(step_op)
  expect_logged()

  # Three full cycles: nine silent runs, then a tenth run that logs.
  for _cycle in range(3):
    clear_log()
    for _step in range(9):
      hooked_sess.run(step_op)
      expect_not_logged()
    hooked_sess.run(step_op)
    expect_logged()

  # Add additional run to verify proper reset when called multiple times.
  clear_log()
  hooked_sess.run(step_op)
  expect_not_logged()

  # end() is expected to log only when at_end was requested.
  clear_log()
  logging_hook.end(sess)
  if at_end:
    expect_logged()
  else:
    expect_not_logged()
def test_global_step_not_found(self):
  """Checks that hook.begin() raises when the graph has no global step.

  The hook is built inside a fresh graph in which no global step is created,
  and begin() is expected to raise a RuntimeError.
  """
  expected_error = "should be created to use LoggingMetricHook."
  graph = tf.Graph()
  with graph.as_default():
    foo = tf.constant(42.0, name="foo")
    logging_hook = metric_hook.LoggingMetricHook(
        tensors=[foo.name],
        at_end=True,
        metric_logger=self._logger)
    with self.assertRaisesRegexp(RuntimeError, expected_error):
      logging_hook.begin()
def get_logging_metric_hook(tensors_to_log=None,
                            every_n_secs=600,
                            **kwargs):  # pylint: disable=unused-argument
  """Builds a LoggingMetricHook backed by the benchmark logger.

  Args:
    tensors_to_log: List of tensor names or dictionary mapping labels to tensor
      names. When None, _TENSORS_TO_LOG is used.
    every_n_secs: `int`, how often (in seconds) the metric is logged. Defaults
      to 600, i.e. every 10 minutes.
    **kwargs: Ignored by this function.

  Returns:
    A LoggingMetricHook whose metric logger is the object returned by
    logger.get_benchmark_logger().
  """
  tensors = _TENSORS_TO_LOG if tensors_to_log is None else tensors_to_log
  return metric_hook.LoggingMetricHook(
      tensors=tensors,
      metric_logger=logger.get_benchmark_logger(),
      every_n_secs=every_n_secs)
def get_logging_metric_hook(tensors_to_log=None,
                            every_n_secs=600,
                            **kwargs):  # pylint: disable=unused-argument
  """Creates the LoggingMetricHook used to record tensor values as metrics.

  Args:
    tensors_to_log: List of tensor names or dictionary mapping labels to tensor
      names. Falls back to _TENSORS_TO_LOG when left as None.
    every_n_secs: `int`, the logging frequency in seconds. The default of 600
      means every 10 minutes.
    **kwargs: Ignored by this function.

  Returns:
    A LoggingMetricHook that reports through the logger obtained from
    logger.get_benchmark_logger().
  """
  if tensors_to_log is not None:
    selected_tensors = tensors_to_log
  else:
    selected_tensors = _TENSORS_TO_LOG

  return metric_hook.LoggingMetricHook(
      tensors=selected_tensors,
      metric_logger=logger.get_benchmark_logger(),
      every_n_secs=every_n_secs)
def test_print_at_end_only(self):
  """Checks that a hook with at_end=True logs only when end() is called.

  Three steps are run and the logger is expected to stay empty throughout.
  After hook.end(), exactly one metric is expected, named after the "foo"
  tensor, with value 42.0, a unit of None and a global step of 0.
  """
  # Use the tf.compat.v1 endpoints, as the other tests in this file do, so the
  # test does not depend on the TF1-only top-level names.
  with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    tf.compat.v1.train.get_or_create_global_step()
    t = tf.constant(42.0, name='foo')
    train_op = tf.constant(3)
    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], at_end=True, metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    # Nothing is expected in the logger while steps are still running.
    for _ in range(3):
      mon_sess.run(train_op)
      self.assertEqual(self._logger.logged_metric, [])

    hook.end(sess)
    self.assertEqual(len(self._logger.logged_metric), 1)
    metric = self._logger.logged_metric[0]
    self.assertRegexpMatches(metric["name"], "foo")
    self.assertEqual(metric["value"], 42.0)
    self.assertIsNone(metric["unit"])
    self.assertEqual(metric["global_step"], 0)
def test_illegal_args(self):
  """Checks that invalid LoggingMetricHook constructor arguments are rejected.

  Every case below is expected to raise a ValueError whose message matches
  the pattern given to assertRaisesRegexp.
  """
  # Zero and negative values for every_n_iter.
  for bad_every_n_iter in (0, -10):
    with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):
      metric_hook.LoggingMetricHook(
          tensors=['t'], every_n_iter=bad_every_n_iter)

  # Both triggers supplied at once, then no trigger supplied at all.
  for trigger_kwargs in ({'every_n_iter': 5, 'every_n_secs': 5}, {}):
    with self.assertRaisesRegexp(ValueError, 'xactly one of'):
      metric_hook.LoggingMetricHook(tensors=['t'], **trigger_kwargs)

  sink_error = 'log_dir and metric_logger'

  # Neither log_dir nor metric_logger is given.
  with self.assertRaisesRegexp(ValueError, sink_error):
    metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=5)

  # Both log_dir and metric_logger are given.
  with self.assertRaisesRegexp(ValueError, sink_error):
    metric_hook.LoggingMetricHook(
        tensors=['t'],
        every_n_iter=5,
        log_dir=self._log_dir,
        metric_logger=self._logger)