Esempio n. 1
0
    def testAdditionalHooks(self):
        """Evaluate once with a tfdbg DumpingDebugHook passed via `hooks`.

        Verifies that the evaluation still yields the expected accuracy and
        that the hook wrote at least one debug dump under `dumping_root`
        containing non-empty tensor data. The dump root is removed afterwards
        even if an assertion fails.
        """
        checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
        log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

        # First, save out the current model to a checkpoint:
        self._prepareCheckpoint(checkpoint_path)

        # Next, determine the metric to evaluate:
        value_op, update_op = metric_ops.streaming_accuracy(
            self._predictions, self._labels)

        dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
        dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
        try:
            # Run the evaluation and verify the results:
            accuracy_value = evaluation.evaluate_once('',
                                                      checkpoint_path,
                                                      log_dir,
                                                      eval_op=update_op,
                                                      final_op=value_op,
                                                      hooks=[dumping_hook])
            self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

            # Fail with a clear assertion (rather than an IndexError) if the
            # hook produced no run_* directory. glob.glob() returns entries in
            # arbitrary order, so sort to make the inspected run deterministic.
            dump_dirs = sorted(glob.glob(os.path.join(dumping_root, 'run_*')))
            self.assertTrue(
                dump_dirs,
                'No run_* dump directory found under %s' % dumping_root)
            dump = debug_data.DebugDumpDir(dump_dirs[0])
            # Here we simply assert that the dumped data has been loaded and is
            # non-empty. We do not care about the detailed model-internal tensors or
            # their values.
            self.assertTrue(dump.dumped_tensor_data)
        finally:
            # Clean up the dump root regardless of test outcome.
            if os.path.isdir(dumping_root):
                shutil.rmtree(dumping_root)
Esempio n. 2
0
  def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
    """A stateful legacy (tuple-returning) watch_fn alternates per run.

    Runs 1 and 3 (1-based) watch all nodes and op types; runs 2 and 4 use a
    pattern that matches no non-empty name. Each of the four runs must still
    produce its own run_* dump directory.
    """
    state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
      del fetches, feed_dict
      state["run_counter"] += 1
      # Odd 1-based runs watch everything; even runs watch nothing.
      pattern = r".*" if state["run_counter"] % 2 == 1 else r"$^"
      return "DebugIdentity", pattern, pattern

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
      mon_sess.run(self.inc_v)

    def run_index(path):
      # Order directories by the integer following the first underscore.
      return int(os.path.basename(path).split("_")[1])

    dump_dirs = sorted(
        glob.glob(os.path.join(self.session_root, "run_*")), key=run_index)
    self.assertEqual(4, len(dump_dirs))

    for run_number, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      watched_everything = run_number % 2 == 0
      if watched_everything:
        self.assertAllClose([10.0 + 1.0 * run_number],
                            dump.get_tensors("v", 0, "DebugIdentity"))
      else:
        self.assertEqual(0, dump.size)

      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)
Esempio n. 3
0
 def testHookNotExceedingLimit(self):
   """One hooked run watching only `delta` nodes must not raise.

   The byte limit itself is configured outside this method; the test passes
   as long as the single `run` call completes without an error.
   """
   def _watch_fn(fetches, feeds):
     del fetches, feeds
     return "DebugIdentity", r".*delta.*", r".*"

   hook = hooks.DumpingDebugHook(
       self.session_root, watch_fn=_watch_fn, log_usage=False)
   hooked_sess = monitored_session._HookedSession(self.sess, [hook])
   hooked_sess.run(self.inc_v)
Esempio n. 4
0
  def testDumpingDebugHookWithoutWatchFnWorks(self):
    """Without a watch_fn, a single hooked run yields exactly one dump."""
    hooked_session = monitored_session._HookedSession(
        self.sess,
        [hooks.DumpingDebugHook(self.session_root, log_usage=False)])
    hooked_session.run(self.inc_v)

    run_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    self.assertEqual(1, len(run_dirs))
    only_run_dir = run_dirs[0]

    self._assert_correct_run_subdir_naming(os.path.basename(only_run_dir))
    dump = debug_data.DebugDumpDir(only_run_dir)
    self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))

    # The dump records what was fetched and that nothing was fed.
    self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
    self.assertEqual(repr(None), dump.run_feed_keys_info)
Esempio n. 5
0
 def testHookExceedingLimit(self):
   """Repeated hooked runs watching `delta` eventually exceed the byte limit.

   Like in `testWrapperSessionExceedingLimit`, the first two calls should be
   within the byte limit, but the third one should error out with a
   ValueError due to exceeding the limit.
   """
   def _watch_fn(fetches, feeds):
     del fetches, feeds
     return "DebugIdentity", r".*delta.*", r".*"

   hook = hooks.DumpingDebugHook(
       self.session_root, watch_fn=_watch_fn, log_usage=False)
   hooked_sess = monitored_session._HookedSession(self.sess, [hook])
   # Runs 1 and 2 are expected to stay within the limit.
   for _ in range(2):
     hooked_sess.run(self.inc_v)
   # Run 3 is expected to exceed it.
   with self.assertRaises(ValueError):
     hooked_sess.run(self.inc_v)
Esempio n. 6
0
    def testDumpingDebugHookWithStatefulWatchFnWorks(self):
        """A stateful WatchOptions-returning watch_fn alternates per run.

        Odd 1-based runs watch only tensors whose dtype matches ".*_ref";
        even runs whitelist only empty node names and op types. All four runs
        must produce a run_* directory, and only the watched ones hold data.
        """
        watch_fn_state = {"run_counter": 0}

        def counting_watch_fn(fetches, feed_dict):
            del fetches, feed_dict
            watch_fn_state["run_counter"] += 1
            if watch_fn_state["run_counter"] % 2 == 0:
                # Even-index run: the "^$" patterns match no non-empty name.
                return framework.WatchOptions(debug_ops="DebugIdentity",
                                              node_name_regex_whitelist=r"^$",
                                              op_type_regex_whitelist=r"^$")
            # Odd-index run (1-based): watch every ref-type tensor.
            return framework.WatchOptions(
                debug_ops="DebugIdentity",
                tensor_dtype_regex_whitelist=".*_ref")

        dumping_hook = hooks.DumpingDebugHook(self.session_root,
                                              watch_fn=counting_watch_fn,
                                              log_usage=False)
        mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
        for _ in range(4):
            mon_sess.run(self.inc_v)

        # Order run directories by the integer after the first underscore.
        run_dirs = sorted(
            glob.glob(os.path.join(self.session_root, "run_*")),
            key=lambda path: int(os.path.basename(path).split("_")[1]))
        self.assertEqual(4, len(run_dirs))

        for index, run_dir in enumerate(run_dirs):
            self._assert_correct_run_subdir_naming(os.path.basename(run_dir))
            dump = debug_data.DebugDumpDir(run_dir)
            if index % 2 == 1:
                self.assertEqual(0, dump.size)
            else:
                self.assertAllClose([10.0 + 1.0 * index],
                                    dump.get_tensors("v", 0, "DebugIdentity"))
                dumped_node_names = [
                    datum.node_name for datum in dump.dumped_tensor_data]
                self.assertNotIn("delta", dumped_node_names)

            self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
            self.assertEqual(repr(None), dump.run_feed_keys_info)