def _parse_graph(graph: torch._C.Graph, model: torch.nn.Module) -> List[IrNode]:
    """
    Implements a depth-first graph extraction to obtain connectivity information
    in the form of an IrNode list. Depth-first extraction is realized using
    recursion.

    :param graph: PyTorch JIT trace graph for the model or a submodule
    :param model: PyTorch model to create the connected graph from
    :return: List of IrNodes created from traversing the trace graph
    """
    ir_nodes_list = []
    curr_inputs = [inp for inp in graph.inputs()]

    # A map of sub-graph models and node names that require recursive parsing;
    # tracks modules that are referenced within the sub-graph.
    node_name_to_module = {curr_inputs[0].debugName(): model}
    for node in graph.nodes():
        outputs = [output for output in node.outputs()]

        if 'GetAttr' in node.kind():
            # For GetAttr nodes, the output name refers to the module itself,
            # not to the module's output(s), so record the module reference.
            assert len(outputs) == 1
            node_name = outputs[0].debugName()
            assert node_name not in node_name_to_module
            module = _get_module_instance(node, node_name_to_module)
            node_name_to_module[node_name] = module
        else:
            op_type: str = ConnectedGraph._parse_op_type(node)
            if "Constant" not in op_type:
                ir_node = IrNode(node_type=op_type,
                                 inputs=[inp for inp in node.inputs()
                                         if "Constant" not in ConnectedGraph._parse_op_type(inp.node())],
                                 outputs=outputs,
                                 module=None)
                ir_nodes_list.append(ir_node)

    # Attach the module resolved via GetAttr inputs to each IrNode, and drop
    # those GetAttr values from the node's input list.
    for ir_node in ir_nodes_list:
        inputs = []
        for inp in ir_node.inputs:
            if "GetAttr" in inp.node().kind():
                if ir_node.node_type in ConnectedGraph.op_type_map.values():
                    module = node_name_to_module[inp.node().input().debugName()]
                    assert is_leaf_module(module)
                    if ir_node.module is None:
                        ir_node.module = module
                    else:
                        assert ir_node.module == module
            else:
                inputs.append(inp)
        ir_node.inputs = inputs

    return ir_nodes_list
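
# Usage sketch (illustrative, not from the original source): _parse_graph expects
# the torch._C.Graph of a JIT trace together with the traced module. The
# torchvision model and the input shape below are assumptions for demonstration.
def _example_parse_graph() -> None:
    import torchvision

    model = torchvision.models.resnet18().eval()
    trace = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
    for ir_node in _parse_graph(trace.graph, model):
        print(ir_node.node_type, ir_node.module)
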
def generate_proto_nodes(
    self,
    g: torch._C.Graph,
    onnx_vars: Dict[TorchValueID, onnx.TensorProto],
    val_tab: Dict[TorchValueID, ONNXValueID],
) -> Tuple[List[onnx.NodeProto], Dict[TorchValueID, onnx.TensorProto], Dict[TorchValueID, ONNXValueID]]:
    """
    Convert the nodes of a TorchScript graph into ONNX NodeProtos, collecting
    constant-folded initializers into ``onnx_vars`` and the torch-value to
    ONNX-name mapping into ``val_tab``. Recurses into ``prim::If`` blocks via
    ``block2subgraph``.
    """
    node_name_counter: int = 0

    def node_name(n: torch._C.Node) -> str:
        nonlocal node_name_counter
        op = n.kind().split("::")[-1]
        node_name_counter += 1
        return f"{op}_{node_name_counter - 1}"

    val_tab_rev: Dict[ONNXValueID, TorchValueID] = {v: k for k, v in val_tab.items()}

    def register_val_name(id: TorchValueID, name: ONNXValueID, shadow: bool = False) -> ONNXValueID:
        assert id not in val_tab, f"{id} already registered in {g}"
        if shadow:
            # Append a numeric suffix until the name no longer collides
            new_name = name
            c = 1
            while new_name in val_tab_rev:
                new_name = ONNXValueID(f"{name}_{c}")
                c += 1
            name = new_name
        else:
            assert name not in val_tab_rev, f"{name} already registered in {g}"
        val_tab_rev[name] = id
        val_tab[id] = name
        assert len(val_tab_rev) == len(val_tab)
        return name

    def value_name(v: torch._C.Value) -> ONNXValueID:
        if _unique_id(v) in self.attrs:
            return self.attrs[_unique_id(v)]
        n: torch._C.Node = v.node() or v.uses()[0].user
        scope: str = self.node_scope.get(n, n.scopeName())
        if len(scope) > 0:
            scope += "."
        scope = _remove_prefix(scope.split("/")[-1], "__module.")
        scope = _remove_prefix(scope, f"{_ppe_ignore_scope}.")
        return ONNXValueID(f"{scope}{v.debugName()}")

    def block2subgraph(name: str, b: torch._C.Block, doc_string: str) -> onnx.GraphProto:
        branch_nodes, _, _ = self.generate_proto_nodes(cast(torch._C.Graph, b), onnx_vars, val_tab)
        branch_inputs: List[onnx.ValueInfoProto] = []
        for i in b.inputs():
            branch_inputs.append(onnx.ValueInfoProto())
            branch_inputs[-1].name = val_tab[_unique_id(i)]
            if not self.strip_doc_string:
                branch_inputs[-1].doc_string = repr(i)
        branch_outputs: List[onnx.ValueInfoProto] = []
        for i in b.outputs():
            branch_outputs.append(onnx.ValueInfoProto())
            branch_outputs[-1].name = val_tab[_unique_id(i)]
            if not self.strip_doc_string:
                branch_outputs[-1].doc_string = repr(i)
        branch_graph: onnx.GraphProto = onnx.helper.make_graph(
            name=name,
            nodes=branch_nodes,
            # TODO(twata): Support initializers if needed
            inputs=branch_inputs,
            outputs=branch_outputs,
            doc_string=doc_string,
        )
        return branch_graph

    # Nodes and initializers
    onnx_nodes: List[onnx.NodeProto] = []
    self_count: int = 0
    # Run only in root graph
    if self.g == g:
        if self.input_names is not None:
            for idx, v in enumerate(g.inputs()):
                if self.is_self(v):  # Skip module's self input
                    self_count += 1
                    continue
                register_val_name(_unique_id(v), ONNXValueID(self.input_names[idx - self_count]))
            assert (len(list(g.inputs())) - self_count) == len(self.input_names)
        if self.output_names is not None:
            if len(self.output_names) != len(list(g.outputs())):
                warnings.warn(
                    f"Specified output_names ({self.output_names}) count and "
                    f"graph outputs ({list(g.outputs())}) count differ")
            for idx, v in enumerate(g.outputs()):
                if idx >= len(self.output_names):
                    break
                register_val_name(_unique_id(v), ONNXValueID(self.output_names[idx]))

    none_nodes: List[torch._C.Node] = []
    for n in g.nodes():
        # Skip None value node
        if n.mustBeNone():
            none_nodes.append(n)
            continue
        if n.kind() == "prim::GetAttr":
            continue
        if n.kind() == "onnx::Constant":
            if len(n.output().uses()) == 0:
                warnings.warn(f"Unused constant left: {n}")
                continue
            # Skip constant-folded initializers
            if _unique_id(n.output()) in self.attrs:
                continue
        for i in n.inputs():
            if self.is_self(i):
                continue
            if i.node() is not None and i.node() in none_nodes:
                continue
            if _unique_id(i) in self.attrs and _unique_id(i) not in onnx_vars:
                k: ONNXValueID = self.attrs[_unique_id(i)]
                assert isinstance(self.vars[k], torch.Tensor)
                t: torch.Tensor = cast(torch.Tensor, self.vars[k])
                onnx_vars[_unique_id(i)] = _tensor_to_proto(t, name=k)
                register_val_name(_unique_id(i), value_name(i), shadow=True)
                continue
            if _unique_id(i) not in val_tab:
                register_val_name(_unique_id(i), value_name(i))
        for o in n.outputs():
            if _unique_id(o) not in val_tab:
                register_val_name(_unique_id(o), value_name(o), shadow=True)

        def assign_onnx_values(
            onnx_values: List[str],
            prefix: str,
            torch_values: Iterator[torch._C.Value],
        ) -> None:
            assert len(onnx_values) == 0
            for v in torch_values:
                if v.node() is not None and v.node() in none_nodes:
                    onnx_values.append("")
                    continue
                k: ONNXValueID = val_tab.get(_unique_id(v), value_name(v))
                if _unique_id(v) not in val_tab:
                    register_val_name(_unique_id(v), k)
                onnx_values.append(k)

        new_nd = onnx.NodeProto()
        new_nd.name = node_name(n)
        new_nd.op_type = n.kind().split("::")[-1]
        if n.kind() == "prim::If":
            if n in self.node_doc_string:
                new_nd.doc_string = f"""## Symbolic node
{n}
{self.node_doc_string[n]}"""
            blocks: List[torch._C.Block] = list(n.blocks())
            assert len(blocks) == 2
            for attr_name, block in zip(["then_branch", "else_branch"], blocks):
                sub_g = block2subgraph(f"{new_nd.name}_{attr_name}", block, new_nd.doc_string)
                new_nd.attribute.append(onnx.helper.make_attribute(attr_name, sub_g))
        else:
            assert len(list(n.blocks())) == 0, f"Node with block needs to be handled separately: {n}"
            if n in self.node_doc_string:
                new_nd.doc_string = self.node_doc_string[n]
        for attr_name in n.attributeNames():
            if n.kindOf(attr_name) == "t":
                attr = onnx.helper.make_attribute(attr_name, _tensor_to_proto(n.t(attr_name)))
            else:
                attr = onnx.helper.make_attribute(attr_name, n[attr_name])
            new_nd.attribute.append(attr)
        assign_onnx_values(new_nd.input, new_nd.name, n.inputs())
        assign_onnx_values(new_nd.output, new_nd.name, n.outputs())
        onnx_nodes.append(new_nd)

    return onnx_nodes, onnx_vars, val_tab
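
# Usage sketch (hypothetical; generate_proto_nodes is a method of the exporter,
# so the protos below are hand-built stand-ins for its return values): the
# NodeProto list and TensorProto initializers it produces are exactly what
# onnx.helper.make_graph consumes when assembling the final model.
def _example_assemble_model() -> None:
    from onnx import TensorProto, checker, helper

    nodes = [helper.make_node("Add", inputs=["x", "w"], outputs=["y"], name="Add_0")]
    initializers = [helper.make_tensor("w", TensorProto.FLOAT, dims=[1], vals=[1.0])]
    graph_proto = helper.make_graph(
        nodes=nodes,
        name="pfto_graph",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
        initializer=initializers,
    )
    model_proto = helper.make_model(graph_proto, opset_imports=[helper.make_opsetid("", 13)])
    checker.check_model(model_proto)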