예제 #1
0
 def call_impl(self, env, shape, dtype, order):
     """Emit a ``ConstantFill`` node whose output shape is read from ``shape``.

     Args:
         env: Environment used to emit ONNX nodes via ``env.calc``.
         shape: Value holding the output shape; converted with ``to_tensor``.
         dtype: Value whose ``.value`` is mapped through ``utils.onnx_dtype``.
         order: Value whose ``.value`` must be ``'C'``.

     Returns:
         Whatever ``env.calc`` returns for the emitted ``ConstantFill`` node.
     """
     # Any order other than 'C' is rejected.
     assert order.value == 'C'
     onnx_type = utils.onnx_dtype(dtype.value)
     # input_as_shape=1 tells ConstantFill to use its input as the output
     # shape. No `value` attribute is passed, so the op's default fill
     # value applies.
     # NOTE(review): ConstantFill was an experimental ONNX op (the standard
     # replacement is ConstantOfShape) -- confirm the backend supports it.
     shape_tensor = shape.to_tensor(env)
     return env.calc(
         'ConstantFill',
         inputs=[shape_tensor.name],
         input_as_shape=1,
         dtype=onnx_type,
     )
예제 #2
0
 def call_impl(self, env, shape, fill_value, dtype, order):
     """Emit an ``Expand`` node broadcasting ``fill_value`` to ``shape``.

     When ``dtype`` is not none, the expanded tensor is passed through
     ``castto`` with the element type from ``utils.onnx_dtype``. Otherwise
     the ``Expand`` result is returned unchanged.

     Args:
         env: Environment used to emit ONNX nodes via ``env.calc``.
         shape: Value holding the target shape; converted with ``to_tensor``.
         fill_value: Value to broadcast; converted with ``to_tensor``.
         dtype: Optional target dtype (checked with ``is_none()``).
         order: Value whose ``.value`` must be ``'C'``.

     Returns:
         The ``Expand`` result, or its cast when ``dtype`` is given.
     """
     # Any order other than 'C' is rejected.
     assert order.value == 'C'
     value_name = fill_value.to_tensor(env).name
     shape_name = shape.to_tensor(env).name
     expanded = env.calc('Expand', inputs=[value_name, shape_name])
     if dtype.is_none():
         return expanded
     return castto(expanded, utils.onnx_dtype(dtype.value), env)
예제 #3
0
    def to_tensor(self,
                  env: 'utils.Env',
                  dtype: type = None) -> onnx.ValueInfoProto:
        """Convert this value to a tensor in place and return it.

        Python values (``self.is_py``) are converted with ``utils.totensor``.
        The original Python value is first recorded as ``self.const_value``.
        Other values that are sequences are first stacked with
        ``ChainerSequenceStack``. When ``dtype`` is given, a ``Cast`` node
        is then emitted.

        ``self.value`` is overwritten with the converted tensor, so later
        calls start from the result of this one.

        Args:
            env: Environment used to emit ONNX nodes via ``env.calc``.
            dtype: Optional target element type. For Python values it is
                forwarded to ``utils.totensor``; otherwise a non-None value
                triggers a ``Cast``.

        Returns:
            The tensor now stored in ``self.value``.
        """
        if not self.is_py:
            if self.is_sequence():
                # Collapse the sequence into a single stacked tensor.
                self.value = env.calc('ChainerSequenceStack',
                                      inputs=[self.value.name])
                self.is_py = False

            if dtype is not None:
                # NOTE(review): a Cast is emitted on every call that passes
                # a dtype, even if the value already has that element type.
                onnx_type = utils.onnx_dtype(dtype)
                self.value = env.calc('Cast', inputs=[self.value.name],
                                      to=onnx_type)
                # Stamp the target element type onto the Cast output's type
                # info (presumably env.calc leaves it unset -- confirm).
                self.value.type.tensor_type.elem_type = onnx_type
        else:
            # Remember the Python value before it is replaced by a tensor.
            self.const_value = Value(self.value)
            # TODO(hamaji): Rewrite `totensor` to convert a Python
            # list to a tensor.
            self.value = utils.totensor(self.value, env, dtype=dtype)
            self.is_py = False

        assert self.is_tensor()
        return self.value