Пример #1
0
    def get_tensor_history(self, node_name, depth=0):
        """
        Collect the tensor history of the given node.

        The graph hands back a flat list together with a count; the first
        ``count`` entries are labelled as outputs and the rest as inputs
        before the list is returned.

        Args:
            node_name (str): The debug name of the node.
            depth (int): How many layers the caller wants to trace.
                Default is 0.

        Returns:
            dict, ``{'tensor_history': [...]}`` holding the basic tensor
            history (tensor name, tensor type and node type).

        Raises:
            DebuggerNodeNotInGraphError: If ``node_name`` is not in the graph.
        """
        self._graph_exists()
        if not self._graph.exist_node(node_name):
            raise DebuggerNodeNotInGraphError(node_name)

        history, output_count = self._graph.get_tensor_history(
            node_name, depth)
        # Attach the tensor type: leading entries are outputs, the rest inputs.
        labelled_segments = (
            (history[:output_count], 'output'),
            (history[output_count:], 'input'),
        )
        for segment, tensor_type in labelled_segments:
            self._update_tensor_history(segment, tensor_type)
        log.debug("Get %d tensors in tensor history for node <%s>.",
                  len(history), node_name)
        return {'tensor_history': history}
Пример #2
0
    def get_full_name_by_node_name(self, node_name):
        """
        Get the full name of a node from its node name.

        Args:
            node_name (str): The node name used as key in the normal node map.

        Returns:
            str, the node's ``full_name``; '' when ``node_name`` is empty.

        Raises:
            DebuggerNodeNotInGraphError: If ``node_name`` is not in the
                normal node map.
        """
        if not node_name:
            return ''
        node = self._normal_node_map.get(node_name)
        # Compare with None explicitly: ``dict.get`` signals a miss with None,
        # and a node object that happens to be falsy is still a valid hit.
        if node is None:
            log.error("Node <%s> is not in graph.", node_name)
            raise DebuggerNodeNotInGraphError(node_name=node_name)

        return node.full_name
Пример #3
0
    def validate_node_name(self, node_name, graph_name):
        """
        Check that the selected graph contains the given node.

        Args:
            node_name (str): The UI node name.
            graph_name (str): The graph name passed to ``self._get_graph``.

        Raises:
            DebuggerNodeNotInGraphError: If the graph returned by
                ``self._get_graph(graph_name=graph_name)`` does not contain
                ``node_name``.
        """
        target_graph = self._get_graph(graph_name=graph_name)
        if target_graph.exist_node(name=node_name):
            return
        log.error("graph %s doesn't find node: %s.", graph_name, node_name)
        raise DebuggerNodeNotInGraphError(node_name)
Пример #4
0
    def get_node_type(self, node_name):
        """
        Get the type of the node.

        Args:
            node_name (str): The full name of the node with its scope.

        Returns:
            str, the node's ``type``; 'name_scope' when ``node_name`` is empty.

        Raises:
            DebuggerNodeNotInGraphError: If ``node_name`` is not in the
                normal node map.
        """
        if not node_name:
            return 'name_scope'
        node = self._normal_node_map.get(node_name)
        # Compare with None explicitly: ``dict.get`` signals a miss with None,
        # and a node object that happens to be falsy is still a valid hit.
        if node is None:
            log.error("Node <%s> is not in graph.", node_name)
            raise DebuggerNodeNotInGraphError(node_name=node_name)

        return node.type
Пример #5
0
    def list_nodes(self, scope):
        """
        List the nodes directly under a scope of the graph.

        Args:
            scope (str): The name of a scope. An empty/falsy scope is passed
                through to the graph without an existence check.

        Returns:
            TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': [<Node object>]}.
                example:
                    {
                      "nodes" : [
                        {
                          "attr" :
                          {
                            "index" : "i: 0\\n"
                          },
                          "input" : {},
                          "name" : "input_tensor",
                          "output" :
                          {
                            "Default/TensorAdd-op17" :
                            {
                              "edge_type" : "data",
                              "scope" : "name_scope",
                              "shape" : [1, 16, 128, 128]
                            }
                          },
                          "output_i" : -1,
                          "proxy_input" : {},
                          "proxy_output" : {},
                          "independent_layout" : False,
                          "subnode_count" : 0,
                          "type" : "Data"
                        }
                      ]
                    }

        Raises:
            DebuggerNodeNotInGraphError: If a non-empty ``scope`` does not
                exist in the graph.
        """
        scope_is_known = not scope or self._graph.exist_node(scope)
        if not scope_is_known:
            raise DebuggerNodeNotInGraphError(node_name=scope)

        return {'nodes': self._graph.list_node_by_scope(scope=scope)}
Пример #6
0
    def get_node_type(self, node_name):
        """
        Get the type of a node.

        Args:
            node_name (str): The full name of the node with its scope.

        Returns:
            str, the leaf node's own ``type`` when ``node_name`` is a leaf
            node, otherwise ``NodeTypeEnum.NAME_SCOPE.value``.

        Raises:
            DebuggerNodeNotInGraphError: If a non-empty ``node_name`` does not
                exist in the graph.
        """
        if node_name:
            if not self.exist_node(name=node_name):
                raise DebuggerNodeNotInGraphError(node_name=node_name)

        leaf = self._leaf_nodes.get(node_name)
        return NodeTypeEnum.NAME_SCOPE.value if leaf is None else leaf.type