Example #1
0
 def test_illegal_args(self):
   """The LoggingMetricHook constructor rejects invalid argument combinations."""
   # Each pattern drops the message's first letter, presumably so that either
   # capitalization ("Invalid"/"invalid", "Exactly"/"exactly") matches.
   # A non-positive every_n_iter is rejected.
   with self.assertRaisesRegex(ValueError, "nvalid every_n_iter"):
     metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=0)
   with self.assertRaisesRegex(ValueError, "nvalid every_n_iter"):
     metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=-10)
   # Passing both triggers is rejected ...
   with self.assertRaisesRegex(ValueError, "xactly one of"):
     metric_hook.LoggingMetricHook(
         tensors=["t"], every_n_iter=5, every_n_secs=5)
   # ... and so is passing neither trigger.
   with self.assertRaisesRegex(ValueError, "xactly one of"):
     metric_hook.LoggingMetricHook(tensors=["t"])
   # A valid trigger without a metric_logger is still rejected.
   with self.assertRaisesRegex(ValueError, "metric_logger"):
     metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)
Example #2
0
  def test_log_tensors(self):
    """With only at_end=True, tensors are logged once, when the hook ends."""
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      # begin() raises RuntimeError when the graph has no global step.
      tf.compat.v1.train.get_or_create_global_step()
      t1 = tf.constant(42.0, name="foo")
      t2 = tf.constant(43.0, name="bar")
      train_op = tf.constant(3)
      hook = metric_hook.LoggingMetricHook(
          tensors=[t1, t2], at_end=True, metric_logger=self._logger)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
      sess.run(tf.compat.v1.global_variables_initializer())

      # With no every_n_iter / every_n_secs, nothing is logged while running.
      for _ in range(3):
        mon_sess.run(train_op)
        self.assertEqual(self._logger.logged_metric, [])

      hook.end(sess)
      # One entry per tensor, in the order the tensors were passed. train_op
      # is a constant and never advances the global step, so it stays at 0.
      self.assertEqual(len(self._logger.logged_metric), 2)
      expected = [("foo", 42.0), ("bar", 43.0)]
      for metric, (name, value) in zip(self._logger.logged_metric, expected):
        with self.subTest(name=name):
          self.assertRegex(str(metric["name"]), name)
          self.assertEqual(metric["value"], value)
          self.assertIsNone(metric["unit"])
          self.assertEqual(metric["global_step"], 0)
Example #3
0
  def _validate_print_every_n_secs(self, sess, at_end):
    """Checks time-based logging with every_n_secs=1.0.

    The tensor is logged on the first run, suppressed on a run within the same
    second, and logged again after sleeping for one second.

    Args:
      sess: Session used to initialize variables and drive the hook. Assumed
        to run the current default graph, where the tensors below are created.
      at_end: Forwarded to the hook; when True, `hook.end` must log the tensor
        one final time, otherwise it must not.
    """
    t = tf.constant(42.0, name="foo")
    train_op = tf.constant(3)

    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_secs=1.0, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    # The first run logs.
    mon_sess.run(train_op)
    self.assertRegex(str(self._logger.logged_metric), t.name)

    # An immediate second run falls inside the 1s window and must not log.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self._logger.logged_metric))
    # Wait out the interval so the next run triggers again.
    time.sleep(1.0)

    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertRegex(str(self._logger.logged_metric), t.name)

    # end() logs once more only when at_end is set.
    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
      self.assertRegex(str(self._logger.logged_metric), t.name)
    else:
      self.assertNotIn(t.name, str(self._logger.logged_metric))
Example #4
0
  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks iteration-based logging with every_n_iter=10.

    The tensor is logged on the first run and then on every 10th run, and is
    absent from the nine runs in between.

    Args:
      sess: Session used to initialize variables and drive the hook. Assumed
        to run the current default graph, where the tensors below are created.
      at_end: Forwarded to the hook; when True, `hook.end` must log the tensor
        one final time, otherwise it must not.
    """
    t = tf.constant(42.0, name="foo")

    train_op = tf.constant(3)
    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_iter=10, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    # The first run logs.
    mon_sess.run(train_op)
    self.assertRegex(str(self._logger.logged_metric), t.name)

    # Over three full periods: nine silent runs, then the 10th run logs.
    for _ in range(3):
      self._logger.logged_metric = []
      for _ in range(9):
        mon_sess.run(train_op)
        self.assertNotIn(t.name, str(self._logger.logged_metric))
      mon_sess.run(train_op)
      self.assertRegex(str(self._logger.logged_metric), t.name)

    # Add additional run to verify proper reset when called multiple times.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self._logger.logged_metric))

    # end() logs once more only when at_end is set.
    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
      self.assertRegex(str(self._logger.logged_metric), t.name)
    else:
      self.assertNotIn(t.name, str(self._logger.logged_metric))
Example #5
0
  def test_global_step_not_found(self):
    """begin() raises RuntimeError when the graph has no global step."""
    with tf.Graph().as_default():
      t = tf.constant(42.0, name="foo")
      hook = metric_hook.LoggingMetricHook(
          tensors=[t.name], at_end=True, metric_logger=self._logger)

      # The fresh graph above never creates a global step, so begin() fails.
      with self.assertRaisesRegex(
          RuntimeError, "should be created to use LoggingMetricHook."):
        hook.begin()
def get_logging_metric_hook(tensors_to_log=None,
                            every_n_secs=600,
                            **kwargs):  # pylint: disable=unused-argument
  """Builds a time-triggered LoggingMetricHook wired to the benchmark logger.

  Args:
    tensors_to_log: List of tensor names or dictionary mapping labels to tensor
      names. When None, the module-level _TENSORS_TO_LOG is used instead.
    every_n_secs: `int`, number of seconds between metric logs. The default of
      600 means every 10 minutes.
    **kwargs: Extra keyword arguments; accepted for call compatibility but
      ignored.

  Returns:
    A LoggingMetricHook whose metric_logger is the object returned by
    `logger.get_benchmark_logger()`.
  """
  tensors = _TENSORS_TO_LOG if tensors_to_log is None else tensors_to_log
  benchmark_logger = logger.get_benchmark_logger()
  return metric_hook.LoggingMetricHook(
      tensors=tensors,
      metric_logger=benchmark_logger,
      every_n_secs=every_n_secs,
  )