Example #1
    def handle_list_construct(self, g: torch._C.Graph, n: torch._C.Node) -> None:
        # Int-typed list inputs are exported as Concat; other lists become Sequence ops
        is_integer_output: bool = n.output().type().getElementType().kind() == "IntType"
        if len(list(n.inputs())) > 0 and is_integer_output:

            def gen_concat(g: torch._C.Graph, *args: Any) -> torch._C.Value:
                seq: List[torch._C.Value] = []
                for i in args:
                    if i.type().kind() == "IntType" or len(i.type().sizes()) == 0:
                        seq.append(
                            sym_hel._unsqueeze_helper(g, i, axes_i=[0])  # type: ignore[no-untyped-call,call-arg]
                        )
                    else:
                        seq.append(i)
                return cast(torch._C.Value, g.op("Concat", *seq, axis_i=0))

            self.run_symbolic_function(g, n, gen_concat)
        else:

            def gen_seq(g: torch._C.Graph, *args: Any) -> torch._C.Value:
                if len(args) == 0:
                    return cast(torch._C.Value, g.op("SequenceEmpty"))  # TODO(twata): Set dtype attribute
                else:
                    return cast(torch._C.Value, g.op("SequenceConstruct", *args))

            self.run_symbolic_function(g, n, gen_seq)
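The prim::ListConstruct nodes this handler rewrites come from list literals in TorchScript. A minimal, illustrative way to produce one (not part of the source):

import torch

@torch.jit.script
def make_flat(x: torch.Tensor) -> torch.Tensor:
    # [x.shape[0], -1] becomes prim::ListConstruct in the TorchScript graph;
    # being an int list, handle_list_construct exports it as Unsqueeze + Concat.
    return x.reshape([x.shape[0], -1])

print(make_flat.graph)  # look for prim::ListConstruct feeding aten::reshape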
Example #2
    def handle_getattr(self, g: torch._C.Graph, n: torch._C.Node) -> None:
        if self.is_self(n.input()) or self.attrs[_unique_id(n.input())] == _ppe_ignore_scope:
            self.attrs[_unique_id(n.output())] = ONNXValueID(n.s("name"))
        else:
            self.attrs[_unique_id(n.output())] = ONNXValueID(
                "%s.%s"
                % (
                    self.attrs[_unique_id(n.input())],
                    n.s("name"),
                )
            )
        var_name = self.attrs[_unique_id(n.output())]
        if var_name in self.vars:
            assert isinstance(self.vars[var_name], torch.Tensor)
            n.output().inferTypeFrom(cast(torch.Tensor, self.vars[var_name]))
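handle_getattr keys everything on _unique_id, which is not shown in these examples. A plausible sketch of such a helper, assuming it wraps the graph-unique id of a value (an assumption, not the project's actual code):

import torch

def _unique_id(v: torch._C.Value) -> str:
    # torch._C.Value.unique() returns an int that is unique within the graph,
    # making it a stable key for the attrs/vars dictionaries (hypothetical).
    return f"v{v.unique()}"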
Example #3
    def handle_constant(self, g: torch._C.Graph, n: torch._C.Node) -> None:
        # Skip None constant node
        if n.mustBeNone():
            return

        def gen_const(g: torch._C.Graph, value: Any = None) -> torch._C.Value:
            c = cast(torch._C.Value, g.op("Constant"))
            if n.kindOf("value") == "ival":
                ival = n.output().toIValue()
                if isinstance(ival, list) and len(ival) > 0 and not isinstance(ival[0], (int, float)):
                    vals: List[torch._C.Value] = []
                    for i in ival:
                        if isinstance(i, torch.Tensor):
                            vals.append(cast(torch._C.Value, g.op("prim::Constant", value_t=i)))
                        else:
                            assert i is None
                            vals.append(cast(torch._C.Value, g.op("prim::Constant")))
                            vals[-1].setType(torch._C.NoneType.get())
                    c = cast(torch._C.Value, g.op("prim::ListConstruct"))
                    for v in vals:
                        c.node().addInput(v)
                else:
                    c.node().t_("value", torch.tensor(ival))
            else:
                c.node().copyAttributes(n)
            return c

        self.run_symbolic_function(g, n, gen_const)
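Constant nodes whose "value" attribute has kind "ival" carry a TorchScript IValue such as a list. One illustrative way to obtain such a node for inspection:

import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return x.permute([1, 0])

torch._C._jit_pass_constant_propagation(f.graph)
print(f.graph)  # the [1, 0] list may fold into a prim::Constant with an ival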
Example #4
    def run_symbolic_function(self, g: torch._C.Graph, n: torch._C.Node, sym_func: Callable) -> None:
        attrs: Dict[str, Any] = {}
        for a in n.attributeNames():
            if a == "value" and n.kindOf("value") == "ival":
                attrs[a] = n.output().toIValue()
            else:
                attrs[a] = n[a]
        if "inplace" in attrs:
            del attrs["inplace"]
        node_inputs = list(n.inputs())
        if n.kind() == "prim::PythonOp":
            node_inputs.extend(n.scalar_args())
        sym_outs = _to_tuple_if_not_sequence(sym_func(g, *node_inputs, **attrs))
        assert len(sym_outs) == n.outputsSize(), f"{sym_outs}: {len(sym_outs)} vs {n.outputsSize()}"

        def list_added_nodes() -> List[torch._C.Node]:
            # Walk back from the symbolic outputs toward the original node's
            # inputs, collecting every node that sym_func added to the graph.
            start_vals: Set[torch._C.Value] = set(list(n.inputs()))
            ret: Set[torch._C.Node] = set()
            target_vals: List[torch._C.Value] = list(sym_outs)
            for i in sym_outs:
                if i in start_vals:
                    continue
                ret.add(i.node())
                target_vals.extend(list(i.node().inputs()))
            while len(target_vals) > 0:
                i = target_vals.pop()
                if i in start_vals:
                    continue
                ret.add(i.node())
                start_vals.add(i)
                target_vals.extend(list(i.node().inputs()))
            return list(ret)

        sym_nodes: List[torch._C.Node] = list_added_nodes()

        self.log(f"Converting node {n.kind()}", n)
        if len(sym_nodes) > 0:
            self.log(f"Converted node {n.kind()}", "\n".join([str(i) for i in sym_nodes]))

        # Generate doc string before old node lifetime ends
        for sym_nd in sym_nodes:
            if not self.strip_doc_string:
                self.node_doc_string[sym_nd] = onnx_node_doc_string(sym_nd, n)
            self.node_scope[sym_nd] = n.scopeName()

        # Replace uses of old node output with symbolic outputs
        for old_out, new_out in zip(n.outputs(), sym_outs):
            old_out.replaceAllUsesWith(new_out)
            assert len(old_out.uses()) == 0
            new_out.copyMetadata(old_out)
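_to_tuple_if_not_sequence is referenced above but not defined in these examples. A minimal sketch of the presumed behavior (an assumption, not the project's code):

from typing import Sequence, Union

import torch

def _to_tuple_if_not_sequence(
    v: "Union[torch._C.Value, Sequence[torch._C.Value]]",
) -> "Sequence[torch._C.Value]":
    # Symbolic functions may return one Value or a sequence of Values;
    # normalize to a sequence so it can be zipped against n.outputs().
    if isinstance(v, torch._C.Value):
        return (v,)
    return tuple(v)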
Example #5
    def symbolic_function(self, n: torch._C.Node) -> Optional[Callable]:
        ns, op = n.kind().split("::")
        if op.endswith("_"):  # For inplace op
            op = op[:-1]
        if ns == "prim" and op == "PythonOp":
            pyobj = n.pyobj()
            if hasattr(pyobj, "__self__") and issubclass(pyobj.__self__, torch.autograd.Function):
                pyobj = pyobj.__self__
            assert issubclass(pyobj, torch.autograd.Function)
            assert hasattr(pyobj, "symbolic"), f"symbolic method not supported in {pyobj}"
            # TODO(twata): Use repr(pyobj) in scope name or doc_string
            return cast(Callable, pyobj.symbolic)
        else:
            if ns == "prim":
                op = f"prim_{op}"
            if sym_reg.is_registered_op(op, "", self.opset_version):  # type: ignore[no-untyped-call]
                return cast(
                    Callable, sym_reg.get_registered_op(op, "", self.opset_version)  # type: ignore[no-untyped-call]
                )
            else:
                return None
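The prim::PythonOp branch expects torch.autograd.Function subclasses to carry a symbolic staticmethod, which is the standard torch.onnx extension point:

import torch

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clamp(min=0)

    @staticmethod
    def symbolic(g, x):
        # Called at export time with the ONNX graph and the input Values.
        return g.op("Relu", x)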
Example #6
def _get_module_instance(
        node: torch._C.Node,
        node_name_to_module: Dict[str, torch.nn.Module]) -> torch.nn.Module:
    """
    Get the torch.nn.Module referenced by the node.
    :param node: trace graph node
    :param node_name_to_module: dictionary of module index by output_name referenced in the sub-graph
    :return: torch module corresponding to the node
    """
    input_name: str = node.input().debugName()
    attributes = _get_attribute_name(node)
    model = node_name_to_module[input_name]
    sub_model = getattr(model, attributes['name'])
    return sub_model
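An illustrative driver for the helper (assumed usage; the setup here is hypothetical and relies on _get_attribute_name from Example #8 below):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 2))
traced = torch.jit.trace(model, torch.randn(1, 4))
self_name = next(traced.graph.inputs()).debugName()
node_name_to_module = {self_name: model}
for node in traced.graph.findAllNodes("prim::GetAttr"):
    if node.input().debugName() == self_name:
        print(_get_module_instance(node, node_name_to_module))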
Example #7
    def handle_if(self, g: torch._C.Graph, n: torch._C.Node) -> None:
        # Generated onnx node doc string should be added later since DCE isn't completed yet
        doc_str: str = f"""
## Original node
{n}
## Scope
{n.scopeName()}
## Source Range
```
{n.sourceRange()}
```
"""

        # The If node is reused as-is to keep the graph lint happy
        for b in n.blocks():
            block_nodes = list(b.nodes())
            for b_n in block_nodes:
                self.generate_onnx_node(cast(torch._C.Graph, b), b_n)

        if not self.strip_doc_string:
            self.node_doc_string[n] = doc_str

        # Move to last of graph to keep the execution order of node
        n.moveBefore(g.return_node())
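The prim::If nodes with nested blocks handled above come from scripted control flow; a quick, illustrative way to produce one:

import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    if bool(x.sum() > 0):
        return x * 2
    return x - 1

print(f.graph)  # contains a prim::If node with a block per branch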
Example #8
def _get_attribute_name(node: torch._C.Node) -> Dict[str, str]:
    """
    Retrieve the attributes associated with the graph node
    :param node: trace graph node
    :return: a dictionary of attributes associated with the node
    """
    attributes = {}
    # node description has pseudo-code of the form  '... torch_mangle_2.Module = prim::GetAttr[name="fc"](%self.1)'
    # for the above example attributeNames() iterator should return a string 'name'
    node_desc = str(node)
    for attribute_name in node.attributeNames():
        pattern = attribute_name + '="'
        if pattern in node_desc:
            attributes[attribute_name] = node_desc.split(pattern)[1].split('"')[0]
    return attributes
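A small, illustrative check of the parser: for a node whose description contains prim::GetAttr[name="fc"](%self.1), it returns {'name': 'fc'}. Note that node.s("name") is the typed accessor for the same string:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 2))
traced = torch.jit.trace(model, torch.randn(1, 4))
node = traced.graph.findAllNodes("prim::GetAttr")[0]
print(_get_attribute_name(node))  # e.g. {'name': '0'}
print(node.s("name"))             # same value via the typed accessor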
Example #9
    def generate_onnx_node(self, g: torch._C.Graph, n: torch._C.Node) -> None:
        node_kind: str = n.kind()
        if node_kind in self.handler:
            self.handler[node_kind](self, g, n)
            return

        f: Optional[Callable] = self.symbolic_function(n)
        if self.operator_export_type in [OperatorExportTypes.ONNX_ATEN, OperatorExportTypes.ONNX_FALLTHROUGH] or (
            self.operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and f is None
        ):
            def gen_aten_node(g: torch._C.Graph, *inputs: Any) -> Union[torch._C.Value, Sequence[torch._C.Value]]:
                ret = g.op("ATen", *inputs, outputs=len(list(n.outputs())))
                if n.outputsSize() == 1:
                    v: torch._C.Value = cast(torch._C.Value, ret)
                else:
                    v = cast(Sequence[torch._C.Value], ret)[-1]
                v.node().copyAttributes(n)
                v.node().s_("operator", n.kind().split("::")[-1])
                return ret

            f = gen_aten_node
        assert f is not None, f"Symbolic function for {n.kind()} not found"
        self.run_symbolic_function(g, n, f)
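The flags checked above belong to torch.onnx.OperatorExportTypes; the ATen fallback path corresponds to exporting like this (in the torch versions this code targets):

import io

import torch

model = torch.nn.Linear(4, 2)
torch.onnx.export(
    model,
    torch.randn(1, 4),
    io.BytesIO(),
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)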
Example #10
def _node_get(node: torch._C.Node, key: str):
    """Get the attribute `key` of a node, dispatching on the attribute's kind (the return type is polymorphic)."""
    sel = node.kindOf(key)
    return getattr(node, sel)(key)
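Illustrative usage: kindOf returns a short selector such as "i", "f", "s" or "t" that doubles as the name of the typed accessor, so the getattr call dispatches to node.i, node.f, node.s, or node.t:

import torch

@torch.jit.script
def f() -> int:
    return 3

const = f.graph.findAllNodes("prim::Constant")[0]
print(const.kindOf("value"))      # 'i'
print(_node_get(const, "value"))  # 3, via const.i("value")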
Example #11
    def node_name(n: torch._C.Node) -> str:
        nonlocal node_name_counter
        op = n.kind().split("::")[-1]
        node_name_counter += 1
        return f"{op}_{node_name_counter - 1}"
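The nonlocal statement implies node_name is nested inside a function that owns the counter; a minimal sketch of that enclosing setup (an assumption, not shown in the source):

import torch

def make_node_namer():
    node_name_counter = 0  # shared state captured by the closure

    def node_name(n: torch._C.Node) -> str:
        nonlocal node_name_counter
        op = n.kind().split("::")[-1]
        node_name_counter += 1
        return f"{op}_{node_name_counter - 1}"  # e.g. "Constant_0", "Constant_1"

    return node_name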