コード例 #1
0
ファイル: export.py プロジェクト: pfnet/pytorch-pfn-extras
 def gen_concat(g: torch._C.Graph, *args: Any) -> torch._C.Value:
     seq: List[torch._C.Value] = []
     for i in args:
         if i.type().kind() == "IntType" or len(i.type().sizes()) == 0:
             seq.append(
                 sym_hel._unsqueeze_helper(g, i, axes_i=[0])  # type: ignore[no-untyped-call,call-arg]
             )
         else:
             seq.append(i)
     return cast(torch._C.Value, g.op("Concat", *seq, axis_i=0))
コード例 #2
0
ファイル: export.py プロジェクト: pfnet/pytorch-pfn-extras
 def gen_const(g: torch._C.Graph, value: Any = None) -> torch._C.Value:
     c = cast(torch._C.Value, g.op("Constant"))
     if n.kindOf("value") == "ival":
         ival = n.output().toIValue()
         if isinstance(ival, list) and not isinstance(ival[0], (int, float)):
             vals: List[torch._C.Value] = []
             for i in ival:
                 if isinstance(i, torch.Tensor):
                     vals.append(cast(torch._C.Value, g.op("prim::Constant", value_t=i)))
                 else:
                     assert i is None
                     vals.append(cast(torch._C.Value, g.op("prim::Constant")))
                     vals[-1].setType(torch._C.NoneType.get())
             c = cast(torch._C.Value, g.op("prim::ListConstruct"))
             for v in vals:
                 c.node().addInput(v)
         else:
             c.node().t_("value", torch.tensor(ival))
     else:
         c.node().copyAttributes(n)
     return c
コード例 #3
0
ファイル: export.py プロジェクト: pfnet/pytorch-pfn-extras
 def gen_aten_node(g: torch._C.Graph, *inputs: Any) -> Union[torch._C.Value, Sequence[torch._C.Value]]:
     ret = g.op("ATen", *inputs, outputs=len(list(n.outputs())))
     v: torch._C.Value = cast(torch._C.Value, ret) if n.outputsSize() == 1 else cast(Sequence[torch._C.Value], ret)[-1]
     v.node().copyAttributes(n)
     v.node().s_("operator", n.kind().split("::")[-1])
     return ret
コード例 #4
0
ファイル: export.py プロジェクト: pfnet/pytorch-pfn-extras
 def gen_seq(g: torch._C.Graph, *args: Any) -> torch._C.Value:
     if len(args) == 0:
         return cast(torch._C.Value, g.op("SequenceEmpty"))  # TODO(twata): Set dtype attribute
     else:
         return cast(torch._C.Value, g.op("SequenceConstruct", *args))