Exemplo n.º 1
0
 def test_illegal_args(self):
     """Checks that LoggingMetricHook rejects invalid constructor arguments.

     Each case builds the hook with ``tensors=["t"]`` plus the listed keyword
     arguments and expects a ValueError whose message matches the pattern.
     """
     # The patterns omit the leading letter ("nvalid", "xactly"), presumably
     # so they match whether or not the message capitalizes it.
     cases = [
         # every_n_iter must be positive.
         ("nvalid every_n_iter", {"every_n_iter": 0}),
         ("nvalid every_n_iter", {"every_n_iter": -10}),
         # Supplying both schedules, or neither, is rejected.
         ("xactly one of", {"every_n_iter": 5, "every_n_secs": 5}),
         ("xactly one of", {}),
         # A valid schedule without a metric_logger is still rejected.
         ("metric_logger", {"every_n_iter": 5}),
     ]
     for pattern, kwargs in cases:
         with self.subTest(pattern=pattern, kwargs=kwargs):
             # assertRaisesRegex replaces the assertRaisesRegexp alias, which
             # was removed from unittest in Python 3.12.
             with self.assertRaisesRegex(ValueError, pattern):
                 metric_hook.LoggingMetricHook(tensors=["t"], **kwargs)
Exemplo n.º 2
0
    def test_log_tensors(self):
        """Checks that an `at_end` hook logs nothing while training and each
        tensor exactly once when `end()` is called."""
        with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
            tf.compat.v1.train.get_or_create_global_step()
            t1 = tf.constant(42.0, name="foo")
            t2 = tf.constant(43.0, name="bar")
            train_op = tf.constant(3)
            hook = metric_hook.LoggingMetricHook(tensors=[t1, t2],
                                                 at_end=True,
                                                 metric_logger=self._logger)
            hook.begin()
            mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
            sess.run(tf.compat.v1.global_variables_initializer())

            # The hook only sets at_end, so training runs must not log.
            for _ in range(3):
                mon_sess.run(train_op)
                self.assertEqual(self._logger.logged_metric, [])

            hook.end(sess)
            logged = self._logger.logged_metric
            self.assertEqual(len(logged), 2)

            # Metrics are expected in the same order as `tensors`. train_op
            # never increments the global step, so both report step 0.
            expected = [("foo", 42.0), ("bar", 43.0)]
            for metric, (name, value) in zip(logged, expected):
                # assertRegex replaces assertRegexpMatches (removed in 3.12).
                self.assertRegex(str(metric["name"]), name)
                self.assertEqual(metric["value"], value)
                self.assertIsNone(metric["unit"])
                self.assertEqual(metric["global_step"], 0)
Exemplo n.º 3
0
    def _validate_print_every_n_secs(self, sess, at_end):
        """Checks that a time-triggered LoggingMetricHook logs on schedule.

        Builds a LoggingMetricHook with ``every_n_secs=1.0`` over a constant
        tensor and drives it through a hooked session. It asserts four things:
        the first run logs; an immediate second run does not; a run after a
        1-second sleep logs again; and ``hook.end`` logs only when ``at_end``
        is True.

        Args:
          sess: Session used to initialize variables and wrapped by the hooked
            session. Assumed to run the caller's current default graph, where
            the ops below are created (verify against callers).
          at_end: Passed to the hook. Whether `hook.end` is expected to log.
        """
        t = tf.constant(42.0, name="foo")
        train_op = tf.constant(3)

        hook = metric_hook.LoggingMetricHook(tensors=[t.name],
                                             every_n_secs=1.0,
                                             at_end=at_end,
                                             metric_logger=self._logger)
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        sess.run(tf.compat.v1.global_variables_initializer())

        # t.name is checked as a literal substring of the logged metrics, so
        # the positive (assertIn) and negative (assertNotIn) checks agree.
        mon_sess.run(train_op)
        self.assertIn(t.name, str(self._logger.logged_metric))

        # Less than a second has passed, so this run must not log.
        self._logger.logged_metric = []
        mon_sess.run(train_op)
        self.assertNotIn(t.name, str(self._logger.logged_metric))
        time.sleep(1.0)

        # The 1-second interval has elapsed, so this run logs again.
        self._logger.logged_metric = []
        mon_sess.run(train_op)
        self.assertIn(t.name, str(self._logger.logged_metric))

        self._logger.logged_metric = []
        hook.end(sess)
        if at_end:
            self.assertIn(t.name, str(self._logger.logged_metric))
        else:
            self.assertNotIn(t.name, str(self._logger.logged_metric))
Exemplo n.º 4
0
    def test_global_step_not_found(self):
        """Checks that `begin()` raises when the graph has no global step."""
        with tf.Graph().as_default():
            t = tf.constant(42.0, name="foo")
            hook = metric_hook.LoggingMetricHook(tensors=[t.name],
                                                 at_end=True,
                                                 metric_logger=self._logger)

            # This fresh graph never creates a global step, so begin() fails.
            # assertRaisesRegex replaces assertRaisesRegexp (removed in 3.12).
            with self.assertRaisesRegex(
                    RuntimeError,
                    "should be created to use LoggingMetricHook."):
                hook.begin()
Exemplo n.º 5
0
def get_logging_metric_hook(tensors_to_log=None, every_n_secs=600, **kwargs):
    """Build a LoggingMetricHook that reports to the benchmark logger.

    Args:
      tensors_to_log: List of tensor names, or a dict mapping labels to tensor
        names. When None, the module-level `_TENSORS_TO_LOG` is used instead.
      every_n_secs: How often, in seconds, the hook logs metrics. Defaults to
        600 (every 10 minutes).
      **kwargs: Accepted but ignored. None of these values reach the hook.

    Returns:
      A `metric_hook.LoggingMetricHook` whose metric logger is
      `logger.get_benchmark_logger()`.
    """
    # Extra keyword arguments are discarded, presumably so a caller can pass
    # one shared set of kwargs to several hook factories (verify at callers).
    del kwargs
    tensors = _TENSORS_TO_LOG if tensors_to_log is None else tensors_to_log
    return metric_hook.LoggingMetricHook(
        tensors=tensors,
        metric_logger=logger.get_benchmark_logger(),
        every_n_secs=every_n_secs,
    )
Exemplo n.º 6
0
    def _validate_print_every_n_steps(self, sess, at_end):
        """Checks that a step-triggered LoggingMetricHook logs on schedule.

        With ``every_n_iter=10`` the hook is expected to log on the first run
        and then on every 10th run after it. The test runs three such cycles,
        then one extra run that must not log. Finally it checks that
        ``hook.end`` logs only when ``at_end`` is True.

        Args:
          sess: Session used to initialize variables and wrapped by the hooked
            session. Assumed to run the caller's current default graph, where
            the ops below are created (verify against callers).
          at_end: Passed to the hook. Whether `hook.end` is expected to log.
        """
        t = tf.constant(42.0, name="foo")
        train_op = tf.constant(3)

        hook = metric_hook.LoggingMetricHook(tensors=[t.name],
                                             every_n_iter=10,
                                             at_end=at_end,
                                             metric_logger=self._logger)
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        sess.run(tf.compat.v1.global_variables_initializer())

        # t.name is checked as a literal substring of the logged metrics, so
        # the positive (assertIn) and negative (assertNotIn) checks agree.
        mon_sess.run(train_op)
        self.assertIn(t.name, str(self._logger.logged_metric))

        for _cycle in range(3):
            self._logger.logged_metric = []
            for _step in range(9):
                mon_sess.run(train_op)
                self.assertNotIn(t.name, str(self._logger.logged_metric))
            # The 10th run of the cycle hits the every_n_iter boundary.
            mon_sess.run(train_op)
            self.assertIn(t.name, str(self._logger.logged_metric))

        # Add additional run to verify proper reset when called multiple times.
        self._logger.logged_metric = []
        mon_sess.run(train_op)
        self.assertNotIn(t.name, str(self._logger.logged_metric))

        self._logger.logged_metric = []
        hook.end(sess)
        if at_end:
            self.assertIn(t.name, str(self._logger.logged_metric))
        else:
            self.assertNotIn(t.name, str(self._logger.logged_metric))