def testAdditionalHooks(self):
  checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
  log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

  # First, save out the current model to a checkpoint:
  self._prepareCheckpoint(checkpoint_path)

  # Next, determine the metric to evaluate:
  value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                      self._labels)

  dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
  dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
  try:
    # Run the evaluation and verify the results:
    accuracy_value = evaluation.evaluate_once(
        '',
        checkpoint_path,
        log_dir,
        eval_op=update_op,
        final_op=value_op,
        hooks=[dumping_hook])
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

    dump = debug_data.DebugDumpDir(
        glob.glob(os.path.join(dumping_root, 'run_*'))[0])
    # Here we simply assert that the dumped data has been loaded and is
    # non-empty. We do not care about the detailed model-internal tensors or
    # their values.
    self.assertTrue(dump.dumped_tensor_data)
  finally:
    if os.path.isdir(dumping_root):
      shutil.rmtree(dumping_root)
def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
  watch_fn_state = {"run_counter": 0}

  def counting_watch_fn(fetches, feed_dict):
    del fetches, feed_dict
    watch_fn_state["run_counter"] += 1
    if watch_fn_state["run_counter"] % 2 == 1:
      # If odd-index run (1-based), watch everything.
      return "DebugIdentity", r".*", r".*"
    else:
      # If even-index run, watch nothing.
      return "DebugIdentity", r"$^", r"$^"

  dumping_hook = hooks.DumpingDebugHook(
      self.session_root, watch_fn=counting_watch_fn, log_usage=False)
  mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
  for _ in range(4):
    mon_sess.run(self.inc_v)

  dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
  dump_dirs = sorted(
      dump_dirs, key=lambda x: int(os.path.basename(x).split("_")[1]))
  self.assertEqual(4, len(dump_dirs))

  for i, dump_dir in enumerate(dump_dirs):
    self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
    dump = debug_data.DebugDumpDir(dump_dir)
    if i % 2 == 0:
      self.assertAllClose([10.0 + 1.0 * i],
                          dump.get_tensors("v", 0, "DebugIdentity"))
    else:
      self.assertEqual(0, dump.size)

    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)
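# For reference: the legacy watch_fn form exercised above returns a 3-tuple,
# which these tests use as (debug_ops, node_name_regex_whitelist,
# op_type_regex_whitelist); that mapping is an inference from the
# WatchOptions-based test (testDumpingDebugHookWithStatefulWatchFnWorks)
# further below. The helper here is an illustrative sketch of the stateless
# "watch everything" case and is not called by any test.
def _example_legacy_watch_fn(fetches, feed_dict):
  del fetches, feed_dict
  return "DebugIdentity", r".*", r".*"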
def testHookNotExceedingLimit(self):

  def _watch_fn(fetches, feeds):
    del fetches, feeds
    return "DebugIdentity", r".*delta.*", r".*"

  dumping_hook = hooks.DumpingDebugHook(
      self.session_root, watch_fn=_watch_fn, log_usage=False)
  mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
  mon_sess.run(self.inc_v)
def testDumpingDebugHookWithoutWatchFnWorks(self):
  dumping_hook = hooks.DumpingDebugHook(self.session_root, log_usage=False)
  mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
  mon_sess.run(self.inc_v)

  dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
  self.assertEqual(1, len(dump_dirs))

  self._assert_correct_run_subdir_naming(os.path.basename(dump_dirs[0]))
  dump = debug_data.DebugDumpDir(dump_dirs[0])
  self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

  self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
  self.assertEqual(repr(None), dump.run_feed_keys_info)
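# The hook tests in this file depend on a fixture built in setUp (self.sess,
# self.session_root, self.inc_v), which does not appear in this section. The
# helper below is a sketch of what that fixture presumably provides, inferred
# from the assertions (a float variable "v" starting at 10.0 that is
# incremented by a constant "delta" of 1.0). The names, module imports, and
# exact construction are assumptions; the helper is not called by the tests.
def _sketch_of_assumed_fixture(self):
  import tempfile

  from tensorflow.python.client import session
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import dtypes
  from tensorflow.python.ops import state_ops
  from tensorflow.python.ops import variables

  self.session_root = tempfile.mkdtemp()
  self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
  self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
  self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
  self.sess = session.Session()
  self.sess.run(self.v.initializer)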
def testHookExceedingLimit(self):

  def _watch_fn(fetches, feeds):
    del fetches, feeds
    return "DebugIdentity", r".*delta.*", r".*"

  dumping_hook = hooks.DumpingDebugHook(
      self.session_root, watch_fn=_watch_fn, log_usage=False)
  mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
  # Like in `testWrapperSessionExceedingLimit`, the first two calls
  # should be within the byte limit, but the third one should error
  # out due to exceeding the limit.
  mon_sess.run(self.inc_v)
  mon_sess.run(self.inc_v)
  with self.assertRaises(ValueError):
    mon_sess.run(self.inc_v)
def testDumpingDebugHookWithStatefulWatchFnWorks(self):
  watch_fn_state = {"run_counter": 0}

  def counting_watch_fn(fetches, feed_dict):
    del fetches, feed_dict
    watch_fn_state["run_counter"] += 1
    if watch_fn_state["run_counter"] % 2 == 1:
      # If odd-index run (1-based), watch every ref-type tensor.
      return framework.WatchOptions(
          debug_ops="DebugIdentity", tensor_dtype_regex_whitelist=".*_ref")
    else:
      # If even-index run, watch nothing.
      return framework.WatchOptions(
          debug_ops="DebugIdentity",
          node_name_regex_whitelist=r"^$",
          op_type_regex_whitelist=r"^$")

  dumping_hook = hooks.DumpingDebugHook(
      self.session_root, watch_fn=counting_watch_fn, log_usage=False)
  mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
  for _ in range(4):
    mon_sess.run(self.inc_v)

  dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
  dump_dirs = sorted(
      dump_dirs, key=lambda x: int(os.path.basename(x).split("_")[1]))
  self.assertEqual(4, len(dump_dirs))

  for i, dump_dir in enumerate(dump_dirs):
    self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
    dump = debug_data.DebugDumpDir(dump_dir)
    if i % 2 == 0:
      self.assertAllClose([10.0 + 1.0 * i],
                          dump.get_tensors("v", 0, "DebugIdentity"))
      self.assertNotIn(
          "delta",
          [datum.node_name for datum in dump.dumped_tensor_data])
    else:
      self.assertEqual(0, dump.size)

    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)
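# Outside of these tests, a DumpingDebugHook is attached through the standard
# SessionRunHook mechanism rather than the internal _HookedSession used above.
# The helper below is a minimal sketch of such usage, assuming a
# caller-supplied train_op and dump_root (both hypothetical); it is
# illustrative only and is not called by the tests.
def _example_monitored_session_usage(train_op, dump_root):
  hook = hooks.DumpingDebugHook(dump_root, log_usage=False)
  with monitored_session.MonitoredSession(hooks=[hook]) as sess:
    sess.run(train_op)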