def test_basic():
    """Verify that a ``cde.Tensor`` and NumPy views of it share one buffer.

    The test wraps a NumPy int array in a ``cde.Tensor`` and obtains a view
    of it with ``np.array(..., copy=False)``. It checks that writes through
    that view are visible in a second array returned by ``as_array()``, and
    that both arrays report the same data pointer. It also checks that the
    tensor keeps reporting ``int64`` as its type.
    """
    tensor = cde.Tensor(np.array([1, 2, 3, 4, 5]))
    expected_type = cde.DataType("int64")

    # Obtain a NumPy array over the tensor without requesting a copy,
    # then write through it.
    first_view = np.array(tensor, copy=False)
    first_view[0] = 0
    assert np.array_equal(np.array([0, 2, 3, 4, 5]), first_view)
    assert tensor.type() == expected_type

    # A write through the first view must show up in the array returned by
    # as_array(), i.e. both arrays alias the same memory.
    second_view = tensor.as_array()
    first_view[0] = 2
    assert np.array_equal(np.array([2, 2, 3, 4, 5]), second_view)
    assert tensor.type() == expected_type

    first_ptr = first_view.__array_interface__['data']
    second_ptr = second_view.__array_interface__['data']
    assert first_ptr == second_ptr
def mstypelist_to_detypelist(type_list):
    """
    Convert a list of MindSpore dtypes into dataset-engine (de) data types.

    The conversion happens in place: each element of ``type_list`` is
    overwritten, and the same list object is returned.

    Args:
        type_list (list[mindspore.dtype]): a list of MindSpore's dtype. Entries
            may be None; they are replaced with ``cde.DataType("")``.

    Returns:
        The list of de data type (the same object that was passed in).
    """
    for position, dtype in enumerate(type_list):
        # Entries before `position` were already rewritten; `dtype` is
        # still the original value at this index.
        if dtype is None:
            type_list[position] = cde.DataType("")
        else:
            type_list[position] = mstype_to_detype(dtype)
    return type_list
def mstype_to_detype(type_):
    """
    Map a MindSpore dtype to the corresponding dataset-engine (de) data type.

    Args:
        type_ (mindspore.dtype): MindSpore's dtype.

    Returns:
        The ``cde.DataType`` whose name matches ``type_``.

    Raises:
        NotImplementedError: if ``type_`` is not a ``typing.Type`` instance.
        KeyError: if ``type_`` is a ``typing.Type`` with no entry in the table
            below.
    """
    if isinstance(type_, typing.Type):
        # Map each supported MindSpore dtype to the de type name string.
        # Only the matching cde.DataType is constructed.
        de_type_names = {
            mstype.bool_: "bool",
            mstype.int8: "int8",
            mstype.int16: "int16",
            mstype.int32: "int32",
            mstype.int64: "int64",
            mstype.uint8: "uint8",
            mstype.uint16: "uint16",
            mstype.uint32: "uint32",
            mstype.uint64: "uint64",
            mstype.float16: "float16",
            mstype.float32: "float32",
            mstype.float64: "float64",
            mstype.string: "string",
        }
        return cde.DataType(de_type_names[type_])
    raise NotImplementedError()