Esempio n. 1
0
 def _process_graph_def(self, graph_def):
   """Record the gated-gRPC debug watches present in a GraphDef.

   For every debug node in `graph_def` whose `gated_grpc` attribute is true,
   a `DebugWatch(node_name, output_slot, debug_op)` parsed from the debug
   node's name is added to `self._gated_grpc_debug_watches`.

   Args:
     graph_def: A GraphDef proto. It is only read, never modified.
   """
   for node_def in graph_def.node:
     # Test membership before indexing: `node_def.attr[key]` on a protobuf
     # map inserts a default entry when the key is missing, which would
     # silently mutate the caller's GraphDef.
     if (debug_graphs.is_debug_node(node_def.name) and
         "gated_grpc" in node_def.attr and
         node_def.attr["gated_grpc"].b):
       node_name, output_slot, _, debug_op = (
           debug_graphs.parse_debug_node_name(node_def.name))
       self._gated_grpc_debug_watches.add(
           DebugWatch(node_name, output_slot, debug_op))
Esempio n. 2
0
    def testParseDebugNodeName_valid(self):
        """A well-formed debug node name is split into its four components."""
        # Layout exercised here:
        #   "__dbg_<watched node>:<output slot>_<debug op index>_<debug op>"
        parsed = debug_graphs.parse_debug_node_name(
            "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity")
        node, slot, op_index, op = parsed

        self.assertEqual("ns_a/ns_b/node_c", node)
        self.assertEqual(1, slot)
        self.assertEqual(0, op_index)
        self.assertEqual("DebugIdentity", op)
Esempio n. 3
0
    def get_gated_grpc_tensors(self, matching_debug_op=None):
        """Extract all nodes with gated-gRPC debug ops attached.

        Results are cached per debug op name in `self._grpc_gated_tensors`.
        This method is thread-safe: the cache lookup and its population both
        happen while holding `self._grpc_gated_lock`.

        Args:
          matching_debug_op: Return only the tensors and nodes watched by
            debug ops of this op type (optional). If `None` (or any other
            falsy value), only `DebugIdentity` debug ops are extracted.

        Returns:
          A list of (node_name, op_type, output_slot, debug_op) tuples, where
          `node_name` and `op_type` describe the watched node rather than the
          debug node itself. The cached list object is returned, so callers
          should treat it as read-only.

        Raises:
          KeyError: If a gated debug node watches a node that is not present
            in `self._graph_def`.
          ValueError: Propagated from `debug_graphs.parse_debug_node_name` if
            a matching debug node's name cannot be parsed.
        """
        with self._grpc_gated_lock:
            matching_debug_op = matching_debug_op or "DebugIdentity"
            if matching_debug_op not in self._grpc_gated_tensors:
                # First, construct a map from node name to op type so the op
                # type of each watched node can be looked up.
                node_name_to_op_type = {
                    node.name: node.op for node in self._graph_def.node}

                # Second, populate the output list.
                gated = []
                for node in self._graph_def.node:
                    if node.op != matching_debug_op:
                        continue
                    # Direct membership test instead of scanning every attr
                    # key; checking `in` first also avoids indexing a missing
                    # key, which on a protobuf map would insert an entry.
                    if ("gated_grpc" not in node.attr
                            or not node.attr["gated_grpc"].b):
                        continue
                    (
                        node_name,
                        output_slot,
                        _,
                        debug_op,
                    ) = debug_graphs.parse_debug_node_name(node.name)
                    gated.append((
                        node_name,
                        node_name_to_op_type[node_name],
                        output_slot,
                        debug_op,
                    ))
                self._grpc_gated_tensors[matching_debug_op] = gated

            return self._grpc_gated_tensors[matching_debug_op]
  def get_gated_grpc_tensors(self, matching_debug_op=None):
    """Extract all nodes with gated-gRPC debug ops attached.

    Results are cached per debug op name in `self._grpc_gated_tensors`.
    This method is thread-safe: the cache lookup and its population both
    happen while holding `self._grpc_gated_lock`.

    Args:
      matching_debug_op: Return only the tensors and nodes watched by debug
        ops of this op type (optional). If `None` (or any other falsy
        value), only `DebugIdentity` debug ops are extracted.

    Returns:
      A list of (node_name, op_type, output_slot, debug_op) tuples, where
      `node_name` and `op_type` describe the watched node rather than the
      debug node itself. The cached list object is returned, so callers
      should treat it as read-only.

    Raises:
      KeyError: If a gated debug node watches a node that is not present in
        `self._graph_def`.
      ValueError: Propagated from `debug_graphs.parse_debug_node_name` if a
        matching debug node's name cannot be parsed.
    """
    with self._grpc_gated_lock:
      matching_debug_op = matching_debug_op or 'DebugIdentity'
      if matching_debug_op not in self._grpc_gated_tensors:
        # First, construct a map from node name to op type so the op type of
        # each watched node can be looked up.
        node_name_to_op_type = {
            node.name: node.op for node in self._graph_def.node}

        # Second, populate the output list.
        gated = []
        for node in self._graph_def.node:
          if node.op != matching_debug_op:
            continue
          # Direct membership test instead of scanning every attr key;
          # checking `in` first also avoids indexing a missing key, which on
          # a protobuf map would insert an entry.
          if 'gated_grpc' not in node.attr or not node.attr['gated_grpc'].b:
            continue
          node_name, output_slot, _, debug_op = (
              debug_graphs.parse_debug_node_name(node.name))
          gated.append(
              (node_name, node_name_to_op_type[node_name], output_slot,
               debug_op))
        self._grpc_gated_tensors[matching_debug_op] = gated

      return self._grpc_gated_tensors[matching_debug_op]
Esempio n. 5
0
    def testParseDebugNodeName_invalidWatchedTensorName(self):
        """A debug node name lacking the ':<output slot>' part is rejected."""
        invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"

        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(ValueError,
                                    "Invalid tensor name in debug node name"):
            debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
Esempio n. 6
0
    def testParseDebugNodeName_missingDebugOpIndex(self):
        """A debug node name without a debug op index is rejected."""
        invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"

        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(ValueError, "Invalid debug node name"):
            debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
Esempio n. 7
0
    def testParseDebugNodeName_invalidPrefix(self):
        """A name not starting with the '__dbg_' prefix is rejected."""
        invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"

        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(ValueError, "Invalid prefix"):
            debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)