def test_illegal_args(self):
    """Check that LoggingMetricHook rejects invalid constructor arguments.

    Each case pairs the pattern the ValueError message is expected to match
    with the keyword arguments passed alongside ``tensors=["t"]``.
    """
    rejected_cases = (
        # Zero and negative iteration intervals.
        ("nvalid every_n_iter", {"every_n_iter": 0}),
        ("nvalid every_n_iter", {"every_n_iter": -10}),
        # Both triggers supplied, then neither trigger supplied.
        ("xactly one of", {"every_n_iter": 5, "every_n_secs": 5}),
        ("xactly one of", {}),
        # A trigger is supplied but no metric_logger.
        ("metric_logger", {"every_n_iter": 5}),
    )
    for expected_pattern, hook_kwargs in rejected_cases:
        with self.assertRaisesRegexp(ValueError, expected_pattern):
            metric_hook.LoggingMetricHook(tensors=["t"], **hook_kwargs)
def test_log_tensors(self):
    """Check that an at_end hook logs each tensor once, only on end().

    Two named constants are registered with the hook. Nothing may be logged
    while the train op is being run; after ``hook.end`` the logger must hold
    one record per tensor, in registration order.
    """
    # (tensor name, constant value) for every tensor handed to the hook.
    expected = (("foo", 42.0), ("bar", 43.0))

    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
        tf.compat.v1.train.get_or_create_global_step()
        logged_tensors = [
            tf.constant(value, name=name) for name, value in expected
        ]
        train_op = tf.constant(3)

        hook = metric_hook.LoggingMetricHook(
            tensors=logged_tensors,
            at_end=True,
            metric_logger=self._logger,
        )
        hook.begin()
        hooked_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        sess.run(tf.compat.v1.global_variables_initializer())

        for _ in range(3):
            hooked_sess.run(train_op)
            # With only at_end set, intermediate steps must not log.
            self.assertEqual(self._logger.logged_metric, [])

        hook.end(sess)
        self.assertEqual(len(self._logger.logged_metric), 2)
        for record, (name, value) in zip(self._logger.logged_metric, expected):
            self.assertRegexpMatches(str(record["name"]), name)
            self.assertEqual(record["value"], value)
            self.assertEqual(record["unit"], None)
            self.assertEqual(record["global_step"], 0)
def _validate_print_every_n_secs(self, sess, at_end):
    """Drive a hook built with ``every_n_secs=1.0`` and check when it logs.

    Args:
      sess: Session that is wrapped in a hooked session and passed to
        ``hook.end``.
      at_end: Forwarded to the hook; decides whether ``hook.end`` is
        expected to log the tensor.
    """
    tensor = tf.constant(42.0, name="foo")
    tensor_name = tensor.name
    train_op = tf.constant(3)

    def logged_text():
        # String form of everything the fake logger has captured so far.
        return str(self._logger.logged_metric)

    def clear_log():
        self._logger.logged_metric = []

    hook = metric_hook.LoggingMetricHook(
        tensors=[tensor_name],
        every_n_secs=1.0,
        at_end=at_end,
        metric_logger=self._logger,
    )
    hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    # The very first run is expected to log.
    hooked_sess.run(train_op)
    self.assertRegexpMatches(logged_text(), tensor_name)

    # An immediate second run must stay quiet. A substring search is used
    # because assertNotRegexpMatches is not supported by python 3.1 and later.
    clear_log()
    hooked_sess.run(train_op)
    self.assertEqual(logged_text().find(tensor_name), -1)

    # Once the one-second interval has elapsed, the next run logs again.
    time.sleep(1.0)
    clear_log()
    hooked_sess.run(train_op)
    self.assertRegexpMatches(logged_text(), tensor_name)

    # end() logs only when at_end was requested.
    clear_log()
    hook.end(sess)
    if at_end:
        self.assertRegexpMatches(logged_text(), tensor_name)
    else:
        # assertNotRegexpMatches is not supported by python 3.1 and later
        self.assertEqual(logged_text().find(tensor_name), -1)
def test_global_step_not_found(self):
    """Check that begin() raises RuntimeError when no global step was created."""
    with tf.Graph().as_default():
        # No global step is created in this graph before the hook begins.
        tensor = tf.constant(42.0, name="foo")
        hook = metric_hook.LoggingMetricHook(
            tensors=[tensor.name],
            at_end=True,
            metric_logger=self._logger,
        )
        expected_message = "should be created to use LoggingMetricHook."
        with self.assertRaisesRegexp(RuntimeError, expected_message):
            hook.begin()
def get_logging_metric_hook(tensors_to_log=None, every_n_secs=600, **kwargs):  # pylint: disable=unused-argument
    """Build a LoggingMetricHook wired to the benchmark logger.

    Args:
      tensors_to_log: List of tensor names or dictionary mapping labels to
        tensor names. When None, _TENSORS_TO_LOG is used instead.
      every_n_secs: `int`, logging interval in seconds. Defaults to 600
        (every 10 minutes).
      **kwargs: Accepted but not used by this function.

    Returns:
      A metric_hook.LoggingMetricHook created with the selected tensors, the
      logger returned by logger.get_benchmark_logger(), and `every_n_secs`.
    """
    selected_tensors = (
        _TENSORS_TO_LOG if tensors_to_log is None else tensors_to_log
    )
    return metric_hook.LoggingMetricHook(
        tensors=selected_tensors,
        metric_logger=logger.get_benchmark_logger(),
        every_n_secs=every_n_secs,
    )
def _validate_print_every_n_steps(self, sess, at_end):
    """Drive a hook built with ``every_n_iter=10`` and check when it logs.

    Args:
      sess: Session that is wrapped in a hooked session and passed to
        ``hook.end``.
      at_end: Forwarded to the hook; decides whether ``hook.end`` is
        expected to log the tensor.
    """
    tensor = tf.constant(42.0, name="foo")
    tensor_name = tensor.name
    train_op = tf.constant(3)

    def logged_text():
        # String form of everything the fake logger has captured so far.
        return str(self._logger.logged_metric)

    def clear_log():
        self._logger.logged_metric = []

    hook = metric_hook.LoggingMetricHook(
        tensors=[tensor_name],
        every_n_iter=10,
        at_end=at_end,
        metric_logger=self._logger,
    )
    hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    # The very first run is expected to log.
    hooked_sess.run(train_op)
    self.assertRegexpMatches(logged_text(), tensor_name)

    cycles = 3
    quiet_runs_per_cycle = 9
    for _ in range(cycles):
        clear_log()
        for _ in range(quiet_runs_per_cycle):
            hooked_sess.run(train_op)
            # assertNotRegexpMatches is not supported by python 3.1 and later
            self.assertEqual(logged_text().find(tensor_name), -1)
        # The tenth run of the cycle is expected to log again.
        hooked_sess.run(train_op)
        self.assertRegexpMatches(logged_text(), tensor_name)

    # Add additional run to verify proper reset when called multiple times.
    clear_log()
    hooked_sess.run(train_op)
    # assertNotRegexpMatches is not supported by python 3.1 and later
    self.assertEqual(logged_text().find(tensor_name), -1)

    # end() logs only when at_end was requested.
    clear_log()
    hook.end(sess)
    if at_end:
        self.assertRegexpMatches(logged_text(), tensor_name)
    else:
        # assertNotRegexpMatches is not supported by python 3.1 and later
        self.assertEqual(logged_text().find(tensor_name), -1)