Example #1
0
def test_scalar_cast_grad():
    """Check that the gradient of ``F.scalar_cast`` w.r.t. its input is 1.

    The cast target (``int8``) is captured by closure rather than passed as an
    argument, so ``C.grad`` differentiates with respect to the value only.
    """
    value = 255.5
    target_type = get_py_obj_dtype(ms.int8)

    def cast_to_target(val):
        return F.scalar_cast(val, target_type)

    grad_value = C.grad(cast_to_target)(value)
    assert grad_value == 1
Example #2
0
def test_scalar_cast():
    """Check that ``F.scalar_cast`` of 8.5 to ``int64`` yields 8 under ``ms_function``."""
    target_type = get_py_obj_dtype(ms.int64)

    @ms_function
    def cast_fn(value, dst_type):
        return F.scalar_cast(value, dst_type)

    result = cast_fn(8.5, target_type)
    assert result == 8
Example #3
0
def test_dtype():
    """Check ``dtype.get_py_obj_dtype`` on Python values, their types, and class attributes.

    Each value is resolved both directly and via ``type(value)``, and both
    forms must map to the same dtype. Resolving the metatype ``type`` itself
    must raise ``NotImplementedError``.
    """
    # (value, expected dtype) pairs; each value is checked as an instance and
    # as its type.
    cases = (
        (1.5, ms.float64),
        (100, ms.int64),
        (False, ms.bool_),
        ([1, 2, 3], ms.list_),
        ((2, 4, 5), ms.tuple_),
    )
    # TODO(review): str values (e.g. "string type") are not covered; add a
    # case here if get_py_obj_dtype is expected to support them.
    for value, expected in cases:
        assert dtype.get_py_obj_dtype(value) == expected
        assert dtype.get_py_obj_dtype(type(value)) == expected

    # Attribute values read off a Foo instance resolve like plain values.
    for init_value, expected in ((3, ms.int64), (3.1, ms.float64)):
        y = Foo(init_value)
        assert dtype.get_py_obj_dtype(y.x) == expected
        assert dtype.get_py_obj_dtype(type(y.x)) == expected

    # Both an instance (holding a float) and the class itself report a single
    # field whose type resolves to int64.
    for target in (Foo(3.1), Foo):
        fields = get_class_attrib_types(target)
        assert len(fields) == 1
        assert dtype.get_py_obj_dtype(fields[0]) == ms.int64

    # type(type(1.5)) is the builtin metatype ``type``. Only the call under
    # test sits inside the raises-context, so an error raised during setup
    # cannot be mistaken for the expected one.
    metatype = type(type(1.5))
    with pytest.raises(NotImplementedError):
        dtype.get_py_obj_dtype(metatype)
def typeof(x):
    """Return the dtype that ``get_py_obj_dtype`` computes for ``x``.

    Thin alias kept so callers can use the conventional ``typeof`` name.
    """
    inferred = get_py_obj_dtype(x)
    return inferred