示例#1
0
    def testGrpcDebugHookWithStatelessWatchFnWorks(self):
        """Run a hooked session with a stateless watch_fn and check the dump.

        Builds the graph w = u * v, wraps a session with a GrpcDebugHook
        pointed at localhost:<self._server_port>, runs w once, and then reads
        the dump under self._dump_root to check what was recorded for the
        "u/read" and "v/read" nodes.
        """
        def read_nodes_watch_fn(feeds, fetch_keys):
            # Stateless: both arguments are discarded, so every call returns
            # equivalent watch options.
            del feeds, fetch_keys
            watch_options = framework.WatchOptions(
                debug_ops=["DebugIdentity", "DebugNumericSummary"],
                node_name_regex_whitelist=r".*/read",
                op_type_regex_whitelist=None,
                tolerate_debug_op_creation_failures=True)
            return watch_options

        # Graph under test: two variables and their product.
        var_u = variables.Variable(2.1, name="u")
        var_v = variables.Variable(20.0, name="v")
        product = math_ops.multiply(var_u, var_v, name="w")

        session_config = session_debug_testlib.no_rewrite_session_config()
        plain_sess = session.Session(config=session_config)
        # Initialize the variables on the plain session, before any hook is
        # attached.
        for variable in (var_u, var_v):
            plain_sess.run(variable.initializer)

        # A _HookedSession is used directly so that the test does not depend
        # on the internal implementation of Estimators.
        debug_urls = ["localhost:%d" % self._server_port]
        grpc_debug_hook = hooks.GrpcDebugHook(
            debug_urls, watch_fn=read_nodes_watch_fn)
        hooked_sess = monitored_session._HookedSession(
            plain_sess, [grpc_debug_hook])

        # This run goes through the hook, which should stream tensor data to
        # the GRPC endpoint.
        product_value = hooked_sess.run(product)
        self.assertAllClose(42.0, product_value)

        # Verify that the hook monitored the correct tensors.
        dump = debug_data.DebugDumpDir(self._dump_root)
        # Presumably 2 watched nodes x 2 debug ops -- matches the per-node
        # checks below.
        self.assertEqual(4, dump.size)
        for node_name, expected_value in (("u/read", 2.1), ("v/read", 20.0)):
            identity_tensors = dump.get_tensors(node_name, 0, "DebugIdentity")
            self.assertAllClose([expected_value], identity_tensors)
            summary_tensors = dump.get_tensors(
                node_name, 0, "DebugNumericSummary")
            self.assertEqual(14, len(summary_tensors[0]))
示例#2
0
 def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self):
     """Construct a GrpcDebugHook with and without the grpc:// URL prefix.

     There are no explicit assertions: the test passes as long as neither
     constructor call raises.
     """
     for address in ("grpc://foo:42424", "foo:42424"):
         hooks.GrpcDebugHook([address])
示例#3
0
 def testConstructGrpcDebugHookWithGrpcInUrlRaisesValueError(self):
     """Check that a grpc://-prefixed address makes the constructor raise.

     GrpcDebugHook is expected to raise ValueError when given the address
     list ["grpc://foo:42"].
     """
     # Callable form of assertRaises: invokes hooks.GrpcDebugHook with the
     # address list and requires a ValueError.
     self.assertRaises(ValueError, hooks.GrpcDebugHook, ["grpc://foo:42"])