def test_step_counter_every_n_secs(self):
    """StepCounterHook on a 0.1s timer should write summaries for steps 2 and 3."""
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
        variables.get_or_create_global_step()
        step_op = training_util._increment_global_step(1)
        writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
        counter_hook = basic_session_run_hooks.StepCounterHook(
            summary_writer=writer, every_n_steps=None, every_n_secs=0.1)
        counter_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [counter_hook])

        # One initial step, then two more steps that each follow a 0.2s pause
        # (longer than the 0.1s interval given to the hook).
        hooked_sess.run(step_op)
        for _ in range(2):
            time.sleep(0.2)
            hooked_sess.run(step_op)
        counter_hook.end(sess)

        writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=graph,
            expected_summaries={})
        self.assertTrue(writer.summaries, 'No summaries were created.')
        self.assertItemsEqual([2, 3], writer.summaries.keys())
        for step_summaries in writer.summaries.values():
            first_value = step_summaries[0].value[0]
            self.assertEqual('global_step/sec', first_value.tag)
            self.assertGreater(first_value.simple_value, 0)
def _validate_print_every_n_steps(self, sess, at_end):
    """Drive a LoggingMetricHook (every_n_iter=10) and check on which runs it logs.

    Args:
        sess: session that gets wrapped in a _HookedSession.
        at_end: forwarded to the hook; decides whether a metric is expected
            after hook.end().
    """
    tensor = tf.constant(42.0, name="foo")
    train_op = tf.constant(3)
    logging_hook = metric_hook.LoggingMetricHook(
        tensors=[tensor.name],
        every_n_iter=10,
        at_end=at_end,
        metric_logger=self._logger)
    logging_hook.begin()
    hooked_sess = monitored_session._HookedSession(  # pylint: disable=protected-access
        sess, [logging_hook])
    sess.run(tf.compat.v1.global_variables_initializer())

    def logged_text():
        return str(self._logger.logged_metric)

    # The very first run is expected to log the tensor.
    hooked_sess.run(train_op)
    self.assertRegexpMatches(logged_text(), tensor.name)

    for _ in range(3):
        self._logger.logged_metric = []
        # Nine runs without any output ...
        for _ in range(9):
            hooked_sess.run(train_op)
            # Negative check done via a plain substring search.
            self.assertEqual(logged_text().find(tensor.name), -1)
        # ... and the tenth one logs again.
        hooked_sess.run(train_op)
        self.assertRegexpMatches(logged_text(), tensor.name)

    # One extra run to verify the counter was reset after the last log.
    self._logger.logged_metric = []
    hooked_sess.run(train_op)
    self.assertEqual(logged_text().find(tensor.name), -1)

    self._logger.logged_metric = []
    logging_hook.end(sess)
    if not at_end:
        self.assertEqual(logged_text().find(tensor.name), -1)
    else:
        self.assertRegexpMatches(logged_text(), tensor.name)
def test_step_counter_every_n_steps(self):
    """StepCounterHook with every_n_steps=10 should record steps 11 and 21, with no warning."""
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
        variables.get_or_create_global_step()
        step_op = training_util._increment_global_step(1)
        writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
        counter_hook = basic_session_run_hooks.StepCounterHook(
            summary_writer=writer, every_n_steps=10)
        counter_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [counter_hook])

        with test.mock.patch.object(tf_logging, 'warning') as warning_mock:
            for _ in range(30):
                time.sleep(0.01)
                hooked_sess.run(step_op)
            # tf_logging.warning must never have been invoked.
            self.assertIsNone(warning_mock.call_args)
        counter_hook.end(sess)

        writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=graph,
            expected_summaries={})
        self.assertItemsEqual([11, 21], writer.summaries.keys())
        for step in [11, 21]:
            first_value = writer.summaries[step][0].value[0]
            self.assertEqual('global_step/sec', first_value.tag)
            self.assertGreater(first_value.simple_value, 0)
def test_save_secs_calls_listeners_periodically(self):
    """A CheckpointSaverHook with save_secs=2 should notify its listener of four saves."""
    with self.graph.as_default():
        listener = MockCheckpointSaverListener()
        saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir,
            save_secs=2,
            scaffold=self.scaffold,
            listeners=[listener])
        saver_hook.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked_sess = monitored_session._HookedSession(sess, [saver_hook])

            def train(times):
                for _ in range(times):
                    hooked_sess.run(self.train_op)

            train(2)  # the first of these runs is expected to save
            time.sleep(2.5)
            train(3)  # again only the first run after the pause saves
            time.sleep(2.5)
            train(2)  # first run saves; the second is left for end()
            saver_hook.end(sess)  # final save happens here
        expected_counts = {
            'begin': 1,
            'before_save': 4,
            'after_save': 4,
            'end': 1,
        }
        self.assertEqual(expected_counts, listener.get_counts())
def test_save_steps_saves_periodically(self):
    """CheckpointSaverHook with save_steps=2 should checkpoint global steps 1, 3 and 5."""
    with self.graph.as_default():
        saver_hook = tf.train.CheckpointSaverHook(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        saver_hook.begin()
        self.scaffold.finalize()
        with tf.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
            hooked_sess.run(self.train_op)
            # After each further run, the checkpointed global step must match:
            # not saved (1), saved (3), not saved (3), saved (5).
            for checkpointed_step in (1, 3, 3, 5):
                hooked_sess.run(self.train_op)
                stored = tf.contrib.framework.load_variable(
                    self.model_dir, self.global_step.name)
                self.assertEqual(checkpointed_step, stored)
def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
    """Check on which runs ExamplesPerSecondHook emits an 'exp/sec' message.

    Args:
        sess: session that gets wrapped in a _HookedSession.
        every_n_steps: logging interval handed to the hook.
        warm_steps: warm-up step count handed to the hook; logging is only
            expected once the global step exceeds it.
    """
    eps_hook = hooks.ExamplesPerSecondHook(
        batch_size=256, every_n_steps=every_n_steps, warm_steps=warm_steps)
    eps_hook.begin()
    hooked_sess = monitored_session._HookedSession(  # pylint: disable=protected-access
        sess, [eps_hook])
    sess.run(tf.global_variables_initializer())

    def assert_logged():
        self.assertRegexpMatches(str(self.logged_message), 'exp/sec')

    def assert_silent():
        # Negative check done via a plain substring search.
        self.assertEqual(str(self.logged_message).find('exp/sec'), -1)

    self.logged_message = ''
    for _ in range(every_n_steps):
        hooked_sess.run(self.train_op)
        assert_silent()

    hooked_sess.run(self.train_op)
    current_step = sess.run(self.global_step)
    if current_step > warm_steps:
        assert_logged()
    else:
        assert_silent()

    # One extra run to verify proper reset when called multiple times.
    self.logged_message = ''
    hooked_sess.run(self.train_op)
    current_step = sess.run(self.global_step)
    if every_n_steps == 1 and current_step > warm_steps:
        assert_logged()
    else:
        assert_silent()

    eps_hook.end(sess)
def test_save_secs_saving_once_every_three_steps(self, mock_time):
    """SummarySaverHook with save_secs=9 should save at steps 1, 4 and 7.

    Args:
        mock_time: mock whose return_value serves as the clock and is advanced
            by 3.1 after every step (presumably injected by a patch decorator
            that is not visible here).
    """
    mock_time.return_value = 1484695987.209386
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=9.,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
        saver_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
        for _ in range(8):
            hooked_sess.run(self.train_op)
            mock_time.return_value += 3.1
        saver_hook.end(sess)

    # 8 * 3.1 = 24.8 mocked seconds elapsed; with a 9 second interval counted
    # from the first step, that gives three saves.
    expected = {
        1: {'my_summary': 1.0},
        4: {'my_summary': 2.0},
        7: {'my_summary': 3.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
    """A stateful tuple-returning watch_fn should watch tensors on every other run."""
    watch_fn_state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
        del fetches, feed_dict
        watch_fn_state["run_counter"] += 1
        # Odd runs (1-based) watch everything; even runs use a pattern
        # that matches nothing.
        is_odd_run = watch_fn_state["run_counter"] % 2 == 1
        pattern = r".*" if is_odd_run else r"$^"
        return "DebugIdentity", pattern, pattern

    def run_index(path):
        return int(os.path.basename(path).split("_")[1])

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    hooked_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
        hooked_sess.run(self.inc_v)

    run_dirs = sorted(
        glob.glob(os.path.join(self.session_root, "run_*")), key=run_index)
    self.assertEqual(4, len(run_dirs))

    for position, run_dir in enumerate(run_dirs):
        self._assert_correct_run_subdir_naming(os.path.basename(run_dir))
        dump = debug_data.DebugDumpDir(run_dir)
        if position % 2 != 0:
            self.assertEqual(0, dump.size)
        else:
            self.assertAllClose(
                [10.0 + 1.0 * position],
                dump.get_tensors("v", 0, "DebugIdentity"))

        self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)
def test_summary_writer_defs(self):
    """CheckpointSaverHook should add the scaffold's meta graph to the summary writer.

    The fake summary writer is installed for the duration of the test and is
    always uninstalled again, even when an assertion or a session call fails,
    so that a failure here does not leave the fake installed afterwards.
    """
    testing.FakeSummaryWriter.install()
    try:
        tf.train.SummaryWriterCache.clear()
        summary_writer = tf.train.SummaryWriterCache.get(self.model_dir)

        with self.graph.as_default():
            hook = tf.train.CheckpointSaverHook(
                self.model_dir, save_steps=2, scaffold=self.scaffold)
            hook.begin()
            self.scaffold.finalize()
            with tf.Session() as sess:
                sess.run(self.scaffold.init_op)
                mon_sess = monitored_session._HookedSession(sess, [hook])
                mon_sess.run(self.train_op)
            summary_writer.assert_summaries(
                test_case=self,
                expected_logdir=self.model_dir,
                expected_added_meta_graphs=[
                    meta_graph.create_meta_graph_def(
                        graph_def=self.graph.as_graph_def(add_shapes=True),
                        saver_def=self.scaffold.saver.saver_def)
                ])
    finally:
        # Runs on success and on failure alike.
        testing.FakeSummaryWriter.uninstall()
def _validate_print_every_n_secs(self, sess, at_end):
    """Drive a LoggingMetricHook (every_n_secs=1.0) and check on which runs it logs.

    Args:
        sess: session that gets wrapped in a _HookedSession.
        at_end: forwarded to the hook; decides whether a metric is expected
            after hook.end().
    """
    tensor = tf.constant(42.0, name='foo')
    train_op = tf.constant(3)
    logging_hook = metric_hook.LoggingMetricHook(
        tensors=[tensor.name],
        every_n_secs=1.0,
        at_end=at_end,
        metric_logger=self._logger)
    logging_hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
    sess.run(tf.global_variables_initializer())

    def logged_text():
        return str(self._logger.logged_metric)

    # First run logs.
    hooked_sess.run(train_op)
    self.assertRegexpMatches(logged_text(), tensor.name)

    # An immediate second run stays silent (checked via substring search).
    self._logger.logged_metric = []
    hooked_sess.run(train_op)
    self.assertEqual(logged_text().find(tensor.name), -1)
    time.sleep(1.0)

    # After waiting out the interval the hook logs again.
    self._logger.logged_metric = []
    hooked_sess.run(train_op)
    self.assertRegexpMatches(logged_text(), tensor.name)

    self._logger.logged_metric = []
    logging_hook.end(sess)
    if not at_end:
        self.assertEqual(logged_text().find(tensor.name), -1)
    else:
        self.assertRegexpMatches(logged_text(), tensor.name)
def test_log_tensors(self):
    """With at_end=True only, LoggingMetricHook logs both tensors once, at end()."""
    with tf.Graph().as_default(), tf.Session() as sess:
        tf.train.get_or_create_global_step()
        foo_tensor = tf.constant(42.0, name='foo')
        bar_tensor = tf.constant(43.0, name='bar')
        train_op = tf.constant(3)

        logging_hook = metric_hook.LoggingMetricHook(
            tensors=[foo_tensor, bar_tensor],
            at_end=True,
            metric_logger=self._logger)
        logging_hook.begin()
        hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
        sess.run(tf.global_variables_initializer())

        # Nothing may be logged while the session is still stepping.
        for _ in range(3):
            hooked_sess.run(train_op)
            self.assertEqual(self._logger.logged_metric, [])

        logging_hook.end(sess)
        self.assertEqual(len(self._logger.logged_metric), 2)
        expectations = (("foo", 42.0), ("bar", 43.0))
        for index, (name_pattern, value) in enumerate(expectations):
            metric = self._logger.logged_metric[index]
            self.assertRegexpMatches(str(metric["name"]), name_pattern)
            self.assertEqual(metric["value"], value)
            self.assertEqual(metric["unit"], None)
            self.assertEqual(metric["global_step"], 0)
def testBothHooksAndUserHaveFeeds(self):
    """Feeds requested by two hooks are combined with the caller's own feed_dict."""
    with tf.Graph().as_default(), tf.Session() as sess:
        first_hook = FakeHook()
        second_hook = FakeHook()
        hooked_sess = monitored_session._HookedSession(
            sess=sess, hooks=[first_hook, second_hook])

        a_tensor = tf.constant([0], name='a_tensor')
        b_tensor = tf.constant([0], name='b_tensor')
        c_tensor = tf.constant([0], name='c_tensor')
        total = a_tensor + b_tensor + c_tensor

        first_hook.request = tf.train.SessionRunArgs(
            None, feed_dict={a_tensor: [5]})
        second_hook.request = tf.train.SessionRunArgs(
            None, feed_dict={b_tensor: [10]})
        sess.run(tf.global_variables_initializer())

        user_feeds = {c_tensor: [20]}
        result = hooked_sess.run(fetches=total, feed_dict=user_feeds)
        self.assertEqual(result, [35])
        # The dict handed in by the caller must still hold only its own entry.
        self.assertEqual(len(user_feeds), 1)
def test_multiple_summaries(self):
    """SummarySaverHook given a list of two summary ops saves both at steps 1 and 9."""
    summary_ops = [self.summary_op, self.summary_op2]
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=summary_ops)

    with self.test_session() as sess:
        saver_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
        for _ in range(10):
            hooked_sess.run(self.train_op)
        saver_hook.end(sess)

    expected = {
        1: {'my_summary': 1.0, 'my_summary2': 2.0},
        9: {'my_summary': 2.0, 'my_summary2': 4.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
def test_capture(self):
    """MetadataCaptureHook with step=5 writes its files only on the run after step 5."""
    global_step = tf.contrib.framework.get_or_create_global_step()
    # Arbitrary computation for the hooked session to execute.
    some_weights = tf.get_variable("weigths", [2, 128])
    computation = tf.nn.softmax(some_weights)

    capture_hook = hooks.MetadataCaptureHook(
        params={"step": 5},
        model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    capture_hook.begin()

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        #pylint: disable=W0212
        hooked_sess = monitored_session._HookedSession(sess, [capture_hook])

        # Step 0: nothing is written.
        sess.run(tf.assign(global_step, 0))
        hooked_sess.run(computation)
        self.assertEqual(gfile.ListDirectory(self.model_dir), [])

        # Step 5: the run at this step still leaves the directory empty ...
        sess.run(tf.assign(global_step, 5))
        hooked_sess.run(computation)
        self.assertEqual(gfile.ListDirectory(self.model_dir), [])

        # ... and the following run produces the three output files.
        hooked_sess.run(computation)
        written = set(gfile.ListDirectory(self.model_dir))
        self.assertEqual(written, {"run_meta", "tfprof_log", "timeline.json"})
def test_step_counter_every_n_secs(self):
    """StepCounterHook on a 0.1s timer should write summaries for steps 2 and 3."""
    with tf.Graph().as_default() as graph, tf.Session() as sess:
        global_step = tf.contrib.framework.get_or_create_global_step()
        step_op = tf.assign_add(global_step, 1)
        writer = testing.FakeSummaryWriter(self.log_dir, graph)
        counter_hook = tf.train.StepCounterHook(
            summary_writer=writer, every_n_steps=None, every_n_secs=0.1)
        counter_hook.begin()
        sess.run(tf.initialize_all_variables())
        hooked_sess = monitored_session._HookedSession(sess, [counter_hook])

        # One initial step, then two more steps that each follow a 0.2s pause.
        hooked_sess.run(step_op)
        for _ in range(2):
            time.sleep(0.2)
            hooked_sess.run(step_op)
        counter_hook.end(sess)

        writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=graph,
            expected_summaries={})
        self.assertTrue(writer.summaries, "No summaries were created.")
        self.assertItemsEqual([2, 3], writer.summaries.keys())
        for step_summaries in writer.summaries.values():
            first_value = step_summaries[0].value[0]
            self.assertEqual("global_step/sec", first_value.tag)
            self.assertGreater(first_value.simple_value, 0)
def _validate_print_every_n_steps(self, sess, at_end):
    """Drive a LoggingTensorHook (every_n_iter=10) and check on which runs it logs.

    Args:
        sess: session that gets wrapped in a _HookedSession.
        at_end: forwarded to the hook; decides whether a message is expected
            after hook.end().
    """
    tensor = constant_op.constant(42.0, name='foo')
    train_op = constant_op.constant(3)
    logging_hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[tensor.name], every_n_iter=10, at_end=at_end)
    logging_hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
    sess.run(variables_lib.global_variables_initializer())

    def assert_logged():
        self.assertRegexpMatches(str(self.logged_message), tensor.name)

    def assert_silent():
        # Negative check done via a plain substring search.
        self.assertEqual(str(self.logged_message).find(tensor.name), -1)

    # The very first run is expected to log the tensor.
    hooked_sess.run(train_op)
    assert_logged()

    for _ in range(3):
        self.logged_message = ''
        for _ in range(9):
            hooked_sess.run(train_op)
            assert_silent()
        hooked_sess.run(train_op)
        assert_logged()

    # One extra run to verify the counter was reset after the last log.
    self.logged_message = ''
    hooked_sess.run(train_op)
    assert_silent()

    self.logged_message = ''
    logging_hook.end(sess)
    if at_end:
        assert_logged()
    else:
        assert_silent()
def test_summary_saver(self):
    """SummarySaverHook with save_steps=8 should save at steps 1, 9, 17 and 25.

    The summarised tensor is an assign_add on a variable, so its value grows
    by 1.0 every time the summary op is evaluated; the expected summary values
    1.0 through 4.0 therefore count the number of saves.
    """
    with tf.Graph().as_default() as g, tf.Session() as sess:
        log_dir = 'log/dir'
        summary_writer = testing.FakeSummaryWriter(log_dir, g)
        var = tf.Variable(0.0)
        tensor = tf.assign_add(var, 1.0)
        summary_op = tf.scalar_summary('my_summary', tensor)
        global_step = tf.contrib.framework.get_or_create_global_step()
        train_op = tf.assign_add(global_step, 1)
        hook = tf.train.SummarySaverHook(
            summary_op=summary_op,
            save_steps=8,
            summary_writer=summary_writer)
        hook.begin()
        sess.run(tf.initialize_all_variables())
        mon_sess = monitored_session._HookedSession(sess, [hook])
        # The loop index is not needed; only the number of runs matters.
        for _ in range(30):
            mon_sess.run(train_op)
        hook.end(sess)
        summary_writer.assert_summaries(
            test_case=self,
            expected_logdir=log_dir,
            expected_graph=g,
            expected_summaries={
                1: {'my_summary': 1.0},
                9: {'my_summary': 2.0},
                17: {'my_summary': 3.0},
                25: {'my_summary': 4.0},
            })
def DISABLED_test_save_secs_calls_listeners_periodically(self):
    """A CheckpointSaverHook with save_secs=2 should notify its listener of four saves."""
    with self.graph.as_default():
        listener = MockCheckpointSaverListener()
        saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir,
            save_secs=2,
            scaffold=self.scaffold,
            listeners=[listener])
        saver_hook.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked_sess = monitored_session._HookedSession(sess, [saver_hook])

            def train(times):
                for _ in range(times):
                    hooked_sess.run(self.train_op)

            train(2)  # the first of these runs is expected to save
            time.sleep(2.5)
            train(3)  # again only the first run after the pause saves
            time.sleep(2.5)
            train(2)  # first run saves; the second is left for end()
            saver_hook.end(sess)  # final save happens here
        expected_counts = {
            'begin': 1,
            'before_save': 4,
            'after_save': 4,
            'end': 1,
        }
        self.assertEqual(expected_counts, listener.get_counts())
def testBothHooksAndUserHaveFeeds(self):
    """Feeds requested by two hooks are combined with the caller's own feed_dict."""
    with tf.Graph().as_default(), tf.Session() as sess:
        first_hook = FakeHook()
        second_hook = FakeHook()
        hooked_sess = monitored_session._HookedSession(
            sess=sess, hooks=[first_hook, second_hook])

        a_tensor = tf.constant([0], name='a_tensor')
        b_tensor = tf.constant([0], name='b_tensor')
        c_tensor = tf.constant([0], name='c_tensor')
        total = a_tensor + b_tensor + c_tensor

        first_hook.request = tf.train.SessionRunArgs(
            None, feed_dict={a_tensor: [5]})
        second_hook.request = tf.train.SessionRunArgs(
            None, feed_dict={b_tensor: [10]})
        sess.run(tf.initialize_all_variables())

        user_feeds = {c_tensor: [20]}
        result = hooked_sess.run(fetches=total, feed_dict=user_feeds)
        self.assertEqual(result, [35])
        # The dict handed in by the caller must still hold only its own entry.
        self.assertEqual(len(user_feeds), 1)
def test_step_counter_every_n_secs(self):
    """StepCounterHook on a 0.1s timer should write summaries for steps 2 and 3."""
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
        global_step = variables.get_or_create_global_step()
        step_op = state_ops.assign_add(global_step, 1)
        writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
        counter_hook = basic_session_run_hooks.StepCounterHook(
            summary_writer=writer, every_n_steps=None, every_n_secs=0.1)
        counter_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [counter_hook])

        # One initial step, then two more steps that each follow a 0.2s pause.
        hooked_sess.run(step_op)
        for _ in range(2):
            time.sleep(0.2)
            hooked_sess.run(step_op)
        counter_hook.end(sess)

        writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=graph,
            expected_summaries={})
        self.assertTrue(writer.summaries, 'No summaries were created.')
        self.assertItemsEqual([2, 3], writer.summaries.keys())
        for step_summaries in writer.summaries.values():
            first_value = step_summaries[0].value[0]
            self.assertEqual('global_step/sec', first_value.tag)
            self.assertGreater(first_value.simple_value, 0)
def test_save_secs_saving_once_every_three_steps(self):
    """SummarySaverHook with save_secs=0.9 and 0.3s per step should save at steps 1, 4, 7."""
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=0.9,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
        saver_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
        for _ in range(8):
            hooked_sess.run(self.train_op)
            time.sleep(0.3)
        saver_hook.end(sess)

    expected = {
        1: {'my_summary': 1.0},
        4: {'my_summary': 2.0},
        7: {'my_summary': 3.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
def test_multiple_summaries(self):
    """SummarySaverHook given a list of two summary ops saves both at steps 1 and 9."""
    summary_ops = [self.summary_op, self.summary_op2]
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=summary_ops)

    with self.test_session() as sess:
        saver_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
        for _ in range(10):
            hooked_sess.run(self.train_op)
        saver_hook.end(sess)

    expected = {
        1: {'my_summary': 1.0, 'my_summary2': 2.0},
        9: {'my_summary': 2.0, 'my_summary2': 4.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
def test_log_tensors(self):
    """With at_end=True only, LoggingMetricHook logs both tensors once, at end()."""
    with tf.Graph().as_default(), tf.Session() as sess:
        tf.train.get_or_create_global_step()
        foo_tensor = tf.constant(42.0, name='foo')
        bar_tensor = tf.constant(43.0, name='bar')
        train_op = tf.constant(3)

        logging_hook = metric_hook.LoggingMetricHook(
            tensors=[foo_tensor, bar_tensor],
            at_end=True,
            metric_logger=self._logger)
        logging_hook.begin()
        hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
        sess.run(tf.global_variables_initializer())

        # Nothing may be logged while the session is still stepping.
        for _ in range(3):
            hooked_sess.run(train_op)
            self.assertEqual(self._logger.logged_metric, [])

        logging_hook.end(sess)
        self.assertEqual(len(self._logger.logged_metric), 2)
        expectations = (("foo", 42.0), ("bar", 43.0))
        for index, (name_pattern, value) in enumerate(expectations):
            metric = self._logger.logged_metric[index]
            self.assertRegexpMatches(str(metric["name"]), name_pattern)
            self.assertEqual(metric["value"], value)
            self.assertEqual(metric["unit"], None)
            self.assertEqual(metric["global_step"], 0)
def _validate_print_every_n_secs(self, sess, at_end):
    """Drive a LoggingMetricHook (every_n_secs=1.0) and check on which runs it logs.

    Args:
        sess: session that gets wrapped in a _HookedSession.
        at_end: forwarded to the hook; decides whether a metric is expected
            after hook.end().
    """
    tensor = tf.constant(42.0, name='foo')
    train_op = tf.constant(3)
    logging_hook = metric_hook.LoggingMetricHook(
        tensors=[tensor.name],
        every_n_secs=1.0,
        at_end=at_end,
        metric_logger=self._logger)
    logging_hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
    sess.run(tf.global_variables_initializer())

    def logged_text():
        return str(self._logger.logged_metric)

    # First run logs.
    hooked_sess.run(train_op)
    self.assertRegexpMatches(logged_text(), tensor.name)

    # An immediate second run stays silent (checked via substring search).
    self._logger.logged_metric = []
    hooked_sess.run(train_op)
    self.assertEqual(logged_text().find(tensor.name), -1)
    time.sleep(1.0)

    # After waiting out the interval the hook logs again.
    self._logger.logged_metric = []
    hooked_sess.run(train_op)
    self.assertRegexpMatches(logged_text(), tensor.name)

    self._logger.logged_metric = []
    logging_hook.end(sess)
    if not at_end:
        self.assertEqual(logged_text().find(tensor.name), -1)
    else:
        self.assertRegexpMatches(logged_text(), tensor.name)
def _validate_print_every_n_secs(self, sess, at_end):
    """Drive a LoggingTensorHook (every_n_secs=1.0) and check on which runs it logs.

    Args:
        sess: session that gets wrapped in a _HookedSession.
        at_end: forwarded to the hook; decides whether a message is expected
            after hook.end().
    """
    tensor = constant_op.constant(42.0, name='foo')
    train_op = constant_op.constant(3)
    logging_hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[tensor.name], every_n_secs=1.0, at_end=at_end)
    logging_hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
    sess.run(variables_lib.global_variables_initializer())

    def assert_logged():
        self.assertRegexpMatches(str(self.logged_message), tensor.name)

    def assert_silent():
        # Negative check done via a plain substring search.
        self.assertEqual(str(self.logged_message).find(tensor.name), -1)

    # First run logs.
    hooked_sess.run(train_op)
    assert_logged()

    # An immediate second run stays silent.
    self.logged_message = ''
    hooked_sess.run(train_op)
    assert_silent()
    time.sleep(1.0)

    # After waiting out the interval the hook logs again.
    self.logged_message = ''
    hooked_sess.run(train_op)
    assert_logged()

    self.logged_message = ''
    logging_hook.end(sess)
    if at_end:
        assert_logged()
    else:
        assert_silent()
def test_save_steps(self):
    """SummarySaverHook with save_steps=8 should save at steps 1, 9, 17 and 25."""
    saver_hook = tf.train.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
        saver_hook.begin()
        sess.run(tf.initialize_all_variables())
        hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
        for _ in range(30):
            hooked_sess.run(self.train_op)
        saver_hook.end(sess)

    expected = {
        1: {'my_summary': 1.0},
        9: {'my_summary': 2.0},
        17: {'my_summary': 3.0},
        25: {'my_summary': 4.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
def _validate_print_every_n_steps(self, sess, at_end):
    """Drive a LoggingTensorHook (every_n_iter=10) and check on which runs it logs.

    Args:
        sess: session that gets wrapped in a _HookedSession.
        at_end: forwarded to the hook; decides whether a message is expected
            after hook.end().
    """
    tensor = constant_op.constant(42.0, name='foo')
    train_op = constant_op.constant(3)
    logging_hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[tensor.name], every_n_iter=10, at_end=at_end)
    logging_hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
    sess.run(variables_lib.global_variables_initializer())

    def assert_logged():
        self.assertRegexpMatches(str(self.logged_message), tensor.name)

    def assert_silent():
        # Negative check done via a plain substring search.
        self.assertEqual(str(self.logged_message).find(tensor.name), -1)

    # The very first run is expected to log the tensor.
    hooked_sess.run(train_op)
    assert_logged()

    for _ in range(3):
        self.logged_message = ''
        for _ in range(9):
            hooked_sess.run(train_op)
            assert_silent()
        hooked_sess.run(train_op)
        assert_logged()

    # One extra run to verify the counter was reset after the last log.
    self.logged_message = ''
    hooked_sess.run(train_op)
    assert_silent()

    self.logged_message = ''
    logging_hook.end(sess)
    if at_end:
        assert_logged()
    else:
        assert_silent()
def test_stop_based_on_num_step(self):
    """StopAtStepHook(num_steps=10) started at step 5 should request a stop from step 15."""
    stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=10)

    with ops.Graph().as_default():
        global_step = variables.get_or_create_global_step()
        no_op = control_flow_ops.no_op()
        stop_hook.begin()
        with session_lib.Session() as sess:
            hooked_sess = monitored_session._HookedSession(sess, [stop_hook])

            # The hook sees global step 5 when the session is created.
            sess.run(state_ops.assign(global_step, 5))
            stop_hook.after_create_session(sess, None)
            hooked_sess.run(no_op)
            self.assertFalse(hooked_sess.should_stop())

            for step, expect_stop in ((13, False), (14, False), (15, True)):
                sess.run(state_ops.assign(global_step, step))
                hooked_sess.run(no_op)
                if expect_stop:
                    self.assertTrue(hooked_sess.should_stop())
                else:
                    self.assertFalse(hooked_sess.should_stop())

            # Clear the flag by hand: a later step must raise it again.
            sess.run(state_ops.assign(global_step, 16))
            hooked_sess._should_stop = False
            hooked_sess.run(no_op)
            self.assertTrue(hooked_sess.should_stop())
def test_save_steps_saves_periodically(self):
    """CheckpointSaverHook with save_steps=2 should checkpoint global steps 1, 3 and 5."""
    with self.graph.as_default():
        saver_hook = tf.train.CheckpointSaverHook(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        saver_hook.begin()
        self.scaffold.finalize()
        with tf.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
            hooked_sess.run(self.train_op)
            # After each further run, the checkpointed global step must match:
            # not saved (1), saved (3), not saved (3), saved (5).
            for checkpointed_step in (1, 3, 3, 5):
                hooked_sess.run(self.train_op)
                stored = tf.contrib.framework.load_variable(
                    self.model_dir, self.global_step.name)
                self.assertEqual(checkpointed_step, stored)
def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
    """With send_traceback_and_source_code=False the server has no traceback or source."""
    u = variables.VariableV1(2.1, name="u")
    v = variables.VariableV1(20.0, name="v")
    product = math_ops.multiply(u, v, name="w")

    plain_sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    plain_sess.run(variables.global_variables_initializer())

    server_address = "localhost:%d" % self._server_port
    debug_hook = hooks.TensorBoardDebugHook(
        [server_address], send_traceback_and_source_code=False)
    hooked_sess = monitored_session._HookedSession(plain_sess, [debug_hook])

    # The watch point has to be requested before the hooked run happens.
    self._server.request_watch("u/read", 0, "DebugIdentity")
    self.assertAllClose(42.0, hooked_sess.run(product))

    # Because sending was disabled, both server queries must raise.
    missing_op = r"Op .*u/read.* does not exist"
    with self.assertRaisesRegex(ValueError, missing_op):
        self.assertTrue(self._server.query_op_traceback("u/read"))
    missing_source = r".* has not received any source file"
    with self.assertRaisesRegex(ValueError, missing_source):
        self._server.query_source_file_line(__file__, 1)
def test_step_counter_every_n_secs(self):
    """StepCounterHook on a 0.1s timer should write summaries for steps 2 and 3."""
    with tf.Graph().as_default() as graph, tf.Session() as sess:
        global_step = tf.contrib.framework.get_or_create_global_step()
        step_op = tf.assign_add(global_step, 1)
        writer = testing.FakeSummaryWriter(self.log_dir, graph)
        counter_hook = tf.train.StepCounterHook(
            summary_writer=writer, every_n_steps=None, every_n_secs=0.1)
        counter_hook.begin()
        sess.run(tf.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [counter_hook])

        # One initial step, then two more steps that each follow a 0.2s pause.
        hooked_sess.run(step_op)
        for _ in range(2):
            time.sleep(0.2)
            hooked_sess.run(step_op)
        counter_hook.end(sess)

        writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=graph,
            expected_summaries={})
        self.assertTrue(writer.summaries, 'No summaries were created.')
        self.assertItemsEqual([2, 3], writer.summaries.keys())
        for step_summaries in writer.summaries.values():
            first_value = step_summaries[0].value[0]
            self.assertEqual('global_step/sec', first_value.tag)
            self.assertGreater(first_value.simple_value, 0)
def test_save_secs_saving_once_every_step(self):
    """SummarySaverHook with save_secs=0.5 and 0.5s per step should save on all four steps."""
    saver_hook = tf.train.SummarySaverHook(
        save_steps=None,
        save_secs=0.5,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
        saver_hook.begin()
        sess.run(tf.initialize_all_variables())
        hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
        for _ in range(4):
            hooked_sess.run(self.train_op)
            time.sleep(0.5)
        saver_hook.end(sess)

    expected = {
        1: {'my_summary': 1.0},
        2: {'my_summary': 2.0},
        3: {'my_summary': 3.0},
        4: {'my_summary': 4.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
def test_capture(self):
    """MetadataCaptureHook with step=5 writes its files only on the run after step 5."""
    global_step = tf.contrib.framework.get_or_create_global_step()
    # Arbitrary computation for the hooked session to execute.
    some_weights = tf.get_variable("weigths", [2, 128])
    computation = tf.nn.softmax(some_weights)

    capture_hook = hooks.MetadataCaptureHook(
        params={"step": 5}, model_dir=self.model_dir)
    capture_hook.begin()

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        #pylint: disable=W0212
        hooked_sess = monitored_session._HookedSession(sess, [capture_hook])

        # Step 0: nothing is written.
        sess.run(tf.assign(global_step, 0))
        hooked_sess.run(computation)
        self.assertEqual(gfile.ListDirectory(self.model_dir), [])

        # Step 5: the run at this step still leaves the directory empty ...
        sess.run(tf.assign(global_step, 5))
        hooked_sess.run(computation)
        self.assertEqual(gfile.ListDirectory(self.model_dir), [])

        # ... and the following run produces the three output files.
        hooked_sess.run(computation)
        written = set(gfile.ListDirectory(self.model_dir))
        self.assertEqual(written, {"run_meta", "tfprof_log", "timeline.json"})
def test_save_secs_saves_periodically(self):
    """CheckpointSaverHook with save_secs=2 should checkpoint global steps 1, 3 and 6."""
    with self.graph.as_default():
        saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_secs=2, scaffold=self.scaffold)
        saver_hook.begin()
        self.scaffold.finalize()

        def checkpointed_step():
            return checkpoint_utils.load_variable(
                self.model_dir, self.global_step.name)

        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked_sess = monitored_session._HookedSession(sess, [saver_hook])

            # Two quick runs: only the first one is expected to save.
            for _ in range(2):
                hooked_sess.run(self.train_op)
            self.assertEqual(1, checkpointed_step())

            # After a pause, the first of three runs saves.
            time.sleep(2.5)
            for _ in range(3):
                hooked_sess.run(self.train_op)
            self.assertEqual(3, checkpointed_step())

            # After another pause, a single run saves again.
            time.sleep(2.5)
            hooked_sess.run(self.train_op)
            self.assertEqual(6, checkpointed_step())
def test_global_step_name(self):
    """StepCounterHook should tag its summary after the global step variable ('bar/foo')."""
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
        step_collections = [
            ops.GraphKeys.GLOBAL_STEP,
            ops.GraphKeys.GLOBAL_VARIABLES,
        ]
        with variable_scope.variable_scope('bar'):
            foo_step = variable_scope.get_variable(
                'foo',
                initializer=0,
                trainable=False,
                collections=step_collections)
        step_op = state_ops.assign_add(foo_step, 1)
        writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
        counter_hook = basic_session_run_hooks.StepCounterHook(
            summary_writer=writer, every_n_steps=1, every_n_secs=None)
        counter_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [counter_hook])

        for _ in range(2):
            hooked_sess.run(step_op)
        counter_hook.end(sess)

        writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=graph,
            expected_summaries={})
        self.assertTrue(writer.summaries, 'No summaries were created.')
        self.assertItemsEqual([2], writer.summaries.keys())
        first_value = writer.summaries[2][0].value[0]
        self.assertEqual('bar/foo/sec', first_value.tag)
def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
    """Check that send_traceback_and_source_code=False suppresses sending.

    Runs ``w = u * v`` through a session hooked with a TensorBoardDebugHook
    created with ``send_traceback_and_source_code=False``, then expects the
    test server to raise ValueError both for an op-traceback query and for a
    source-file query.
    """
    var_u = variables.Variable(2.1, name="u")
    var_v = variables.Variable(20.0, name="v")
    product = math_ops.multiply(var_u, var_v, name="w")

    raw_sess = session.Session(config=no_rewrite_session_config())
    raw_sess.run(variables.global_variables_initializer())

    server_address = "localhost:%d" % self._server_port
    debug_hook = hooks.TensorBoardDebugHook(
        [server_address], send_traceback_and_source_code=False)
    hooked_sess = monitored_session._HookedSession(raw_sess, [debug_hook])

    # The watch point has to be registered before the hooked run happens.
    self._server.request_watch("u/read", 0, "DebugIdentity")
    self.assertAllClose(42.0, hooked_sess.run(product))

    # With sending disabled the server must know neither the op traceback
    # nor any source file.
    with self.assertRaisesRegexp(
            ValueError, r"Op .*u/read.* does not exist"):
        self.assertTrue(self._server.query_op_traceback("u/read"))
    with self.assertRaisesRegexp(
            ValueError, r".* has not received any source file"):
        self._server.query_source_file_line(__file__, 1)
def test_global_step_name(self):
    """Verify the summary tag StepCounterHook derives from the global step.

    A non-trainable variable ``foo`` under scope ``bar`` is registered in
    the GLOBAL_STEP collection and incremented twice through a hooked
    session. The fake writer must then hold one summary, at step 2, whose
    first value is tagged ``bar/foo/sec``.
    """
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
        with variable_scope.variable_scope('bar'):
            custom_step = variable_scope.get_variable(
                'foo',
                initializer=0,
                trainable=False,
                collections=[
                    ops.GraphKeys.GLOBAL_STEP,
                    ops.GraphKeys.GLOBAL_VARIABLES,
                ])
        bump_step = state_ops.assign_add(custom_step, 1)

        fake_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
        step_hook = basic_session_run_hooks.StepCounterHook(
            summary_writer=fake_writer,
            every_n_steps=1,
            every_n_secs=None)
        step_hook.begin()
        sess.run(variables_lib.global_variables_initializer())

        monitored = monitored_session._HookedSession(sess, [step_hook])
        monitored.run(bump_step)
        monitored.run(bump_step)
        step_hook.end(sess)

        fake_writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=g,
            expected_summaries={})
        self.assertTrue(fake_writer.summaries, 'No summaries were created.')
        self.assertItemsEqual([2], fake_writer.summaries.keys())

        tag = fake_writer.summaries[2][0].value[0].tag
        self.assertEqual('bar/foo/sec', tag)
def DISABLED_test_save_steps_saves_periodically(self):
    """Check the checkpointed global step of a save_steps=2 saver hook.

    After the first two runs the checkpoint is expected to hold step 1;
    each of the next three single runs is followed by a check for steps
    3, 3 and 5 respectively.
    """
    with self.graph.as_default():
        saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        saver_hook.begin()
        self.scaffold.finalize()

        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked_sess = monitored_session._HookedSession(
                sess, [saver_hook])

            hooked_sess.run(self.train_op)
            hooked_sess.run(self.train_op)  # Not saved
            self.assertEqual(
                1,
                checkpoint_utils.load_variable(
                    self.model_dir, self.global_step.name))

            # One run per entry: saved (3), not saved (3), saved (5).
            for expected_step in (3, 3, 5):
                hooked_sess.run(self.train_op)
                self.assertEqual(
                    expected_step,
                    checkpoint_utils.load_variable(
                        self.model_dir, self.global_step.name))
def _validate_print_every_n_secs(self, sess, at_end):
    """Exercise a LoggingTensorHook configured with every_n_secs=1.0.

    Args:
      sess: session the hook is attached to; also passed to ``hook.end``.
      at_end: forwarded to the hook; when true the tensor name must be
        logged again by ``hook.end``, otherwise it must not be.

    The tensor name is expected in ``self.logged_message`` after the first
    run, absent after an immediate second run, and present again for a run
    made after sleeping one second.
    """
    tensor = constant_op.constant(42.0, name='foo')
    step_op = constant_op.constant(3)
    logging_hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[tensor.name], every_n_secs=1.0, at_end=at_end)
    logging_hook.begin()
    hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
    sess.run(variables_lib.global_variables_initializer())

    def assert_logged():
        self.assertRegexpMatches(str(self.logged_message), tensor.name)

    def assert_not_logged():
        # str.find is used here because, per the original author's note,
        # assertNotRegexpMatches is not usable on every targeted Python.
        self.assertEqual(str(self.logged_message).find(tensor.name), -1)

    hooked_sess.run(step_op)
    assert_logged()

    self.logged_message = ''
    hooked_sess.run(step_op)
    assert_not_logged()
    time.sleep(1.0)

    self.logged_message = ''
    hooked_sess.run(step_op)
    assert_logged()

    self.logged_message = ''
    logging_hook.end(sess)
    if at_end:
        assert_logged()
    else:
        assert_not_logged()
def test_save_secs_saving_once_every_three_steps(self):
    """Check a SummarySaverHook with save_secs=0.9 against 0.3 s steps.

    Eight training runs are made with a 0.3 s sleep after each one; the
    writer is then expected to hold ``my_summary`` values 1.0, 2.0 and 3.0
    at steps 1, 4 and 7.
    """
    saver_hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=0.9,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
        saver_hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess = monitored_session._HookedSession(sess, [saver_hook])
        for _ in range(8):
            hooked_sess.run(self.train_op)
            time.sleep(0.3)
        saver_hook.end(sess)

    expected = {
        1: {'my_summary': 1.0},
        4: {'my_summary': 2.0},
        7: {'my_summary': 3.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
def test_feeding_placeholder(self):
    """Check that FeedFnHook supplies the placeholder value for a run.

    The hook's feed function maps the placeholder to 1.0, so evaluating
    ``placeholder + 1`` through the hooked session must yield 2.
    """
    with ops.Graph().as_default(), session_lib.Session() as sess:
        placeholder = array_ops.placeholder(dtype=dtypes.float32)
        plus_one = placeholder + 1

        def feed_one():
            return {placeholder: 1.0}

        feed_hook = basic_session_run_hooks.FeedFnHook(feed_fn=feed_one)
        feed_hook.begin()

        hooked_sess = monitored_session._HookedSession(sess, [feed_hook])
        result = hooked_sess.run(plus_one)
        self.assertEqual(result, 2)
def testHookNotExceedingLimit(self):
    """Run ``inc_v`` once through a session hooked with DumpingDebugHook.

    The hook is given a watch function that ignores its arguments and
    always returns the same triple. The test makes no assertions of its
    own; it passes as long as the hooked run does not raise.
    """

    def _watch_fn(fetches, feeds):
        # The arguments are deliberately unused.
        del fetches, feeds
        return "DebugIdentity", r".*delta.*", r".*"

    debug_hook = hooks.DumpingDebugHook(
        self.session_root,
        watch_fn=_watch_fn,
        log_usage=False)
    hooked_sess = monitored_session._HookedSession(self.sess, [debug_hook])
    hooked_sess.run(self.inc_v)
def test_feeding_placeholder(self):
    """Verify a FeedFnHook-fed placeholder is visible to the hooked run.

    ``x`` is fed 1.0 by the hook's feed function, so ``x + 1`` evaluated
    via the hooked session is expected to equal 2.
    """
    with ops.Graph().as_default(), session_lib.Session() as sess:
        x_input = array_ops.placeholder(dtype=dtypes.float32)
        incremented = x_input + 1

        hooks_under_test = [
            basic_session_run_hooks.FeedFnHook(
                feed_fn=lambda: {x_input: 1.0}),
        ]
        hooks_under_test[0].begin()

        monitored = monitored_session._HookedSession(sess, hooks_under_test)
        self.assertEqual(monitored.run(incremented), 2)
def test_sweeps(self):
    """Drive a _SweepHook through a row sweep, a column sweep, and back.

    Boolean variables record whether the init, prep and train ops handed to
    the hook were executed. The test checks these flags and the value of
    the row-sweep variable after each hooked run; between sweeps it marks
    the current sweep as done by running an assign op through the hooked
    session.
    """
    row_sweep_flag = variables.Variable(True)
    sweep_done_flag = variables.Variable(False)

    # One flag per op handed to the hook; each op sets its flag to True.
    init_ran = variables.Variable(False)
    row_prep_ran = variables.Variable(False)
    col_prep_ran = variables.Variable(False)
    row_train_ran = variables.Variable(False)
    col_train_ran = variables.Variable(False)

    init_op = state_ops.assign(init_ran, True)
    row_prep_op = state_ops.assign(row_prep_ran, True)
    col_prep_op = state_ops.assign(col_prep_ran, True)
    row_train_op = state_ops.assign(row_train_ran, True)
    col_train_op = state_ops.assign(col_train_ran, True)
    noop_train = control_flow_ops.no_op()

    # Clears the done flag and toggles the sweep direction.
    clear_done = state_ops.assign(sweep_done_flag, False)
    toggle_direction = state_ops.assign(
        row_sweep_flag, math_ops.logical_not(row_sweep_flag))
    switch_op = control_flow_ops.group(clear_done, toggle_direction)
    finish_sweep = state_ops.assign(sweep_done_flag, True)

    with self.test_session() as sess:
        hook = wals_lib._SweepHook(
            row_sweep_flag,
            sweep_done_flag,
            init_op,
            [row_prep_op],
            [col_prep_op],
            row_train_op,
            col_train_op,
            switch_op)
        hooked_sess = monitored_session._HookedSession(sess, [hook])
        sess.run([variables.global_variables_initializer()])

        # Row sweep.
        hooked_sess.run(noop_train)
        self.assertTrue(
            sess.run(init_ran),
            msg='init op not run by the Sweephook')
        self.assertTrue(
            sess.run(row_prep_ran),
            msg='row_prep_op not run by the SweepHook')
        self.assertTrue(
            sess.run(row_train_ran),
            msg='row_train_op not run by the SweepHook')
        self.assertTrue(
            sess.run(row_sweep_flag),
            msg='Row sweep is not complete but is_row_sweep_var is False.')

        # Col sweep.
        hooked_sess.run(finish_sweep)
        hooked_sess.run(noop_train)
        self.assertTrue(
            sess.run(col_prep_ran),
            msg='col_prep_op not run by the SweepHook')
        self.assertTrue(
            sess.run(col_train_ran),
            msg='col_train_op not run by the SweepHook')
        self.assertFalse(
            sess.run(row_sweep_flag),
            msg='Col sweep is not complete but is_row_sweep_var is True.')

        # Row sweep.
        hooked_sess.run(finish_sweep)
        hooked_sess.run(noop_train)
        self.assertTrue(
            sess.run(row_sweep_flag),
            msg='Col sweep is complete but is_row_sweep_var is False.')
def test_sweeps(self):
    """Feed row/column index batches to a _SweepHook and check its state.

    Four hooked runs of ``self._train_op`` are made with different row and
    column index feeds. After each one the test checks the fixture's
    init/prep flags, the row-sweep variable, and the hook's
    ``_is_sweep_done`` attribute.
    """

    def ind_feed(row_indices, col_indices):
        # Feed dict for the two index placeholders.
        feed = {}
        feed[self._input_row_indices_ph] = row_indices
        feed[self._input_col_indices_ph] = col_indices
        return feed

    with self.test_session() as sess:
        row_sweep_flag = variables.Variable(True)
        hook = wals_lib._SweepHook(
            row_sweep_flag,
            self._train_op,
            self._num_rows,
            self._num_cols,
            self._input_row_indices_ph,
            self._input_col_indices_ph,
            self._row_prep_ops,
            self._col_prep_ops,
            self._init_ops)
        hooked_sess = monitored_session._HookedSession(sess, [hook])
        sess.run([variables.global_variables_initializer()])

        # Init ops should run before the first run. Row sweep not completed.
        hooked_sess.run(self._train_op, ind_feed([0, 1, 2], []))
        self.assertTrue(
            sess.run(self._init_done),
            msg='init ops not run by the sweep_hook')
        self.assertTrue(
            sess.run(self._row_prep_done),
            msg='row_prep not run by the sweep_hook')
        self.assertTrue(
            sess.run(row_sweep_flag),
            msg='Row sweep is not complete but is_row_sweep is '
            'False.')

        # Row sweep completed.
        hooked_sess.run(
            self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
        self.assertFalse(
            sess.run(row_sweep_flag),
            msg='Row sweep is complete but is_row_sweep is True.')
        self.assertTrue(
            hook._is_sweep_done,
            msg='Sweep is complete but is_sweep_done is False.')

        # Col init ops should run. Col sweep not completed.
        hooked_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
        self.assertTrue(
            sess.run(self._col_prep_done),
            msg='col_prep not run by the sweep_hook')
        self.assertFalse(
            sess.run(row_sweep_flag),
            msg='Col sweep is not complete but is_row_sweep is '
            'True.')
        self.assertFalse(
            hook._is_sweep_done,
            msg='Sweep is not complete but is_sweep_done is True.')

        # Col sweep completed.
        hooked_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
        self.assertTrue(
            sess.run(row_sweep_flag),
            msg='Col sweep is complete but is_row_sweep is False')
        self.assertTrue(
            hook._is_sweep_done,
            msg='Sweep is complete but is_sweep_done is False.')
def test_sweeps(self):
    """Feed index batches to a _SweepHook and check sweep bookkeeping.

    Four hooked runs of ``self._train_op`` are made with different row and
    column index feeds. After each run the test checks the fixture's
    init/prep flags, the row-sweep variable, the hook's
    ``_is_sweep_done_var`` and the completed-sweeps counter (expected to be
    1 after the second run and 2 after the fourth).
    """

    def ind_feed(row_indices, col_indices):
        # Feed dict for the two index placeholders.
        return {
            self._input_row_indices_ph: row_indices,
            self._input_col_indices_ph: col_indices
        }

    with self.test_session() as sess:
        is_row_sweep_var = variables.Variable(True)
        completed_sweeps_var = variables.Variable(0)
        sweep_hook = wals_lib._SweepHook(
            is_row_sweep_var,
            [self._train_op],
            self._num_rows,
            self._num_cols,
            self._input_row_indices_ph,
            self._input_col_indices_ph,
            self._row_prep_ops,
            self._col_prep_ops,
            self._init_ops,
            completed_sweeps_var)
        mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
        sess.run([variables.global_variables_initializer()])

        # Init ops should run before the first run. Row sweep not completed.
        mon_sess.run(self._train_op, ind_feed([0, 1, 2], []))
        self.assertTrue(
            sess.run(self._init_done),
            msg='init ops not run by the sweep_hook')
        self.assertTrue(
            sess.run(self._row_prep_done),
            msg='row_prep not run by the sweep_hook')
        self.assertTrue(
            sess.run(is_row_sweep_var),
            msg='Row sweep is not complete but is_row_sweep is '
            'False.')

        # Row sweep completed.
        mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
        # assertEqual (rather than assertTrue(x == 1)) so that a failure
        # reports the counter value that was actually observed.
        self.assertEqual(
            1,
            sess.run(completed_sweeps_var),
            msg='Completed sweeps should be equal to 1.')
        self.assertTrue(
            sess.run(sweep_hook._is_sweep_done_var),
            msg='Sweep is complete but is_sweep_done is False.')

        # Col init ops should run. Col sweep not completed.
        mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
        self.assertTrue(
            sess.run(self._col_prep_done),
            msg='col_prep not run by the sweep_hook')
        self.assertFalse(
            sess.run(is_row_sweep_var),
            msg='Col sweep is not complete but is_row_sweep is '
            'True.')
        self.assertFalse(
            sess.run(sweep_hook._is_sweep_done_var),
            msg='Sweep is not complete but is_sweep_done is True.')

        # Col sweep completed.
        mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
        self.assertTrue(
            sess.run(sweep_hook._is_sweep_done_var),
            msg='Sweep is complete but is_sweep_done is False.')
        self.assertEqual(
            2,
            sess.run(completed_sweeps_var),
            msg='Completed sweeps should be equal to 2.')
def test_print_formatter(self):
    """Check that LoggingTensorHook logs through a custom formatter.

    The formatter renders the watched tensor's value as ``qqq=<value>``;
    after one hooked run the first logged message must be ``qqq=42.0``.
    """
    with ops.Graph().as_default(), session_lib.Session() as sess:
        watched = constant_op.constant(42.0, name='foo')
        step_op = constant_op.constant(3)

        def render(items):
            return 'qqq=%s' % items[watched.name]

        logging_hook = basic_session_run_hooks.LoggingTensorHook(
            tensors=[watched.name], every_n_iter=10, formatter=render)
        logging_hook.begin()

        hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess.run(step_op)

        first_message = self.logged_message[0]
        self.assertEqual(first_message, 'qqq=42.0')
def test_saves_when_saver_and_scaffold_both_missing(self):
    """Check CheckpointSaverHook works with neither saver nor scaffold given.

    The hook is built with only the model directory and ``save_steps=1``;
    after one hooked training run the checkpoint must hold global step 1.
    """
    with self.graph.as_default():
        saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_steps=1)
        saver_hook.begin()
        self.scaffold.finalize()

        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked_sess = monitored_session._HookedSession(
                sess, [saver_hook])
            hooked_sess.run(self.train_op)

            saved_step = checkpoint_utils.load_variable(
                self.model_dir, self.global_step.name)
            self.assertEqual(1, saved_step)
def testDumpingDebugHookWithoutWatchFnWorks(self):
    """Check the dump produced by a DumpingDebugHook with no watch_fn.

    After one hooked run of ``inc_v`` exactly one ``run_*`` directory must
    exist under the session root. Its name is validated, and the loaded
    dump must report ``v`` as [10.0], the fetch as ``repr(self.inc_v)`` and
    the feed keys as ``repr(None)``.
    """
    debug_hook = hooks.DumpingDebugHook(self.session_root, log_usage=False)
    hooked_sess = monitored_session._HookedSession(self.sess, [debug_hook])
    hooked_sess.run(self.inc_v)

    run_dir_pattern = os.path.join(self.session_root, "run_*")
    run_dirs = glob.glob(run_dir_pattern)
    self.assertEqual(1, len(run_dirs))

    run_dir = run_dirs[0]
    self._assert_correct_run_subdir_naming(os.path.basename(run_dir))

    loaded_dump = debug_data.DebugDumpDir(run_dir)
    self.assertAllClose(
        [10.0], loaded_dump.get_tensors("v", 0, "DebugIdentity"))
    self.assertEqual(repr(self.inc_v), loaded_dump.run_fetches_info)
    self.assertEqual(repr(None), loaded_dump.run_feed_keys_info)
def test_print_first_step(self):
    """Check the first message of a LoggingTensorHook with every_n_iter=1.

    When the hook runs every iteration the first iteration has no elapsed
    duration, so the logged message must mention ``foo`` but must not
    contain ``sec``.
    """
    with ops.Graph().as_default(), session_lib.Session() as sess:
        watched = constant_op.constant(42.0, name='foo')
        step_op = constant_op.constant(3)
        logging_hook = basic_session_run_hooks.LoggingTensorHook(
            tensors={'foo': watched}, every_n_iter=1)
        logging_hook.begin()

        hooked_sess = monitored_session._HookedSession(sess, [logging_hook])
        sess.run(variables_lib.global_variables_initializer())
        hooked_sess.run(step_op)

        self.assertRegexpMatches(str(self.logged_message), 'foo')
        # No elapsed time is available on the first run.
        sec_position = str(self.logged_message).find('sec')
        self.assertEqual(sec_position, -1)