コード例 #1
0
ファイル: tensor.py プロジェクト: zwxlib/MegEngine
def tensor(data: Union[list, np.ndarray] = None,
           *,
           dtype: str = None,
           device: mgb.CompNode = None,
           requires_grad: bool = None):
    r"""A helper function to create a :class:`~.Tensor` using existing data.

    :param data: an existing data array, must be Python list, NumPy array or
        None. Anything that is not an ``np.ndarray`` is first converted with
        :func:`numpy.array`.
    :param dtype: target Tensor data type, one of
        ``("uint8", "int8", "int16", "int32", "float32", "float16")``.
    :param device: target device for Tensor storing. It is passed through
        ``_use_default_if_none`` (presumably to fill in a default device when
        ``None`` -- confirm against that helper).
    :param requires_grad: whether its gradient will be calculated during
        :meth:`~.Optimizer.backward`.
    :return: a new :class:`~.Tensor` wrapping the array created by
        ``mgb.make_shared`` from ``data``.
    :raises NotImplementedError: if ``data`` is already a :class:`~.Tensor`.
    :raises TypeError: if ``dtype`` is not one of the supported dtypes, or if
        ``dtype`` is omitted and ``data`` is an ndarray whose own dtype is not
        supported.
    """
    supported_dtypes = ("uint8", "int8", "int16", "int32", "float32",
                        "float16")
    if isinstance(data, Tensor):
        raise NotImplementedError(
            "creating a Tensor from an existing Tensor is not supported")
    if dtype is not None and np.dtype(dtype).name not in supported_dtypes:
        raise TypeError("unsupported dtype {}".format(dtype))
    if data is not None:
        if not isinstance(data, np.ndarray):
            data = np.array(data, dtype=dtype)
            # In order to accept tensor([1]), automatically convert to a
            # 32-bit number instead of NumPy's default 64-bit when the input
            # data is not an ndarray. Note this replaces any user-given
            # dtype with mgb's normalized form of it.
            # NOTE(review): assumes to_mgb_supported_dtype maps an
            # already-supported dtype to itself -- confirm.
            dtype = mgb.to_mgb_supported_dtype(data.dtype)
        if dtype is None and data.dtype.name not in supported_dtypes:
            # An ndarray passed without an explicit dtype keeps its own
            # dtype, so that dtype must already be supported.
            raise TypeError("unsupported dtype {}".format(data.dtype))

    device, _ = _use_default_if_none(device, None)
    shared_nd = mgb.make_shared(device, value=data, dtype=dtype)
    return Tensor(shared_nd, requires_grad=requires_grad)
コード例 #2
0
ファイル: tensor.py プロジェクト: zwxlib/MegEngine
 def __setstate__(self, state):
     r"""Rebuild this tensor in place from a pickled ``state`` dict.

     The ``"data"``, ``"device"`` and ``"dtype"`` entries are required (a
     missing one raises :class:`KeyError`). The optional ``"metadata"`` dict
     may supply ``"requires_grad"``, which defaults to ``None``. Every entry
     read is popped, so ``state`` (and its metadata dict) is mutated.
     """
     # Pop the mandatory entries in a fixed order so that the first missing
     # key is the one reported.
     value, comp_node, target_dtype = (
         state.pop(key) for key in ("data", "device", "dtype"))
     grad_flag = state.pop("metadata", {}).pop("requires_grad", None)
     shared = mgb.make_shared(comp_node, value=value, dtype=target_dtype)
     self._reset(shared, requires_grad=grad_flag)
コード例 #3
0
ファイル: tensor.py プロジェクト: ztjryg4/MegEngine
 def set_dtype(self, dtype: str = None):
     r"""Change the data type of this tensor in place.

     When the stored value (``__val``) is set, it is cast with
     :meth:`astype`, converted to a NumPy array and re-wrapped with
     ``mgb.make_shared`` on ``self.device``. Otherwise, when the stored
     symbol (``__sym``) is set, it is replaced by ``__sym.astype(dtype)``.
     When both are ``None`` nothing is changed.

     :param dtype: target data type.
     """
     if self.__val is not None:
         comp_node = self.device
         host_value = self.astype(dtype).numpy()
         self.__val = mgb.make_shared(comp_node, value=host_value)
     elif self.__sym is not None:
         self.__sym = self.__sym.astype(dtype)