def slice2list(self):
    """Lower a gast subscript node into ONNX start/end bound tensors.

    ``self`` is the subscript node (``gast.Slice``, ``gast.Index`` or
    ``gast.ExtSlice``), not an instance of an enclosing class. ``env`` is
    not a parameter; it is read from the enclosing scope.

    Returns:
        A ``(lower, upper, squeeze)`` triple. ``lower`` and ``upper`` are
        tensors holding one start/end bound per indexed dimension (an
        ``ExtSlice`` concatenates its per-dimension bounds along axis 0 after
        casting them to INT64). ``squeeze`` is a list of bools with one entry
        per dimension: False for a slice dimension, True for a scalar-index
        dimension (presumably so the caller can squeeze that axis away).

    Raises:
        AssertionError: for a ``Slice`` with a step.
        Exception: if ``self`` is not one of the supported node types.
    """
    # TODO(hamaji): Use 2**63-1 instead.
    int_max = 2**31 - 1

    if isinstance(self, gast.Slice):
        assert self.step is None

        def bound_to_tensor(node, default):
            # An omitted bound (e.g. the start in `x[:3]`) becomes `default`.
            if node is None:
                return Value(np.array([default])).to_tensor(env)
            bound = eval_ast(node, env)
            if not bound.is_tensor():
                return Value(np.array([bound.value])).to_tensor(env)
            return unsqueeze(bound.value)

        start = bound_to_tensor(self.lower, 0)
        stop = bound_to_tensor(self.upper, int_max)
        return start, stop, [False]

    if isinstance(self, gast.Index):
        idx = eval_ast(self.value, env)

        if isinstance(idx.value, tuple):
            # A tuple (e.g. from `x[1, 2]`) can arrive here; re-dispatch it
            # as an ExtSlice made of constant scalar indices.
            # TODO(satos): there may be a cleaner way to handle this.
            dims = [
                gast.Index(gast.NameConstant(value=elem))
                for elem in idx.value
            ]
            start, stop, flags = slice2list(gast.ExtSlice(dims=dims))
            return start, stop, flags

        if not idx.is_py:
            # Dynamic index: the selected range is [idx, idx + 1).
            # NOTE(review): unlike the constant branch below, a runtime value
            # of -1 is not special-cased here (idx + 1 == 0) — confirm.
            start = unsqueeze(idx.value)
            one = totensor(1, env)
            incremented = env.calc(
                "Add",
                inputs=[idx.to_tensor(env).name, one.name],
            )
            stop = unsqueeze(incremented)
            return start, stop, [True]

        # Constant index. -1 + 1 would be 0, so the last element instead
        # uses int_max as its end bound.
        start = totensor(np.array([idx.value]), env)
        end_value = int_max if idx.value == -1 else idx.value + 1
        stop = totensor(np.array([end_value]), env)
        return start, stop, [True]

    if isinstance(self, gast.ExtSlice):
        parts = [slice2list(dim) for dim in self.dims]
        starts = tuple(
            castto(part[0], TensorProto.INT64, env) for part in parts)
        start = _concat(starts, 0, env)
        stops = tuple(
            castto(part[1], TensorProto.INT64, env) for part in parts)
        stop = _concat(stops, 0, env)
        flags = [flag for part in parts for flag in part[2]]
        return start, stop, flags

    raise Exception(self, " is not Python slice")
def to_tensor(self, env: 'utils.Env',
              dtype: type = None) -> onnx.ValueInfoProto:
    """Convert this value to an ONNX tensor in place and return it.

    Afterwards ``self.value`` holds the tensor and ``self.is_py`` is False.

    - A Python value is converted with ``utils.totensor`` (which receives
      ``dtype``). A ``Value`` wrapping the original Python object is kept
      in ``self.const_value``.
    - A sequence value is stacked into one tensor with
      ``ChainerSequenceStack``.
    - For a non-Python value, a non-None ``dtype`` appends a ``Cast`` node,
      and the resulting value's ``elem_type`` is set to match.

    Args:
        env: environment in which new nodes are emitted.
        dtype: optional target element type (mapped via
            ``utils.onnx_dtype`` for the ``Cast`` path).

    Returns:
        The tensor now stored in ``self.value``.
    """
    if not self.is_py:
        if self.is_sequence():
            stacked = env.calc('ChainerSequenceStack',
                               inputs=[self.value.name])
            self.value = stacked
            self.is_py = False

        if dtype is not None:
            onnx_type = utils.onnx_dtype(dtype)
            casted = env.calc('Cast', inputs=[self.value.name], to=onnx_type)
            self.value = casted
            casted.type.tensor_type.elem_type = onnx_type
    else:
        self.const_value = Value(self.value)
        # TODO(hamaji): Rewrite `totensor` to convert a Python
        # list to a tensor.
        self.value = utils.totensor(self.value, env, dtype=dtype)
        self.is_py = False

    assert self.is_tensor()
    return self.value
def call_impl(self, env, a, axis, dtype, out):
    """Emit ONNX nodes computing ``np.cumsum(a)`` and return the result.

    Only the plain ``np.cumsum(a)`` form is supported: ``axis``, ``dtype``
    and ``out`` must all be None (enforced with asserts).

    The running sum is built with an ONNX ``Loop`` that runs
    ``ChainerGenericLen(a)`` iterations. Iteration ``cnt`` reads element
    ``cnt`` via ``ChainerGenericGetItem``, adds it to the loop-carried
    accumulator, and emits the new partial sum as a scan output. The stacked
    scan outputs form the returned tensor.

    NOTE(review): assumes ``a`` is 1-D. NumPy flattens an N-D input when
    ``axis`` is None; this lowering does not.
    NOTE(review): the accumulator starts from ``totensor(0, env)``. If that
    yields an integer tensor, a floating-point ``a`` may hit a type mismatch
    in the loop-carried value — confirm.

    Returns:
        The tensor holding the stacked partial sums.
    """
    assert axis.is_none()  # TODO(hamaji): Not supported yet.
    assert dtype.is_none()  # TODO(hamaji): Not supported yet.
    assert out.is_none()  # TODO(hamaji): Not supported yet.
    # Moreover, the input is assumed to be a 1-D tensor.
    # The return value appears to be a tensor regardless of the input.
    # TODO(satos): this assumption is rather strong.
    v = a.to_tensor(env)
    # If the result were a Sequence rather than a tensor, this could be
    # written as neatly as SplitAxis.
    # (A disabled draft once flattened `v` here with Flatten(axis=0) followed
    # by Squeeze(axes=[0]) to lift the 1-D restriction.)

    ls = env.calc(
        'ChainerGenericLen',
        inputs=[v.name],
    )

    def dummy():
        # Fresh name for a Loop output whose value is discarded.
        return "dummy_" + new_tensor().name

    # Loop body graph. Per the ONNX Loop spec, the body inputs are
    # (iteration number, condition, loop-carried values...). Its outputs are
    # (condition, loop-carried values..., scan outputs...).
    localenv = Env(env.module)
    cnt = new_tensor()   # iteration number
    cond = new_tensor()  # loop condition, passed through unchanged
    s = new_tensor()     # carried accumulator (running sum)
    gtx = new_tensor()   # carried copy of the input tensor
    tx = localenv.calc(
        "ChainerGenericGetItem",
        inputs=[gtx.name, cnt.name],
    )
    ts = localenv.calc(
        "Add",
        inputs=[tx.name, s.name],
    )
    # The partial sum is both the next accumulator (ts) and a scan output
    # (ts2), so the scan output goes through a separate Identity node.
    ts2 = localenv.calc("Identity", inputs=[ts.name])

    zero = totensor(0, env)
    res = new_tensor()
    # Loop inputs: trip count, no condition input (""), then the initial
    # carried values (input tensor, zero accumulator). Only the scan output
    # (the stacked partial sums) is kept; the final carried values are
    # discarded.
    env.addnode('Loop', inputs=[ls.name, "", v.name, zero.name],
                outputs=[dummy(), dummy(), res.name],
                body=utils.make_graph(localenv.nodes, "Cumsum_subgraph",
                                      [cnt, cond, gtx, s],
                                      [cond, gtx, ts, ts2]))
    return res