Example 1
def _parse_graph(graph: torch._C.Graph,
                 model: torch.nn.Module) -> List[IrNode]:
    """
    Implements a depth-first graph extraction to obtain connectivity information in the form of an IrNodes list.
    Depth-first extraction is realized using recursion.

    :param graph: PyTorch JIT graph for the model or a submodule
    :param model: PyTorch model to create the connected graph from
    :return: List of IrNodes created by traversing the graph
    """
    ir_nodes_list = []
    curr_inputs = [inp for inp in graph.inputs()]

    # Map from graph value names to the module instances they refer to;
    # the first graph input corresponds to the model itself
    node_name_to_module = {curr_inputs[0].debugName(): model}
    for node in graph.nodes():
        outputs = [output for output in node.outputs()]

        # retrieving a module reference
        if 'GetAttr' in node.kind():
            # For GetAttr lines, the output name will be referring to the module, and not the module's output(s)
            assert len(outputs) == 1
            node_name = outputs[0].debugName()
            assert node_name not in node_name_to_module
            module = _get_module_instance(node, node_name_to_module)
            node_name_to_module[node_name] = module
        else:
            op_type: str = ConnectedGraph._parse_op_type(node)
            if "Constant" not in op_type:
                ir_node = IrNode(node_type=op_type,
                                 inputs=[
                                     inp for inp in node.inputs()
                                     if "Constant" not in
                                     ConnectedGraph._parse_op_type(inp.node())
                                 ],
                                 outputs=outputs,
                                 module=None)
                ir_nodes_list.append(ir_node)

    for ir_node in ir_nodes_list:
        inputs = []
        for inp in ir_node.inputs:
            if "GetAttr" in inp.node().kind():
                if ir_node.node_type in ConnectedGraph.op_type_map.values():
                    module = node_name_to_module[
                        inp.node().input().debugName()]
                    assert is_leaf_module(module)
                    if ir_node.module is None:
                        ir_node.module = module
                    else:
                        assert ir_node.module == module
            else:
                inputs.append(inp)
        ir_node.inputs = inputs

    return ir_nodes_list
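
A minimal usage sketch for context (illustrative only; it assumes the surrounding ConnectedGraph/IrNode definitions are importable and that no extra preprocessing of the trace is required):

# Hypothetical usage: obtain a torch._C.Graph by JIT-tracing a model and walk it
# with _parse_graph as defined above. The model and shapes are illustrative.
import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.ReLU())
trace = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
ir_nodes = _parse_graph(trace.graph, model)
print([ir_node.node_type for ir_node in ir_nodes])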
Example 2
    def optimize_onnx(self, graph: torch._C.Graph) -> torch._C.Graph:
        if pytorch_pfn_extras.requires("1.9.0"):
            self.run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph, self.onnx_lowprecision_cast, self.opset_version)
        else:
            self.run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph)

        if self.do_constant_folding and self.opset_version in torch.onnx.constant_folding_opset_versions:
            folded: Dict[str, torch.IValue] = torch._C._jit_pass_onnx_constant_fold(  # type: ignore[attr-defined]
                graph, self.vars, self.opset_version
            )
            # Replace inputs with constant nodes
            input_table: Dict[str, torch._C.Value] = {i.debugName(): i for i in graph.inputs()}
            for k, t in folded.items():
                c: torch._C.Value = graph.create("onnx::Constant", 1).output()
                assert isinstance(t, torch.Tensor)
                c.node().t_("value", cast(torch.Tensor, t))
                graph.prependNode(c.node())
                # TODO(twata): Determine folded nodes from original graph and document it
                self.node_doc_string[c.node()] = f"Constant folded node: {input_table[k]}"
                input_table[k].replaceAllUsesWith(c)
                c.copyMetadata(input_table[k])
                self.attrs[_unique_id(c)] = ONNXValueID(k)
                self.vars[k] = t
                del input_table[k]
            # Erase the now-redundant trailing inputs so that only the
            # len(input_table) remaining (non-folded) inputs are left
            for _ in range(len(list(graph.inputs())) - len(input_table)):
                graph.eraseInput(len(input_table))
            torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)  # type: ignore[attr-defined]

        if self.onnx_peephole:
            self.run_jit_pass(torch._C._jit_pass_onnx_peephole, graph, self.opset_version, self.fixed_batch_size)

        return graph
Example 3
 def gen_concat(g: torch._C.Graph, *args: Any) -> torch._C.Value:
     seq: List[torch._C.Value] = []
     for i in args:
         if i.type().kind() == "IntType" or len(i.type().sizes()) == 0:
             seq.append(
                 sym_hel._unsqueeze_helper(g, i, axes_i=[0])  # type: ignore[no-untyped-call,call-arg]
             )
         else:
             seq.append(i)
     return cast(torch._C.Value, g.op("Concat", *seq, axis_i=0))
Example 4
 def gen_const(g: torch._C.Graph, value: Any = None) -> torch._C.Value:
     c = cast(torch._C.Value, g.op("Constant"))
     # `n` is not defined in this snippet; in the original source it is the
     # constant node being converted, captured from the enclosing scope
     if n.kindOf("value") == "ival":
         ival = n.output().toIValue()
         if isinstance(ival, list) and not isinstance(ival[0], (int, float)):
             vals: List[torch._C.Value] = []
             for i in ival:
                 if isinstance(i, torch.Tensor):
                     vals.append(cast(torch._C.Value, g.op("prim::Constant", value_t=i)))
                 else:
                     assert i is None
                     vals.append(cast(torch._C.Value, g.op("prim::Constant")))
                     vals[-1].setType(torch._C.NoneType.get())
             c = cast(torch._C.Value, g.op("prim::ListConstruct"))
             for v in vals:
                 c.node().addInput(v)
         else:
             c.node().t_("value", torch.tensor(ival))
     else:
         c.node().copyAttributes(n)
     return c
Example 5
def _new_node(g: torch._C.Graph, opname: str, outputs, *args, **kwargs):
    if "::" in opname:
        aten = False
        ns_opname = opname
    else:
        aten = kwargs.pop("aten", False)
        ns = "aten" if aten else "onnx"
        ns_opname = ns + "::" + opname
    n = g.create(ns_opname, args, outputs)  # type: ignore[attr-defined]
    for k, v in sorted(kwargs.items()):
        # TODO: enable inplace in aten exporting mode.
        if k == "inplace":
            continue
        _add_attribute(n, k, v, aten=aten)
    return n
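
A short illustrative sketch of the namespace handling above (nodes are created but not inserted into any graph; only their kinds are inspected):

# Sketch: the "onnx"/"aten" namespace is prepended unless an explicit
# "ns::op" qualified name is supplied.
import torch

g = torch._C.Graph()
print(_new_node(g, "Relu", 1).kind())             # onnx::Relu
print(_new_node(g, "gelu", 1, aten=True).kind())  # aten::gelu
print(_new_node(g, "mydomain::MyOp", 1).kind())   # mydomain::MyOp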
Example 6
def _export_jit_graph_to_onnx_model_proto(graph: torch._C.Graph,
                                          operator_export_type: int):
    from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_operator_export_type, _set_opset_version

    _set_onnx_shape_inference(True)
    _set_operator_export_type(operator_export_type)
    torch._C._jit_pass_run_decompositions(graph)
    graph = torch.onnx.utils._optimize_graph(graph,
                                             operator_export_type,
                                             params_dict={})
    proto, _, _, _ = graph._export_onnx(
        {},
        torch.onnx._globals.GLOBALS.export_onnx_opset_version,
        {},
        False,
        operator_export_type,
        False,
        False,
        {},
        True,
        "",
        {},
    )
    return proto
Example 7
    def handle_if(self, g: torch._C.Graph, n: torch._C.Node) -> None:
        # Generated onnx node doc string should be added later since DCE isn't completed yet
        doc_str: str = f"""
## Original node
{n}
## Scope
{n.scopeName()}
## Source Range
```
{n.sourceRange()}
```
"""

        # The If node will be reused to keep the graph lint happy
        for b in n.blocks():
            block_nodes = list(b.nodes())
            for b_n in block_nodes:
                self.generate_onnx_node(cast(torch._C.Graph, b), b_n)

        if not self.strip_doc_string:
            self.node_doc_string[n] = doc_str

        # Move to last of graph to keep the execution order of node
        n.moveBefore(g.return_node())
Example 8
def _graph_op(
    g: torch._C.Graph,
    opname: str,
    *raw_args: torch._C.Value,
    outputs: int = 1,
    **kwargs,
) -> Union[torch._C.Value, Tuple[torch._C.Value, ...]]:
    r"""Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    This function is monkey-patched onto Graph.

    Args:
        g: The Torch graph.
        opname: The ONNX operator name, e.g., `Abs` or `Add`. TODO(justinchu): Update examples to correct ones.
        raw_args: The inputs to the operator; usually provided
            as arguments to the `symbolic` definition.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is greater than one, this function returns a tuple
            of output `Value`s, representing each output of the ONNX operator
            in positional order.
        kwargs: The attributes of the ONNX operator, whose keys are named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`.  The valid type specifiers are
            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
            specified with type float accepts either a single float, or a
            list of floats (e.g., you would say `dims_i` for a `dims` attribute
            that takes a list of integers).

    Returns:
        The Value representing the single output of this operator (see the `outputs`
        keyword argument for operators with multiple outputs).
    """
    # Filter out None attributes; this is convenient for callers, who can pass
    # None attributes and have them simply not show up in the node
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    def const_if_tensor(arg):
        if arg is None:
            return arg
        elif isinstance(arg, torch._C.Value):
            return arg
        else:
            return g.op("Constant", value_z=arg)  # type: ignore[attr-defined]

    args = [const_if_tensor(arg) for arg in raw_args]
    n = g.insertNode(_new_node(g, opname, outputs, *args, **kwargs))  # type: ignore[attr-defined]

    # Import utils to get _params_dict because it is a global that is accessed by c++ code
    from torch.onnx import utils

    if GLOBALS.onnx_shape_inference:
        torch._C._jit_pass_onnx_node_shape_type_inference(
            n, utils._params_dict, GLOBALS.export_onnx_opset_version
        )

    if outputs == 1:
        return n.output()
    return tuple(n.outputs())
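
To illustrate the attribute-suffix convention described in the docstring, here are two hypothetical symbolic functions (Elu and TopK are standard ONNX operators; the helper names are made up for this sketch):

# alpha_f -> float attribute "alpha"; axis_i -> int attribute "axis".
import torch

def elu_symbolic(g: torch._C.Graph, x: torch._C.Value) -> torch._C.Value:
    return g.op("Elu", x, alpha_f=1.0)

def topk_symbolic(g: torch._C.Graph, x: torch._C.Value, k: torch._C.Value):
    # With outputs=2 a tuple of the two output Values is returned.
    return g.op("TopK", x, k, axis_i=-1, outputs=2)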
Example 9
    def generate_proto_nodes(
        self,
        g: torch._C.Graph,
        onnx_vars: Dict[TorchValueID, onnx.TensorProto],
        val_tab: Dict[TorchValueID, ONNXValueID],
    ) -> Tuple[List[onnx.NodeProto], Dict[TorchValueID, onnx.TensorProto], Dict[TorchValueID, ONNXValueID],]:
        node_name_counter: int = 0

        def node_name(n: torch._C.Node) -> str:
            nonlocal node_name_counter
            op = n.kind().split("::")[-1]
            node_name_counter += 1
            return f"{op}_{node_name_counter - 1}"

        val_tab_rev: Dict[ONNXValueID, TorchValueID] = {v: k for k, v in val_tab.items()}

        def register_val_name(id: TorchValueID, name: ONNXValueID, shadow: bool = False) -> ONNXValueID:
            assert id not in val_tab, f"{id} already registered in {g}"
            if shadow:
                new_name = name
                c = 1
                while new_name in val_tab_rev:
                    new_name = ONNXValueID(f"{name}_{c}")
                    c += 1
                name = new_name
            else:
                assert name not in val_tab_rev, f"{name} already registered in {g}"
            val_tab_rev[name] = id
            val_tab[id] = name
            assert len(val_tab_rev) == len(val_tab)
            return name

        def value_name(v: torch._C.Value) -> ONNXValueID:
            if _unique_id(v) in self.attrs:
                return self.attrs[_unique_id(v)]

            n: torch._C.Node = v.node() or v.uses()[0].user
            scope: str = self.node_scope.get(n, n.scopeName())
            if len(scope) > 0:
                scope += "."
            scope = _remove_prefix(scope.split("/")[-1], "__module.")
            scope = _remove_prefix(scope, f"{_ppe_ignore_scope}.")
            return ONNXValueID(f"{scope}{v.debugName()}")

        def block2subgraph(name: str, b: torch._C.Block, doc_string: str) -> onnx.GraphProto:
            branch_nodes, _, _ = self.generate_proto_nodes(cast(torch._C.Graph, b), onnx_vars, val_tab)
            branch_inputs: List[onnx.ValueInfoProto] = []
            for i in b.inputs():
                branch_inputs.append(onnx.ValueInfoProto())
                branch_inputs[-1].name = val_tab[_unique_id(i)]
                if not self.strip_doc_string:
                    branch_inputs[-1].doc_string = repr(i)
            branch_outputs: List[onnx.ValueInfoProto] = []
            for i in b.outputs():
                branch_outputs.append(onnx.ValueInfoProto())
                branch_outputs[-1].name = val_tab[_unique_id(i)]
                if not self.strip_doc_string:
                    branch_outputs[-1].doc_string = repr(i)

            branch_graph: onnx.GraphProto = onnx.helper.make_graph(
                name=name,
                nodes=branch_nodes,
                # TODO(twata): Support initializers if needed
                inputs=branch_inputs,
                outputs=branch_outputs,
                doc_string=doc_string,
            )

            return branch_graph

        # Nodes and initializers
        onnx_nodes: List[onnx.NodeProto] = []
        self_count: int = 0
        # Run only in root graph
        if self.g == g:
            if self.input_names is not None:
                for idx, v in enumerate(g.inputs()):
                    if self.is_self(v):  # Skip module's self input
                        self_count += 1
                        continue
                    register_val_name(_unique_id(v), ONNXValueID(self.input_names[idx - self_count]))
                assert (len(list(g.inputs())) - self_count) == len(self.input_names)
            if self.output_names is not None:
                if len(self.output_names) != len(list(g.outputs())):
                    warnings.warn(f"Specified output_names ({self.output_names}) count and graph outputs ({list(g.outputs())}) count differ")
                for idx, v in enumerate(g.outputs()):
                    if idx >= len(self.output_names):
                        break
                    register_val_name(_unique_id(v), ONNXValueID(self.output_names[idx]))
        none_nodes: List[torch._C.Node] = []
        for n in g.nodes():
            # Skip None value node
            if n.mustBeNone():
                none_nodes.append(n)
                continue
            if n.kind() == "prim::GetAttr":
                continue
            if n.kind() == "onnx::Constant" :
                if len(n.output().uses()) == 0:
                    warnings.warn(f"Unused constant left: {n}")
                    continue
                # Skip constant folded initialzers
                if _unique_id(n.output()) in self.attrs:
                    continue
            for i in n.inputs():
                if self.is_self(i):
                    continue
                if i.node() is not None and i.node() in none_nodes:
                    continue
                if _unique_id(i) in self.attrs and _unique_id(i) not in onnx_vars:
                    k: ONNXValueID = self.attrs[_unique_id(i)]
                    assert isinstance(self.vars[k], torch.Tensor)
                    t: torch.Tensor = cast(torch.Tensor, self.vars[k])
                    onnx_vars[_unique_id(i)] = _tensor_to_proto(t, name=k)
                    register_val_name(_unique_id(i), value_name(i), shadow=True)
                    continue
                if _unique_id(i) not in val_tab:
                    register_val_name(_unique_id(i), value_name(i))

            for o in n.outputs():
                if _unique_id(o) not in val_tab:
                    register_val_name(_unique_id(o), value_name(o), shadow=True)

            def assign_onnx_values(
                onnx_values: List[str],
                prefix: str,
                torch_values: Iterator[torch._C.Value],
            ) -> None:
                assert len(onnx_values) == 0
                for v in torch_values:
                    if v.node() is not None and v.node() in none_nodes:
                        onnx_values.append("")
                        continue
                    k: ONNXValueID = val_tab.get(_unique_id(v), value_name(v))
                    if _unique_id(v) not in val_tab:
                        register_val_name(_unique_id(v), k)
                    onnx_values.append(k)

            new_nd = onnx.NodeProto()
            new_nd.name = node_name(n)
            new_nd.op_type = n.kind().split("::")[-1]
            if n.kind() == "prim::If":
                if n in self.node_doc_string:
                    new_nd.doc_string = f"""## Symbolic node
{n}
{self.node_doc_string[n]}"""
                blocks: List[torch._C.Block] = list(n.blocks())
                assert len(blocks) == 2
                for attr_name, block in zip(["then_branch", "else_branch"], blocks):
                    sub_g = block2subgraph(f"{new_nd.name}_{attr_name}", block, new_nd.doc_string)
                    new_nd.attribute.append(onnx.helper.make_attribute(attr_name, sub_g))
            else:
                assert len(list(n.blocks())) == 0, f"Node with block needs to be handled separately: {n}"
                if n in self.node_doc_string:
                    new_nd.doc_string = self.node_doc_string[n]
                for attr_name in n.attributeNames():
                    if n.kindOf(attr_name) == "t":
                        attr = onnx.helper.make_attribute(attr_name, _tensor_to_proto(n.t(attr_name)))
                    else:
                        attr = onnx.helper.make_attribute(attr_name, n[attr_name])
                    new_nd.attribute.append(attr)
            assign_onnx_values(new_nd.input, new_nd.name, n.inputs())
            assign_onnx_values(new_nd.output, new_nd.name, n.outputs())
            onnx_nodes.append(new_nd)

        return onnx_nodes, onnx_vars, val_tab
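
The collision handling used by `register_val_name` (the `shadow=True` path) can be shown in isolation with a standalone sketch, with a plain set standing in for `val_tab_rev`:

# Standalone illustration of the shadow-renaming scheme: append _1, _2, ...
# until the name no longer collides.
def shadow_name(name: str, taken: set) -> str:
    new_name, c = name, 1
    while new_name in taken:
        new_name = f"{name}_{c}"
        c += 1
    return new_name

print(shadow_name("conv.weight", {"conv.weight", "conv.weight_1"}))  # conv.weight_2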
Example 10
 def gen_aten_node(g: torch._C.Graph, *inputs: Any) -> Union[torch._C.Value, Sequence[torch._C.Value]]:
     ret = g.op("ATen", *inputs, outputs=len(list(n.outputs())))
     v: torch._C.Value = cast(torch._C.Value, ret) if n.outputsSize() == 1 else cast(Sequence[torch._C.Value], ret)[-1]
     v.node().copyAttributes(n)
     v.node().s_("operator", n.kind().split("::")[-1])
     return ret
Example 11
 def gen_seq(g: torch._C.Graph, *args: Any) -> torch._C.Value:
     if len(args) == 0:
         return cast(torch._C.Value, g.op("SequenceEmpty"))  # TODO(twata): Set dtype attribute
     else:
         return cast(torch._C.Value, g.op("SequenceConstruct", *args))
Example 12
    def optimize_torch(self, graph: torch._C.Graph) -> torch._C.Graph:
        self.run_jit_pass(torch._C._jit_pass_inline_fork_wait, graph)  # type: ignore[attr-defined]
        if self.torch_constant_prop:
            self.run_jit_pass(torch._C._jit_pass_constant_propagation, graph)  # type: ignore[attr-defined]

        # _split_tensor_list_constants(graph, graph)
        # run dce to eliminate dead parts of the graph that might have been
        # left behind by things like symbolic_override
        self.run_jit_pass(torch._C._jit_pass_dce, graph)

        self.run_jit_pass(torch._C._jit_pass_canonicalize_graph_fuser_ops, graph)  # type: ignore[attr-defined]
        torch._C._jit_pass_peephole(graph, True)  # type: ignore[attr-defined]
        self.run_jit_pass(torch._C._jit_pass_fuse_addmm, graph)  # type: ignore[attr-defined]

        torch._C._jit_pass_peephole(graph, True)  # type: ignore[attr-defined]
        torch._C._jit_pass_lower_all_tuples(graph)  # type: ignore[attr-defined]
        # In _jit_pass_onnx, symbolic functions are called for each node for conversion.
        # However, there are nodes that cannot be converted without additional context.
        # For example, the number of outputs from split (and whether it is static or
        # dynamic) is unknown until the point where it is unpacked by the ListUnpack
        # node. This pass does a preprocess and prepares the nodes such that enough
        # context can be received by the symbolic function.
        # torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, None)
        torch._C._jit_pass_onnx_preprocess(graph)  # type: ignore[attr-defined]

        # Sanity-check the graph (tuples were already lowered above, since ONNX does not support them)
        torch._C._jit_pass_lint(graph)

        # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
        torch._C._jit_pass_prepare_division_for_onnx(graph)  # type: ignore[attr-defined]

        torch._C._jit_pass_onnx_remove_print(graph)  # type: ignore[attr-defined]
        torch._C._jit_pass_onnx_preprocess_caffe2(graph)  # type: ignore[attr-defined]

        if self.operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
            sym_hel._quantized_ops.clear()
            # Unpack quantized weights for conv and linear ops and insert into graph.
            torch._C._jit_pass_onnx_unpack_quantized_weights(graph, self.vars)  # type: ignore[attr-defined]
            # Insert permutes before and after each conv op to ensure correct order.
            torch._C._jit_pass_onnx_quantization_insert_permutes(graph, self.vars)  # type: ignore[attr-defined]

            # Find consecutive permutes that are no-ops and remove them.
            torch._C._jit_pass_custom_pattern_based_rewrite_graph(  # type: ignore[attr-defined]
                """
            graph(%Pi):
                %Pq = quantized::nhwc2nchw(%Pi)
                %Pr = quantized::nchw2nhwc(%Pq)
                return (%Pr)""",
                """
            graph(%Ri):
                return (%Ri)""",
                graph,
            )

        # onnx only supports tensors, so we turn all out number types into tensors
        torch._C._jit_pass_erase_number_types(graph)  # type: ignore[attr-defined]

        # Bind input_names up front so the dynamic-shape pass below also works
        # when self.input_names is None
        input_names: List[str] = []
        if self.input_names is not None:
            input_names = self.input_names.copy()
            if self.self_id is not None:
                input_names.insert(0, cast(str, self.self_name))
            assert len(list(graph.inputs())) == len(input_names)
            inputs = list(graph.inputs())
            for idx, n in enumerate(input_names):
                inputs[idx].setDebugName(n)
        torch._C._jit_pass_onnx_set_dynamic_input_shape(  # type: ignore[attr-defined]
            graph, self.dynamic_axes or {}, input_names or []
        )

        return graph
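
As an illustration of the split/ListUnpack situation mentioned in the comments above, here is a hypothetical module whose number of split outputs only becomes known where the result is unpacked:

# The number of outputs of torch.split is only fixed once the returned list is
# unpacked, which is why _jit_pass_onnx_preprocess runs before symbolic conversion.
import torch

class SplitModule(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        a, b = torch.split(x, x.shape[1] // 2, dim=1)
        return a + b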