def call_impl(self, env, a, axis, dtype, out): assert axis.is_none() # TODO(hamaji): Not supported yet. assert dtype.is_none() # TODO(hamaji): Not supported yet. assert out.is_none() # TODO(hamaji): Not supported yet. # さらにさらに、入力は1次元のTensorである、と仮定してしまいます # 戻り値は入力に依らずテンソルらしい # TODO(satos) さすがに仮定がきつい v = a.to_tensor(env) # これ戻り値がテンソルでなくSequenceなら、SplitAxisみたいにかっこよく書けるはず """ a = new_tensor() env.addnode( 'Flatten', inputs=[v.name],outputs[a.name], axis=0 ) v = a a = new_tensor() env.addnode( 'Squeeze', inputs=[v.name],outputs[a.name], axes=[0] ) """ ls = env.calc( 'ChainerGenericLen', inputs=[v.name], ) def dummy(): return "dummy_" + new_tensor().name localenv = Env(env.module) cnt = new_tensor() cond = new_tensor() s = new_tensor() gtx = new_tensor() tx = localenv.calc( "ChainerGenericGetItem", inputs=[gtx.name, cnt.name], ) ts = localenv.calc( "Add", inputs=[tx.name, s.name], ) ts2 = localenv.calc("Identity", inputs=[ts.name]) zero = totensor(0, env) res = new_tensor() env.addnode('Loop', inputs=[ls.name, "", v.name, zero.name], outputs=[dummy(), dummy(), res.name], body=utils.make_graph(localenv.nodes, "Cumsum_subgraph", [cnt, cond, gtx, s], [cond, gtx, ts, ts2])) return res
def to_tensor(self, env: 'utils.Env', dtype: type = None) -> onnx.ValueInfoProto:
    """Coerce this value in place into an ONNX tensor and return it.

    A Python constant is lowered through ``utils.totensor`` (the original
    constant is first preserved in ``self.const_value``).  A sequence value
    is stacked into a single tensor via ``ChainerSequenceStack``.  For
    non-Python values, an explicit ``Cast`` is emitted when *dtype* is
    given, and the recorded element type is updated to match.
    """
    if self.is_py:
        # Keep the original Python constant around before lowering it.
        self.const_value = Value(self.value)
        # TODO(hamaji): Rewrite `totensor` to convert a Python
        # list to a tensor.
        self.value = utils.totensor(self.value, env, dtype=dtype)
        self.is_py = False
    else:
        if self.is_sequence():
            # Stack the sequence elements into one tensor.
            self.value = env.calc('ChainerSequenceStack',
                                  inputs=[self.value.name])
            self.is_py = False
        if dtype is not None:
            onnx_dt = utils.onnx_dtype(dtype)
            self.value = env.calc('Cast', inputs=[self.value.name],
                                  to=onnx_dt)
            self.value.type.tensor_type.elem_type = onnx_dt
    assert self.is_tensor()
    return self.value