Esempio n. 1
0
def test_basic():
    """Verify that NumPy arrays obtained from a cde.Tensor alias its buffer.

    Two views are taken of the same tensor: one through ``np.array(...,
    copy=False)`` and one through ``Tensor.as_array()``. A write through the
    first view must be visible through the second, and both views must report
    the same data pointer.
    """
    source = np.array([1, 2, 3, 4, 5])
    tensor = cde.Tensor(source)

    # First view: ask NumPy not to copy the tensor's data.
    view = np.array(tensor, copy=False)
    view[0] = 0
    expected = np.array([0, 2, 3, 4, 5])
    assert np.array_equal(expected, view)
    # NOTE(review): assumes NumPy's default integer dtype is int64 on the
    # test platform, since `source` is built without an explicit dtype.
    assert tensor.type() == cde.DataType("int64")

    # Second view: writing through `view` must show up in `second_view`.
    second_view = tensor.as_array()
    view[0] = 2
    expected = np.array([2, 2, 3, 4, 5])
    assert np.array_equal(expected, second_view)
    assert tensor.type() == cde.DataType("int64")

    # Both views must point at the very same memory.
    view_ptr = view.__array_interface__['data']
    second_ptr = second_view.__array_interface__['data']
    assert view_ptr == second_ptr
Esempio n. 2
0
def mstypelist_to_detypelist(type_list):
    """
    Convert a list of MindSpore dtypes into the matching de data types.

    The conversion happens in place: every element of ``type_list`` is
    overwritten, and the same list object is returned.

    Args:
        type_list (list[mindspore.dtype]): a list of MindSpore's dtype;
            ``None`` entries are allowed.

    Returns:
        The list of de data type (the very object passed in). Each ``None``
        entry is replaced by ``cde.DataType("")``; every other entry is
        converted with ``mstype_to_detype``.
    """
    for position, ms_type in enumerate(type_list):
        if ms_type is None:
            # A missing dtype maps to an empty de DataType.
            converted = cde.DataType("")
        else:
            converted = mstype_to_detype(ms_type)
        type_list[position] = converted

    return type_list
Esempio n. 3
0
def mstype_to_detype(type_):
    """
    Get de data type corresponding to mindspore dtype.

    Args:
        type_ (mindspore.dtype): MindSpore's dtype.

    Returns:
        The data type of de: a new ``cde.DataType`` built from the name that
        corresponds to ``type_``.

    Raises:
        NotImplementedError: If ``type_`` is not an instance of ``typing.Type``.
        KeyError: If ``type_`` is a ``typing.Type`` that has no entry in the
            supported-dtype table below.
    """
    # NOTE(review): `typing` here is presumably MindSpore's dtype typing
    # module rather than the stdlib one; confirm against the file's imports.
    if not isinstance(type_, typing.Type):
        raise NotImplementedError(
            "mstype_to_detype expects a MindSpore dtype, got an object of "
            "type {}".format(type(type_).__name__))

    # Look up only the de type *name* and construct a single cde.DataType.
    # Previously a DataType was instantiated for every supported dtype on
    # each call, and all but one were thrown away.
    de_type_name = {
        mstype.bool_: "bool",
        mstype.int8: "int8",
        mstype.int16: "int16",
        mstype.int32: "int32",
        mstype.int64: "int64",
        mstype.uint8: "uint8",
        mstype.uint16: "uint16",
        mstype.uint32: "uint32",
        mstype.uint64: "uint64",
        mstype.float16: "float16",
        mstype.float32: "float32",
        mstype.float64: "float64",
        mstype.string: "string",
    }[type_]  # unsupported dtypes raise KeyError, as before

    return cde.DataType(de_type_name)