Exemplo n.º 1
0
    def slice2list(self):
        """Lower a gast subscript node to ONNX start/end index tensors.

        Despite its name, ``self`` is the gast subscript node being
        converted (``gast.Slice``, ``gast.Index`` or ``gast.ExtSlice``). The
        function calls itself by bare name for tuple indices and for the
        dimensions of an ``ExtSlice``. ``env`` is not a parameter; it is read
        from an enclosing scope. NOTE(review): presumably this is nested
        inside a function that receives ``env`` -- confirm.

        Returns:
            A triple ``(lower, upper, squeeze)``. ``lower`` and ``upper`` are
            tensors of start and end indices, one entry per subscripted
            dimension (presumably 1-D; multi-dimension results are
            concatenated along axis 0 after a cast to INT64). ``squeeze`` is
            a list of bools, one per dimension. An entry is True where the
            subscript was a single index rather than a range (presumably so
            the caller can drop that axis -- confirm against the caller).

        Raises:
            AssertionError: if a ``gast.Slice`` has a step.
            Exception: if ``self`` is not one of the supported node types.
        """
        # TODO(hamaji): Use 2**63-1 instead.
        # Used as the end index when the upper bound is open ("to the end").
        int_max = 2**31 - 1
        if isinstance(self, gast.Slice):
            # Strided slices (a[x:y:z]) are not supported.
            assert self.step is None

            # Convert one slice bound to a tensor. An omitted bound becomes
            # the default `v`. A bound that evaluates to a tensor is
            # unsqueezed. Any other bound becomes a one-element constant array.
            def f(x, v):
                if x is None:
                    return Value(np.array([v])).to_tensor(env)
                x = eval_ast(x, env)
                if x.is_tensor():
                    return unsqueeze(x.value)
                else:
                    return Value(np.array([x.value])).to_tensor(env)

            lower = f(self.lower, 0)
            upper = f(self.upper, int_max)
            # A range keeps its dimension.
            squeeze = [False]
        elif isinstance(self, gast.Index):
            idx = eval_ast(self.value, env)
            if isinstance(idx.value, tuple):  # A Tuple can arrive here.
                # TODO(satos): There is probably a better way to do this.
                # Re-express a tuple index as an ExtSlice of scalar indices.
                vs = [
                    gast.Index(gast.NameConstant(value=v)) for v in idx.value
                ]
                lower, upper, squeeze = slice2list(gast.ExtSlice(dims=vs))
            elif not idx.is_py:
                # Index known only at runtime: end = index + 1 via an Add node.
                # NOTE(review): unlike the Python-int branch below, -1 is not
                # special-cased here, so a runtime index of -1 gives
                # end == 0; verify whether that produces an empty slice.
                lower = unsqueeze(idx.value)
                ot = totensor(1, env)
                upper = env.calc(
                    "Add",
                    inputs=[idx.to_tensor(env).name, ot.name],
                )
                upper = unsqueeze(upper)
                squeeze = [True]
            else:
                lower = totensor(np.array([idx.value]), env)
                # For -1, `idx + 1` would be 0 (an empty range), so slice to
                # the end instead.
                upper_value = idx.value + 1 if idx.value != -1 else int_max
                upper = totensor(np.array([upper_value]), env)
                squeeze = [True]
        elif isinstance(self, gast.ExtSlice):
            # Convert each dimension separately, then cast every bound to
            # INT64 and concatenate them along axis 0.
            ds = list(map(slice2list, self.dims))
            lower = _concat(
                tuple(map(lambda x: castto(x[0], TensorProto.INT64, env), ds)),
                0, env)
            upper = _concat(
                tuple(map(lambda x: castto(x[1], TensorProto.INT64, env), ds)),
                0, env)
            # Flatten the per-dimension squeeze lists into one list.
            squeeze = sum(map(lambda x: x[2], ds), [])
        else:
            raise Exception(self, " is not Python slice")

        return lower, upper, squeeze
Exemplo n.º 2
0
    def to_tensor(self, env: 'utils.Env',
                  dtype: type = None) -> onnx.ValueInfoProto:
        """Materialize this value as an ONNX tensor, in place, and return it.

        A Python value is converted with ``utils.totensor``. The original
        Python value is kept in ``self.const_value`` as a ``Value``. A
        sequence is first stacked into a tensor with a
        ``ChainerSequenceStack`` node. For a non-Python value, a ``Cast``
        node is appended when ``dtype`` is given, and the result's
        ``elem_type`` is updated to match.

        Args:
            env: Environment that receives the generated nodes.
            dtype: Optional target element type. For Python values it is
                forwarded to ``utils.totensor``. For graph values it triggers
                a ``Cast``.

        Returns:
            The tensor now held in ``self.value``.
        """
        if not self.is_py:
            if self.is_sequence():
                stacked = env.calc('ChainerSequenceStack',
                                   inputs=[self.value.name])
                self.value = stacked
                self.is_py = False

            if dtype is not None:
                onnx_type = utils.onnx_dtype(dtype)
                casted = env.calc('Cast', inputs=[self.value.name],
                                  to=onnx_type)
                self.value = casted
                # Record the new element type on the result's type info.
                casted.type.tensor_type.elem_type = onnx_type
        else:
            python_value = self.value
            # Keep the original Python value available as a constant.
            self.const_value = Value(python_value)
            # TODO(hamaji): Rewrite `totensor` to convert a Python
            # list to a tensor.
            self.value = utils.totensor(python_value, env, dtype=dtype)
            self.is_py = False

        assert self.is_tensor()
        return self.value
Exemplo n.º 3
0
    def call_impl(self, env, a, axis, dtype, out):
        """Emit an ONNX ``Loop`` that computes the cumulative sum of ``a``.

        Only the default arguments are supported: ``axis``, ``dtype`` and
        ``out`` must all be None. The loop runs ``ChainerGenericLen(a)``
        iterations. Each iteration reads element ``i`` with
        ``ChainerGenericGetItem`` and adds it to a running accumulator that
        starts at 0. Each partial sum is emitted as a scan output, so the
        returned tensor holds the running sums.

        Returns:
            The tensor holding the stacked scan outputs of the loop.
        """
        assert axis.is_none()  # TODO(hamaji): Not supported yet.
        assert dtype.is_none()  # TODO(hamaji): Not supported yet.
        assert out.is_none()  # TODO(hamaji): Not supported yet.
        # On top of that, the input is assumed to be a 1-D tensor.
        # The return value apparently is a tensor regardless of the input.
        # TODO(satos): This assumption is too strong.
        src = a.to_tensor(env)

        # If the return value were a Sequence instead of a tensor, this could
        # be written more elegantly, like SplitAxis.
        length = env.calc('ChainerGenericLen', inputs=[src.name])

        # Loop body: iteration counter, condition, accumulator and the
        # (loop-carried, unchanged) input sequence.
        body_env = Env(env.module)
        iter_num = new_tensor()
        keep_going = new_tensor()
        acc_in = new_tensor()
        seq_in = new_tensor()
        elem = body_env.calc(
            "ChainerGenericGetItem",
            inputs=[seq_in.name, iter_num.name],
        )
        acc_out = body_env.calc("Add", inputs=[elem.name, acc_in.name])
        # A separate copy of the partial sum is emitted as the scan output.
        scan_out = body_env.calc("Identity", inputs=[acc_out.name])

        init_acc = totensor(0, env)

        result = new_tensor()
        # The final carried values (sequence and accumulator) are unused.
        unused_final_seq = "dummy_" + new_tensor().name
        unused_final_acc = "dummy_" + new_tensor().name
        body = utils.make_graph(body_env.nodes, "Cumsum_subgraph",
                                [iter_num, keep_going, seq_in, acc_in],
                                [keep_going, seq_in, acc_out, scan_out])
        env.addnode('Loop',
                    inputs=[length.name, "", src.name, init_acc.name],
                    outputs=[unused_final_seq, unused_final_acc, result.name],
                    body=body)

        return result