def testExtractGatedGrpcTensorsFoundGatedGrpcOps(self):
        """Checks the gated-gRPC tensors reported for a gated_grpc=True run.

        Builds the graph and run options through
        `self._createTestGraphAndRunOptions(sess, gated_grpc=True)`, runs `z`
        once while collecting `RunMetadata`, wraps the first partition graph
        in a `DebugGraphWrapper` and inspects `get_gated_grpc_tensors()`.
        """
        with tf.compat.v1.Session() as sess:
            z, run_options = self._createTestGraphAndRunOptions(
                sess, gated_grpc=True)
            sess.run(tf.compat.v1.global_variables_initializer())

            run_metadata = config_pb2.RunMetadata()
            z_value = sess.run(
                z, options=run_options, run_metadata=run_metadata)
            self.assertAllClose([10.0], z_value)

            first_partition = run_metadata.partition_graphs[0]
            wrapper = debug_graphs_helper.DebugGraphWrapper(first_partition)
            gated_tensors = wrapper.get_gated_grpc_tensors()

            # Element 1 of every entry is the op type; it only has to be
            # present (truthy) here.
            for entry in gated_tensors:
                self.assertTrue(entry[1])

            # The op type is dropped before the membership checks because op
            # type names may change over time (e.g., 'VariableV2' -->
            # 'VariableV3'). The remaining elements are presumably
            # (node name, output slot, debug op) -- confirm against
            # `get_gated_grpc_tensors`.
            without_op_type = [
                (entry[0], entry[2], entry[3]) for entry in gated_tensors
            ]

            # Same order as the graph's nodes are checked one by one: the
            # first missing entry is the one that fails the test.
            for node_name in ("a", "b", "c", "d", "x", "y", "z"):
                self.assertIn(
                    (node_name, 0, "DebugIdentity"), without_op_type)
  def testMaybeBaseExpandedNodeName(self):
    """Checks `maybe_base_expanded_node_name` on a small variable graph.

    The graph holds variables 'foo/a' and 'bar/b' plus an add op 'baz/c'.
    The expected name for 'baz/c' depends on whether TF2 behavior is enabled.
    """
    with tf.compat.v1.Session() as sess:
      var_a = tf.Variable([1.0], name='foo/a')
      var_b = tf.Variable([2.0], name='bar/b')
      tf.add(var_a, var_b, name='baz/c')

      wrapper = debug_graphs_helper.DebugGraphWrapper(sess.graph_def)
      expand = wrapper.maybe_base_expanded_node_name

      # (queried node name, expected result), checked in this order.
      fixed_expectations = (
          ('foo/a', 'foo/a/(a)'),
          ('bar/b', 'bar/b/(b)'),
          ('foo/a/read', 'foo/a/read'),
          ('bar/b/read', 'bar/b/read'),
      )
      for node_name, expected in fixed_expectations:
        self.assertEqual(expected, expand(node_name))

      # NOTE(#1705): TF 2.0 tf.add creates nested nodes, so 'baz/c' is only
      # returned unchanged when TF2 behavior is off.
      expected_sum_name = (
          'baz/c/(c)' if tensorflow_python_tf2.enabled() else 'baz/c')
      self.assertEqual(expected_sum_name, expand('baz/c'))
  def testGraphDefProperty(self):
    """Checks that `graph_def` equals the GraphDef the wrapper was built from.

    Runs `z` once with gated_grpc=True run options while collecting
    `RunMetadata`, then wraps the first partition graph.
    """
    with tf.compat.v1.Session() as sess:
      z, run_options = self._createTestGraphAndRunOptions(sess, gated_grpc=True)
      sess.run(tf.compat.v1.global_variables_initializer())

      run_metadata = config_pb2.RunMetadata()
      z_value = sess.run(z, options=run_options, run_metadata=run_metadata)
      self.assertAllClose([10.0], z_value)

      first_partition = run_metadata.partition_graphs[0]
      wrapper = debug_graphs_helper.DebugGraphWrapper(first_partition)
      self.assertProtoEquals(first_partition, wrapper.graph_def)
  def testExtractGatedGrpcTensorsFoundNoGatedGrpcOps(self):
    """Checks that a gated_grpc=False run reports no gated-gRPC tensors.

    Runs `z` once while collecting `RunMetadata`, wraps the first partition
    graph and expects `get_gated_grpc_tensors()` to return an empty list.
    """
    with tf.compat.v1.Session() as sess:
      z, run_options = self._createTestGraphAndRunOptions(
          sess, gated_grpc=False)
      sess.run(tf.compat.v1.global_variables_initializer())

      run_metadata = config_pb2.RunMetadata()
      z_value = sess.run(z, options=run_options, run_metadata=run_metadata)
      self.assertAllClose([10.0], z_value)

      first_partition = run_metadata.partition_graphs[0]
      wrapper = debug_graphs_helper.DebugGraphWrapper(first_partition)
      self.assertEqual([], wrapper.get_gated_grpc_tensors())
    def add_graph(self, run_key, device_name, graph_def, debug=False):
        """Add a GraphDef.

        The `GraphDef` is wrapped in a `DebugGraphWrapper` and stored under
        `run_key` and the string form of `device_name`. A wrapper already
        stored for the same run key and device is replaced.

        Args:
          run_key: A key for the run, containing information about the feeds,
            fetches, and targets.
          device_name: The name of the device that the `GraphDef` is for.
          graph_def: An instance of the `GraphDef` proto.
          debug: Whether `graph_def` consists of the debug ops. Selects the
            debug-graph mapping when true, the original-graph mapping
            otherwise.
        """
        graph_dict = (self._run_key_to_debug_graphs
                      if debug else self._run_key_to_original_graphs)
        # Per-run mapping of device_name to wrapped GraphDef; created on the
        # first graph added for this run key.
        device_graphs = graph_dict.setdefault(run_key, {})
        device_graphs[tf.compat.as_str(device_name)] = (
            debug_graphs_helper.DebugGraphWrapper(graph_def))
    def testMaybeBaseExpandedNodeName(self):
        """Checks `maybe_base_expanded_node_name` on a small variable graph.

        The graph holds variables 'foo/a' and 'bar/b' plus an add op 'baz/c'.
        """
        # NOTE(review): `tf.Session` is the TF 1.x spelling and 'baz/c' is
        # expected to come back unchanged; presumably this test targets
        # TF 1.x graph construction only -- confirm before running it under
        # TF 2.x.
        with tf.Session() as sess:
            var_a = tf.Variable([1.0], name='foo/a')
            var_b = tf.Variable([2.0], name='bar/b')
            tf.add(var_a, var_b, name='baz/c')

            wrapper = debug_graphs_helper.DebugGraphWrapper(sess.graph_def)
            expand = wrapper.maybe_base_expanded_node_name

            # (queried node name, expected result), checked in this order.
            expectations = (
                ('foo/a', 'foo/a/(a)'),
                ('bar/b', 'bar/b/(b)'),
                ('foo/a/read', 'foo/a/read'),
                ('bar/b/read', 'bar/b/read'),
                ('baz/c', 'baz/c'),
            )
            for node_name, expected in expectations:
                self.assertEqual(expected, expand(node_name))