def _parse_graph(graph: torch._C.Graph,
                 model: torch.nn.Module) -> List[IrNode]:
    """
    Extract connectivity information from a TorchScript graph as a list of IrNodes.

    The graph is processed in two passes:

    1. Walk the graph's nodes in order. ``GetAttr`` nodes are resolved to module
       instances (via ``_get_module_instance``) and recorded under their output's
       debug name. Every other node whose op type is not a Constant becomes an
       IrNode; its inputs exclude values produced by Constant nodes.
    2. For every IrNode, inputs produced by ``GetAttr`` nodes are removed from its
       input list. If the IrNode's type is one of ``ConnectedGraph.op_type_map``'s
       values, the module owning that attribute is attached as ``ir_node.module``.
       The module must be a leaf module and, if several attributes are read, the
       same module each time.

    NOTE(review): an earlier docstring described this as a recursive depth-first
    extraction; no recursion is visible in this function itself. Presumably the
    caller recurses over submodules; confirm.

    :param graph: Pytorch JIT graph for the model or a submodule
    :param model: Pytorch model to create connected graph from; bound to the
        graph's first input (the module's ``self``)
    :return: List of IrNodes created from traversing the graph
    """
    ir_nodes_list = []
    curr_inputs = list(graph.inputs())

    # Maps the debug name of every value that refers to a module to that module
    # instance. The graph's first input is the model itself; GetAttr outputs are
    # added as they are encountered.
    node_name_to_module = {curr_inputs[0].debugName(): model}
    for node in graph.nodes():
        outputs = list(node.outputs())

        # Retrieving a module reference
        if 'GetAttr' in node.kind():
            # For GetAttr lines, the output name refers to the module, not the module's output(s)
            assert len(outputs) == 1
            node_name = outputs[0].debugName()
            assert node_name not in node_name_to_module
            node_name_to_module[node_name] = _get_module_instance(node, node_name_to_module)
            continue

        op_type: str = ConnectedGraph._parse_op_type(node)
        if "Constant" in op_type:
            # Constant nodes are not represented as IrNodes.
            continue

        ir_node = IrNode(node_type=op_type,
                         inputs=[
                             inp for inp in node.inputs()
                             if "Constant" not in ConnectedGraph._parse_op_type(inp.node())
                         ],
                         outputs=outputs,
                         module=None)
        ir_nodes_list.append(ir_node)

    # Evaluated once instead of once per input in the loop below.
    known_op_types = tuple(ConnectedGraph.op_type_map.values())
    for ir_node in ir_nodes_list:
        inputs = []
        for inp in ir_node.inputs:
            if "GetAttr" not in inp.node().kind():
                inputs.append(inp)
                continue
            # GetAttr-produced inputs are dropped from the connectivity list. For
            # known op types, the module owning the attribute (the GetAttr node's
            # own input) is attached to the IrNode instead.
            if ir_node.node_type in known_op_types:
                module = node_name_to_module[inp.node().input().debugName()]
                assert is_leaf_module(module)
                if ir_node.module is None:
                    ir_node.module = module
                else:
                    assert ir_node.module == module
        ir_node.inputs = inputs

    return ir_nodes_list
Beispiel #2
0
    def generate_proto_nodes(
        self,
        g: torch._C.Graph,
        onnx_vars: Dict[TorchValueID, onnx.TensorProto],
        val_tab: Dict[TorchValueID, ONNXValueID],
    ) -> Tuple[List[onnx.NodeProto], Dict[TorchValueID, onnx.TensorProto], Dict[TorchValueID, ONNXValueID],]:
        node_name_counter: int = 0

        def node_name(n: torch._C.Node) -> str:
            nonlocal node_name_counter
            op = n.kind().split("::")[-1]
            node_name_counter += 1
            return f"{op}_{node_name_counter - 1}"

        val_tab_rev: Dict[ONNXValueID, TorchValueID] = {v: k for k, v in val_tab.items()}

        def register_val_name(id: TorchValueID, name: ONNXValueID, shadow: bool = False) -> ONNXValueID:
            assert id not in val_tab, f"{id} already registered in {g}"
            if shadow:
                new_name = name
                c = 1
                while new_name in val_tab_rev:
                    new_name = ONNXValueID(f"{name}_{c}")
                    c += 1
                name = new_name
            else:
                assert name not in val_tab_rev, f"{name} already registered in {g}"
            val_tab_rev[name] = id
            val_tab[id] = name
            assert len(val_tab_rev) == len(val_tab)
            return name

        def value_name(v: torch._C.Value) -> ONNXValueID:
            if _unique_id(v) in self.attrs:
                return self.attrs[_unique_id(v)]

            n: torch._C.Node = v.node() or v.uses()[0].user
            scope: str = self.node_scope.get(n, n.scopeName())
            if len(scope) > 0:
                scope += "."
            scope = _remove_prefix(scope.split("/")[-1], "__module.")
            scope = _remove_prefix(scope, f"{_ppe_ignore_scope}.")
            return ONNXValueID(f"{scope}{v.debugName()}")

        def block2subgraph(name: str, b: torch._C.Block, doc_string: str) -> onnx.GraphProto:
            branch_nodes, _, _ = self.generate_proto_nodes(cast(torch._C.Graph, b), onnx_vars, val_tab)
            branch_inputs: List[onnx.ValueInfoProto] = []
            for i in b.inputs():
                branch_inputs.append(onnx.ValueInfoProto())
                branch_inputs[-1].name = val_tab[_unique_id(i)]
                if not self.strip_doc_string:
                    branch_inputs[-1].doc_string = repr(i)
            branch_outputs: List[onnx.ValueInfoProto] = []
            for i in b.outputs():
                branch_outputs.append(onnx.ValueInfoProto())
                branch_outputs[-1].name = val_tab[_unique_id(i)]
                if not self.strip_doc_string:
                    branch_outputs[-1].doc_string = repr(i)

            branch_graph: onnx.GraphProto = onnx.helper.make_graph(
                name=name,
                nodes=branch_nodes,
                # TODO(twata): Support initializers if needed
                inputs=branch_inputs,
                outputs=branch_outputs,
                doc_string=doc_string,
            )

            return branch_graph

        # Nodes and initializers
        onnx_nodes: List[onnx.NodeProto] = []
        self_count: int = 0
        # Run only in root graph
        if self.g == g:
            if self.input_names is not None:
                for idx, v in enumerate(g.inputs()):
                    if self.is_self(v):  # Skip module's self input
                        self_count += 1
                        continue
                    register_val_name(_unique_id(v), ONNXValueID(self.input_names[idx - self_count]))
                assert (len(list(g.inputs())) - self_count) == len(self.input_names)
            if self.output_names is not None:
                if len(self.output_names) != len(list(g.outputs())):
                    warnings.warn(f"Specified output_names ({self.output_names}) count and graph outputs ({list(g.outputs())}) count differ")
                for idx, v in enumerate(g.outputs()):
                    if idx >= len(self.output_names):
                        break
                    register_val_name(_unique_id(v), ONNXValueID(self.output_names[idx]))
        none_nodes: List[torch._C.Node] = []
        for n in g.nodes():
            # Skip None value node
            if n.mustBeNone():
                none_nodes.append(n)
                continue
            if n.kind() == "prim::GetAttr":
                continue
            if n.kind() == "onnx::Constant" :
                if len(n.output().uses()) == 0:
                    warnings.warn(f"Unused constant left: {n}")
                    continue
                # Skip constant folded initialzers
                if _unique_id(n.output()) in self.attrs:
                    continue
            for i in n.inputs():
                if self.is_self(i):
                    continue
                if i.node() is not None and i.node() in none_nodes:
                    continue
                if _unique_id(i) in self.attrs and _unique_id(i) not in onnx_vars:
                    k: ONNXValueID = self.attrs[_unique_id(i)]
                    assert isinstance(self.vars[k], torch.Tensor)
                    t: torch.Tensor = cast(torch.Tensor, self.vars[k])
                    onnx_vars[_unique_id(i)] = _tensor_to_proto(t, name=k)
                    register_val_name(_unique_id(i), value_name(i), shadow=True)
                    continue
                if _unique_id(i) not in val_tab:
                    register_val_name(_unique_id(v), value_name(i))

            for o in n.outputs():
                if _unique_id(o) not in val_tab:
                    register_val_name(_unique_id(o), value_name(o), shadow=True)

            def assign_onnx_values(
                onnx_values: List[str],
                prefix: str,
                torch_values: Iterator[torch._C.Value],
            ) -> None:
                assert len(onnx_values) == 0
                for v in torch_values:
                    if v.node() is not None and v.node() in none_nodes:
                        onnx_values.append("")
                        continue
                    k: ONNXValueID = val_tab.get(_unique_id(v), value_name(v))
                    if _unique_id(v) not in val_tab:
                        register_val_name(_unique_id(v), k)
                    onnx_values.append(k)

            new_nd = onnx.NodeProto()
            new_nd.name = node_name(n)
            new_nd.op_type = n.kind().split("::")[-1]
            if n.kind() == "prim::If":
                if n in self.node_doc_string:
                    new_nd.doc_string = f"""## Symbolic node
{n}
{self.node_doc_string[n]}"""
                blocks: List[torch._C.Block] = list(n.blocks())
                assert len(blocks) == 2
                for attr_name, block in zip(["then_branch", "else_branch"], blocks):
                    sub_g = block2subgraph(f"{new_nd.name}_{attr_name}", block, new_nd.doc_string)
                    new_nd.attribute.append(onnx.helper.make_attribute(attr_name, sub_g))
            else:
                assert len(list(n.blocks())) == 0, f"Node with block needs to be handled separately: {n}"
                if n in self.node_doc_string:
                    new_nd.doc_string = self.node_doc_string[n]
                for attr_name in n.attributeNames():
                    if n.kindOf(attr_name) == "t":
                        attr = onnx.helper.make_attribute(attr_name, _tensor_to_proto(n.t(attr_name)))
                    else:
                        attr = onnx.helper.make_attribute(attr_name, n[attr_name])
                    new_nd.attribute.append(attr)
            assign_onnx_values(new_nd.input, new_nd.name, n.inputs())
            assign_onnx_values(new_nd.output, new_nd.name, n.outputs())
            onnx_nodes.append(new_nd)

        return onnx_nodes, onnx_vars, val_tab