Ejemplo n.º 1
0
 def testDumpingOnASingleRunWorksWithRelativePathForDebugDumpDir(self):
     """A run's dump directory can be loaded via a CWD-relative path.

     Runs `self.inc_v` once through the dumping wrapper, changes the working
     directory to `self.session_root`, and opens the single resulting run
     directory using a path relative to that directory.
     """
     sess = dumping_wrapper.DumpingDebugWrapperSession(
         self.sess, session_root=self.session_root, log_usage=False)
     sess.run(self.inc_v)
     dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
     # Fail with a clear assertion (instead of an IndexError on an empty glob,
     # or silently picking one of several dirs) unless exactly one run
     # directory was written.
     self.assertEqual(1, len(dump_dirs))
     # relpath depends only on the two paths, so it can be computed before
     # changing directories.
     relative_dump_dir = os.path.relpath(dump_dirs[0], self.session_root)
     cwd = os.getcwd()
     try:
         os.chdir(self.session_root)
         dump = debug_data.DebugDumpDir(relative_dump_dir)
         self.assertAllClose([10.0],
                             dump.get_tensors("v", 0, "DebugIdentity"))
     finally:
         # Always restore the original working directory so the chdir does
         # not leak into other tests.
         os.chdir(cwd)
Ejemplo n.º 2
0
    def testDumpingOnASingleRunWorks(self):
        """One wrapped run yields exactly one dump dir with the expected data."""
        wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
            self.sess, session_root=self.session_root, log_usage=False)
        wrapped_sess.run(self.inc_v)

        run_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
        self.assertEqual(1, len(run_dirs))
        (only_run_dir,) = run_dirs

        self._assert_correct_run_subdir_naming(os.path.basename(only_run_dir))
        dump = debug_data.DebugDumpDir(only_run_dir)
        self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

        # The dump records the run's fetches and feed keys as repr strings.
        self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)
Ejemplo n.º 3
0
    def testDumpingOnMultipleRunsWorks(self):
        """Each of several wrapped runs gets its own, correctly named dump dir."""
        wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
            self.sess, session_root=self.session_root, log_usage=False)
        num_runs = 3
        for _ in range(num_runs):
            wrapped_sess.run(self.inc_v)

        def sort_key(path):
            # Order directories by the integer field that follows "run_".
            return int(os.path.basename(path).split("_")[1])

        run_dirs = sorted(
            glob.glob(os.path.join(self.session_root, "run_*")), key=sort_key)
        self.assertEqual(num_runs, len(run_dirs))
        for run_index, run_dir in enumerate(run_dirs):
            self._assert_correct_run_subdir_naming(os.path.basename(run_dir))
            dump = debug_data.DebugDumpDir(run_dir)
            # Expected value of "v": 10.0 in the first run, +1.0 per run after.
            expected_v = 10.0 + 1.0 * run_index
            self.assertAllClose([expected_v],
                                dump.get_tensors("v", 0, "DebugIdentity"))
            self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
            self.assertEqual(repr(None), dump.run_feed_keys_info)
Ejemplo n.º 4
0
 def testWrapperSessionExceedingLimit(self):
   """A run raises ValueError once cumulative dumped bytes exceed the limit.

   NOTE(review): the 10-byte limit referenced below is not configured in this
   method; it is presumably set up elsewhere (e.g. in setUp) -- confirm.
   """
   def _watch_delta_only(fetches, feeds):
     del fetches, feeds  # Unused: every run gets the same watch config.
     return "DebugIdentity", r".*delta.*", r".*"

   sess = dumping_wrapper.DumpingDebugWrapperSession(
       self.sess, session_root=self.session_root,
       watch_fn=_watch_delta_only, log_usage=False)
   # The watch function restricts each run to a single 4-byte dump: the
   # scalar float32 'delta:0' tensor.
   # Run 1 succeeds (4 bytes used); run 2 succeeds (8 bytes used).
   for _ in range(2):
     sess.run(self.inc_v)
   # Run 3 would bring the total to 12 bytes, over the 10-byte limit.
   self.assertRaises(ValueError, sess.run, self.inc_v)
Ejemplo n.º 5
0
  def testDumpingFromMultipleThreadsObeysThreadNameFilter(self):
    """Only runs from threads matching `thread_name_filter` are dumped.

    The filter matches "MainThread", so the run of `self.delta` on the main
    thread must be dumped. The run of `self.eta` on a thread named
    "ChildThread" must complete but must not produce a dump directory.
    """
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root, log_usage=False,
        thread_name_filter=r"MainThread$")

    self.assertAllClose(1.0, sess.run(self.delta))

    child_thread_results = []

    def child_thread_job():
      # Fetch `eta` exactly once. Record the result so the main thread can
      # verify the run actually completed.
      child_thread_results.append(sess.run(self.eta))

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    thread.join()
    # An exception raised inside the child thread is not propagated by
    # join(). If the child run failed, nothing was appended.
    self.assertEqual(1, len(child_thread_results))

    dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    self.assertEqual(1, len(dump_dirs))
    dump = debug_data.DebugDumpDir(dump_dirs[0])
    self.assertEqual(1, dump.size)
    self.assertEqual("delta", dump.dumped_tensor_data[0].node_name)
Ejemplo n.º 6
0
  def testDumpingWithLegacyWatchFnOnFetchesWorks(self):
    """Use a watch_fn whose allowlists differ depending on the fetch.

    Runs fetching `inc_v` watch everything; all other runs (here, `dec_v`)
    watch nothing.
    """

    def watch_fn(fetches, feeds):
      del feeds
      # Choose the allowlists from the name of the fetched tensor.
      if fetches.name != "inc_v:0":
        # The regex "$^" cannot match any non-empty name, so nothing is
        # watched.
        return "DebugIdentity", r"$^", r"$^"
      return "DebugIdentity", r".*", r".*"

    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess,
        session_root=self.session_root,
        watch_fn=watch_fn,
        log_usage=False)

    for _ in range(3):
      sess.run(self.inc_v)
      sess.run(self.dec_v)

    def run_dir_sort_key(path):
      # Order directories by the integer field that follows "run_".
      return int(os.path.basename(path).split("_")[1])

    run_dirs = sorted(
        glob.glob(os.path.join(self.session_root, "run_*")),
        key=run_dir_sort_key)
    self.assertEqual(6, len(run_dirs))

    for run_index, run_dir in enumerate(run_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(run_dir))
      dump = debug_data.DebugDumpDir(run_dir)
      if run_index % 2:
        # Odd positions are the `dec_v` runs: nothing was dumped.
        self.assertEqual(0, dump.size)
        self.assertEqual(repr(self.dec_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)
        continue
      # Even positions are the `inc_v` runs. Each inc/dec pair changes the
      # expected value of "v" by -0.4.
      pair_index = run_index // 2
      self.assertGreater(dump.size, 0)
      self.assertAllClose([10.0 - 0.4 * pair_index],
                          dump.get_tensors("v", 0, "DebugIdentity"))
      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)
Ejemplo n.º 7
0
    def testDumpingWithLegacyWatchFnWithNonDefaultDebugOpsWorks(self):
        """Use a watch_fn that requests debug ops beyond DebugIdentity."""

        def watch_fn(fetches, feeds):
            del fetches, feeds
            # Both regex allowlists match everything; two debug ops per tensor.
            return ["DebugIdentity", "DebugNumericSummary"], r".*", r".*"

        wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
            self.sess,
            session_root=self.session_root,
            watch_fn=watch_fn,
            log_usage=False)
        wrapped_sess.run(self.inc_v)

        run_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
        self.assertEqual(1, len(run_dirs))
        dump = debug_data.DebugDumpDir(run_dirs[0])

        self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
        # NOTE(review): 14 is presumably the length of DebugNumericSummary's
        # output vector for this tensor -- confirm against the op's docs.
        numeric_summaries = dump.get_tensors("v", 0, "DebugNumericSummary")
        self.assertEqual(14, len(numeric_summaries[0]))
Ejemplo n.º 8
0
    def before_run(self, run_context):
        """Builds debug-watch RunOptions for the upcoming `Session.run` call.

        If no session wrapper exists yet, creates a `DumpingDebugWrapperSession`
        around `run_context.session`. On that call only, `watch_graph` is
        asked to reset disk byte usage. Every call increments the wrapper's
        run-call count and derives the watch config from the original fetches
        and feed_dict.

        Args:
          run_context: hook run context; its `session` (and the session's
            `graph`) and `original_args.fetches` / `original_args.feed_dict`
            are read.

        Returns:
          A `session_run_hook.SessionRunArgs` with no fetches, no feed_dict,
          and `options` set to the debug-configured `RunOptions`.
        """
        creating_wrapper = not self._session_wrapper
        if creating_wrapper:
            self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
                run_context.session,
                self._session_root,
                watch_fn=self._watch_fn,
                thread_name_filter=self._thread_name_filter,
                log_usage=self._log_usage)

        # Bump the count before preparing the watch config, so the config
        # reflects the current run.
        self._session_wrapper.increment_run_call_count()

        original_args = run_context.original_args
        # pylint: disable=protected-access
        debug_urls, watch_options = (
            self._session_wrapper._prepare_run_watch_config(
                original_args.fetches, original_args.feed_dict))
        # pylint: enable=protected-access

        run_options = config_pb2.RunOptions()
        debug_utils.watch_graph(
            run_options,
            run_context.session.graph,
            debug_urls=debug_urls,
            debug_ops=watch_options.debug_ops,
            node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
            op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
            tensor_dtype_regex_allowlist=(
                watch_options.tensor_dtype_regex_allowlist),
            tolerate_debug_op_creation_failures=(
                watch_options.tolerate_debug_op_creation_failures),
            reset_disk_byte_usage=creating_wrapper)

        return session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=run_options)
Ejemplo n.º 9
0
 def testDumpingWrapperWithEmptyFetchWorks(self):
     """Running the dumping wrapper with an empty fetch list does not raise.

     There are no explicit assertions: the test passes if `run` completes.
     """
     wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
         self.sess, session_root=self.session_root, log_usage=False)
     no_fetches = []
     wrapped_sess.run(no_fetches)
Ejemplo n.º 10
0
 def dumping_wrapper(sess):  # pylint: disable=invalid-name
   """Wraps `sess` in a `DumpingDebugWrapperSession` rooted at `dump_root`.

   `dump_root` is a free variable resolved from a scope outside this
   function (not visible here).
   """
   wrapped_sess = dumping_wrapper_lib.DumpingDebugWrapperSession(
       sess, dump_root)
   return wrapped_sess