예제 #1
0
    def call_impl(self, env, a, axis, dtype, out):
        """Emit an ONNX `Loop` that computes a running (cumulative) sum of `a`.

        Only the default form is supported: `axis`, `dtype` and `out` must
        all be None (enforced by the asserts below).

        The input is converted with `a.to_tensor(env)` and is assumed to be a
        1-D tensor (see the TODO below). The generated loop runs
        `ChainerGenericLen(v)` times. Each iteration fetches item `cnt` of the
        input via `ChainerGenericGetItem` and adds it to the loop-carried
        accumulator. It emits the new sum as a per-iteration scan output.

        Args:
            env: Graph-building environment that receives the new nodes.
            a: Input value; converted to a tensor via `to_tensor`.
            axis: Must be None (not supported yet).
            dtype: Must be None (not supported yet).
            out: Must be None (not supported yet).

        Returns:
            The `Loop`'s scan output tensor, i.e. the partial sums from every
            iteration stacked together.
        """
        assert axis.is_none()  # TODO(hamaji): Not supported yet.
        assert dtype.is_none()  # TODO(hamaji): Not supported yet.
        assert out.is_none()  # TODO(hamaji): Not supported yet.
        # Furthermore, we simply assume that the input is a 1-D tensor.
        # The return value is apparently a tensor regardless of the input.
        # TODO(satos): This assumption is admittedly too strong.
        v = a.to_tensor(env)

        # If the return value were a Sequence rather than a tensor, this could
        # be written neatly, like SplitAxis.
        # The triple-quoted string below is a disabled alternative (flatten
        # the input with Flatten + Squeeze). It is an unused string expression
        # with no runtime effect, and it is not valid Python as written
        # (`outputs[a.name]` is missing `=`).
        """
        a = new_tensor()
        env.addnode(
            'Flatten',
            inputs=[v.name],outputs[a.name],
            axis=0
        )
        v = a
        a = new_tensor()
        env.addnode(
            'Squeeze',
            inputs=[v.name],outputs[a.name],
            axes=[0]
        )
        """
        # Trip count (M) for the Loop below.
        ls = env.calc(
            'ChainerGenericLen',
            inputs=[v.name],
        )

        # Fresh, unique name for a Loop output whose value is discarded.
        def dummy():
            return "dummy_" + new_tensor().name

        # The loop body is built in its own Env; its nodes become the
        # `body` subgraph of the Loop.
        localenv = Env(env.module)
        # Body graph inputs, in ONNX Loop order:
        # iteration number, condition, then loop-carried values.
        cnt = new_tensor()   # iteration number
        cond = new_tensor()  # loop condition (passed through unchanged)
        s = new_tensor()     # running sum (loop-carried)
        gtx = new_tensor()   # the input tensor (loop-carried, unchanged)
        # Item `cnt` of the input (presumably an indexing op; confirm).
        tx = localenv.calc(
            "ChainerGenericGetItem",
            inputs=[gtx.name, cnt.name],
        )
        # New running sum, fed back as the next iteration's `s`.
        ts = localenv.calc(
            "Add",
            inputs=[tx.name, s.name],
        )
        # Separate copy of the running sum, used as the scan output.
        ts2 = localenv.calc("Identity", inputs=[ts.name])

        # Initial accumulator value.
        # NOTE(review): built from the Python int 0; confirm its dtype matches
        # `v` for non-integer inputs (Add needs matching element types).
        zero = totensor(0, env)

        res = new_tensor()
        # Loop inputs: trip count, omitted condition (""), then the initial
        # values of the loop-carried tensors (input tensor, accumulator).
        # Body outputs: condition, carried input, new sum, scan output.
        # Loop outputs: final carried input and final sum (both discarded via
        # dummy names), then the stacked scan output `res`.
        env.addnode('Loop',
                    inputs=[ls.name, "", v.name, zero.name],
                    outputs=[dummy(), dummy(), res.name],
                    body=utils.make_graph(localenv.nodes, "Cumsum_subgraph",
                                          [cnt, cond, gtx, s],
                                          [cond, gtx, ts, ts2]))

        return res
예제 #2
0
    def to_tensor(self,
                  env: 'utils.Env',
                  dtype: type = None) -> onnx.ValueInfoProto:
        """Convert this value into an ONNX tensor in place and return it.

        If ``self.is_py`` is set, the Python object is wrapped in a new
        ``Value`` and stored in ``self.const_value``. It is then turned into a
        tensor with ``utils.totensor``, which receives ``dtype``.

        Otherwise, a sequence value (``self.is_sequence()``) is first stacked
        with a ``ChainerSequenceStack`` node. When ``dtype`` is given, a
        ``Cast`` node to ``utils.onnx_dtype(dtype)`` is appended, and the
        result's element type is recorded on its type info.

        Args:
            env: Environment in which any conversion nodes are emitted.
            dtype: Optional target element type.

        Returns:
            ``self.value``, which is asserted to be a tensor.
        """
        if self.is_py:
            # Keep a wrapper of the original Python object before replacing it.
            self.const_value = Value(self.value)
            # TODO(hamaji): Rewrite `totensor` to convert a Python
            # list to a tensor.
            self.value = utils.totensor(self.value, env, dtype=dtype)
            self.is_py = False
            assert self.is_tensor()
            return self.value

        # From here on, the value is already an ONNX value.
        if self.is_sequence():
            stacked = env.calc('ChainerSequenceStack',
                               inputs=[self.value.name])
            self.value = stacked
            self.is_py = False

        if dtype is not None:
            target = utils.onnx_dtype(dtype)
            casted = env.calc('Cast', inputs=[self.value.name], to=target)
            # Record the new element type on the Cast output's type info.
            casted.type.tensor_type.elem_type = target
            self.value = casted

        assert self.is_tensor()
        return self.value