def testDumpingOnASingleRunWorksWithRelativePathForDebugDumpDir(self):
    """Load the dump of a single run through a relative directory path.

    The working directory is temporarily switched to the session root so
    that the run directory can be handed to `DebugDumpDir` as a path relative
    to it. The original working directory is always restored afterwards.
    """
    wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    wrapped_sess.run(self.inc_v)

    run_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    original_cwd = os.getcwd()
    try:
        # Enter the session root so the relative path below resolves.
        os.chdir(self.session_root)
        relative_run_dir = os.path.relpath(run_dirs[0], self.session_root)
        dump = debug_data.DebugDumpDir(relative_run_dir)
        self.assertAllClose(
            [10.0], dump.get_tensors("v", 0, "DebugIdentity"))
    finally:
        # Restore the working directory even if an assertion above failed.
        os.chdir(original_cwd)
def testDumpingOnASingleRunWorks(self):
    """Check that one wrapped run produces exactly one well-formed dump.

    Verifies the number and naming of the `run_*` sub-directories, the dumped
    value of `v`, and the recorded fetch / feed-key information.
    """
    wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    wrapped_sess.run(self.inc_v)

    run_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    self.assertEqual(1, len(run_dirs))

    run_dir = run_dirs[0]
    self._assert_correct_run_subdir_naming(os.path.basename(run_dir))

    dump = debug_data.DebugDumpDir(run_dir)
    dumped_v = dump.get_tensors("v", 0, "DebugIdentity")
    self.assertAllClose([10.0], dumped_v)

    # The run fetched inc_v and supplied no feeds.
    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)
def testDumpingOnMultipleRunsWorks(self):
    """Check that three wrapped runs produce three separate, ordered dumps."""

    def _run_sort_key(path):
        # Order run directories by the integer that follows the first "_" in
        # the directory's base name.
        return int(os.path.basename(path).split("_")[1])

    wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    for _ in range(3):
        wrapped_sess.run(self.inc_v)

    run_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    run_dirs.sort(key=_run_sort_key)
    self.assertEqual(3, len(run_dirs))

    for i, run_dir in enumerate(run_dirs):
        self._assert_correct_run_subdir_naming(os.path.basename(run_dir))
        dump = debug_data.DebugDumpDir(run_dir)

        # The dumped value of v is expected to grow by 1.0 per run.
        expected_v = [10.0 + 1.0 * i]
        self.assertAllClose(
            expected_v, dump.get_tensors("v", 0, "DebugIdentity"))

        self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)
def testWrapperSessionExceedingLimit(self):
    """Check that a run raises ValueError once the dump byte limit is hit.

    NOTE(review): the byte limit itself is not set inside this method; it is
    presumably configured by the test fixture or environment - confirm.
    """

    def _watch_fn(fetches, feeds):
        # Ignore the run arguments; always restrict watching to nodes whose
        # names match ".*delta.*".
        del fetches, feeds
        return "DebugIdentity", r".*delta.*", r".*"

    wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=_watch_fn,
        log_usage=False)

    # Because of the watch function, each run should dump a single tensor:
    # 'delta:0', a float32 scalar occupying 4 bytes.

    # Run 1: expected to succeed, leaving disk usage at 4 bytes.
    wrapped_sess.run(self.inc_v)

    # Run 2: expected to succeed, leaving disk usage at 8 bytes.
    wrapped_sess.run(self.inc_v)

    # Run 3: expected to fail, since the total (12 bytes) would exceed the
    # limit (10 bytes, per the assumption noted in the docstring).
    with self.assertRaises(ValueError):
        wrapped_sess.run(self.inc_v)
def testDumpingFromMultipleThreadsObeysThreadNameFilter(self):
    """Check that only runs from threads matching the name filter are dumped.

    The wrapper is created with the filter r"MainThread$". One run is issued
    from the calling thread and one from a thread named "ChildThread"; the
    test expects exactly one dump directory, containing only the "delta"
    tensor from the first run.
    """
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        log_usage=False,
        thread_name_filter=r"MainThread$")

    self.assertAllClose(1.0, sess.run(self.delta))

    def child_thread_job():
        # Run eta exactly once. The previous form,
        # `sess.run(sess.run(self.eta))`, passed the fetched *value* of eta
        # back in as a fetch. That value is not a graph element, so the
        # outer call could only raise inside this thread, where the error
        # was silently lost instead of failing the test.
        sess.run(self.eta)

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    thread.join()

    # Only the first run above should have produced a dump.
    dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    self.assertEqual(1, len(dump_dirs))

    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertEqual(1, dump.size)
    self.assertEqual("delta", dump.dumped_tensor_data[0].node_name)
def testDumpingWithLegacyWatchFnOnFetchesWorks(self):
    """Use a watch_fn that returns different whitelists for different runs.

    Runs of `inc_v` and `dec_v` are interleaved three times. Dumps of the
    `inc_v` runs are expected to be non-empty, dumps of the `dec_v` runs
    are expected to be empty.
    """

    def watch_fn(fetches, feeds):
        # Select the regexes from the fetch name alone: match everything for
        # inc_v, and use a never-matching pattern for anything else.
        del feeds
        pattern = r".*" if fetches.name == "inc_v:0" else r"$^"
        return "DebugIdentity", pattern, pattern

    wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)

    for _ in range(3):
        wrapped_sess.run(self.inc_v)
        wrapped_sess.run(self.dec_v)

    def _run_sort_key(path):
        # Order run directories by the integer that follows the first "_" in
        # the directory's base name.
        return int(os.path.basename(path).split("_")[1])

    run_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    run_dirs.sort(key=_run_sort_key)
    self.assertEqual(6, len(run_dirs))

    for i, run_dir in enumerate(run_dirs):
        self._assert_correct_run_subdir_naming(os.path.basename(run_dir))
        dump = debug_data.DebugDumpDir(run_dir)

        if i % 2:
            # Odd positions are the dec_v runs: nothing should be dumped.
            self.assertEqual(0, dump.size)
            self.assertEqual(repr(self.dec_v), dump.run_fetches_info)
        else:
            # Even positions are the inc_v runs. The expected value of v
            # drops by 0.4 for every completed inc_v / dec_v pair.
            self.assertGreater(dump.size, 0)
            expected_v = [10.0 - 0.4 * (i / 2)]
            self.assertAllClose(
                expected_v, dump.get_tensors("v", 0, "DebugIdentity"))
            self.assertEqual(repr(self.inc_v), dump.run_fetches_info)

        # No run supplied any feeds.
        self.assertEqual(repr(None), dump.run_feed_keys_info)
def testDumpingWithLegacyWatchFnWithNonDefaultDebugOpsWorks(self):
    """Use a watch_fn that specifies non-default debug ops.

    The watch_fn requests both "DebugIdentity" and "DebugNumericSummary";
    the test checks that dumps from both ops are present for `v`.
    """

    def watch_fn(fetches, feeds):
        # The run arguments are ignored: every run gets the same two ops
        # with match-everything regexes.
        del fetches, feeds
        debug_ops = ["DebugIdentity", "DebugNumericSummary"]
        return debug_ops, r".*", r".*"

    wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)
    wrapped_sess.run(self.inc_v)

    run_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    self.assertEqual(1, len(run_dirs))

    dump = debug_data.DebugDumpDir(run_dirs[0])
    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

    # The first numeric-summary dump of v is expected to have 14 elements.
    numeric_summaries = dump.get_tensors("v", 0, "DebugNumericSummary")
    self.assertEqual(14, len(numeric_summaries[0]))
def before_run(self, run_context):
    """Build the debug-enabled run arguments for the upcoming run call.

    On the first invocation (while `self._session_wrapper` is still falsy) a
    `DumpingDebugWrapperSession` is created around `run_context.session`.
    Every invocation then increments the wrapper's run-call count, asks the
    wrapper for the debug URLs and watch options of this run, and records
    them in a fresh `RunOptions` via `debug_utils.watch_graph`.

    Args:
        run_context: Context of the run call; its `session` and
            `original_args` (`fetches`, `feed_dict`) attributes are read.

    Returns:
        A `SessionRunArgs` with no fetches and no feeds, carrying only the
        debug-configured `RunOptions`.
    """
    # Disk-byte usage is reset only on the call that creates the wrapper.
    wrapper_created_now = not self._session_wrapper
    if wrapper_created_now:
        self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
            run_context.session,
            self._session_root,
            watch_fn=self._watch_fn,
            thread_name_filter=self._thread_name_filter,
            log_usage=self._log_usage)

    wrapper = self._session_wrapper
    wrapper.increment_run_call_count()

    original_args = run_context.original_args
    # pylint: disable=protected-access
    debug_urls, watch_options = wrapper._prepare_run_watch_config(
        original_args.fetches, original_args.feed_dict)
    # pylint: enable=protected-access

    run_options = config_pb2.RunOptions()
    watch_kwargs = dict(
        debug_urls=debug_urls,
        debug_ops=watch_options.debug_ops,
        node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=(
            watch_options.tensor_dtype_regex_allowlist),
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures),
        reset_disk_byte_usage=wrapper_created_now)
    debug_utils.watch_graph(
        run_options, run_context.session.graph, **watch_kwargs)

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)
def testDumpingWrapperWithEmptyFetchWorks(self):
    """Check that running the wrapper with an empty fetch list does not raise."""
    wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False)
    no_fetches = []
    wrapped_sess.run(no_fetches)
def dumping_wrapper(sess):  # pylint: disable=invalid-name
    """Wrap `sess` in a `DumpingDebugWrapperSession`.

    `dump_root` is not a parameter of this function; it is looked up in the
    enclosing scope and passed as the second positional argument.

    Args:
        sess: The session object to wrap.

    Returns:
        The newly created `DumpingDebugWrapperSession`.
    """
    wrapped_sess = dumping_wrapper_lib.DumpingDebugWrapperSession(
        sess, dump_root)
    return wrapped_sess