コード例 #1
0
def publish_traceback(debug_server_urls, graph, feed_dict, fetches,
                      old_graph_version):
    """Send graph traceback and source to debug servers if the graph is newer.

    The current `graph.version` is compared against `old_graph_version`. Only
    when the current version is strictly greater is anything transmitted: the
    Python call stack captured here, the graph, and its associated source code
    are sent to the debug server(s) at the given gRPC URL(s).

    Args:
      debug_server_urls: Either one gRPC debug server URL as a `str`, or a
        `list` of debug server URLs.
      graph: The Python `tf.Graph` object whose traceback is published.
      feed_dict: The feed dictionary passed to the `Session.run()` call.
      fetches: The fetches passed to the `Session.run()` call.
      old_graph_version: The previously seen graph version to compare to.

    Returns:
      `graph.version` (an `int`) when it exceeds `old_graph_version`;
      otherwise `old_graph_version` itself.
    """
    # Imported lazily rather than at module top, because grpc is not yet a
    # hard pip dependency.
    # TODO(cais): Consider moving this back to the top, after grpc becomes a
    # pip dependency of tensorflow or tf_debug.
    # pylint:disable=g-import-not-at-top
    from tensorflow.python.debug.lib import source_remote
    # pylint:enable=g-import-not-at-top
    is_newer = graph.version > old_graph_version
    if not is_newer:
        # Nothing changed since the last publish; keep the old version.
        return old_graph_version

    key = common.get_run_key(feed_dict, fetches)
    # The stack is captured inside this function, so the traceback sent
    # includes this frame.
    source_remote.send_graph_tracebacks(
        debug_server_urls,
        key,
        traceback.extract_stack(),
        graph,
        send_source=True,
    )
    return graph.version
コード例 #2
0
 def testGetRunKeyFlat(self):
     """Flat feed dict and fetch list map to their tensor-name lists."""
     const_a = constant_op.constant(10.0, name="a")
     const_b = constant_op.constant(20.0, name="b")
     key_json = common.get_run_key({"a": const_a}, [const_a, const_b])
     decoded = json.loads(key_json)
     # Index 0 of the decoded key is checked against feed names, index 1
     # against fetch names.
     self.assertItemsEqual(["a:0"], decoded[0])
     self.assertItemsEqual(["a:0", "b:0"], decoded[1])
コード例 #3
0
 def testOnFeedOneFetch(self):
     """One feed and one distinct fetch land in separate name lists."""
     feed_tensor = constant_op.constant(10.0, name="a")
     fetch_tensor = constant_op.constant(20.0, name="b")
     parsed = json.loads(
         common.get_run_key({"a": feed_tensor}, [fetch_tensor]))
     # Feed names are checked at index 0, fetch names at index 1.
     self.assertItemsEqual(["a:0"], parsed[0])
     self.assertItemsEqual(["b:0"], parsed[1])
コード例 #4
0
ファイル: common_test.py プロジェクト: Wajih-O/tensorflow
 def testGetRunKeyFlat(self):
   """Flat feed dict and fetch list map to their tensor-name lists."""
   tensor_a = constant_op.constant(10.0, name="a")
   tensor_b = constant_op.constant(20.0, name="b")
   feeds = {"a": tensor_a}
   fetch_list = [tensor_a, tensor_b]
   decoded = json.loads(common.get_run_key(feeds, fetch_list))
   # Index 0 holds the feed names, index 1 the fetch names.
   self.assertItemsEqual(["a:0"], decoded[0])
   self.assertItemsEqual(["a:0", "b:0"], decoded[1])
コード例 #5
0
ファイル: common_test.py プロジェクト: Wajih-O/tensorflow
 def testOnFeedOneFetch(self):
   """One feed and one distinct fetch land in separate name lists."""
   fed = constant_op.constant(10.0, name="a")
   fetched = constant_op.constant(20.0, name="b")
   key_json = common.get_run_key({"a": fed}, [fetched])
   parsed = json.loads(key_json)
   # Feed names are checked at index 0, fetch names at index 1.
   self.assertItemsEqual(["a:0"], parsed[0])
   self.assertItemsEqual(["b:0"], parsed[1])
コード例 #6
0
ファイル: common_test.py プロジェクト: Wajih-O/tensorflow
 def testGetRunKeyNestedFetches(self):
   """Fetches nested in lists and dicts all appear in the fetch name list."""
   t_a = constant_op.constant(10.0, name="a")
   t_b = constant_op.constant(20.0, name="b")
   t_c = constant_op.constant(30.0, name="c")
   t_d = constant_op.constant(30.0, name="d")
   nested_fetches = {"set1": [t_a, t_b], "set2": {"c": t_c, "d": t_d}}
   parsed = json.loads(common.get_run_key({}, nested_fetches))
   # An empty feed dict yields no feed names.
   self.assertItemsEqual([], parsed[0])
   self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], parsed[1])
コード例 #7
0
 def testGetRunKeyNestedFetches(self):
     """Fetches nested in lists and dicts all appear in the fetch name list."""
     const_a = constant_op.constant(10.0, name="a")
     const_b = constant_op.constant(20.0, name="b")
     const_c = constant_op.constant(30.0, name="c")
     const_d = constant_op.constant(30.0, name="d")
     inner = {"c": const_c, "d": const_d}
     fetch_tree = {"set1": [const_a, const_b], "set2": inner}
     key_json = common.get_run_key({}, fetch_tree)
     decoded = json.loads(key_json)
     # No feeds were given, so the feed-name list is expected to be empty.
     self.assertItemsEqual([], decoded[0])
     self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], decoded[1])
コード例 #8
0
ファイル: grpc_wrapper.py プロジェクト: AnishShah/tensorflow
def publish_traceback(debug_server_urls, graph, feed_dict, fetches,
                      old_graph_version):
  """Send graph traceback and source to debug servers if the graph is newer.

  The current `graph.version` is compared against `old_graph_version`. Only
  when the current version is strictly greater is anything transmitted: the
  Python call stack captured here, the graph, and its associated source code
  are sent to the debug server(s) at the given gRPC URL(s).

  Args:
    debug_server_urls: Either one gRPC debug server URL as a `str`, or a
      `list` of debug server URLs.
    graph: The Python `tf.Graph` object whose traceback is published.
    feed_dict: The feed dictionary passed to the `Session.run()` call.
    fetches: The fetches passed to the `Session.run()` call.
    old_graph_version: The previously seen graph version to compare to.

  Returns:
    `graph.version` (an `int`) when it exceeds `old_graph_version`;
    otherwise `old_graph_version` itself.
  """
  # Imported lazily rather than at module top, because grpc is not yet a
  # hard pip dependency.
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  # pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top
  graph_is_newer = graph.version > old_graph_version
  if not graph_is_newer:
    # Nothing changed since the last publish; keep the old version.
    return old_graph_version

  key = common.get_run_key(feed_dict, fetches)
  # The stack is captured inside this function, so the traceback sent
  # includes this frame.
  source_remote.send_graph_tracebacks(
      debug_server_urls,
      key,
      traceback.extract_stack(),
      graph,
      send_source=True)
  return graph.version