def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
  """Checks that send_traceback_and_source_code=False suppresses those sends.

  Runs two assign_add ops four times through a TensorBoardDebugWrapperSession
  constructed with send_traceback_and_source_code=False. After every run it
  verifies the updated variable values, and that querying the debug server
  for an op traceback or for source-file content raises ValueError.
  """
  with session.Session(config=no_rewrite_session_config()) as sess:
    v_1 = variables.Variable(50.0, name="v_1")
    v_2 = variables.Variable(-50.0, name="v_2")
    delta_1 = constant_op.constant(5.0, name="delta_1")
    delta_2 = constant_op.constant(-5.0, name="delta_2")
    inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
    inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

    sess.run(variables.global_variables_initializer())

    # Disable the sending of traceback and source code.
    sess = grpc_wrapper.TensorBoardDebugWrapperSession(
        sess, self._debug_server_url_1, send_traceback_and_source_code=False)

    for i in xrange(4):
      self._server_1.clear_data()

      if i == 0:
        # Request a breakpoint on delta_1:0:DebugIdentity before the first
        # run only.
        self._server_1.request_watch(
            "delta_1", 0, "DebugIdentity", breakpoint=True)

      output = sess.run([inc_v_1, inc_v_2])
      # Each run adds delta_1 (5.0) to v_1 and delta_2 (-5.0) to v_2.
      self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

      # No op traceback or source code should have been received by the debug
      # server due to the disabling above. The queries are expected to raise,
      # so they are called directly (not wrapped in assertTrue): a call that
      # returns instead of raising then fails with a clear "not raised"
      # message rather than an unrelated assertTrue failure.
      with self.assertRaisesRegexp(
          ValueError, r"Op .*delta_1.* does not exist"):
        self._server_1.query_op_traceback("delta_1")
      with self.assertRaisesRegexp(
          ValueError, r".* has not received any source file"):
        self._server_1.query_source_file_line(__file__, 1)
def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
  """Toggles breakpoints on delta_1/delta_2 across four Session.run() calls.

  Runs 0 and 2 request breakpoints on delta_1:0:DebugIdentity and
  delta_2:0:DebugIdentity. Runs 1 and 3 remove them. Every run must produce
  the expected variable updates. The debug server's received tensor values
  and its registered breakpoint set must follow the toggling.
  """
  with session.Session(config=no_rewrite_session_config()) as sess:
    v_1 = variables.Variable(50.0, name="v_1")
    v_2 = variables.Variable(-50.0, name="v_2")
    delta_1 = constant_op.constant(5.0, name="delta_1")
    delta_2 = constant_op.constant(-5.0, name="delta_2")
    inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
    inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

    sess.run([v_1.initializer, v_2.initializer])

    # The TensorBoardDebugWrapperSession should add a DebugIdentity debug op
    # with attribute gated_grpc=True for every tensor in the graph.
    sess = grpc_wrapper.TensorBoardDebugWrapperSession(
        sess, self._debug_server_url_1)

    for i in xrange(4):
      self._server_1.clear_data()

      if i in (0, 2):
        # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
        self._server_1.request_watch(
            "delta_1", 0, "DebugIdentity", breakpoint=True)
        self._server_1.request_watch(
            "delta_2", 0, "DebugIdentity", breakpoint=True)
      else:
        # Disable the breakpoint in runs 1 and 3.
        self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
        self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

      output = sess.run([inc_v_1, inc_v_2])
      self.assertAllClose(
          [50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

      if i in (0, 2):
        # During runs 0 and 2, the server should have received the published
        # debug tensors delta_[1,2]:0:DebugIdentity. The breakpoints should
        # have been unblocked by EventReply responses from the server.
        self.assertAllClose(
            [5.0],
            self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
        self.assertAllClose(
            [-5.0],
            self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
        # After the runs, the server should have properly registered the
        # breakpoints.
        # NOTE(review): assumes `breakpoints` holds (node_name, output_slot,
        # debug_op) tuples mirroring the request_watch arguments — confirm
        # against the test server implementation.
        self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                             ("delta_2", 0, "DebugIdentity")},
                            self._server_1.breakpoints)
      else:
        # After the end of runs 1 and 3, the server has received the requests
        # to disable the breakpoints at delta_[1,2]:0:DebugIdentity.
        self.assertSetEqual(set(), self._server_1.breakpoints)
def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
  """Toggles breakpoints across runs and checks traceback/source delivery.

  Runs 0 and 2 request breakpoints on delta_1:0:DebugIdentity and
  delta_2:0:DebugIdentity. Runs 1 and 3 remove them. Every run must produce
  the expected variable updates, and the debug server's received tensor
  values and registered breakpoint set must follow the toggling.

  In addition, the first run must deliver op tracebacks and this file's
  source to the server. Later runs must not resend them: the server's data
  is cleared at the start of each iteration, so those queries raise
  ValueError.
  """
  with session.Session(
      config=session_debug_testlib.no_rewrite_session_config()) as sess:
    v_1 = variables.Variable(50.0, name="v_1")
    v_2 = variables.Variable(-50.0, name="v_2")
    delta_1 = constant_op.constant(5.0, name="delta_1")
    delta_2 = constant_op.constant(-5.0, name="delta_2")
    inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
    inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

    sess.run([v_1.initializer, v_2.initializer])

    # The TensorBoardDebugWrapperSession should add a DebugIdentity debug op
    # with attribute gated_grpc=True for every tensor in the graph.
    sess = grpc_wrapper.TensorBoardDebugWrapperSession(
        sess, self._debug_server_url_1)

    for i in xrange(4):
      self._server_1.clear_data()

      if i in (0, 2):
        # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
        self._server_1.request_watch(
            "delta_1", 0, "DebugIdentity", breakpoint=True)
        self._server_1.request_watch(
            "delta_2", 0, "DebugIdentity", breakpoint=True)
      else:
        # Disable the breakpoint in runs 1 and 3.
        self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
        self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

      output = sess.run([inc_v_1, inc_v_2])
      self.assertAllClose(
          [50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

      if i in (0, 2):
        # During runs 0 and 2, the server should have received the published
        # debug tensors delta_[1,2]:0:DebugIdentity. The breakpoints should
        # have been unblocked by EventReply responses from the server.
        self.assertAllClose(
            [5.0],
            self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
        self.assertAllClose(
            [-5.0],
            self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
        # After the runs, the server should have properly registered the
        # breakpoints.
        # NOTE(review): assumes `breakpoints` holds (node_name, output_slot,
        # debug_op) tuples mirroring the request_watch arguments — confirm
        # against the test server implementation.
        self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                             ("delta_2", 0, "DebugIdentity")},
                            self._server_1.breakpoints)
      else:
        # After the end of runs 1 and 3, the server has received the requests
        # to disable the breakpoints at delta_[1,2]:0:DebugIdentity.
        self.assertSetEqual(set(), self._server_1.breakpoints)

      if i == 0:
        # Check that the server has received the stack trace of every op.
        for op_name in ("delta_1", "delta_2", "inc_v_1", "inc_v_2"):
          self.assertTrue(self._server_1.query_op_traceback(op_name))
        # Check that the server has received the python file content.
        # Query an arbitrary line to make sure that is the case. The file is
        # closed before the comparison; only its first line is needed.
        with open(__file__, "rt") as this_source_file:
          first_line = this_source_file.readline().strip()
        self.assertEqual(
            first_line, self._server_1.query_source_file_line(__file__, 1))
      else:
        # In later Session.run() calls, the traceback shouldn't have been sent
        # because it is already sent in the 1st call. So calling
        # query_op_traceback() should lead to an exception, because the test
        # debug server clears the data at the beginning of every iteration.
        with self.assertRaises(ValueError):
          self._server_1.query_op_traceback("delta_1")
        with self.assertRaises(ValueError):
          self._server_1.query_source_file_line(__file__, 1)