def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
    """Tracebacks and source files are withheld when sending is disabled.

    Wraps the session in a TensorBoardDebugWrapperSession constructed with
    `send_traceback_and_source_code=False`, runs the two assign_add ops four
    times (requesting a breakpoint on delta_1:0:DebugIdentity before the
    first run only), and checks after every run that the debug server has
    received neither op tracebacks nor source-file contents.
    """
    with session.Session(config=no_rewrite_session_config()) as sess:
      v_1 = variables.Variable(50.0, name="v_1")
      v_2 = variables.Variable(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run(variables.global_variables_initializer())

      # Disable the sending of traceback and source code.
      sess = grpc_wrapper.TensorBoardDebugWrapperSession(
          sess, self._debug_server_url_1, send_traceback_and_source_code=False)

      for i in xrange(4):
        # Reset what the server has received so the queries below only see
        # data sent during this iteration's run.
        self._server_1.clear_data()

        if i == 0:
          # The breakpoint is requested only before the first run.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)

        output = sess.run([inc_v_1, inc_v_2])
        # Each run adds delta_1 (+5.0) to v_1 and delta_2 (-5.0) to v_2.
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        # No op traceback or source code should have been received by the debug
        # server due to the disabling above. Both queries are expected to
        # raise, so they are called bare. Wrapping one in assertTrue would be
        # dead code on success and would mask the real failure ("ValueError
        # not raised") behind "False is not true" otherwise.
        # NOTE(review): assertRaisesRegexp is a deprecated unittest alias
        # (removed in Python 3.12); presumably kept for Python 2 compatibility
        # (the file uses xrange). Switch to assertRaisesRegex when possible.
        with self.assertRaisesRegexp(
            ValueError, r"Op .*delta_1.* does not exist"):
          self._server_1.query_op_traceback("delta_1")
        with self.assertRaisesRegexp(
            ValueError, r".* has not received any source file"):
          self._server_1.query_source_file_line(__file__, 1)
Example #2
0
    def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
        """Breakpoints can be toggled on and off between Session.run() calls.

        Runs 0 and 2 request breakpoints on delta_1:0:DebugIdentity and
        delta_2:0:DebugIdentity; runs 1 and 3 withdraw them. Every run must
        still yield the expected variable values, and the debug server's
        state must reflect whether the breakpoints were active.
        """
        with session.Session(config=no_rewrite_session_config()) as sess:
            v_1 = variables.Variable(50.0, name="v_1")
            v_2 = variables.Variable(-50.0, name="v_2")
            delta_1 = constant_op.constant(5.0, name="delta_1")
            delta_2 = constant_op.constant(-5.0, name="delta_2")
            inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
            inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

            sess.run([v_1.initializer, v_2.initializer])

            # The wrapper is expected to give every tensor in the graph a
            # DebugIdentity debug op with attribute gated_grpc=True.
            sess = grpc_wrapper.TensorBoardDebugWrapperSession(
                sess, self._debug_server_url_1)

            server = self._server_1
            toggled_nodes = ("delta_1", "delta_2")
            # (debug tensor key, value the server should have received) pairs,
            # checked in this order during the breakpoint-enabled runs.
            expected_debug_values = (
                ("delta_1:0:DebugIdentity", [5.0]),
                ("delta_2:0:DebugIdentity", [-5.0]),
            )

            for run_index in xrange(4):
                server.clear_data()
                # Even-numbered runs (0 and 2) have the breakpoints enabled.
                breakpoints_on = run_index % 2 == 0

                for node_name in toggled_nodes:
                    if breakpoints_on:
                        server.request_watch(
                            node_name, 0, "DebugIdentity", breakpoint=True)
                    else:
                        server.request_unwatch(node_name, 0, "DebugIdentity")

                output = sess.run([inc_v_1, inc_v_2])
                increment = 5.0 * (run_index + 1)
                self.assertAllClose([50.0 + increment, -50 - increment], output)

                if not breakpoints_on:
                    # The unwatch requests left no breakpoint registered.
                    self.assertSetEqual(set(), server.breakpoints)
                    continue

                # The published debug tensors reached the server, and the
                # breakpoints were released by its EventReply responses.
                for tensor_key, expected_value in expected_debug_values:
                    self.assertAllClose(
                        expected_value, server.debug_tensor_values[tensor_key])
                # NOTE(review): breakpoint registration itself is not asserted
                # for these runs (only its absence in runs 1 and 3 is);
                # consider checking server.breakpoints here as well.
Example #3
0
    def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
        """Breakpoints toggle per run; tracebacks/source go out only once.

        Runs 0 and 2 request breakpoints on delta_1:0:DebugIdentity and
        delta_2:0:DebugIdentity; runs 1 and 3 withdraw them. Beyond the
        variable values and breakpoint state, the test checks that op
        tracebacks and this file's source reach the server with the first
        run only.
        """
        with session.Session(config=session_debug_testlib.
                             no_rewrite_session_config()) as sess:
            v_1 = variables.Variable(50.0, name="v_1")
            v_2 = variables.Variable(-50.0, name="v_2")
            delta_1 = constant_op.constant(5.0, name="delta_1")
            delta_2 = constant_op.constant(-5.0, name="delta_2")
            inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
            inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

            sess.run([v_1.initializer, v_2.initializer])

            # The wrapper is expected to give every tensor in the graph a
            # DebugIdentity debug op with attribute gated_grpc=True.
            sess = grpc_wrapper.TensorBoardDebugWrapperSession(
                sess, self._debug_server_url_1)

            server = self._server_1
            toggled_nodes = ("delta_1", "delta_2")
            # (debug tensor key, value the server should have received) pairs,
            # checked in this order during the breakpoint-enabled runs.
            expected_debug_values = (
                ("delta_1:0:DebugIdentity", [5.0]),
                ("delta_2:0:DebugIdentity", [-5.0]),
            )
            traced_ops = ("delta_1", "delta_2", "inc_v_1", "inc_v_2")

            for run_index in xrange(4):
                server.clear_data()
                # Even-numbered runs (0 and 2) have the breakpoints enabled.
                breakpoints_on = run_index % 2 == 0

                for node_name in toggled_nodes:
                    if breakpoints_on:
                        server.request_watch(
                            node_name, 0, "DebugIdentity", breakpoint=True)
                    else:
                        server.request_unwatch(node_name, 0, "DebugIdentity")

                output = sess.run([inc_v_1, inc_v_2])
                increment = 5.0 * (run_index + 1)
                self.assertAllClose([50.0 + increment, -50 - increment], output)

                if breakpoints_on:
                    # The published debug tensors reached the server, and the
                    # breakpoints were released by its EventReply responses.
                    for tensor_key, expected_value in expected_debug_values:
                        self.assertAllClose(
                            expected_value,
                            server.debug_tensor_values[tensor_key])
                    # NOTE(review): breakpoint registration itself is not
                    # asserted for these runs (only its absence in runs 1 and
                    # 3 is); consider checking server.breakpoints here too.
                else:
                    # The unwatch requests left no breakpoint registered.
                    self.assertSetEqual(set(), server.breakpoints)

                if run_index == 0:
                    # The first run must deliver a traceback for every op...
                    for op_name in traced_ops:
                        self.assertTrue(server.query_op_traceback(op_name))
                    # ...and this file's source, spot-checked via line 1.
                    with open(__file__, "rt") as source_file:
                        expected_first_line = source_file.readline().strip()
                    self.assertEqual(
                        expected_first_line,
                        server.query_source_file_line(__file__, 1))
                else:
                    # Tracebacks and source were already sent with the first
                    # run, and the server's data is cleared at the top of each
                    # iteration, so these queries are expected to raise.
                    with self.assertRaises(ValueError):
                        server.query_op_traceback("delta_1")
                    with self.assertRaises(ValueError):
                        server.query_source_file_line(__file__, 1)