def _parse_graph(graph: torch._C.Graph, model: torch.nn.Module) -> List[IrNode]:
    """
    Extract connectivity information from a JIT graph in the form of an IrNodes list.

    The graph nodes are visited once, in order. ``GetAttr`` nodes are not turned into IrNodes; they are
    resolved to module references instead. Nodes whose op type contains "Constant" are skipped. A second
    pass then removes ``GetAttr`` inputs from every IrNode and, for op types listed in
    ``ConnectedGraph.op_type_map``, attaches the referenced module to the IrNode.

    :param graph: Pytorch JIT trace graph for model or a submodule
    :param model: Pytorch model to create connected graph from
    :return: List of IrNodes created from traversing the trace graph
    """
    ir_nodes_list = []
    graph_inputs = list(graph.inputs())

    # Maps the debug name of a value that refers to a module onto that module.
    # It is seeded with the first graph input, which is mapped to `model` itself.
    node_name_to_module = {graph_inputs[0].debugName(): model}

    # Pass 1: resolve module references and create one IrNode per non-constant operation.
    for node in graph.nodes():
        if 'GetAttr' in node.kind():
            # For GetAttr lines, the output name refers to the module, not to the module's output(s)
            attr_outputs = list(node.outputs())
            assert len(attr_outputs) == 1
            node_name = attr_outputs[0].debugName()
            assert node_name not in node_name_to_module
            node_name_to_module[node_name] = _get_module_instance(node, node_name_to_module)
            continue

        op_type: str = ConnectedGraph._parse_op_type(node)
        if "Constant" in op_type:
            continue

        node_outputs = list(node.outputs())
        # Inputs produced by constant nodes carry no connectivity information and are dropped here.
        non_constant_inputs = [
            inp for inp in node.inputs()
            if "Constant" not in ConnectedGraph._parse_op_type(inp.node())
        ]
        ir_nodes_list.append(
            IrNode(node_type=op_type, inputs=non_constant_inputs, outputs=node_outputs, module=None)
        )

    # Pass 2: strip GetAttr inputs and bind the referenced leaf module where the op type is a known one.
    for ir_node in ir_nodes_list:
        kept_inputs = []
        for inp in ir_node.inputs:
            if "GetAttr" not in inp.node().kind():
                kept_inputs.append(inp)
                continue
            # GetAttr inputs are always removed; the module is only bound for known op types.
            if ir_node.node_type not in ConnectedGraph.op_type_map.values():
                continue
            module = node_name_to_module[inp.node().input().debugName()]
            assert is_leaf_module(module)
            if ir_node.module is None:
                ir_node.module = module
            else:
                # Every GetAttr input of one IrNode must point at the same module.
                assert ir_node.module == module
        ir_node.inputs = kept_inputs

    return ir_nodes_list
def optimize_onnx(self, graph: torch._C.Graph) -> torch._C.Graph:
    """Run the ONNX-level optimization passes on ``graph`` and return it.

    The passes are, in order: scalar type analysis, optional constant folding
    (folded graph inputs are replaced by ``onnx::Constant`` nodes and erased
    from the input list), and the optional ONNX peephole pass.

    Args:
        graph: The graph to optimize. It is modified in place.

    Returns:
        The same ``graph`` object.
    """
    # The extra arguments are only passed when `pytorch_pfn_extras.requires("1.9.0")` holds.
    if pytorch_pfn_extras.requires("1.9.0"):
        analysis_args = (self.onnx_lowprecision_cast, self.opset_version)
    else:
        analysis_args = ()
    self.run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph, *analysis_args)

    if self.do_constant_folding and self.opset_version in torch.onnx.constant_folding_opset_versions:
        folded: Dict[str, torch.IValue] = torch._C._jit_pass_onnx_constant_fold(  # type: ignore[attr-defined]
            graph, self.vars, self.opset_version
        )
        # Replace input with constant nodes
        input_table: Dict[str, torch._C.Value] = {}
        for graph_input in graph.inputs():
            input_table[graph_input.debugName()] = graph_input

        for name, tensor in folded.items():
            const_value: torch._C.Value = graph.create("onnx::Constant", 1).output()
            assert isinstance(tensor, torch.Tensor)
            const_node = const_value.node()
            const_node.t_("value", cast(torch.Tensor, tensor))
            graph.prependNode(const_node)

            folded_input = input_table[name]
            # TODO(twata): Determine folded nodes from original graph and document it
            self.node_doc_string[const_value.node()] = f"Constant folded node: {folded_input}"
            folded_input.replaceAllUsesWith(const_value)
            const_value.copyMetadata(folded_input)
            self.attrs[_unique_id(const_value)] = ONNXValueID(name)
            self.vars[name] = tensor
            # Only the inputs that were NOT folded remain in the table.
            del input_table[name]

        # Erase one input per folded entry, always at the index just past the surviving inputs.
        # NOTE(review): this assumes the folded inputs sit at the tail of the input list - confirm.
        remaining = len(input_table)
        erase_count = len(list(graph.inputs())) - remaining
        for _ in range(erase_count):
            graph.eraseInput(remaining)
        torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)  # type: ignore[attr-defined]

    if self.onnx_peephole:
        self.run_jit_pass(torch._C._jit_pass_onnx_peephole, graph, self.opset_version, self.fixed_batch_size)

    return graph
def gen_concat(g: torch._C.Graph, *args: Any) -> torch._C.Value:
    """Emit a ``Concat`` op along axis 0 over ``args``.

    An argument whose type kind is ``"IntType"``, or whose type reports an
    empty ``sizes()``, is first unsqueezed on axis 0 so it can be concatenated.

    Args:
        g: Graph used to create the ops.
        *args: Values to concatenate.

    Returns:
        The output value of the ``Concat`` op.
    """

    def _needs_unsqueeze(value: Any) -> bool:
        if value.type().kind() == "IntType":
            return True
        return len(value.type().sizes()) == 0

    seq: List[torch._C.Value] = [
        sym_hel._unsqueeze_helper(g, value, axes_i=[0])  # type: ignore[no-untyped-call,call-arg]
        if _needs_unsqueeze(value)
        else value
        for value in args
    ]
    concat = g.op("Concat", *seq, axis_i=0)
    return cast(torch._C.Value, concat)
def gen_const(g: torch._C.Graph, value: Any = None) -> torch._C.Value:
    """Create the replacement for the constant node ``n``.

    ``n`` is not a parameter: it is read from the enclosing scope, which is
    not visible here. ``value`` is accepted but not used by this function.

    Args:
        g: Graph used to create the ops.
        value: Unused.

    Returns:
        The output of a ``Constant`` op, or of a ``prim::ListConstruct`` op
        when the constant is a list whose first element is not an int/float.
    """
    c = cast(torch._C.Value, g.op("Constant"))

    if n.kindOf("value") != "ival":
        # Non-IValue constants simply carry their attributes over.
        c.node().copyAttributes(n)
        return c

    ival = n.output().toIValue()
    # NOTE(review): `ival[0]` raises IndexError for an empty list - confirm that cannot occur.
    if not isinstance(ival, list) or isinstance(ival[0], (int, float)):
        c.node().t_("value", torch.tensor(ival))
        return c

    # List of tensors and/or None: build one prim::Constant per element.
    elements: List[torch._C.Value] = []
    for item in ival:
        if isinstance(item, torch.Tensor):
            elements.append(cast(torch._C.Value, g.op("prim::Constant", value_t=item)))
            continue
        assert item is None
        none_const = cast(torch._C.Value, g.op("prim::Constant"))
        elements.append(none_const)
        none_const.setType(torch._C.NoneType.get())

    # The "Constant" op created above is not returned on this path.
    c = cast(torch._C.Value, g.op("prim::ListConstruct"))
    for element in elements:
        c.node().addInput(element)
    return c
def _new_node(g: torch._C.Graph, opname: str, outputs, *args, **kwargs): if "::" in opname: aten = False ns_opname = opname else: aten = kwargs.pop("aten", False) ns = "aten" if aten else "onnx" ns_opname = ns + "::" + opname n = g.create(ns_opname, args, outputs) # type: ignore[attr-defined] for k, v in sorted(kwargs.items()): # TODO: enable inplace in aten exporting mode. if k == "inplace": continue _add_attribute(n, k, v, aten=aten) return n
def _export_jit_graph_to_onnx_model_proto(graph: torch._C.Graph, operator_export_type: int):
    """Export a JIT graph to an ONNX model proto.

    Enables ONNX shape inference, sets the operator export type, runs the
    decomposition pass and ``torch.onnx.utils._optimize_graph`` on the graph,
    then exports it with the globally configured opset version.

    Args:
        graph: The JIT graph to export. The decomposition pass is run on it in place.
        operator_export_type: Forwarded to the exporter configuration,
            ``_optimize_graph`` and ``_export_onnx``.

    Returns:
        The first element of the 4-tuple returned by ``graph._export_onnx``.
    """
    # `_set_opset_version` is imported by the original code although it is not called here.
    from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_operator_export_type, _set_opset_version

    _set_onnx_shape_inference(True)
    _set_operator_export_type(operator_export_type)

    torch._C._jit_pass_run_decompositions(graph)
    optimized = torch.onnx.utils._optimize_graph(graph, operator_export_type, params_dict={})

    opset_version = torch.onnx._globals.GLOBALS.export_onnx_opset_version
    # `_export_onnx` takes its arguments positionally; their order is kept exactly.
    proto, _unused_a, _unused_b, _unused_c = optimized._export_onnx(
        {},
        opset_version,
        {},
        False,
        operator_export_type,
        False,
        False,
        {},
        True,
        "",
        {},
    )
    return proto
def handle_if(self, g: torch._C.Graph, n: torch._C.Node) -> None:
    """Convert the nodes inside the blocks of node ``n`` and keep ``n`` itself.

    Args:
        g: Graph that contains ``n``.
        n: The node whose blocks are converted (named "If node" by the original comment).
    """
    # Capture the description now, before the blocks are rewritten; it is only
    # attached further down (original note: DCE isn't completed yet at this point).
    doc_str: str = f""" ## Original node {n} ## Scope {n.scopeName()} ## Source Range ``` {n.sourceRange()} ``` """

    # The node itself is reused (original note: to keep graph lint happy);
    # only the nodes inside its blocks are converted.
    for block in n.blocks():
        block_as_graph = cast(torch._C.Graph, block)
        # Snapshot the block's nodes before converting any of them.
        for inner_node in list(block.nodes()):
            self.generate_onnx_node(block_as_graph, inner_node)

    if not self.strip_doc_string:
        self.node_doc_string[n] = doc_str

    # Move to last of graph to keep the execution order of node
    return_node = g.return_node()
    n.moveBefore(return_node)
def _graph_op(
    g: torch._C.Graph,
    opname: str,
    *raw_args: torch._C.Value,
    outputs: int = 1,
    **kwargs,
) -> Union[torch._C.Value, Tuple[torch._C.Value, ...]]:
    r"""Create the ONNX operator "opname" in "g" with inputs "raw_args" and attributes "kwargs".

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    This function is monkey-patched onto Graph.

    Args:
        g: The Torch graph.
        opname: The ONNX operator name, e.g., `Abs` or `Add`.
            TODO(justinchu): Update examples to correct ones.
        raw_args: The inputs to the operator; usually provided
            as arguments to the `symbolic` definition. `None` and graph values
            are passed through as-is; anything else is wrapped in a `Constant` op.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is not one, this function returns a tuple holding
            each output of the operator in positional order.
        kwargs: The attributes of the ONNX operator, whose keys are named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`.  The valid type specifiers are
            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
            specified with type float accepts either a single float, or a
            list of floats (e.g., you would say `dims_i` for a `dims` attribute
            that takes a list of integers). Attributes whose value is `None`
            are dropped.

    Returns:
        The single output of this operator when `outputs` is one, otherwise
        a tuple of all outputs.
    """
    # Dropping None-valued attributes lets callers forward optional attributes
    # unconditionally without having them show up on the node.
    attributes = {name: value for name, value in kwargs.items() if value is not None}

    def const_if_tensor(arg):
        if arg is None or isinstance(arg, torch._C.Value):
            return arg
        return g.op("Constant", value_z=arg)  # type: ignore[attr-defined]

    node_inputs = list(map(const_if_tensor, raw_args))
    new_node = _new_node(g, opname, outputs, *node_inputs, **attributes)
    n = g.insertNode(new_node)  # type: ignore[attr-defined]

    # Import utils to get _params_dict because it is a global that is accessed by c++ code
    from torch.onnx import utils

    if GLOBALS.onnx_shape_inference:
        torch._C._jit_pass_onnx_node_shape_type_inference(
            n, utils._params_dict, GLOBALS.export_onnx_opset_version
        )

    return n.output() if outputs == 1 else tuple(n.outputs())
def generate_proto_nodes(
    self,
    g: torch._C.Graph,
    onnx_vars: Dict[TorchValueID, onnx.TensorProto],
    val_tab: Dict[TorchValueID, ONNXValueID],
) -> Tuple[List[onnx.NodeProto], Dict[TorchValueID, onnx.TensorProto], Dict[TorchValueID, ONNXValueID],]:
    """Convert the nodes of ``g`` into ``onnx.NodeProto`` objects.

    Every value used by a converted node gets a unique ONNX name, recorded in
    ``val_tab``. Values listed in ``self.attrs`` are additionally materialized
    as tensors in ``onnx_vars``. ``prim::If`` nodes are converted recursively:
    each of their two blocks becomes a sub-graph attribute.

    Args:
        g: Graph (or block cast to a graph) whose nodes are converted.
        onnx_vars: Tensors collected so far, keyed by torch value id.
            Updated in place.
        val_tab: ONNX names assigned so far, keyed by torch value id.
            Updated in place.

    Returns:
        A tuple ``(onnx_nodes, onnx_vars, val_tab)``; the two dicts are the
        objects that were passed in.
    """
    node_name_counter: int = 0

    def node_name(n: torch._C.Node) -> str:
        # Names are "<op>_<running index>", the index being local to this call.
        nonlocal node_name_counter
        op = n.kind().split("::")[-1]
        node_name_counter += 1
        return f"{op}_{node_name_counter - 1}"

    # Reverse lookup used to detect ONNX name collisions.
    val_tab_rev: Dict[ONNXValueID, TorchValueID] = {v: k for k, v in val_tab.items()}

    def register_val_name(val_id: TorchValueID, name: ONNXValueID, shadow: bool = False) -> ONNXValueID:
        # Bind `name` to `val_id`. With `shadow`, a taken name is made unique by
        # appending "_<n>"; without it, a taken name is an error.
        assert val_id not in val_tab, f"{val_id} already registered in {g}"
        if shadow:
            new_name = name
            c = 1
            while new_name in val_tab_rev:
                new_name = ONNXValueID(f"{name}_{c}")
                c += 1
            name = new_name
        else:
            assert name not in val_tab_rev, f"{name} already registered in {g}"
        val_tab_rev[name] = val_id
        val_tab[val_id] = name
        assert len(val_tab_rev) == len(val_tab)
        return name

    def value_name(v: torch._C.Value) -> ONNXValueID:
        # Default ONNX name of a value: "<scope>.<debugName>", or the name
        # recorded in `self.attrs` when there is one.
        if _unique_id(v) in self.attrs:
            return self.attrs[_unique_id(v)]
        n: torch._C.Node = v.node() or v.uses()[0].user
        scope: str = self.node_scope.get(n, n.scopeName())
        if len(scope) > 0:
            scope += "."
        scope = _remove_prefix(scope.split("/")[-1], "__module.")
        scope = _remove_prefix(scope, f"{_ppe_ignore_scope}.")
        return ONNXValueID(f"{scope}{v.debugName()}")

    def block2subgraph(name: str, b: torch._C.Block, doc_string: str) -> onnx.GraphProto:
        # Convert one block of a control-flow node into an ONNX sub-graph.
        branch_nodes, _, _ = self.generate_proto_nodes(cast(torch._C.Graph, b), onnx_vars, val_tab)
        branch_inputs: List[onnx.ValueInfoProto] = []
        for i in b.inputs():
            branch_inputs.append(onnx.ValueInfoProto())
            branch_inputs[-1].name = val_tab[_unique_id(i)]
            if not self.strip_doc_string:
                branch_inputs[-1].doc_string = repr(i)
        branch_outputs: List[onnx.ValueInfoProto] = []
        for i in b.outputs():
            branch_outputs.append(onnx.ValueInfoProto())
            branch_outputs[-1].name = val_tab[_unique_id(i)]
            if not self.strip_doc_string:
                branch_outputs[-1].doc_string = repr(i)

        branch_graph: onnx.GraphProto = onnx.helper.make_graph(
            name=name,
            nodes=branch_nodes,
            # TODO(twata): Support initializers if needed
            inputs=branch_inputs,
            outputs=branch_outputs,
            doc_string=doc_string,
        )
        return branch_graph

    # Nodes and initializers
    onnx_nodes: List[onnx.NodeProto] = []
    self_count: int = 0
    # Run only in root graph
    if self.g == g:
        if self.input_names is not None:
            for idx, v in enumerate(g.inputs()):
                if self.is_self(v):  # Skip module's self input
                    self_count += 1
                    continue
                register_val_name(_unique_id(v), ONNXValueID(self.input_names[idx - self_count]))
            assert (len(list(g.inputs())) - self_count) == len(self.input_names)
        if self.output_names is not None:
            if len(self.output_names) != len(list(g.outputs())):
                warnings.warn(f"Specified output_names ({self.output_names}) count and graph outputs ({list(g.outputs())}) count differ")
            for idx, v in enumerate(g.outputs()):
                if idx >= len(self.output_names):
                    break
                register_val_name(_unique_id(v), ONNXValueID(self.output_names[idx]))

    none_nodes: List[torch._C.Node] = []

    def assign_onnx_values(
        onnx_values: List[str],
        prefix: str,
        torch_values: Iterator[torch._C.Value],
    ) -> None:
        # Fill `onnx_values` with the ONNX name of each torch value; a value
        # produced by a None node becomes the empty string. `prefix` is unused.
        assert len(onnx_values) == 0
        for v in torch_values:
            if v.node() is not None and v.node() in none_nodes:
                onnx_values.append("")
                continue
            k: ONNXValueID = val_tab.get(_unique_id(v), value_name(v))
            if _unique_id(v) not in val_tab:
                register_val_name(_unique_id(v), k)
            onnx_values.append(k)

    for n in g.nodes():
        # Skip None value node
        if n.mustBeNone():
            none_nodes.append(n)
            continue
        if n.kind() == "prim::GetAttr":
            continue
        if n.kind() == "onnx::Constant":
            if len(n.output().uses()) == 0:
                warnings.warn(f"Unused constant left: {n}")
                continue
            # Skip constant folded initializers
            if _unique_id(n.output()) in self.attrs:
                continue

        for i in n.inputs():
            if self.is_self(i):
                continue
            if i.node() is not None and i.node() in none_nodes:
                continue
            if _unique_id(i) in self.attrs and _unique_id(i) not in onnx_vars:
                k: ONNXValueID = self.attrs[_unique_id(i)]
                assert isinstance(self.vars[k], torch.Tensor)
                t: torch.Tensor = cast(torch.Tensor, self.vars[k])
                onnx_vars[_unique_id(i)] = _tensor_to_proto(t, name=k)
                register_val_name(_unique_id(i), value_name(i), shadow=True)
                continue
            if _unique_id(i) not in val_tab:
                # Register the input itself. Keying on `v` here would reuse a
                # stale variable from the root-graph loops above (or fail when
                # they did not run) and bind the name to the wrong value.
                register_val_name(_unique_id(i), value_name(i))

        for o in n.outputs():
            if _unique_id(o) not in val_tab:
                register_val_name(_unique_id(o), value_name(o), shadow=True)

        new_nd = onnx.NodeProto()
        new_nd.name = node_name(n)
        new_nd.op_type = n.kind().split("::")[-1]
        if n.kind() == "prim::If":
            if n in self.node_doc_string:
                new_nd.doc_string = f"""## Symbolic node {n} {self.node_doc_string[n]}"""
            blocks: List[torch._C.Block] = list(n.blocks())
            assert len(blocks) == 2
            for attr_name, block in zip(["then_branch", "else_branch"], blocks):
                sub_g = block2subgraph(f"{new_nd.name}_{attr_name}", block, new_nd.doc_string)
                new_nd.attribute.append(onnx.helper.make_attribute(attr_name, sub_g))
        else:
            assert len(list(n.blocks())) == 0, f"Node with block needs to be handled separately: {n}"
            if n in self.node_doc_string:
                new_nd.doc_string = self.node_doc_string[n]
            for attr_name in n.attributeNames():
                # Tensor attributes must be converted to protos explicitly.
                if n.kindOf(attr_name) == "t":
                    attr = onnx.helper.make_attribute(attr_name, _tensor_to_proto(n.t(attr_name)))
                else:
                    attr = onnx.helper.make_attribute(attr_name, n[attr_name])
                new_nd.attribute.append(attr)
        assign_onnx_values(new_nd.input, new_nd.name, n.inputs())
        assign_onnx_values(new_nd.output, new_nd.name, n.outputs())
        onnx_nodes.append(new_nd)

    return onnx_nodes, onnx_vars, val_tab
def gen_aten_node(g: torch._C.Graph, *inputs: Any) -> Union[torch._C.Value, Sequence[torch._C.Value]]:
    """Emit an ``ATen`` op that stands in for node ``n``.

    ``n`` is not a parameter: it is read from the enclosing scope, which is
    not visible here. The new op gets as many outputs as ``n``, a copy of
    ``n``'s attributes, and an ``operator`` string attribute holding the part
    of ``n.kind()`` after the last ``"::"``.

    Args:
        g: Graph used to create the op.
        *inputs: Inputs of the new op.

    Returns:
        Whatever ``g.op`` returned for the new op.
    """
    output_count = len(list(n.outputs()))
    ret = g.op("ATen", *inputs, outputs=output_count)

    # Any output leads to the created node; take the only one, or else the last.
    if n.outputsSize() == 1:
        v: torch._C.Value = cast(torch._C.Value, ret)
    else:
        v = cast(Sequence[torch._C.Value], ret)[-1]

    v.node().copyAttributes(n)
    operator_name = n.kind().split("::")[-1]
    v.node().s_("operator", operator_name)
    return ret
def gen_seq(g: torch._C.Graph, *args: Any) -> torch._C.Value:
    """Emit an ONNX sequence built from ``args``.

    Args:
        g: Graph used to create the op.
        *args: Elements of the sequence.

    Returns:
        The output of ``SequenceEmpty`` when no element is given, otherwise
        the output of ``SequenceConstruct`` over all elements.
    """
    if not args:
        # TODO(twata): Set dtype attribute
        empty_seq = g.op("SequenceEmpty")
        return cast(torch._C.Value, empty_seq)

    built_seq = g.op("SequenceConstruct", *args)
    return cast(torch._C.Value, built_seq)
def optimize_torch(self, graph: torch._C.Graph) -> torch._C.Graph:
    """Run the TorchScript-level passes that prepare ``graph`` for ONNX conversion.

    After the passes, the graph inputs are renamed to the configured input
    names (when there are any) and the dynamic input shapes are applied.

    Args:
        graph: The graph to optimize. It is modified in place.

    Returns:
        The same ``graph`` object.
    """
    self.run_jit_pass(torch._C._jit_pass_inline_fork_wait, graph)  # type: ignore[attr-defined]
    if self.torch_constant_prop:
        self.run_jit_pass(torch._C._jit_pass_constant_propagation, graph)  # type: ignore[attr-defined]

    # Disabled: _split_tensor_list_constants(graph, graph)
    # run dce to eliminate dead parts of the graph that might have been
    # left behind by things like symbolic_override
    self.run_jit_pass(torch._C._jit_pass_dce, graph)

    self.run_jit_pass(torch._C._jit_pass_canonicalize_graph_fuser_ops, graph)  # type: ignore[attr-defined]
    torch._C._jit_pass_peephole(graph, True)  # type: ignore[attr-defined]
    self.run_jit_pass(torch._C._jit_pass_fuse_addmm, graph)  # type: ignore[attr-defined]

    torch._C._jit_pass_peephole(graph, True)  # type: ignore[attr-defined]
    torch._C._jit_pass_lower_all_tuples(graph)  # type: ignore[attr-defined]
    # in _jit_pass_onnx, symbolic functions are called for each node for conversion.
    # However, there are nodes that cannot be converted without additional context.
    # For example, the number of outputs from split
    # (and whether it is static or dynamic) is unknown
    # until the point where it is unpacked by listUnpack node.
    # This pass does a preprocess, and prepares the nodes such that enough
    # context can be received by the symbolic function.
    # Disabled: torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, None)
    torch._C._jit_pass_onnx_preprocess(graph)  # type: ignore[attr-defined]

    # onnx does not support tuples, so try to remove them
    torch._C._jit_pass_lint(graph)

    # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
    torch._C._jit_pass_prepare_division_for_onnx(graph)  # type: ignore[attr-defined]

    torch._C._jit_pass_onnx_remove_print(graph)  # type: ignore[attr-defined]
    torch._C._jit_pass_onnx_preprocess_caffe2(graph)  # type: ignore[attr-defined]

    if self.operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        sym_hel._quantized_ops.clear()
        # Unpack quantized weights for conv and linear ops and insert into graph.
        torch._C._jit_pass_onnx_unpack_quantized_weights(graph, self.vars)  # type: ignore[attr-defined]
        # Insert permutes before and after each conv op to ensure correct order.
        torch._C._jit_pass_onnx_quantization_insert_permutes(graph, self.vars)  # type: ignore[attr-defined]

        # Find consecutive permutes that are no-ops and remove them.
        torch._C._jit_pass_custom_pattern_based_rewrite_graph(  # type: ignore[attr-defined]
            """ graph(%Pi): %Pq = quantized::nhwc2nchw(%Pi) %Pr = quantized::nchw2nhwc(%Pq) return (%Pr)""",
            """ graph(%Ri): return (%Ri)""",
            graph,
        )

    # onnx only supports tensors, so we turn all out number types into tensors
    torch._C._jit_pass_erase_number_types(graph)  # type: ignore[attr-defined]

    # Bind `input_names` on every path: the dynamic-shape pass below reads it
    # even when no input names were configured, which previously raised
    # UnboundLocalError.
    input_names = [] if self.input_names is None else self.input_names.copy()
    if self.input_names is not None:
        if self.self_id is not None:
            # The module's own "self" input comes first in the graph inputs.
            input_names.insert(0, cast(str, self.self_name))
        inputs = list(graph.inputs())
        assert len(inputs) == len(input_names)
        for graph_input, name in zip(inputs, input_names):
            graph_input.setDebugName(name)
    torch._C._jit_pass_onnx_set_dynamic_input_shape(  # type: ignore[attr-defined]
        graph, self.dynamic_axes or {}, input_names
    )

    return graph