def publish_traceback(debug_server_urls, graph, feed_dict, fetches,
                      old_graph_version):
    """Send the graph traceback and source code when the graph has changed.

    The current `graph.version` is compared against `old_graph_version`. Only
    when the current version is strictly greater (i.e., the graph is newer) are
    the graph traceback and the associated source code sent to the debug
    server(s) at the given gRPC URL(s).

    Args:
      debug_server_urls: A single gRPC debug server URL as a `str`, or a `list`
        of debug server URLs.
      graph: A Python `tf.Graph` object.
      feed_dict: Feed dictionary given to the `Session.run()` call.
      fetches: Fetches from the `Session.run()` call.
      old_graph_version: Old graph version to compare to.

    Returns:
      `graph.version` (an `int`) if it is greater than `old_graph_version`;
      otherwise `old_graph_version`, unchanged.
    """
    # TODO(cais): Consider moving this back to the top, after grpc becomes a
    # pip dependency of tensorflow or tf_debug.
    # pylint:disable=g-import-not-at-top
    from tensorflow.python.debug.lib import source_remote
    # pylint:enable=g-import-not-at-top
    if not graph.version > old_graph_version:
        # Graph has not grown since the last publish: nothing to send.
        return old_graph_version

    run_key = common.get_run_key(feed_dict, fetches)
    source_remote.send_graph_tracebacks(
        debug_server_urls,
        run_key,
        traceback.extract_stack(),
        graph,
        send_source=True)
    return graph.version
Exemple #2
0
    def testSendGraphTracebacksToTwoDebugServers(self):
        """Tracebacks, source and call metadata reach both debug servers.

        Also checks that the gRPC channel is created with options setting the
        max receive/send message lengths to -1.
        """
        this_func_name = "testSendGraphTracebacksToTwoDebugServers"
        with session.Session() as sess:
            a = variables.Variable(21.0, name="two/a")
            a_lineno = line_number_above()
            b = variables.Variable(2.0, name="two/b")
            b_lineno = line_number_above()
            x = math_ops.add(a, b, name="two/x")
            x_lineno = line_number_above()

            send_traceback = traceback.extract_stack()
            send_lineno = line_number_above()

            with test.mock.patch.object(
                    grpc, "insecure_channel",
                    wraps=grpc.insecure_channel) as mock_grpc_channel:
                source_remote.send_graph_tracebacks(
                    [self._server_address, self._server_address_2],
                    "dummy_run_key", send_traceback, sess.graph)
                mock_grpc_channel.assert_called_with(
                    test.mock.ANY,
                    options=[("grpc.max_receive_message_length", -1),
                             ("grpc.max_send_message_length", -1)])

            servers = [self._server, self._server_2]
            for server in servers:
                tb = server.query_op_traceback("two/a")
                self.assertIn((self._curr_file_path, a_lineno, this_func_name),
                              tb)
                tb = server.query_op_traceback("two/b")
                self.assertIn((self._curr_file_path, b_lineno, this_func_name),
                              tb)
                tb = server.query_op_traceback("two/x")
                self.assertIn((self._curr_file_path, x_lineno, this_func_name),
                              tb)

                self.assertIn(
                    (self._curr_file_path, send_lineno, this_func_name),
                    server.query_origin_stack()[-1])

                # Compare without surrounding whitespace: the line that
                # creates `x` above is indented 12 spaces in this file, so a
                # hard-coded indentation prefix would never match.
                self.assertEqual(
                    "x = math_ops.add(a, b, name=\"two/x\")",
                    server.query_source_file_line(__file__, x_lineno).strip())
                # Files inside the TensorFlow Python library are not sent.
                tf_trace = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
                tf_trace_file_path = tf_trace.filename
                with self.assertRaises(ValueError):
                    server.query_source_file_line(tf_trace_file_path, 0)
                self.assertEqual(
                    [debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                    server.query_call_types())
                self.assertEqual(["dummy_run_key"], server.query_call_keys())
                self.assertEqual([sess.graph.version],
                                 server.query_graph_versions())
    def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
        """Source files over the gRPC message length limit are not sent.

        The tracebacks themselves must still reach every server.
        """
        this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"

        # Patch the method to simulate a very small message length limit.
        with test.mock.patch.object(source_remote,
                                    "grpc_message_length_bytes",
                                    return_value=2):
            with session.Session() as sess:
                a = variables.Variable(21.0, name="two/a")
                a_lineno = line_number_above()
                b = variables.Variable(2.0, name="two/b")
                b_lineno = line_number_above()
                x = math_ops.add(a, b, name="two/x")
                x_lineno = line_number_above()

                send_traceback = traceback.extract_stack()
                send_lineno = line_number_above()
                source_remote.send_graph_tracebacks(
                    [self._server_address, self._server_address_2],
                    "dummy_run_key", send_traceback, sess.graph)

                servers = [self._server, self._server_2]
                for server in servers:
                    # Even though the source file content is not sent, the traceback
                    # should have been sent.
                    tb = server.query_op_traceback("two/a")
                    self.assertIn(
                        (self._curr_file_path, a_lineno, this_func_name), tb)
                    tb = server.query_op_traceback("two/b")
                    self.assertIn(
                        (self._curr_file_path, b_lineno, this_func_name), tb)
                    tb = server.query_op_traceback("two/x")
                    self.assertIn(
                        (self._curr_file_path, x_lineno, this_func_name), tb)

                    self.assertIn(
                        (self._curr_file_path, send_lineno, this_func_name),
                        server.query_origin_stack()[-1])

                    tf_trace_file_path = (
                        self._findFirstTraceInsideTensorFlowPyLibrary(x.op))
                    # Verify that the source content is not sent to the server.
                    # Query the loop's `server` (not always `self._server`) so
                    # that both servers are actually checked.
                    with self.assertRaises(ValueError):
                        server.query_source_file_line(tf_trace_file_path, 0)
                    # This test file is larger than the patched 2-byte limit,
                    # so its content must not have been sent either.
                    with self.assertRaises(ValueError):
                        server.query_source_file_line(__file__, x_lineno)
    def testSendGraphTracebacksToTwoDebugServers(self):
        """Tracebacks, source and call metadata reach both debug servers."""
        this_func_name = "testSendGraphTracebacksToTwoDebugServers"
        with session.Session() as sess:
            a = variables.Variable(21.0, name="two/a")
            a_lineno = line_number_above()
            b = variables.Variable(2.0, name="two/b")
            b_lineno = line_number_above()
            x = math_ops.add(a, b, name="two/x")
            x_lineno = line_number_above()

            send_traceback = traceback.extract_stack()
            send_lineno = line_number_above()
            source_remote.send_graph_tracebacks(
                [self._server_address, self._server_address_2],
                "dummy_run_key", send_traceback, sess.graph)

            servers = [self._server, self._server_2]
            for server in servers:
                tb = server.query_op_traceback("two/a")
                self.assertIn((self._curr_file_path, a_lineno, this_func_name),
                              tb)
                tb = server.query_op_traceback("two/b")
                self.assertIn((self._curr_file_path, b_lineno, this_func_name),
                              tb)
                tb = server.query_op_traceback("two/x")
                self.assertIn((self._curr_file_path, x_lineno, this_func_name),
                              tb)

                self.assertIn(
                    (self._curr_file_path, send_lineno, this_func_name),
                    server.query_origin_stack()[-1])

                # Compare without surrounding whitespace: the line that
                # creates `x` above is indented 12 spaces in this file, so a
                # hard-coded indentation prefix would never match.
                self.assertEqual(
                    "x = math_ops.add(a, b, name=\"two/x\")",
                    server.query_source_file_line(__file__, x_lineno).strip())
                # Files inside the TensorFlow Python library are not sent.
                tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(
                    x.op)
                with self.assertRaises(ValueError):
                    server.query_source_file_line(tf_trace_file_path, 0)
                self.assertEqual(
                    [debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                    server.query_call_types())
                self.assertEqual(["dummy_run_key"], server.query_call_keys())
                self.assertEqual([sess.graph.version],
                                 server.query_graph_versions())
  def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
    """Source files over the gRPC message length limit are not sent.

    The tracebacks themselves must still reach every server.
    """
    this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"

    # Patch the method to simulate a very small message length limit.
    with test.mock.patch.object(
        source_remote, "grpc_message_length_bytes", return_value=2):
      with session.Session() as sess:
        a = variables.Variable(21.0, name="two/a")
        a_lineno = line_number_above()
        b = variables.Variable(2.0, name="two/b")
        b_lineno = line_number_above()
        x = math_ops.add(a, b, name="two/x")
        x_lineno = line_number_above()

        send_traceback = traceback.extract_stack()
        send_lineno = line_number_above()
        source_remote.send_graph_tracebacks(
            [self._server_address, self._server_address_2],
            "dummy_run_key", send_traceback, sess.graph)

        servers = [self._server, self._server_2]
        for server in servers:
          # Even though the source file content is not sent, the traceback
          # should have been sent.
          tb = server.query_op_traceback("two/a")
          self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
          tb = server.query_op_traceback("two/b")
          self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
          tb = server.query_op_traceback("two/x")
          self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)

          self.assertIn(
              (self._curr_file_path, send_lineno, this_func_name),
              server.query_origin_stack()[-1])

          tf_trace_file_path = (
              self._findFirstTraceInsideTensorFlowPyLibrary(x.op))
          # Verify that the source content is not sent to the server. Query
          # the loop's `server` (not always `self._server`) so that both
          # servers are actually checked.
          with self.assertRaises(ValueError):
            server.query_source_file_line(tf_trace_file_path, 0)
          # This test file is larger than the patched 2-byte limit, so its
          # content must not have been sent either.
          with self.assertRaises(ValueError):
            server.query_source_file_line(__file__, x_lineno)
Exemple #6
0
    def testSendGraphTracebacksToSingleDebugServer(self):
        """Tracebacks, user source and call metadata reach one debug server."""
        this_func_name = "testSendGraphTracebacksToSingleDebugServer"
        with session.Session() as sess:
            a = variables.Variable(21.0, name="a")
            a_lineno = line_number_above()
            b = variables.Variable(2.0, name="b")
            b_lineno = line_number_above()
            math_ops.add(a, b, name="x")
            x_lineno = line_number_above()

            send_stack = traceback.extract_stack()
            send_lineno = line_number_above()
            source_remote.send_graph_tracebacks(self._server_address,
                                                "dummy_run_key", send_stack,
                                                sess.graph)

            tb = self._server.query_op_traceback("a")
            self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
            tb = self._server.query_op_traceback("b")
            self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
            tb = self._server.query_op_traceback("x")
            self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)

            self.assertIn((self._curr_file_path, send_lineno, this_func_name),
                          self._server.query_origin_stack()[-1])

            # Compare without surrounding whitespace: the line that creates
            # `a` above is indented 12 spaces in this file, so a hard-coded
            # indentation prefix would never match.
            sent_line = self._server.query_source_file_line(__file__, a_lineno)
            self.assertEqual(
                "a = variables.Variable(21.0, name=\"a\")", sent_line.strip())
            # Files in the TensorFlow code base should not have been sent.
            tf_trace = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
            tf_trace_file_path = tf_trace.filename
            with self.assertRaises(ValueError):
                self._server.query_source_file_line(tf_trace_file_path, 0)
            self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                             self._server.query_call_types())
            self.assertEqual(["dummy_run_key"], self._server.query_call_keys())
            self.assertEqual([sess.graph.version],
                             self._server.query_graph_versions())
  def testSendGraphTracebacksToTwoDebugServers(self):
    """Each of two debug servers receives tracebacks, source and metadata."""
    func_name = "testSendGraphTracebacksToTwoDebugServers"
    with session.Session() as sess:
      a = variables.Variable(21.0, name="two/a")
      a_lineno = line_number_above()
      b = variables.Variable(2.0, name="two/b")
      b_lineno = line_number_above()
      x = math_ops.add(a, b, name="two/x")
      x_lineno = line_number_above()

      send_traceback = traceback.extract_stack()
      send_lineno = line_number_above()
      source_remote.send_graph_tracebacks(
          [self._server_address, self._server_address_2],
          "dummy_run_key", send_traceback, sess.graph)

      expected_op_lines = (("two/a", a_lineno),
                           ("two/b", b_lineno),
                           ("two/x", x_lineno))
      for srv in (self._server, self._server_2):
        # Every op's creation site in this test must appear in its traceback.
        for op_name, op_lineno in expected_op_lines:
          self.assertIn((self._curr_file_path, op_lineno, func_name),
                        srv.query_op_traceback(op_name))

        # The innermost origin-stack entry is where the stack was captured.
        origin_stack = srv.query_origin_stack()
        self.assertIn((self._curr_file_path, send_lineno, func_name),
                      origin_stack[-1])

        # This test file's source is sent verbatim...
        self.assertEqual(
            "      x = math_ops.add(a, b, name=\"two/x\")",
            srv.query_source_file_line(__file__, x_lineno))
        # ...but files inside the TensorFlow Python library are not.
        tf_path = self._findFirstTraceInsideTensorFlowPyLibrary(x.op)
        with self.assertRaises(ValueError):
          srv.query_source_file_line(tf_path, 0)

        self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                         srv.query_call_types())
        self.assertEqual(["dummy_run_key"], srv.query_call_keys())
        self.assertEqual([sess.graph.version], srv.query_graph_versions())
  def testSendGraphTracebacksToSingleDebugServer(self):
    """A single debug server receives tracebacks, source and metadata."""
    func_name = "testSendGraphTracebacksToSingleDebugServer"
    with session.Session() as sess:
      a = variables.Variable(21.0, name="a")
      a_lineno = line_number_above()
      b = variables.Variable(2.0, name="b")
      b_lineno = line_number_above()
      math_ops.add(a, b, name="x")
      x_lineno = line_number_above()

      send_stack = traceback.extract_stack()
      send_lineno = line_number_above()
      source_remote.send_graph_tracebacks(
          self._server_address, "dummy_run_key", send_stack, sess.graph)

      srv = self._server
      # Every op's creation site in this test must appear in its traceback.
      for op_name, op_lineno in (("a", a_lineno), ("b", b_lineno),
                                 ("x", x_lineno)):
        self.assertIn((self._curr_file_path, op_lineno, func_name),
                      srv.query_op_traceback(op_name))

      # The innermost origin-stack entry is where the stack was captured.
      self.assertIn((self._curr_file_path, send_lineno, func_name),
                    srv.query_origin_stack()[-1])

      self.assertEqual(
          "      a = variables.Variable(21.0, name=\"a\")",
          srv.query_source_file_line(__file__, a_lineno))
      # Files in the TensorFlow code base should not have been sent.
      tf_path = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
      with self.assertRaises(ValueError):
        srv.query_source_file_line(tf_path, 0)

      self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                       srv.query_call_types())
      self.assertEqual(["dummy_run_key"], srv.query_call_keys())
      self.assertEqual([sess.graph.version], srv.query_graph_versions())
Exemple #9
0
def publish_traceback(debug_server_urls,
                      graph,
                      feed_dict,
                      fetches,
                      old_graph_version):
  """Send the graph traceback and source code when the graph has changed.

  The current `graph.version` is compared against `old_graph_version`. Only
  when the current version is strictly greater (i.e., the graph is newer) are
  the graph traceback and the associated source code sent to the debug
  server(s) at the given gRPC URL(s).

  Args:
    debug_server_urls: A single gRPC debug server URL as a `str`, or a `list`
      of debug server URLs.
    graph: A Python `tf.Graph` object.
    feed_dict: Feed dictionary given to the `Session.run()` call.
    fetches: Fetches from the `Session.run()` call.
    old_graph_version: Old graph version to compare to.

  Returns:
    `graph.version` (an `int`) if it is greater than `old_graph_version`;
    otherwise `old_graph_version`, unchanged.
  """
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  # pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top
  if not graph.version > old_graph_version:
    # Graph has not grown since the last publish: nothing to send.
    return old_graph_version

  run_key = common.get_run_key(feed_dict, fetches)
  source_remote.send_graph_tracebacks(
      debug_server_urls,
      run_key,
      traceback.extract_stack(),
      graph,
      send_source=True)
  return graph.version