def gen_concat(g: torch._C.Graph, *args: Any) -> torch._C.Value:
    """Concatenate the given graph values along axis 0.

    Scalar-like inputs (``IntType`` values or zero-rank tensors) are first
    unsqueezed to rank 1 so that ``Concat`` receives operands of equal rank.
    """
    pieces: List[torch._C.Value] = []
    for value in args:
        # IntType values have no sizes(); the `or` short-circuit keeps the
        # sizes() call off that path, matching the original check order.
        scalar_like = (
            value.type().kind() == "IntType" or len(value.type().sizes()) == 0
        )
        if scalar_like:
            # Promote the scalar to a 1-element tensor before concatenation.
            value = sym_hel._unsqueeze_helper(g, value, axes_i=[0])  # type: ignore[no-untyped-call,call-arg]
        pieces.append(value)
    return cast(torch._C.Value, g.op("Concat", *pieces, axis_i=0))
def gen_const(g: torch._C.Graph, value: Any = None) -> torch._C.Value:
    """Re-emit the constant held by the enclosing node ``n`` into graph ``g``.

    When the node stores an IValue, lists containing tensors (or ``None``
    placeholders) cannot live in a single ``Constant`` attribute, so they are
    rebuilt as a ``prim::ListConstruct`` over per-element ``prim::Constant``
    nodes. Any other IValue is stored directly as a tensor attribute. Nodes
    without an IValue simply have their attributes copied over.

    NOTE(review): ``value`` is unused here; it presumably matches a symbolic
    function signature expected by the caller — confirm before removing.
    """
    c = cast(torch._C.Value, g.op("Constant"))
    if n.kindOf("value") == "ival":
        ival = n.output().toIValue()
        # Guard against an empty list: the original code indexed ival[0]
        # unconditionally, raising IndexError for []. An empty list is
        # representable via the torch.tensor(ival) fallback below.
        if isinstance(ival, list) and ival and not isinstance(ival[0], (int, float)):
            vals: List[torch._C.Value] = []
            for i in ival:
                if isinstance(i, torch.Tensor):
                    vals.append(
                        cast(torch._C.Value, g.op("prim::Constant", value_t=i))
                    )
                else:
                    # Only tensors or None placeholders are expected here.
                    assert i is None
                    vals.append(cast(torch._C.Value, g.op("prim::Constant")))
                    vals[-1].setType(torch._C.NoneType.get())
            c = cast(torch._C.Value, g.op("prim::ListConstruct"))
            for v in vals:
                c.node().addInput(v)
        else:
            # Plain numeric scalars/lists are stored as a tensor attribute.
            c.node().t_("value", torch.tensor(ival))
    else:
        # No IValue payload: clone the source node's attributes verbatim.
        c.node().copyAttributes(n)
    return c
def gen_aten_node(
    g: torch._C.Graph, *inputs: Any
) -> Union[torch._C.Value, Sequence[torch._C.Value]]:
    """Emit an ``ATen`` fallback node mirroring the enclosing node ``n``.

    The new node copies ``n``'s attributes and records the bare operator
    name (the part after ``::``) in its ``operator`` attribute. Returns a
    single value or a sequence of values, matching ``n``'s output count.
    """
    output_count = len(list(n.outputs()))
    ret = g.op("ATen", *inputs, outputs=output_count)
    # With multiple outputs, g.op returns a sequence; the attributes live on
    # the producing node, reachable from any one of its output values.
    if n.outputsSize() == 1:
        v: torch._C.Value = cast(torch._C.Value, ret)
    else:
        v = cast(Sequence[torch._C.Value], ret)[-1]
    v.node().copyAttributes(n)
    v.node().s_("operator", n.kind().split("::")[-1])
    return ret
def gen_seq(g: torch._C.Graph, *args: Any) -> torch._C.Value:
    """Build an ONNX sequence value from ``args``.

    Packs the arguments with ``SequenceConstruct``; with no arguments an
    empty sequence is emitted instead.
    """
    if args:
        return cast(torch._C.Value, g.op("SequenceConstruct", *args))
    # TODO(twata): Set dtype attribute
    return cast(torch._C.Value, g.op("SequenceEmpty"))