def test_illegal_args(self):
   """Checks that LoggingMetricHook rejects invalid constructor arguments.

   Uses assertRaisesRegex; the assertRaisesRegexp alias was deprecated in
   Python 3.2 and removed in Python 3.12.
   """
   # Patterns drop the first letter, presumably so they match regardless of
   # the capitalization used in the error message.
   # every_n_iter must be positive: zero and negative values are rejected.
   with self.assertRaisesRegex(ValueError, "nvalid every_n_iter"):
     metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=0)
   with self.assertRaisesRegex(ValueError, "nvalid every_n_iter"):
     metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=-10)
   # Giving both every_n_iter and every_n_secs is rejected ...
   with self.assertRaisesRegex(ValueError, "xactly one of"):
     metric_hook.LoggingMetricHook(
         tensors=["t"], every_n_iter=5, every_n_secs=5)
   # ... and so is giving neither of them.
   with self.assertRaisesRegex(ValueError, "xactly one of"):
     metric_hook.LoggingMetricHook(tensors=["t"])
   # A metric_logger must be supplied.
   with self.assertRaisesRegex(ValueError, "metric_logger"):
     metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)
Exemplo n.º 2
0
    def test_log_tensors(self):
        """Logs two tensors only at session end and verifies each record.

        With at_end=True, running the train op must not log anything.
        hook.end() then emits one record per tensor, in the order the tensors
        were passed to the hook.
        """
        with tf.Graph().as_default(), tf.Session() as session:
            tf.train.get_or_create_global_step()
            foo_tensor = tf.constant(42.0, name="foo")
            bar_tensor = tf.constant(43.0, name="bar")
            train_op = tf.constant(3)
            hook = metric_hook.LoggingMetricHook(
                tensors=[foo_tensor, bar_tensor],
                at_end=True,
                metric_logger=self._logger)
            hook.begin()
            hooked_sess = monitored_session._HookedSession(session, [hook])  # pylint: disable=protected-access
            session.run(tf.global_variables_initializer())

            # Plain training steps must not trigger any logging.
            for _ in range(3):
                hooked_sess.run(train_op)
                self.assertEqual(self._logger.logged_metric, [])

            hook.end(session)
            logged = self._logger.logged_metric
            self.assertEqual(len(logged), 2)

            # Each record is checked against its (name pattern, value) pair.
            expected = (("foo", 42.0), ("bar", 43.0))
            for record, (name_pattern, value) in zip(logged, expected):
                self.assertRegex(str(record["name"]), name_pattern)
                self.assertEqual(record["value"], value)
                self.assertEqual(record["unit"], None)
                self.assertEqual(record["global_step"], 0)
Exemplo n.º 3
0
    def _validate_print_every_n_secs(self, sess, at_end):
        """Checks time-based (every_n_secs=1.0) logging of a single tensor.

        Args:
          sess: session in which the hooked train op is run.
          at_end: forwarded to LoggingMetricHook. When True, hook.end() must
            log the tensor; otherwise hook.end() must log nothing.
        """
        t = tf.constant(42.0, name="foo")
        train_op = tf.constant(3)

        hook = metric_hook.LoggingMetricHook(tensors=[t.name],
                                             every_n_secs=1.0,
                                             at_end=at_end,
                                             metric_logger=self._logger)
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        sess.run(tf.global_variables_initializer())

        # The first run logs the tensor.
        mon_sess.run(train_op)
        self.assertRegex(str(self._logger.logged_metric), t.name)

        # An immediate second run is inside the 1-second window: no logging.
        # assertNotIn is a literal substring check and shows the logged
        # content on failure.
        self._logger.logged_metric = []
        mon_sess.run(train_op)
        self.assertNotIn(t.name, str(self._logger.logged_metric))
        time.sleep(1.0)

        # After the interval has elapsed, the next run logs again.
        self._logger.logged_metric = []
        mon_sess.run(train_op)
        self.assertRegex(str(self._logger.logged_metric), t.name)

        self._logger.logged_metric = []
        hook.end(sess)
        if at_end:
            self.assertRegex(str(self._logger.logged_metric), t.name)
        else:
            self.assertNotIn(t.name, str(self._logger.logged_metric))
Exemplo n.º 4
0
def get_logging_metric_hook(benchmark_log_dir=None,
                            tensors_to_log=None,
                            every_n_secs=600,
                            **kwargs):  # pylint: disable=unused-argument
    """Function to get LoggingMetricHook.

    Args:
      benchmark_log_dir: `string`, directory path to save the metric log.
        Required despite the `None` default.
      tensors_to_log: List of tensor names or dictionary mapping labels to tensor
        names. If not set, log _TENSORS_TO_LOG by default.
      every_n_secs: `int`, the frequency for logging the metric, in seconds.
        Defaults to 600 (every 10 mins).
      **kwargs: Accepted and ignored.

    Returns:
      A LoggingMetricHook that logs `tensors_to_log` to `benchmark_log_dir`
      every `every_n_secs` seconds.

    Raises:
      ValueError: if `benchmark_log_dir` is not provided.
    """
    if benchmark_log_dir is None:
        # Name the actual parameter so callers know which argument is missing.
        raise ValueError(
            "benchmark_log_dir should be provided to use metric logger")
    if tensors_to_log is None:
        tensors_to_log = _TENSORS_TO_LOG
    return metric_hook.LoggingMetricHook(tensors=tensors_to_log,
                                         log_dir=benchmark_log_dir,
                                         every_n_secs=every_n_secs)
  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks step-based (every_n_iter=10) logging of a single tensor.

    Uses assertRegex / assertNotIn; the assertRegexpMatches alias was removed
    in Python 3.12.

    Args:
      sess: session in which the hooked train op is run.
      at_end: forwarded to LoggingMetricHook. When True, hook.end() must log
        the tensor; otherwise hook.end() must log nothing.
    """
    t = tf.constant(42.0, name="foo")

    train_op = tf.constant(3)
    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_iter=10, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.global_variables_initializer())
    # The very first run logs the tensor.
    mon_sess.run(train_op)
    self.assertRegex(str(self._logger.logged_metric), t.name)
    for _ in range(3):
      self._logger.logged_metric = []
      # The next 9 runs stay silent ...
      for _ in range(9):
        mon_sess.run(train_op)
        self.assertNotIn(t.name, str(self._logger.logged_metric))
      # ... and the 10th run logs again.
      mon_sess.run(train_op)
      self.assertRegex(str(self._logger.logged_metric), t.name)

    # Add additional run to verify proper reset when called multiple times.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self._logger.logged_metric))

    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
      self.assertRegex(str(self._logger.logged_metric), t.name)
    else:
      self.assertNotIn(t.name, str(self._logger.logged_metric))
  def test_global_step_not_found(self):
    """Checks that begin() raises when the graph has no global step.

    Uses assertRaisesRegex; the assertRaisesRegexp alias was removed in
    Python 3.12.
    """
    with tf.Graph().as_default():
      t = tf.constant(42.0, name="foo")
      hook = metric_hook.LoggingMetricHook(
          tensors=[t.name], at_end=True, metric_logger=self._logger)

      # No global step is created in this fresh graph, so begin() must fail.
      with self.assertRaisesRegex(
          RuntimeError, "should be created to use LoggingMetricHook."):
        hook.begin()
Exemplo n.º 7
0
def get_logging_metric_hook(tensors_to_log=None, every_n_secs=600, **kwargs):  # pylint: disable=unused-argument
    """Build a LoggingMetricHook that reports through the benchmark logger.

    Args:
      tensors_to_log: List of tensor names or dictionary mapping labels to
        tensor names. Falls back to _TENSORS_TO_LOG when omitted.
      every_n_secs: `int`, seconds between successive metric logs; the
        default of 600 logs every 10 minutes.
      **kwargs: Accepted and ignored.

    Returns:
      A LoggingMetricHook that reports the selected tensors to the logger
      returned by `logger.get_benchmark_logger()` every `every_n_secs`
      seconds.
    """
    tensors = _TENSORS_TO_LOG if tensors_to_log is None else tensors_to_log
    benchmark_logger = logger.get_benchmark_logger()
    return metric_hook.LoggingMetricHook(
        tensors=tensors,
        metric_logger=benchmark_logger,
        every_n_secs=every_n_secs)
Exemplo n.º 8
0
def get_logging_metric_hook(tensors_to_log=None, every_n_secs=600, **kwargs):  # pylint: disable=unused-argument
    """Function to get LoggingMetricHook.

    Args:
      tensors_to_log: List of tensor names or dictionary mapping labels to
        tensor names. If not set, log _TENSORS_TO_LOG by default.
      every_n_secs: `int`, the frequency for logging the metric, in seconds.
        Defaults to 600 (every 10 mins).
      **kwargs: Accepted and ignored.

    Returns:
      A LoggingMetricHook (not a ProfilerHook) that reports `tensors_to_log`
      to the logger returned by `logger.get_benchmark_logger()` every
      `every_n_secs` seconds.
    """
    if tensors_to_log is None:
        tensors_to_log = _TENSORS_TO_LOG
    return metric_hook.LoggingMetricHook(
        tensors=tensors_to_log,
        metric_logger=logger.get_benchmark_logger(),
        every_n_secs=every_n_secs)
Exemplo n.º 9
0
    def test_print_at_end_only(self):
        """With only at_end=True, nothing is logged until hook.end().

        hook.end() then emits exactly one record for the single tensor.
        Uses assertRegex; the assertRegexpMatches alias was removed in
        Python 3.12.
        """
        with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
            tf.compat.v1.train.get_or_create_global_step()
            t = tf.constant(42.0, name="foo")
            train_op = tf.constant(3)
            hook = metric_hook.LoggingMetricHook(
                tensors=[t.name], at_end=True, metric_logger=self._logger)
            hook.begin()
            mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
            sess.run(tf.compat.v1.global_variables_initializer())

            # Training steps alone must not trigger any logging.
            for _ in range(3):
                mon_sess.run(train_op)
                self.assertEqual(self._logger.logged_metric, [])

            hook.end(sess)
            self.assertEqual(len(self._logger.logged_metric), 1)
            metric = self._logger.logged_metric[0]
            self.assertRegex(metric["name"], "foo")
            self.assertEqual(metric["value"], 42.0)
            self.assertEqual(metric["unit"], None)
            self.assertEqual(metric["global_step"], 0)
 def test_illegal_args(self):
     """Checks that LoggingMetricHook rejects invalid constructor arguments.

     Uses assertRaisesRegex; the assertRaisesRegexp alias was deprecated in
     Python 3.2 and removed in Python 3.12.
     """
     # every_n_iter must be positive: zero and negative values are rejected.
     with self.assertRaisesRegex(ValueError, 'nvalid every_n_iter'):
         metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=0)
     with self.assertRaisesRegex(ValueError, 'nvalid every_n_iter'):
         metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=-10)
     # Giving both every_n_iter and every_n_secs, or neither, is rejected.
     with self.assertRaisesRegex(ValueError, 'xactly one of'):
         metric_hook.LoggingMetricHook(
             tensors=['t'], every_n_iter=5, every_n_secs=5)
     with self.assertRaisesRegex(ValueError, 'xactly one of'):
         metric_hook.LoggingMetricHook(tensors=['t'])
     # Supplying neither or both of log_dir and metric_logger is rejected.
     with self.assertRaisesRegex(ValueError, 'log_dir and metric_logger'):
         metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=5)
     with self.assertRaisesRegex(ValueError, 'log_dir and metric_logger'):
         metric_hook.LoggingMetricHook(
             tensors=['t'],
             every_n_iter=5,
             log_dir=self._log_dir,
             metric_logger=self._logger)