def testExtractGatedGrpcTensorsFoundGatedGrpcOps(self):
    """Gated-gRPC debug watches are reported for every node of the test graph.

    Builds the test graph with gated gRPC debug ops enabled, runs it while
    collecting partition graphs, and checks that the wrapper reports a
    `DebugIdentity` gated watch (index 2 == 0, presumably the output slot) for
    each of the nodes a, b, c, d, x, y and z.
    """
    with tf.compat.v1.Session() as sess:
        fetch, options = self._createTestGraphAndRunOptions(
            sess, gated_grpc=True)
        sess.run(tf.compat.v1.global_variables_initializer())

        metadata = config_pb2.RunMetadata()
        result = sess.run(fetch, options=options, run_metadata=metadata)
        self.assertAllClose([10.0], result)

        wrapper = debug_graphs_helper.DebugGraphWrapper(
            metadata.partition_graphs[0])
        found = wrapper.get_gated_grpc_tensors()

        # Element 1 of each entry is the op type; it must be present.
        for entry in found:
            self.assertTrue(entry[1])

        # Drop the op type before comparing: op type names are not stable
        # across TF versions (e.g. 'VariableV2' --> 'VariableV3').
        without_types = [(entry[0], entry[2], entry[3]) for entry in found]
        for node_name in ("a", "b", "c", "d", "x", "y", "z"):
            self.assertIn((node_name, 0, "DebugIdentity"), without_types)
def testMaybeBaseExpandedNodeName(self):
    """Node names that are also name-scope prefixes get a `/(base)` suffix.

    `foo/a` and `bar/b` are variables, so other nodes (e.g. `foo/a/read`) live
    under them and they expand to `foo/a/(a)` and `bar/b/(b)`. Leaf names such
    as `foo/a/read` are returned unchanged. Whether `baz/c` expands depends on
    whether TF2 behavior is enabled.
    """
    with tf.compat.v1.Session() as sess:
        var_a = tf.Variable([1.0], name='foo/a')
        var_b = tf.Variable([2.0], name='bar/b')
        tf.add(var_a, var_b, name='baz/c')
        wrapper = debug_graphs_helper.DebugGraphWrapper(sess.graph_def)

        expectations = (
            ('foo/a', 'foo/a/(a)'),
            ('bar/b', 'bar/b/(b)'),
            ('foo/a/read', 'foo/a/read'),
            ('bar/b/read', 'bar/b/read'),
        )
        for node_name, expected in expectations:
            self.assertEqual(
                expected, wrapper.maybe_base_expanded_node_name(node_name))

        # NOTE(#1705): TF 2.0 tf.add creates nested nodes.
        expected_sum = (
            'baz/c/(c)' if tensorflow_python_tf2.enabled() else 'baz/c')
        self.assertEqual(
            expected_sum, wrapper.maybe_base_expanded_node_name('baz/c'))
def testGraphDefProperty(self):
    """`DebugGraphWrapper.graph_def` equals the GraphDef it was built from."""
    with tf.compat.v1.Session() as sess:
        fetch, options = self._createTestGraphAndRunOptions(
            sess, gated_grpc=True)
        sess.run(tf.compat.v1.global_variables_initializer())

        metadata = config_pb2.RunMetadata()
        self.assertAllClose(
            [10.0], sess.run(fetch, options=options, run_metadata=metadata))

        partition_graph = metadata.partition_graphs[0]
        wrapper = debug_graphs_helper.DebugGraphWrapper(partition_graph)
        self.assertProtoEquals(partition_graph, wrapper.graph_def)
def testExtractGatedGrpcTensorsFoundNoGatedGrpcOps(self):
    """Without gated gRPC debug ops, no gated tensors are reported."""
    with tf.compat.v1.Session() as sess:
        fetch, options = self._createTestGraphAndRunOptions(
            sess, gated_grpc=False)
        sess.run(tf.compat.v1.global_variables_initializer())

        metadata = config_pb2.RunMetadata()
        result = sess.run(fetch, options=options, run_metadata=metadata)
        self.assertAllClose([10.0], result)

        first_partition = metadata.partition_graphs[0]
        wrapper = debug_graphs_helper.DebugGraphWrapper(first_partition)
        self.assertEqual([], wrapper.get_gated_grpc_tensors())
def add_graph(self, run_key, device_name, graph_def, debug=False):
    """Add a GraphDef.

    The graph is wrapped in a `debug_graphs_helper.DebugGraphWrapper` and
    stored under `run_key` and the device name. A graph previously added for
    the same run key, device name and `debug` flag is replaced.

    Args:
      run_key: A key for the run, containing information about the feeds,
        fetches, and targets.
      device_name: The name of the device that the `GraphDef` is for. It is
        normalized with `tf.compat.as_str` before being used as a key.
      graph_def: An instance of the `GraphDef` proto.
      debug: Whether `graph_def` consists of the debug ops. Selects whether
        the graph is stored among the debug graphs or the original graphs.
    """
    graph_dict = (self._run_key_to_debug_graphs if debug
                  else self._run_key_to_original_graphs)
    # Per run key: mapping from device name to the wrapped graph
    # (DebugGraphWrapper instances, not raw GraphDefs).
    device_to_graph = graph_dict.setdefault(run_key, {})
    device_to_graph[tf.compat.as_str(device_name)] = (
        debug_graphs_helper.DebugGraphWrapper(graph_def))
def testMaybeBaseExpandedNodeName(self):
    """Node names that are also name-scope prefixes get a `/(base)` suffix.

    `foo/a` and `bar/b` are variables, so other nodes (e.g. `foo/a/read`) live
    under them and they expand to `foo/a/(a)` and `bar/b/(b)`. Leaf names such
    as `foo/a/read` and the sum node `baz/c` are returned unchanged.
    """
    with tf.Session() as sess:
        var_a = tf.Variable([1.0], name='foo/a')
        var_b = tf.Variable([2.0], name='bar/b')
        tf.add(var_a, var_b, name='baz/c')
        wrapper = debug_graphs_helper.DebugGraphWrapper(sess.graph_def)

        expectations = (
            ('foo/a', 'foo/a/(a)'),
            ('bar/b', 'bar/b/(b)'),
            ('foo/a/read', 'foo/a/read'),
            ('bar/b/read', 'bar/b/read'),
            ('baz/c', 'baz/c'),
        )
        for node_name, expected in expectations:
            self.assertEqual(
                expected, wrapper.maybe_base_expanded_node_name(node_name))