def publish_traceback(debug_server_urls, graph, feed_dict, fetches,
                      old_graph_version):
  """Send graph traceback and source code when the graph has a newer version.

  Compares `graph.version` against `old_graph_version`. Only when the graph's
  version is strictly greater are the graph traceback and the associated
  source code sent to the debug server(s) at the given gRPC URLs.

  Args:
    debug_server_urls: Either one gRPC debug server URL (`str`) or a `list`
      of such URLs.
    graph: A Python `tf.Graph` object.
    feed_dict: The feed dictionary passed to the `Session.run()` call.
    fetches: The fetches passed to the `Session.run()` call.
    old_graph_version: The previously seen graph version to compare against.

  Returns:
    `graph.version` (an `int`) if it is greater than `old_graph_version`;
    otherwise `old_graph_version` unchanged.
  """
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  #   pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top

  graph_is_newer = graph.version > old_graph_version
  if not graph_is_newer:
    # Nothing new to publish; report the version the caller already knows.
    return old_graph_version

  run_key = common.get_run_key(feed_dict, fetches)
  # The stack is captured in this frame so it reflects the caller chain of
  # publish_traceback itself.
  call_stack = traceback.extract_stack()
  source_remote.send_graph_tracebacks(
      debug_server_urls, run_key, call_stack, graph, send_source=True)
  return graph.version
def testGetRunKeyFlat(self):
  """Checks the JSON run key built from a flat feed dict and a flat fetch list.

  The decoded run key is expected to hold ["a:0"] as its first element and
  ["a:0", "b:0"] (in any order) as its second element.
  """
  const_a = constant_op.constant(10.0, name="a")
  const_b = constant_op.constant(20.0, name="b")

  feeds = {"a": const_a}
  flat_fetches = [const_a, const_b]
  decoded = json.loads(common.get_run_key(feeds, flat_fetches))

  self.assertItemsEqual(["a:0"], decoded[0])
  self.assertItemsEqual(["a:0", "b:0"], decoded[1])
def testOnFeedOneFetch(self):
  """Checks the JSON run key for a single feed entry and a single fetch.

  The decoded run key is expected to hold ["a:0"] as its first element and
  ["b:0"] as its second element.
  """
  fed_tensor = constant_op.constant(10.0, name="a")
  fetched_tensor = constant_op.constant(20.0, name="b")

  encoded_key = common.get_run_key({"a": fed_tensor}, [fetched_tensor])
  decoded = json.loads(encoded_key)

  self.assertItemsEqual(["a:0"], decoded[0])
  self.assertItemsEqual(["b:0"], decoded[1])
def testGetRunKeyFlat(self):
  """Verifies the run key produced for flat (non-nested) feeds and fetches.

  After JSON decoding, element 0 must contain exactly "a:0" and element 1
  must contain exactly "a:0" and "b:0", irrespective of ordering.
  """
  tensor_a = constant_op.constant(10.0, name="a")
  tensor_b = constant_op.constant(20.0, name="b")

  key_json = common.get_run_key({"a": tensor_a}, [tensor_a, tensor_b])
  parsed_key = json.loads(key_json)

  expected_first = ["a:0"]
  expected_second = ["a:0", "b:0"]
  self.assertItemsEqual(expected_first, parsed_key[0])
  self.assertItemsEqual(expected_second, parsed_key[1])
def testOnFeedOneFetch(self):
  """Verifies the run key produced for one feed entry and one fetch.

  After JSON decoding, element 0 must contain exactly "a:0" and element 1
  must contain exactly "b:0".
  """
  tensor_a = constant_op.constant(10.0, name="a")
  tensor_b = constant_op.constant(20.0, name="b")

  single_feed = {"a": tensor_a}
  single_fetch = [tensor_b]
  parsed_key = json.loads(common.get_run_key(single_feed, single_fetch))

  self.assertItemsEqual(["a:0"], parsed_key[0])
  self.assertItemsEqual(["b:0"], parsed_key[1])
def testGetRunKeyNestedFetches(self):
  """Checks the JSON run key when fetches are a nested dict/list structure.

  With an empty feed dict, the first decoded element is expected to be empty
  and the second is expected to contain "a:0", "b:0", "c:0" and "d:0" (in any
  order), i.e. every tensor from the nested fetch structure.
  """
  const_a = constant_op.constant(10.0, name="a")
  const_b = constant_op.constant(20.0, name="b")
  const_c = constant_op.constant(30.0, name="c")
  const_d = constant_op.constant(30.0, name="d")

  nested_fetches = {
      "set1": [const_a, const_b],
      "set2": {"c": const_c, "d": const_d},
  }
  decoded = json.loads(common.get_run_key({}, nested_fetches))

  self.assertItemsEqual([], decoded[0])
  self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], decoded[1])
def testGetRunKeyNestedFetches(self):
  """Verifies the run key produced for nested fetches and no feeds.

  The fetches are a dict holding a list under "set1" and another dict under
  "set2". After JSON decoding, element 0 must be empty and element 1 must
  contain exactly "a:0", "b:0", "c:0" and "d:0", irrespective of ordering.
  """
  tensor_a = constant_op.constant(10.0, name="a")
  tensor_b = constant_op.constant(20.0, name="b")
  tensor_c = constant_op.constant(30.0, name="c")
  tensor_d = constant_op.constant(30.0, name="d")

  list_group = [tensor_a, tensor_b]
  dict_group = {"c": tensor_c, "d": tensor_d}
  key_json = common.get_run_key({}, {"set1": list_group, "set2": dict_group})
  parsed_key = json.loads(key_json)

  expected_fetch_names = ["a:0", "b:0", "c:0", "d:0"]
  self.assertItemsEqual([], parsed_key[0])
  self.assertItemsEqual(expected_fetch_names, parsed_key[1])
def publish_traceback(debug_server_urls, graph, feed_dict, fetches,
                      old_graph_version):
  """Publish the graph traceback and source code if the graph version grew.

  `graph.version` is checked against `old_graph_version`. When the former is
  higher (i.e., the graph is newer), the graph traceback together with the
  associated source code is sent to the debug server at the specified gRPC
  URLs. Otherwise nothing is sent.

  Args:
    debug_server_urls: A single gRPC debug server URL as a `str`, or a `list`
      of debug server URLs.
    graph: A Python `tf.Graph` object.
    feed_dict: Feed dictionary given to the `Session.run()` call.
    fetches: Fetches from the `Session.run()` call.
    old_graph_version: Old graph version to compare to.

  Returns:
    The new graph version as an `int` when
    `graph.version > old_graph_version`; `old_graph_version` otherwise.
  """
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  #   pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top

  if not graph.version > old_graph_version:
    # Graph has not advanced past the known version: skip publishing.
    return old_graph_version

  key_for_run = common.get_run_key(feed_dict, fetches)
  source_remote.send_graph_tracebacks(
      debug_server_urls,
      key_for_run,
      traceback.extract_stack(),
      graph,
      send_source=True)
  return graph.version