예제 #1
0
  def test_step_counter_every_n_secs(self):
    """StepCounterHook(every_n_secs=0.1) reports a rate after slow steps.

    Three steps are run with 0.2 s sleeps between them. Summaries are
    expected only for global steps 2 and 3, each tagged 'global_step/sec'
    with a positive value.
    """
    with ops.Graph().as_default() as graph, session_lib.Session() as session:
      variables.get_or_create_global_step()
      increment_op = training_util._increment_global_step(1)
      writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
      counter_hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=writer, every_n_steps=None, every_n_secs=0.1)

      counter_hook.begin()
      session.run(variables_lib.global_variables_initializer())
      hooked = monitored_session._HookedSession(session, [counter_hook])
      hooked.run(increment_op)
      for _ in range(2):
        # Sleep longer than every_n_secs so the next step triggers the hook.
        time.sleep(0.2)
        hooked.run(increment_op)
      counter_hook.end(session)

      writer.assert_summaries(
          test_case=self,
          expected_logdir=self.log_dir,
          expected_graph=graph,
          expected_summaries={})
      self.assertTrue(writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2, 3], writer.summaries.keys())
      for step_summaries in writer.summaries.values():
        first_value = step_summaries[0].value[0]
        self.assertEqual('global_step/sec', first_value.tag)
        self.assertGreater(first_value.simple_value, 0)
예제 #2
0
  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks that LoggingMetricHook(every_n_iter=10) logs every 10th step.

    Args:
      sess: session to wrap in a `_HookedSession` together with the hook.
      at_end: passed to the hook; when true `hook.end()` is expected to log
        the tensor, otherwise `hook.end()` must not log it.
    """
    foo = tf.constant(42.0, name="foo")

    op = tf.constant(3)
    logging_hook = metric_hook.LoggingMetricHook(
        tensors=[foo.name], every_n_iter=10, at_end=at_end,
        metric_logger=self._logger)
    logging_hook.begin()
    hooked = monitored_session._HookedSession(sess, [logging_hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    def reset():
      self._logger.logged_metric = []

    def assert_logged():
      self.assertRegexpMatches(str(self._logger.logged_metric), foo.name)

    def assert_not_logged():
      # Negative check via str.find; the original notes that
      # assertNotRegexpMatches is not supported by python 3.1 and later.
      self.assertEqual(str(self._logger.logged_metric).find(foo.name), -1)

    # The very first step logs.
    hooked.run(op)
    assert_logged()
    for _ in range(3):
      reset()
      # Nine silent steps, then the tenth logs.
      for _ in range(9):
        hooked.run(op)
        assert_not_logged()
      hooked.run(op)
      assert_logged()

    # One more step right after a logging step must stay silent, verifying
    # the counter resets each time the hook fires.
    reset()
    hooked.run(op)
    assert_not_logged()

    reset()
    logging_hook.end(sess)
    if at_end:
      assert_logged()
    else:
      assert_not_logged()
예제 #3
0
 def test_step_counter_every_n_steps(self):
   """StepCounterHook(every_n_steps=10) summarizes every 10 steps quietly.

   Runs 30 steps and expects 'global_step/sec' summaries with positive
   values keyed by global steps 11 and 21. tf_logging.warning must not be
   called.
   """
   expected_steps = [11, 21]
   with ops.Graph().as_default() as graph, session_lib.Session() as session:
     variables.get_or_create_global_step()
     increment_op = training_util._increment_global_step(1)
     writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
     counter_hook = basic_session_run_hooks.StepCounterHook(
         summary_writer=writer, every_n_steps=10)
     counter_hook.begin()
     session.run(variables_lib.global_variables_initializer())
     hooked = monitored_session._HookedSession(session, [counter_hook])
     with test.mock.patch.object(tf_logging, 'warning') as warning_mock:
       for _ in range(30):
         # Presumably ensures each measured interval is non-zero.
         time.sleep(0.01)
         hooked.run(increment_op)
       # logging.warning should not have been called at all.
       self.assertIsNone(warning_mock.call_args)
     counter_hook.end(session)
     writer.assert_summaries(
         test_case=self,
         expected_logdir=self.log_dir,
         expected_graph=graph,
         expected_summaries={})
     self.assertItemsEqual(expected_steps, writer.summaries.keys())
     for step in expected_steps:
       first_value = writer.summaries[step][0].value[0]
       self.assertEqual('global_step/sec', first_value.tag)
       self.assertGreater(first_value.simple_value, 0)
 def test_save_secs_calls_listeners_periodically(self):
     """CheckpointSaverHook(save_secs=2) notifies its listeners on each save.

     Steps are interleaved with two 2.5 s sleeps so the 2 s save interval
     elapses twice. Saves are expected on the first step, on the first step
     after each sleep, and once more in `hook.end` because the final step did
     not save. The listener should therefore see 4 before_save/after_save
     calls and a single begin/end.
     """
     with self.graph.as_default():
         listener = MockCheckpointSaverListener()
         hook = basic_session_run_hooks.CheckpointSaverHook(
             self.model_dir,
             save_secs=2,
             scaffold=self.scaffold,
             listeners=[listener])
         hook.begin()
         self.scaffold.finalize()
         with session_lib.Session() as sess:
             sess.run(self.scaffold.init_op)
             mon_sess = monitored_session._HookedSession(sess, [hook])
             mon_sess.run(self.train_op)  # hook runs here
             mon_sess.run(self.train_op)
             # Sleep past save_secs so the next step saves.
             time.sleep(2.5)
             mon_sess.run(self.train_op)  # hook runs here
             mon_sess.run(self.train_op)
             mon_sess.run(self.train_op)
             # Sleep past save_secs again.
             time.sleep(2.5)
             mon_sess.run(self.train_op)  # hook runs here
             mon_sess.run(
                 self.train_op)  # hook won't run here, so it does at end
             hook.end(sess)  # hook runs here
         # NOTE(review): relies on real wall-clock sleeps, so it may be slow
         # or timing-sensitive on loaded machines.
         self.assertEqual(
             {
                 'begin': 1,
                 'before_save': 4,
                 'after_save': 4,
                 'end': 1
             }, listener.get_counts())
 def test_save_steps_saves_periodically(self):
   """CheckpointSaverHook(save_steps=2) checkpoints on every other step.

   After each step the global step stored in the latest checkpoint is read
   back. It only advances on the steps where the hook saved.
   """
   with self.graph.as_default():
     saver_hook = tf.train.CheckpointSaverHook(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     saver_hook.begin()
     self.scaffold.finalize()
     with tf.Session() as session:
       session.run(self.scaffold.init_op)
       hooked = monitored_session._HookedSession(session, [saver_hook])
       hooked.run(self.train_op)
       # Checkpointed step seen after each further step:
       # not saved, saved, not saved, saved.
       for expected_step in (1, 3, 3, 5):
         hooked.run(self.train_op)
         checkpointed = tf.contrib.framework.load_variable(
             self.model_dir, self.global_step.name)
         self.assertEqual(expected_step, checkpointed)
예제 #6
0
  def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
    """Checks ExamplesPerSecondHook logging cadence and warm-up suppression.

    Args:
      sess: session to wrap in a `_HookedSession` with the hook.
      every_n_steps: logging interval passed to the hook.
      warm_steps: passed to the hook; 'exp/sec' is only expected once the
        global step exceeds this value.
    """
    marker = 'exp/sec'
    eps_hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps)
    eps_hook.begin()
    hooked = monitored_session._HookedSession(sess, [eps_hook])  # pylint: disable=protected-access
    sess.run(tf.global_variables_initializer())

    def assert_absent():
      # Negative check via str.find instead of assertNotRegexpMatches.
      self.assertEqual(str(self.logged_message).find(marker), -1)

    def assert_logged_iff(should_log):
      if should_log:
        self.assertRegexpMatches(str(self.logged_message), marker)
      else:
        assert_absent()

    self.logged_message = ''
    # The first every_n_steps steps must not log a rate.
    for _ in range(every_n_steps):
      hooked.run(self.train_op)
      assert_absent()

    hooked.run(self.train_op)
    step = sess.run(self.global_step)
    assert_logged_iff(step > warm_steps)

    # Add additional run to verify proper reset when called multiple times.
    self.logged_message = ''
    hooked.run(self.train_op)
    step = sess.run(self.global_step)
    assert_logged_iff(every_n_steps == 1 and step > warm_steps)

    eps_hook.end(sess)
  def test_save_secs_saving_once_every_three_steps(self, mock_time):
    """SummarySaverHook(save_secs=9) with a mocked clock saves every 3 steps.

    `mock_time` is a patched time source whose decorator is not shown here
    (presumably patching time.time; confirm). The fake clock advances 3.1 s
    after each of 8 steps.
    """
    step_seconds = 3.1
    mock_time.return_value = 1484695987.209386
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=9.,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as session:
      saver_hook.begin()
      session.run(variables_lib.global_variables_initializer())
      hooked = monitored_session._HookedSession(session, [saver_hook])
      for _ in range(8):
        hooked.run(self.train_op)
        mock_time.return_value += step_seconds
      saver_hook.end(session)

    # 24.8 s elapse in total (8 * 3.1), and a save happens every 9 s starting
    # from the first step, so summaries land on steps 1, 4 and 7.
    expected = {
        step: {'my_summary': value}
        for step, value in ((1, 1.0), (4, 2.0), (7, 3.0))
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
예제 #8
0
  def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
    """A stateful legacy watch_fn can alternate between watching all and none.

    The watch function watches everything on odd-numbered runs (1-based) and
    uses a match-nothing pattern on even-numbered runs. Each of the four dump
    directories is checked accordingly.
    """
    watch_fn_state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
      del fetches, feed_dict
      watch_fn_state["run_counter"] += 1
      # Odd-index runs watch everything; even-index runs watch nothing.
      pattern = r".*" if watch_fn_state["run_counter"] % 2 == 1 else r"$^"
      return "DebugIdentity", pattern, pattern

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    hooked = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
      hooked.run(self.inc_v)

    def run_index(path):
      # Directories are named "run_<index>..."; sort numerically by index.
      return int(os.path.basename(path).split("_")[1])

    dump_dirs = sorted(
        glob.glob(os.path.join(self.session_root, "run_*")), key=run_index)
    self.assertEqual(4, len(dump_dirs))

    for run_no, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if run_no % 2:
        self.assertEqual(0, dump.size)
      else:
        self.assertAllClose([10.0 + 1.0 * run_no],
                            dump.get_tensors("v", 0, "DebugIdentity"))
      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)
예제 #9
0
  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks that LoggingMetricHook(every_n_iter=10) logs every 10th step.

    Args:
      sess: session to wrap in a `_HookedSession` together with the hook.
      at_end: passed to the hook; when true `hook.end()` is expected to log
        the tensor, otherwise `hook.end()` must not log it.
    """
    foo = tf.constant(42.0, name="foo")

    op = tf.constant(3)
    logging_hook = metric_hook.LoggingMetricHook(
        tensors=[foo.name], every_n_iter=10, at_end=at_end,
        metric_logger=self._logger)
    logging_hook.begin()
    hooked = monitored_session._HookedSession(sess, [logging_hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    def reset():
      self._logger.logged_metric = []

    def assert_logged():
      self.assertRegexpMatches(str(self._logger.logged_metric), foo.name)

    def assert_not_logged():
      # Negative check via str.find; the original notes that
      # assertNotRegexpMatches is not supported by python 3.1 and later.
      self.assertEqual(str(self._logger.logged_metric).find(foo.name), -1)

    # The very first step logs.
    hooked.run(op)
    assert_logged()
    for _ in range(3):
      reset()
      # Nine silent steps, then the tenth logs.
      for _ in range(9):
        hooked.run(op)
        assert_not_logged()
      hooked.run(op)
      assert_logged()

    # One more step right after a logging step must stay silent, verifying
    # the counter resets each time the hook fires.
    reset()
    hooked.run(op)
    assert_not_logged()

    reset()
    logging_hook.end(sess)
    if at_end:
      assert_logged()
    else:
      assert_not_logged()
예제 #10
0
    def test_summary_writer_defs(self):
        """CheckpointSaverHook adds the graph's meta graph to the summary writer.

        The fake summary writer is installed for the duration of the test and
        uninstalled in a ``finally`` clause. A failing assertion therefore
        cannot leave the fake installed for tests that run afterwards.
        """
        testing.FakeSummaryWriter.install()
        try:
            # Clear the cache so get() presumably builds a fresh (fake) writer.
            tf.train.SummaryWriterCache.clear()
            summary_writer = tf.train.SummaryWriterCache.get(self.model_dir)

            with self.graph.as_default():
                hook = tf.train.CheckpointSaverHook(self.model_dir,
                                                    save_steps=2,
                                                    scaffold=self.scaffold)
                hook.begin()
                self.scaffold.finalize()
                with tf.Session() as sess:
                    sess.run(self.scaffold.init_op)
                    mon_sess = monitored_session._HookedSession(sess, [hook])
                    mon_sess.run(self.train_op)
                # Expect one meta graph: the current graph (with shapes)
                # paired with the scaffold's saver definition.
                summary_writer.assert_summaries(
                    test_case=self,
                    expected_logdir=self.model_dir,
                    expected_added_meta_graphs=[
                        meta_graph.create_meta_graph_def(
                            graph_def=self.graph.as_graph_def(add_shapes=True),
                            saver_def=self.scaffold.saver.saver_def)
                    ])
        finally:
            testing.FakeSummaryWriter.uninstall()
예제 #11
0
  def _validate_print_every_n_secs(self, sess, at_end):
    """Checks that LoggingMetricHook(every_n_secs=1.0) rate-limits its logging.

    Args:
      sess: session to wrap in a `_HookedSession` with the hook.
      at_end: passed to the hook; when true `hook.end()` must log the tensor,
        otherwise it must not.
    """
    foo = tf.constant(42.0, name='foo')
    op = tf.constant(3)

    logging_hook = metric_hook.LoggingMetricHook(
        tensors=[foo.name], every_n_secs=1.0, at_end=at_end,
        metric_logger=self._logger)
    logging_hook.begin()
    hooked = monitored_session._HookedSession(sess, [logging_hook])
    sess.run(tf.global_variables_initializer())

    def reset():
      self._logger.logged_metric = []

    def assert_logged():
      self.assertRegexpMatches(str(self._logger.logged_metric), foo.name)

    def assert_not_logged():
      # Negative check via str.find instead of assertNotRegexpMatches.
      self.assertEqual(str(self._logger.logged_metric).find(foo.name), -1)

    # The first step logs immediately.
    hooked.run(op)
    assert_logged()

    # A step run right afterwards is inside the 1 s window and stays silent.
    reset()
    hooked.run(op)
    assert_not_logged()
    time.sleep(1.0)

    # Once the window has passed, the next step logs again.
    reset()
    hooked.run(op)
    assert_logged()

    reset()
    logging_hook.end(sess)
    if at_end:
      assert_logged()
    else:
      assert_not_logged()
예제 #12
0
  def test_log_tensors(self):
    """LoggingMetricHook(at_end=True) logs nothing per step and both tensors at end."""
    with tf.Graph().as_default(), tf.Session() as sess:
      tf.train.get_or_create_global_step()
      foo = tf.constant(42.0, name='foo')
      bar = tf.constant(43.0, name='bar')
      op = tf.constant(3)
      logging_hook = metric_hook.LoggingMetricHook(
          tensors=[foo, bar], at_end=True, metric_logger=self._logger)
      logging_hook.begin()
      hooked = monitored_session._HookedSession(sess, [logging_hook])
      sess.run(tf.global_variables_initializer())

      for _ in range(3):
        hooked.run(op)
        self.assertEqual(self._logger.logged_metric, [])

      logging_hook.end(sess)
      logged = self._logger.logged_metric
      self.assertEqual(len(logged), 2)
      # The run op is a constant, so the global step stays 0 for both.
      expectations = [("foo", 42.0), ("bar", 43.0)]
      for metric, (name, value) in zip(logged, expectations):
        self.assertRegexpMatches(str(metric["name"]), name)
        self.assertEqual(metric["value"], value)
        self.assertEqual(metric["unit"], None)
        self.assertEqual(metric["global_step"], 0)
  def testBothHooksAndUserHaveFeeds(self):
    """Feeds requested by several hooks are merged with the user's feed_dict.

    Two hooks and the caller each feed a different term of a three-way sum,
    so the fetch yields [35] (5 + 10 + 20) only if all three feeds apply.
    The caller's feed_dict must also come back unmodified.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      first_hook = FakeHook()
      second_hook = FakeHook()
      hooked = monitored_session._HookedSession(
          sess=sess, hooks=[first_hook, second_hook])
      a_t, b_t, c_t = [
          tf.constant([0], name=label)
          for label in ('a_tensor', 'b_tensor', 'c_tensor')
      ]
      total = a_t + b_t + c_t
      first_hook.request = tf.train.SessionRunArgs(None, feed_dict={a_t: [5]})
      second_hook.request = tf.train.SessionRunArgs(
          None, feed_dict={b_t: [10]})
      sess.run(tf.global_variables_initializer())

      user_feed = {c_t: [20]}
      result = hooked.run(fetches=total, feed_dict=user_feed)
      self.assertEqual(result, [35])
      # The caller's feed_dict should not be changed.
      self.assertEqual(len(user_feed), 1)
    def test_multiple_summaries(self):
        """SummarySaverHook(save_steps=8) writes every op in a summary list.

        Over 10 steps, both summary ops are expected at global steps 1 and 9.
        """
        saver_hook = basic_session_run_hooks.SummarySaverHook(
            save_steps=8,
            summary_writer=self.summary_writer,
            summary_op=[self.summary_op, self.summary_op2])

        with self.test_session() as session:
            saver_hook.begin()
            session.run(variables_lib.global_variables_initializer())
            hooked = monitored_session._HookedSession(session, [saver_hook])
            for _ in range(10):
                hooked.run(self.train_op)
            saver_hook.end(session)

        expected = {
            1: {'my_summary': 1.0, 'my_summary2': 2.0},
            9: {'my_summary': 2.0, 'my_summary2': 4.0},
        }
        self.summary_writer.assert_summaries(test_case=self,
                                             expected_logdir=self.log_dir,
                                             expected_summaries=expected)
예제 #15
0
  def test_capture(self):
    """MetadataCaptureHook writes its files only after the configured step.

    With params={"step": 5}, nothing is written for the run at step 0 or the
    run at step 5 itself. The following run produces run_meta, tfprof_log and
    timeline.json in model_dir.
    """
    global_step = tf.contrib.framework.get_or_create_global_step()
    # Some test computation
    weights = tf.get_variable("weigths", [2, 128])
    softmax = tf.nn.softmax(weights)

    capture_hook = hooks.MetadataCaptureHook(
        params={"step": 5}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    capture_hook.begin()

    def listing():
      return gfile.ListDirectory(self.model_dir)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      #pylint: disable=W0212
      hooked = monitored_session._HookedSession(sess, [capture_hook])
      # Should not trigger for step 0
      sess.run(tf.assign(global_step, 0))
      hooked.run(softmax)
      self.assertEqual(listing(), [])
      # Should trigger *after* step 5
      sess.run(tf.assign(global_step, 5))
      hooked.run(softmax)
      self.assertEqual(listing(), [])
      hooked.run(softmax)
      self.assertEqual(
          set(listing()), {"run_meta", "tfprof_log", "timeline.json"})
    def test_step_counter_every_n_secs(self):
        """StepCounterHook(every_n_secs=0.1) reports a rate after slow steps.

        Three steps are run with 0.2 s sleeps between them. Summaries are
        expected only for global steps 2 and 3, each tagged
        'global_step/sec' with a positive value.
        """
        with tf.Graph().as_default() as graph, tf.Session() as session:
            step_var = tf.contrib.framework.get_or_create_global_step()
            increment_op = tf.assign_add(step_var, 1)
            writer = testing.FakeSummaryWriter(self.log_dir, graph)
            counter_hook = tf.train.StepCounterHook(
                summary_writer=writer, every_n_steps=None, every_n_secs=0.1)

            counter_hook.begin()
            session.run(tf.initialize_all_variables())
            hooked = monitored_session._HookedSession(session, [counter_hook])
            hooked.run(increment_op)
            for _ in range(2):
                # Sleep longer than every_n_secs so the next step triggers.
                time.sleep(0.2)
                hooked.run(increment_op)
            counter_hook.end(session)

            writer.assert_summaries(
                test_case=self,
                expected_logdir=self.log_dir,
                expected_graph=graph,
                expected_summaries={},
            )
            self.assertTrue(writer.summaries, "No summaries were created.")
            self.assertItemsEqual([2, 3], writer.summaries.keys())
            for step_summaries in writer.summaries.values():
                first_value = step_summaries[0].value[0]
                self.assertEqual("global_step/sec", first_value.tag)
                self.assertGreater(first_value.simple_value, 0)
  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks that LoggingTensorHook(every_n_iter=10) logs every 10th step.

    Args:
      sess: session to wrap in a `_HookedSession` with the hook.
      at_end: passed to the hook; when true `hook.end()` must log the tensor,
        otherwise it must not.
    """
    foo = constant_op.constant(42.0, name='foo')

    op = constant_op.constant(3)
    logging_hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[foo.name], every_n_iter=10, at_end=at_end)
    logging_hook.begin()
    hooked = monitored_session._HookedSession(sess, [logging_hook])
    sess.run(variables_lib.global_variables_initializer())

    def assert_logged():
      self.assertRegexpMatches(str(self.logged_message), foo.name)

    def assert_not_logged():
      # Negative check via str.find instead of assertNotRegexpMatches.
      self.assertEqual(str(self.logged_message).find(foo.name), -1)

    # The very first step logs.
    hooked.run(op)
    assert_logged()
    for _ in range(3):
      self.logged_message = ''
      # Nine silent steps, then the tenth logs.
      for _ in range(9):
        hooked.run(op)
        assert_not_logged()
      hooked.run(op)
      assert_logged()

    # Add additional run to verify proper reset when called multiple times.
    self.logged_message = ''
    hooked.run(op)
    assert_not_logged()

    self.logged_message = ''
    logging_hook.end(sess)
    if at_end:
      assert_logged()
    else:
      assert_not_logged()
 def test_summary_saver(self):
   """SummarySaverHook(save_steps=8) records the summary every 8 steps.

   Over 30 steps, the summary is expected at global steps 1, 9, 17 and 25
   with values 1.0 to 4.0.
   """
   with tf.Graph().as_default() as g, tf.Session() as sess:
     log_dir = 'log/dir'
     summary_writer = testing.FakeSummaryWriter(log_dir, g)
     var = tf.Variable(0.0)
     # The summarized tensor increments `var`, so each recorded value counts
     # how many times the hook has evaluated the summary op.
     tensor = tf.assign_add(var, 1.0)
     summary_op = tf.scalar_summary('my_summary', tensor)
     global_step = tf.contrib.framework.get_or_create_global_step()
     train_op = tf.assign_add(global_step, 1)
     hook = tf.train.SummarySaverHook(
         summary_op=summary_op, save_steps=8, summary_writer=summary_writer)
     hook.begin()
     sess.run(tf.initialize_all_variables())
     mon_sess = monitored_session._HookedSession(sess, [hook])
     for _ in range(30):
       mon_sess.run(train_op)
     hook.end(sess)
     summary_writer.assert_summaries(
         test_case=self,
         expected_logdir=log_dir,
         expected_graph=g,
         expected_summaries={
             1: {'my_summary': 1.0},
             9: {'my_summary': 2.0},
             17: {'my_summary': 3.0},
             25: {'my_summary': 4.0},
         })
 def DISABLED_test_save_secs_calls_listeners_periodically(self):
   """CheckpointSaverHook(save_secs=2) notifies its listeners on each save.

   The method name does not start with 'test', so unittest's default loader
   does not collect it. The reason it was disabled is not shown here
   (possibly the real 2.5 s sleeps; confirm). Saves are expected on the first
   step, on the first step after each sleep, and once more in `hook.end`,
   giving 4 before_save/after_save calls.
   """
   with self.graph.as_default():
     listener = MockCheckpointSaverListener()
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir,
         save_secs=2,
         scaffold=self.scaffold,
         listeners=[listener])
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(self.train_op)  # hook runs here
       mon_sess.run(self.train_op)
       # Sleep past save_secs so the next step saves.
       time.sleep(2.5)
       mon_sess.run(self.train_op)  # hook runs here
       mon_sess.run(self.train_op)
       mon_sess.run(self.train_op)
       # Sleep past save_secs again.
       time.sleep(2.5)
       mon_sess.run(self.train_op)  # hook runs here
       mon_sess.run(self.train_op)  # hook won't run here, so it does at end
       hook.end(sess)  # hook runs here
     self.assertEqual({
         'begin': 1,
         'before_save': 4,
         'after_save': 4,
         'end': 1
     }, listener.get_counts())
예제 #20
0
  def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
    """A stateful legacy watch_fn can alternate between watching all and none.

    The watch function watches everything on odd-numbered runs (1-based) and
    uses a match-nothing pattern on even-numbered runs. Each of the four dump
    directories is checked accordingly.
    """
    watch_fn_state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
      del fetches, feed_dict
      watch_fn_state["run_counter"] += 1
      # Odd-index runs watch everything; even-index runs watch nothing.
      pattern = r".*" if watch_fn_state["run_counter"] % 2 == 1 else r"$^"
      return "DebugIdentity", pattern, pattern

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    hooked = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
      hooked.run(self.inc_v)

    def run_index(path):
      # Directories are named "run_<index>..."; sort numerically by index.
      return int(os.path.basename(path).split("_")[1])

    dump_dirs = sorted(
        glob.glob(os.path.join(self.session_root, "run_*")), key=run_index)
    self.assertEqual(4, len(dump_dirs))

    for run_no, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if run_no % 2:
        self.assertEqual(0, dump.size)
      else:
        self.assertAllClose([10.0 + 1.0 * run_no],
                            dump.get_tensors("v", 0, "DebugIdentity"))
      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)
예제 #21
0
  def testBothHooksAndUserHaveFeeds(self):
    """Feeds requested by several hooks are merged with the user's feed_dict.

    Two hooks and the caller each feed a different term of a three-way sum,
    so the fetch yields [35] (5 + 10 + 20) only if all three feeds apply.
    The caller's feed_dict must also come back unmodified.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      first_hook = FakeHook()
      second_hook = FakeHook()
      hooked = monitored_session._HookedSession(
          sess=sess, hooks=[first_hook, second_hook])
      a_t, b_t, c_t = [
          tf.constant([0], name=label)
          for label in ('a_tensor', 'b_tensor', 'c_tensor')
      ]
      total = a_t + b_t + c_t
      first_hook.request = tf.train.SessionRunArgs(None, feed_dict={a_t: [5]})
      second_hook.request = tf.train.SessionRunArgs(
          None, feed_dict={b_t: [10]})
      sess.run(tf.initialize_all_variables())

      user_feed = {c_t: [20]}
      result = hooked.run(fetches=total, feed_dict=user_feed)
      self.assertEqual(result, [35])
      # The caller's feed_dict should not be changed.
      self.assertEqual(len(user_feed), 1)
  def test_step_counter_every_n_secs(self):
    """StepCounterHook(every_n_secs=0.1) reports a rate after slow steps.

    Three steps are run with 0.2 s sleeps between them. Summaries are
    expected only for global steps 2 and 3, each tagged 'global_step/sec'
    with a positive value.
    """
    with ops.Graph().as_default() as graph, session_lib.Session() as session:
      step_var = variables.get_or_create_global_step()
      increment_op = state_ops.assign_add(step_var, 1)
      writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
      counter_hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=writer, every_n_steps=None, every_n_secs=0.1)

      counter_hook.begin()
      session.run(variables_lib.global_variables_initializer())
      hooked = monitored_session._HookedSession(session, [counter_hook])
      hooked.run(increment_op)
      for _ in range(2):
        # Sleep longer than every_n_secs so the next step triggers the hook.
        time.sleep(0.2)
        hooked.run(increment_op)
      counter_hook.end(session)

      writer.assert_summaries(
          test_case=self,
          expected_logdir=self.log_dir,
          expected_graph=graph,
          expected_summaries={})
      self.assertTrue(writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2, 3], writer.summaries.keys())
      for step_summaries in writer.summaries.values():
        first_value = step_summaries[0].value[0]
        self.assertEqual('global_step/sec', first_value.tag)
        self.assertGreater(first_value.simple_value, 0)
  def test_save_secs_saving_once_every_three_steps(self):
    """SummarySaverHook(save_secs=0.9) saves on roughly every third step.

    Eight steps are run with a real 0.3 s sleep after each one. Summaries are
    expected at global steps 1, 4 and 7.

    NOTE(review): this relies on wall-clock sleeps, so a slow or loaded
    machine could shift which steps save. A mocked-time variant would be
    more robust.
    """
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=0.9,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(8):
        mon_sess.run(self.train_op)
        # Three of these sleeps add up to save_secs (3 * 0.3 = 0.9).
        time.sleep(0.3)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0
            },
            4: {
                'my_summary': 2.0
            },
            7: {
                'my_summary': 3.0
            },
        })
  def test_multiple_summaries(self):
    """SummarySaverHook(save_steps=8) writes every op in a summary list.

    Over 10 steps, both summary ops are expected at global steps 1 and 9.
    """
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=[self.summary_op, self.summary_op2])

    with self.test_session() as session:
      saver_hook.begin()
      session.run(variables_lib.global_variables_initializer())
      hooked = monitored_session._HookedSession(session, [saver_hook])
      for _ in range(10):
        hooked.run(self.train_op)
      saver_hook.end(session)

    expected = {
        1: {'my_summary': 1.0, 'my_summary2': 2.0},
        9: {'my_summary': 2.0, 'my_summary2': 4.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
예제 #25
0
    def test_log_tensors(self):
        """LoggingMetricHook(at_end=True) logs nothing per step and both tensors at end."""
        with tf.Graph().as_default(), tf.Session() as sess:
            tf.train.get_or_create_global_step()
            foo = tf.constant(42.0, name='foo')
            bar = tf.constant(43.0, name='bar')
            op = tf.constant(3)
            logging_hook = metric_hook.LoggingMetricHook(
                tensors=[foo, bar], at_end=True, metric_logger=self._logger)
            logging_hook.begin()
            hooked = monitored_session._HookedSession(sess, [logging_hook])
            sess.run(tf.global_variables_initializer())

            for _ in range(3):
                hooked.run(op)
                self.assertEqual(self._logger.logged_metric, [])

            logging_hook.end(sess)
            logged = self._logger.logged_metric
            self.assertEqual(len(logged), 2)
            # The run op is a constant, so the global step stays 0 for both.
            expectations = [("foo", 42.0), ("bar", 43.0)]
            for metric, (name, value) in zip(logged, expectations):
                self.assertRegexpMatches(str(metric["name"]), name)
                self.assertEqual(metric["value"], value)
                self.assertEqual(metric["unit"], None)
                self.assertEqual(metric["global_step"], 0)
예제 #26
0
    def _validate_print_every_n_secs(self, sess, at_end):
        """Checks that LoggingMetricHook(every_n_secs=1.0) rate-limits logging.

        Args:
            sess: session to wrap in a `_HookedSession` with the hook.
            at_end: passed to the hook; when true `hook.end()` must log the
                tensor, otherwise it must not.
        """
        foo = tf.constant(42.0, name='foo')
        op = tf.constant(3)

        logging_hook = metric_hook.LoggingMetricHook(tensors=[foo.name],
                                                     every_n_secs=1.0,
                                                     at_end=at_end,
                                                     metric_logger=self._logger)
        logging_hook.begin()
        hooked = monitored_session._HookedSession(sess, [logging_hook])
        sess.run(tf.global_variables_initializer())

        def reset():
            self._logger.logged_metric = []

        def assert_logged():
            self.assertRegexpMatches(str(self._logger.logged_metric), foo.name)

        def assert_not_logged():
            # Negative check via str.find instead of assertNotRegexpMatches.
            self.assertEqual(str(self._logger.logged_metric).find(foo.name), -1)

        # The first step logs immediately.
        hooked.run(op)
        assert_logged()

        # A step run right afterwards is inside the 1 s window and stays silent.
        reset()
        hooked.run(op)
        assert_not_logged()
        time.sleep(1.0)

        # Once the window has passed, the next step logs again.
        reset()
        hooked.run(op)
        assert_logged()

        reset()
        logging_hook.end(sess)
        if at_end:
            assert_logged()
        else:
            assert_not_logged()
    def _validate_print_every_n_secs(self, sess, at_end):
        """Checks that LoggingTensorHook(every_n_secs=1.0) rate-limits logging.

        Args:
            sess: session to wrap in a `_HookedSession` with the hook.
            at_end: passed to the hook; when true `hook.end()` must log the
                tensor, otherwise it must not.
        """
        foo = constant_op.constant(42.0, name='foo')
        op = constant_op.constant(3)

        logging_hook = basic_session_run_hooks.LoggingTensorHook(
            tensors=[foo.name], every_n_secs=1.0, at_end=at_end)
        logging_hook.begin()
        hooked = monitored_session._HookedSession(sess, [logging_hook])
        sess.run(variables_lib.global_variables_initializer())

        def assert_logged():
            self.assertRegexpMatches(str(self.logged_message), foo.name)

        def assert_not_logged():
            # Negative check via str.find instead of assertNotRegexpMatches.
            self.assertEqual(str(self.logged_message).find(foo.name), -1)

        # The first step logs immediately.
        hooked.run(op)
        assert_logged()

        # A step run right afterwards is inside the 1 s window and stays silent.
        self.logged_message = ''
        hooked.run(op)
        assert_not_logged()
        time.sleep(1.0)

        # Once the window has passed, the next step logs again.
        self.logged_message = ''
        hooked.run(op)
        assert_logged()

        self.logged_message = ''
        logging_hook.end(sess)
        if at_end:
            assert_logged()
        else:
            assert_not_logged()
    def test_save_steps(self):
        """SummarySaverHook(save_steps=8) writes summaries at steps 1, 9, 17, 25."""
        hook = tf.train.SummarySaverHook(save_steps=8,
                                         summary_writer=self.summary_writer,
                                         summary_op=self.summary_op)

        with self.test_session() as sess:
            hook.begin()
            # `tf.initialize_all_variables` is deprecated; this is its
            # replacement.
            sess.run(tf.global_variables_initializer())
            mon_sess = monitored_session._HookedSession(sess, [hook])
            for _ in range(30):
                mon_sess.run(self.train_op)
            hook.end(sess)

        self.summary_writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_summaries={
                1: {'my_summary': 1.0},
                9: {'my_summary': 2.0},
                17: {'my_summary': 3.0},
                25: {'my_summary': 4.0},
            })
    def _validate_print_every_n_steps(self, sess, at_end):
        """Checks the logging cadence of LoggingTensorHook(every_n_iter=10).

        Expected behavior:
        - the hook logs on the first run;
        - it then logs on every 10th run (nine silent runs, then one logging
          run), three times over;
        - the run after that is silent again;
        - `end()` logs only when `at_end` is set.

        `self.logged_message` is assumed to be filled by a logging patch
        installed elsewhere (e.g. setUp) -- not visible here.

        Args:
          sess: session used for initialization; it is wrapped by
            `_HookedSession`.
          at_end: value forwarded to the hook's `at_end` argument.
        """
        t = constant_op.constant(42.0, name='foo')

        train_op = constant_op.constant(3)
        hook = basic_session_run_hooks.LoggingTensorHook(tensors=[t.name],
                                                         every_n_iter=10,
                                                         at_end=at_end)
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess, [hook])
        sess.run(variables_lib.global_variables_initializer())
        mon_sess.run(train_op)
        self.assertRegex(str(self.logged_message), t.name)
        for _ in range(3):
            self.logged_message = ''
            for _ in range(9):
                mon_sess.run(train_op)
                self.assertNotIn(t.name, str(self.logged_message))
            mon_sess.run(train_op)
            self.assertRegex(str(self.logged_message), t.name)

        # Add additional run to verify proper reset when called multiple times.
        self.logged_message = ''
        mon_sess.run(train_op)
        self.assertNotIn(t.name, str(self.logged_message))

        self.logged_message = ''
        hook.end(sess)
        if at_end:
            self.assertRegex(str(self.logged_message), t.name)
        else:
            self.assertNotIn(t.name, str(self.logged_message))
    def test_stop_based_on_num_step(self):
        """StopAtStepHook(num_steps=10) from step 5 requests a stop at step 15.

        The global step is forced to 5 before `after_create_session`. The
        runs at steps 13 and 14 must not request a stop; steps 15 and 16 must.
        """
        stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=10)

        with ops.Graph().as_default():
            step_var = variables.get_or_create_global_step()
            noop = control_flow_ops.no_op()
            stop_hook.begin()
            with session_lib.Session() as sess:
                hooked = monitored_session._HookedSession(sess, [stop_hook])

                def force_step(value):
                    sess.run(state_ops.assign(step_var, value))

                # Presumably the hook derives its last step (5 + 10) here.
                force_step(5)
                stop_hook.after_create_session(sess, None)
                hooked.run(noop)
                self.assertFalse(hooked.should_stop())

                for value in (13, 14):
                    force_step(value)
                    hooked.run(noop)
                    self.assertFalse(hooked.should_stop())

                force_step(15)
                hooked.run(noop)
                self.assertTrue(hooked.should_stop())

                # Clear the session's stop flag; the hook must set it again.
                force_step(16)
                hooked._should_stop = False
                hooked.run(noop)
                self.assertTrue(hooked.should_stop())
예제 #31
0
  def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
    """Checks ExamplesPerSecondHook logging for a given step schedule.

    `self.logged_message`, `self.train_op` and `self.global_step` are
    assumed to be prepared elsewhere (e.g. setUp) -- not visible here.

    Args:
      sess: session used for initialization; it is wrapped by
        `_HookedSession`.
      every_n_steps: value passed to the hook's `every_n_steps`.
      warm_steps: value passed to the hook's `warm_steps`. No 'exp/sec'
        message is expected while the global step is <= this value.
    """
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.global_variables_initializer())

    self.logged_message = ''
    for _ in range(every_n_steps):
      mon_sess.run(self.train_op)
      self.assertNotIn('exp/sec', str(self.logged_message))

    mon_sess.run(self.train_op)
    global_step_val = sess.run(self.global_step)
    if global_step_val > warm_steps:
      self.assertRegex(str(self.logged_message), 'exp/sec')
    else:
      self.assertNotIn('exp/sec', str(self.logged_message))

    # Add additional run to verify proper reset when called multiple times.
    self.logged_message = ''
    mon_sess.run(self.train_op)
    global_step_val = sess.run(self.global_step)
    if every_n_steps == 1 and global_step_val > warm_steps:
      self.assertRegex(str(self.logged_message), 'exp/sec')
    else:
      self.assertNotIn('exp/sec', str(self.logged_message))

    hook.end(sess)
예제 #32
0
 def test_save_steps_saves_periodically(self):
     """CheckpointSaverHook(save_steps=2) checkpoints on every other step."""
     with self.graph.as_default():
         hook = tf.train.CheckpointSaverHook(self.model_dir,
                                             save_steps=2,
                                             scaffold=self.scaffold)
         hook.begin()
         self.scaffold.finalize()
         with tf.Session() as sess:
             sess.run(self.scaffold.init_op)
             mon_sess = monitored_session._HookedSession(sess, [hook])

             def checkpointed_step():
                 # Global step value read back from the checkpoint on disk.
                 return tf.contrib.framework.load_variable(
                     self.model_dir, self.global_step.name)

             mon_sess.run(self.train_op)
             # After each further run the checkpoint holds:
             # 1 (not saved), 3 (saved), 3 (not saved), 5 (saved).
             for expected in (1, 3, 3, 5):
                 mon_sess.run(self.train_op)
                 self.assertEqual(expected, checkpointed_step())
예제 #33
0
  def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
    """With send_traceback_and_source_code=False nothing of that is sent."""
    var_u = variables.VariableV1(2.1, name="u")
    var_v = variables.VariableV1(20.0, name="v")
    product = math_ops.multiply(var_u, var_v, name="w")

    raw_sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    raw_sess.run(variables.global_variables_initializer())

    debug_hook = hooks.TensorBoardDebugHook(
        ["localhost:%d" % self._server_port],
        send_traceback_and_source_code=False)
    hooked_sess = monitored_session._HookedSession(raw_sess, [debug_hook])

    # Request the watch point before running the graph.
    self._server.request_watch("u/read", 0, "DebugIdentity")
    self.assertAllClose(42.0, hooked_sess.run(product))

    # With sending disabled, the server knows neither the op traceback nor
    # any source file.
    with self.assertRaisesRegex(ValueError, r"Op .*u/read.* does not exist"):
      self.assertTrue(self._server.query_op_traceback("u/read"))
    with self.assertRaisesRegex(ValueError,
                                r".* has not received any source file"):
      self._server.query_source_file_line(__file__, 1)
예제 #34
0
    def test_step_counter_every_n_secs(self):
        """StepCounterHook(every_n_secs=0.1) emits positive global_step/sec."""
        with tf.Graph().as_default() as graph, tf.Session() as sess:
            step = tf.contrib.framework.get_or_create_global_step()
            increment = tf.assign_add(step, 1)
            writer = testing.FakeSummaryWriter(self.log_dir, graph)
            hook = tf.train.StepCounterHook(summary_writer=writer,
                                            every_n_steps=None,
                                            every_n_secs=0.1)

            hook.begin()
            sess.run(tf.global_variables_initializer())
            mon_sess = monitored_session._HookedSession(sess, [hook])
            # Three runs separated by sleeps longer than every_n_secs.
            mon_sess.run(increment)
            for _ in range(2):
                time.sleep(0.2)
                mon_sess.run(increment)
            hook.end(sess)

            writer.assert_summaries(test_case=self,
                                    expected_logdir=self.log_dir,
                                    expected_graph=graph,
                                    expected_summaries={})
            self.assertTrue(writer.summaries, 'No summaries were created.')
            # Only steps 2 and 3 are expected to report; presumably the first
            # run just establishes the timing baseline.
            self.assertItemsEqual([2, 3], writer.summaries.keys())
            for step_summaries in writer.summaries.values():
                value = step_summaries[0].value[0]
                self.assertEqual('global_step/sec', value.tag)
                self.assertGreater(value.simple_value, 0)
예제 #35
0
  def test_save_secs_saving_once_every_step(self):
    """save_secs=0.5 with 0.5s sleeps between runs saves at every step."""
    hook = tf.train.SummarySaverHook(
        save_steps=None,
        save_secs=0.5,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      # `tf.initialize_all_variables` is deprecated; this is its replacement.
      sess.run(tf.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(4):
        mon_sess.run(self.train_op)
        time.sleep(0.5)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {'my_summary': 1.0},
            2: {'my_summary': 2.0},
            3: {'my_summary': 3.0},
            4: {'my_summary': 4.0},
        })
예제 #36
0
    def test_capture(self):
        """MetadataCaptureHook(step=5) writes its files only after step 5 runs."""
        global_step = tf.contrib.framework.get_or_create_global_step()
        # A trivial computation for the hook to observe.
        weights = tf.get_variable("weigths", [2, 128])
        probs = tf.nn.softmax(weights)

        hook = hooks.MetadataCaptureHook(params={"step": 5},
                                         model_dir=self.model_dir)
        hook.begin()

        def listing():
            return gfile.ListDirectory(self.model_dir)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            #pylint: disable=W0212
            mon_sess = monitored_session._HookedSession(sess, [hook])

            # Nothing is captured at step 0.
            sess.run(tf.assign(global_step, 0))
            mon_sess.run(probs)
            self.assertEqual(listing(), [])

            # The run at step 5 still leaves the directory empty; the files
            # appear after the following run.
            sess.run(tf.assign(global_step, 5))
            mon_sess.run(probs)
            self.assertEqual(listing(), [])
            mon_sess.run(probs)
            self.assertEqual(set(listing()),
                             {"run_meta", "tfprof_log", "timeline.json"})
 def test_save_secs_saves_periodically(self):
     """CheckpointSaverHook(save_secs=2) saves on the first run after 2s."""
     with self.graph.as_default():
         hook = basic_session_run_hooks.CheckpointSaverHook(
             self.model_dir, save_secs=2, scaffold=self.scaffold)
         hook.begin()
         self.scaffold.finalize()
         with session_lib.Session() as sess:
             sess.run(self.scaffold.init_op)
             mon_sess = monitored_session._HookedSession(sess, [hook])

             def saved_global_step():
                 return checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name)

             # Each phase is (train runs, global step expected on disk).
             # Only the first run of a phase saves. A 2.5s pause between
             # phases lets the save_secs timer expire.
             phases = ((2, 1), (3, 3), (1, 6))
             for index, (num_runs, expected) in enumerate(phases):
                 if index:
                     time.sleep(2.5)
                 for _ in range(num_runs):
                     mon_sess.run(self.train_op)
                 self.assertEqual(expected, saved_global_step())
  def test_global_step_name(self):
    """StepCounterHook derives its summary tag from the global step's name."""
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
      step_collections = [
          ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES
      ]
      with variable_scope.variable_scope('bar'):
        custom_step = variable_scope.get_variable(
            'foo', initializer=0, trainable=False,
            collections=step_collections)
      increment = state_ops.assign_add(custom_step, 1)
      writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
      hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=writer, every_n_steps=1, every_n_secs=None)

      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(2):
        mon_sess.run(increment)
      hook.end(sess)

      writer.assert_summaries(
          test_case=self,
          expected_logdir=self.log_dir,
          expected_graph=graph,
          expected_summaries={})
      self.assertTrue(writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2], writer.summaries.keys())
      tag = writer.summaries[2][0].value[0].tag
      self.assertEqual('bar/foo/sec', tag)
예제 #39
0
  def test_save_secs_saving_once_every_three_steps(self, mock_time):
    """save_secs=9 with 3.1s of mocked time per step saves at steps 1, 4, 7.

    `mock_time` is assumed to be a patched clock injected by a decorator
    outside this view; advancing its return value replaces real sleeping.
    """
    mock_time.return_value = 1484695987.209386
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=9.,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      step_duration = 3.1
      for _ in range(8):
        mon_sess.run(self.train_op)
        mock_time.return_value += step_duration
      hook.end(sess)

    # 8 * 3.1 = 24.8 mocked seconds. A save every 9 seconds, counted from the
    # first step, lands on steps 1, 4 and 7.
    expected = {
        step: {'my_summary': float(index)}
        for index, step in enumerate((1, 4, 7), start=1)
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
  def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
    """Disabling traceback/source sending keeps that data off the server.

    Evaluates `w = u * v` through a `_HookedSession` that carries a
    `TensorBoardDebugHook` created with `send_traceback_and_source_code=False`.
    It then checks that the debug server has neither the op traceback nor
    any source file.
    """
    u = variables.Variable(2.1, name="u")
    v = variables.Variable(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    sess = session.Session(config=no_rewrite_session_config())
    sess.run(variables.global_variables_initializer())

    grpc_debug_hook = hooks.TensorBoardDebugHook(
        ["localhost:%d" % self._server_port],
        send_traceback_and_source_code=False)
    sess = monitored_session._HookedSession(sess, [grpc_debug_hook])

    # Activate watch point on a tensor before calling sess.run().
    self._server.request_watch("u/read", 0, "DebugIdentity")
    self.assertAllClose(42.0, sess.run(w))

    # Check that the server has _not_ received any tracebacks, as a result of
    # the disabling above. `assertRaisesRegex` replaces the deprecated
    # `assertRaisesRegexp` alias, which Python 3.12 removed.
    with self.assertRaisesRegex(
        ValueError, r"Op .*u/read.* does not exist"):
      self.assertTrue(self._server.query_op_traceback("u/read"))
    with self.assertRaisesRegex(
        ValueError, r".* has not received any source file"):
      self._server.query_source_file_line(__file__, 1)
    def test_global_step_name(self):
        """The step-rate summary tag follows the custom global step's name."""
        with ops.Graph().as_default() as g, session_lib.Session() as sess:
            with variable_scope.variable_scope('bar'):
                # A non-default global step registered under GLOBAL_STEP.
                custom_step = variable_scope.get_variable(
                    'foo',
                    initializer=0,
                    trainable=False,
                    collections=[ops.GraphKeys.GLOBAL_STEP,
                                 ops.GraphKeys.GLOBAL_VARIABLES])
            bump = state_ops.assign_add(custom_step, 1)
            writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
            hook = basic_session_run_hooks.StepCounterHook(
                summary_writer=writer, every_n_steps=1, every_n_secs=None)

            hook.begin()
            sess.run(variables_lib.global_variables_initializer())
            mon_sess = monitored_session._HookedSession(sess, [hook])
            for _ in range(2):
                mon_sess.run(bump)
            hook.end(sess)

            writer.assert_summaries(test_case=self,
                                    expected_logdir=self.log_dir,
                                    expected_graph=g,
                                    expected_summaries={})
            self.assertTrue(writer.summaries, 'No summaries were created.')
            self.assertItemsEqual([2], writer.summaries.keys())
            step_two_value = writer.summaries[2][0].value[0]
            self.assertEqual('bar/foo/sec', step_two_value.tag)
 def DISABLED_test_save_steps_saves_periodically(self):
   """CheckpointSaverHook(save_steps=2) should checkpoint every other step.

   The `DISABLED_` prefix keeps the unittest loader from collecting this
   method, because it only picks up names that start with `test`.
   """
   with self.graph.as_default():
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(self.train_op)
       # Global step read back from the checkpoint after each further run:
       # not saved (1), saved (3), not saved (3), saved (5).
       for expected_step in [1, 3, 3, 5]:
         mon_sess.run(self.train_op)
         on_disk = checkpoint_utils.load_variable(self.model_dir,
                                                  self.global_step.name)
         self.assertEqual(expected_step, on_disk)
  def test_stop_based_on_num_step(self):
    """Checks StopAtStepHook(num_steps=10) for a session starting at step 5."""
    hook = basic_session_run_hooks.StopAtStepHook(num_steps=10)

    with ops.Graph().as_default():
      step = variables.get_or_create_global_step()
      noop = control_flow_ops.no_op()
      hook.begin()
      with session_lib.Session() as sess:
        hooked_sess = monitored_session._HookedSession(sess, [hook])
        sess.run(state_ops.assign(step, 5))
        hook.after_create_session(sess, None)
        hooked_sess.run(noop)
        self.assertFalse(hooked_sess.should_stop())

        # (forced global step, whether a stop is expected afterwards)
        for forced_step, expect_stop in ((13, False), (14, False), (15, True)):
          sess.run(state_ops.assign(step, forced_step))
          hooked_sess.run(noop)
          check = self.assertTrue if expect_stop else self.assertFalse
          check(hooked_sess.should_stop())

        # With the stop flag cleared, a run past the limit must set it again.
        sess.run(state_ops.assign(step, 16))
        hooked_sess._should_stop = False
        hooked_sess.run(noop)
        self.assertTrue(hooked_sess.should_stop())
  def _validate_print_every_n_secs(self, sess, at_end):
    """Checks that LoggingTensorHook(every_n_secs=1.0) logs on a timer.

    Expected behavior:
    - the hook logs on the first run;
    - it stays silent on an immediately following run;
    - it logs again once at least one second has elapsed;
    - it logs from `end()` only when `at_end` is set.

    `self.logged_message` is assumed to be filled by a logging patch
    installed elsewhere (e.g. setUp) -- not visible here.

    Args:
      sess: session used to initialize variables; it is wrapped by a
        `_HookedSession`.
      at_end: value forwarded to the hook's `at_end` argument.
    """
    t = constant_op.constant(42.0, name='foo')
    train_op = constant_op.constant(3)

    hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[t.name], every_n_secs=1.0, at_end=at_end)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])
    sess.run(variables_lib.global_variables_initializer())

    # The first run always logs.
    mon_sess.run(train_op)
    self.assertRegex(str(self.logged_message), t.name)

    # Less than `every_n_secs` has elapsed, so nothing is logged.
    self.logged_message = ''
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self.logged_message))
    time.sleep(1.0)

    # The timer has expired, so the next run logs again.
    self.logged_message = ''
    mon_sess.run(train_op)
    self.assertRegex(str(self.logged_message), t.name)

    self.logged_message = ''
    hook.end(sess)
    if at_end:
      self.assertRegex(str(self.logged_message), t.name)
    else:
      self.assertNotIn(t.name, str(self.logged_message))
    def test_save_secs_saving_once_every_three_steps(self):
        """save_secs=0.9 with 0.3s sleeps saves at steps 1, 4 and 7.

        NOTE(review): uses wall-clock sleeps, so it is timing-sensitive.
        """
        hook = basic_session_run_hooks.SummarySaverHook(
            save_secs=0.9,
            summary_writer=self.summary_writer,
            summary_op=self.summary_op)

        with self.test_session() as sess:
            hook.begin()
            sess.run(variables_lib.global_variables_initializer())
            mon_sess = monitored_session._HookedSession(sess, [hook])
            for _ in range(8):
                mon_sess.run(self.train_op)
                time.sleep(0.3)
            hook.end(sess)

        saved_steps = (1, 4, 7)
        self.summary_writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_summaries={
                step: {'my_summary': float(i + 1)}
                for i, step in enumerate(saved_steps)
            })
예제 #46
0
 def test_feeding_placeholder(self):
     """FeedFnHook feeds the placeholder, so run() needs no explicit feed."""
     with ops.Graph().as_default(), session_lib.Session() as sess:
         x = array_ops.placeholder(dtype=dtypes.float32)
         y = x + 1

         def feed_fn():
             return {x: 1.0}

         hook = basic_session_run_hooks.FeedFnHook(feed_fn=feed_fn)
         hook.begin()
         mon_sess = monitored_session._HookedSession(sess, [hook])
         # The hook supplies x = 1.0, so x + 1 evaluates to 2.
         self.assertEqual(mon_sess.run(y), 2)
예제 #47
0
 def testHookNotExceedingLimit(self):
   """A DumpingDebugHook that watches only '.*delta.*' nodes runs cleanly.

   NOTE(review): the "limit" in the name presumably refers to a dump-size
   limit configured outside this method (e.g. setUp) -- confirm.
   """

   def _watch_fn(fetches, feeds):
     # The arguments are ignored; every run uses the same watch config.
     del fetches, feeds
     return "DebugIdentity", r".*delta.*", r".*"

   hook = hooks.DumpingDebugHook(
       self.session_root, watch_fn=_watch_fn, log_usage=False)
   monitored_session._HookedSession(self.sess, [hook]).run(self.inc_v)
예제 #48
0
 def testHookNotExceedingLimit(self):
   """Runs `self.inc_v` under a DumpingDebugHook watching '.*delta.*' nodes.

   NOTE(review): the "limit" in the name presumably refers to a dump-size
   limit configured outside this method (e.g. setUp) -- confirm.
   """
   watch_spec = ("DebugIdentity", r".*delta.*", r".*")

   def _watch_fn(fetches, feeds):
     del fetches, feeds  # Unused: the spec does not depend on the run.
     return watch_spec

   dumping_hook = hooks.DumpingDebugHook(
       self.session_root, watch_fn=_watch_fn, log_usage=False)
   hooked = monitored_session._HookedSession(self.sess, [dumping_hook])
   hooked.run(self.inc_v)
 def test_feeding_placeholder(self):
   """A hooked run of `x + 1` gets x = 1.0 from FeedFnHook and returns 2."""
   with ops.Graph().as_default(), session_lib.Session() as sess:
     placeholder = array_ops.placeholder(dtype=dtypes.float32)
     plus_one = placeholder + 1
     feed_hook = basic_session_run_hooks.FeedFnHook(
         feed_fn=lambda: {placeholder: 1.0})
     feed_hook.begin()
     hooked = monitored_session._HookedSession(sess, [feed_hook])
     self.assertEqual(hooked.run(plus_one), 2)
예제 #50
0
    def test_sweeps(self):
        """_SweepHook runs init/prep/train ops and alternates row/col sweeps."""
        is_row_sweep_var = variables.Variable(True)
        is_sweep_done_var = variables.Variable(False)
        # Flags that the corresponding ops below set to True.
        (init_done, row_prep_done, col_prep_done, row_train_done,
         col_train_done) = [variables.Variable(False) for _ in range(5)]

        init_op = state_ops.assign(init_done, True)
        row_prep_op = state_ops.assign(row_prep_done, True)
        col_prep_op = state_ops.assign(col_prep_done, True)
        row_train_op = state_ops.assign(row_train_done, True)
        col_train_op = state_ops.assign(col_train_done, True)
        train_op = control_flow_ops.no_op()
        # Clears the done flag and flips the sweep direction.
        switch_op = control_flow_ops.group(
            state_ops.assign(is_sweep_done_var, False),
            state_ops.assign(is_row_sweep_var,
                             math_ops.logical_not(is_row_sweep_var)))
        mark_sweep_done = state_ops.assign(is_sweep_done_var, True)

        with self.test_session() as sess:
            sweep_hook = wals_lib._SweepHook(is_row_sweep_var,
                                             is_sweep_done_var, init_op,
                                             [row_prep_op], [col_prep_op],
                                             row_train_op, col_train_op,
                                             switch_op)
            mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
            sess.run([variables.global_variables_initializer()])

            def expect_true(var, msg):
                self.assertTrue(sess.run(var), msg=msg)

            # Row sweep.
            mon_sess.run(train_op)
            expect_true(init_done, 'init op not run by the Sweephook')
            expect_true(row_prep_done, 'row_prep_op not run by the SweepHook')
            expect_true(row_train_done,
                        'row_train_op not run by the SweepHook')
            expect_true(
                is_row_sweep_var,
                'Row sweep is not complete but is_row_sweep_var is False.')
            # Col sweep.
            mon_sess.run(mark_sweep_done)
            mon_sess.run(train_op)
            expect_true(col_prep_done, 'col_prep_op not run by the SweepHook')
            expect_true(col_train_done,
                        'col_train_op not run by the SweepHook')
            self.assertFalse(
                sess.run(is_row_sweep_var),
                msg='Col sweep is not complete but is_row_sweep_var is True.')
            # Back to a row sweep.
            mon_sess.run(mark_sweep_done)
            mon_sess.run(train_op)
            expect_true(
                is_row_sweep_var,
                'Col sweep is complete but is_row_sweep_var is False.')
    def test_sweeps(self):
        """Walks _SweepHook through a row sweep, then a column sweep.

        Checks the init/prep ops, the direction flag and the hook's
        `_is_sweep_done` attribute after each fed training run.
        """
        row_ph = self._input_row_indices_ph
        col_ph = self._input_col_indices_ph

        with self.test_session() as sess:
            is_row_sweep_var = variables.Variable(True)
            sweep_hook = wals_lib._SweepHook(
                is_row_sweep_var, self._train_op, self._num_rows,
                self._num_cols, row_ph, col_ph, self._row_prep_ops,
                self._col_prep_ops, self._init_ops)
            mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
            sess.run([variables.global_variables_initializer()])

            def step(rows, cols):
                # One hooked training run fed with the given index batches.
                mon_sess.run(self._train_op, {row_ph: rows, col_ph: cols})

            # Init ops run before the first run; the row sweep is unfinished.
            step([0, 1, 2], [])
            self.assertTrue(sess.run(self._init_done),
                            msg='init ops not run by the sweep_hook')
            self.assertTrue(sess.run(self._row_prep_done),
                            msg='row_prep not run by the sweep_hook')
            self.assertTrue(
                sess.run(is_row_sweep_var),
                msg='Row sweep is not complete but is_row_sweep is False.')

            # The row sweep completes.
            step([3, 4], [0, 1, 2, 3, 4, 5, 6])
            self.assertFalse(
                sess.run(is_row_sweep_var),
                msg='Row sweep is complete but is_row_sweep is True.')
            self.assertTrue(
                sweep_hook._is_sweep_done,
                msg='Sweep is complete but is_sweep_done is False.')

            # Col prep ops run; the col sweep is unfinished.
            step([], [0, 1, 2, 3, 4])
            self.assertTrue(sess.run(self._col_prep_done),
                            msg='col_prep not run by the sweep_hook')
            self.assertFalse(
                sess.run(is_row_sweep_var),
                msg='Col sweep is not complete but is_row_sweep is True.')
            self.assertFalse(
                sweep_hook._is_sweep_done,
                msg='Sweep is not complete but is_sweep_done is True.')

            # The col sweep completes.
            step([], [4, 5, 6])
            self.assertTrue(
                sess.run(is_row_sweep_var),
                msg='Col sweep is complete but is_row_sweep is False')
            self.assertTrue(
                sweep_hook._is_sweep_done,
                msg='Sweep is complete but is_sweep_done is False.')
예제 #52
0
  def test_sweeps(self):
    """Checks _SweepHook sweep alternation and the completed-sweep counter."""
    def ind_feed(row_indices, col_indices):
      # Feed dict for one training run over the given row/col index batches.
      return {
          self._input_row_indices_ph: row_indices,
          self._input_col_indices_ph: col_indices
      }

    with self.test_session() as sess:
      is_row_sweep_var = variables.Variable(True)
      completed_sweeps_var = variables.Variable(0)
      sweep_hook = wals_lib._SweepHook(
          is_row_sweep_var,
          [self._train_op],
          self._num_rows,
          self._num_cols,
          self._input_row_indices_ph,
          self._input_col_indices_ph,
          self._row_prep_ops,
          self._col_prep_ops,
          self._init_ops,
          completed_sweeps_var)
      mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
      sess.run([variables.global_variables_initializer()])

      # Init ops should run before the first run. Row sweep not completed.
      mon_sess.run(self._train_op, ind_feed([0, 1, 2], []))
      self.assertTrue(sess.run(self._init_done),
                      msg='init ops not run by the sweep_hook')
      self.assertTrue(sess.run(self._row_prep_done),
                      msg='row_prep not run by the sweep_hook')
      self.assertTrue(sess.run(is_row_sweep_var),
                      msg='Row sweep is not complete but is_row_sweep is '
                      'False.')
      # Row sweep completed. assertEqual (rather than assertTrue(x == 1))
      # reports the actual count when the check fails.
      mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
      self.assertEqual(1, sess.run(completed_sweeps_var),
                       msg='Completed sweeps should be equal to 1.')
      self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
                      msg='Sweep is complete but is_sweep_done is False.')
      # Col init ops should run. Col sweep not completed.
      mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
      self.assertTrue(sess.run(self._col_prep_done),
                      msg='col_prep not run by the sweep_hook')
      self.assertFalse(sess.run(is_row_sweep_var),
                       msg='Col sweep is not complete but is_row_sweep is '
                       'True.')
      self.assertFalse(sess.run(sweep_hook._is_sweep_done_var),
                       msg='Sweep is not complete but is_sweep_done is True.')
      # Col sweep completed.
      mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
      self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
                      msg='Sweep is complete but is_sweep_done is False.')
      self.assertEqual(2, sess.run(completed_sweeps_var),
                       msg='Completed sweeps should be equal to 2.')
예제 #53
0
 def test_print_formatter(self):
   """LoggingTensorHook must route its output through a custom formatter.

   The formatter renders only the 'foo' tensor as 'qqq=<value>', so the first
   logged message is expected to be exactly 'qqq=42.0'.

   NOTE(review): an identical test_print_formatter definition follows this
   one in this chunk; if both live in the same class only the later one runs.
   """
   with ops.Graph().as_default(), session_lib.Session() as sess:
     foo = constant_op.constant(42.0, name='foo')
     step_op = constant_op.constant(3)

     def format_items(items):
       # Look up the logged value by tensor name and add a fixed prefix.
       return 'qqq=%s' % items[foo.name]

     logging_hook = basic_session_run_hooks.LoggingTensorHook(
         tensors=[foo.name], every_n_iter=10, formatter=format_items)
     logging_hook.begin()
     hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
     sess.run(variables_lib.global_variables_initializer())
     hooked_sess.run(step_op)
     first_message = self.logged_message[0]
     self.assertEqual(first_message, 'qqq=42.0')
 def test_print_formatter(self):
   """Checks that a user-supplied formatter controls the logged text.

   Only the 'foo' tensor is logged, and the formatter prints it as
   'qqq=<value>', so the first captured message must be 'qqq=42.0'.

   NOTE(review): an identical test_print_formatter definition precedes this
   one in this chunk; if both live in the same class this one shadows it.
   """
   with ops.Graph().as_default(), session_lib.Session() as sess:
     logged_tensor = constant_op.constant(42.0, name='foo')
     tensor_name = logged_tensor.name
     hook = basic_session_run_hooks.LoggingTensorHook(
         tensors=[tensor_name],
         every_n_iter=10,
         formatter=lambda values: 'qqq=%s' % values[tensor_name])
     dummy_train_op = constant_op.constant(3)
     hook.begin()
     wrapped = monitored_session._HookedSession(sess, [hook])
     sess.run(variables_lib.global_variables_initializer())
     wrapped.run(dummy_train_op)
     self.assertEqual(self.logged_message[0], 'qqq=42.0')
예제 #55
0
 def test_saves_when_saver_and_scaffold_both_missing(self):
   """CheckpointSaverHook must still save when given no saver or scaffold.

   The hook is built with only a model directory and save_steps=1. After a
   single training step, the variable named self.global_step.name is read
   back from the checkpoint in self.model_dir and must equal 1.
   """
   with self.graph.as_default():
     saver_hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=1)
     saver_hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
       hooked_sess.run(self.train_op)
       saved_step = checkpoint_utils.load_variable(
           self.model_dir, self.global_step.name)
       self.assertEqual(1, saved_step)
 def test_saves_when_saver_and_scaffold_both_missing(self):
   """A CheckpointSaverHook with no saver/scaffold args writes a checkpoint.

   One step is run with save_steps=1. The checkpointed value of the variable
   named self.global_step.name is then expected to be 1.

   NOTE(review): this definition duplicates an identically named test just
   above it in this chunk; in a single class only the later one would run.
   """
   with self.graph.as_default():
     checkpoint_hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=1)
     checkpoint_hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as session:
       session.run(self.scaffold.init_op)
       monitored = monitored_session._HookedSession(session, [checkpoint_hook])
       monitored.run(self.train_op)
       step_var_name = self.global_step.name
       self.assertEqual(
           1, checkpoint_utils.load_variable(self.model_dir, step_var_name))
예제 #57
0
  def testDumpingDebugHookWithoutWatchFnWorks(self):
    """DumpingDebugHook built without a watch_fn dumps exactly one run.

    Runs self.inc_v once through the hooked session. It then checks:
      * exactly one "run_*" directory appears under self.session_root,
      * that directory's name passes _assert_correct_run_subdir_naming,
      * the DebugIdentity dump of "v" (slot 0) holds [10.0],
      * the recorded fetches match repr(self.inc_v),
      * the recorded feed keys match repr(None).
    """
    debug_hook = hooks.DumpingDebugHook(self.session_root, log_usage=False)
    hooked_sess = monitored_session._HookedSession(self.sess, [debug_hook])
    hooked_sess.run(self.inc_v)

    run_dir_pattern = os.path.join(self.session_root, "run_*")
    run_dirs = glob.glob(run_dir_pattern)
    self.assertEqual(1, len(run_dirs))
    only_run_dir = run_dirs[0]

    self._assert_correct_run_subdir_naming(os.path.basename(only_run_dir))
    dump = debug_data.DebugDumpDir(only_run_dir)
    # Tensor "v", output slot 0, as captured by the DebugIdentity debug op.
    dumped_v = dump.get_tensors("v", 0, "DebugIdentity")
    self.assertAllClose([10.0], dumped_v)

    # The run was issued with self.inc_v as the fetch and with no feed dict.
    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)
 def test_print_first_step(self):
   """With every_n_iter=1 the first logged message carries no timing info.

   The hook logs on every iteration, but on the first one the elapsed
   duration is None. The message must therefore mention the 'foo' tag and
   contain no 'sec' text.
   """
   # if it runs every iteration, first iteration has None duration.
   with ops.Graph().as_default(), session_lib.Session() as sess:
     t = constant_op.constant(42.0, name='foo')
     train_op = constant_op.constant(3)
     hook = basic_session_run_hooks.LoggingTensorHook(
         tensors={'foo': t}, every_n_iter=1)
     hook.begin()
     mon_sess = monitored_session._HookedSession(sess, [hook])
     sess.run(variables_lib.global_variables_initializer())
     mon_sess.run(train_op)
     logged = str(self.logged_message)
     # assertRegex replaces the deprecated assertRegexpMatches alias, which
     # was removed from unittest in Python 3.12.
     self.assertRegex(logged, 'foo')
     # in first run, elapsed time is None, so no '... sec' text is logged.
     # assertNotIn reports the offending message on failure, unlike
     # assertEqual(logged.find('sec'), -1).
     self.assertNotIn('sec', logged)