def test_scalar_cast_grad():
    """Check that the gradient of ``F.scalar_cast`` w.r.t. its input is 1.

    A float (255.5) is cast to the Python type corresponding to ``ms.int8``;
    the test expects the gradient of that cast w.r.t. the float input to be 1.
    """
    source_value = 255.5
    target_type = get_py_obj_dtype(ms.int8)

    def fx_cast(x):
        # ``target_type`` is captured from the enclosing scope rather than
        # passed as an argument, so only ``x`` is differentiated.
        return F.scalar_cast(x, target_type)

    grad_value = C.grad(fx_cast)(source_value)
    assert grad_value == 1
def test_scalar_cast():
    """Check that ``F.scalar_cast`` inside an ``ms_function`` converts 8.5 to 8.

    The target type is the Python object type that ``get_py_obj_dtype``
    returns for ``ms.int64``.
    """
    @ms_function
    def fn_cast(x, t):
        return F.scalar_cast(x, t)

    int64_py_type = get_py_obj_dtype(ms.int64)
    result = fn_cast(8.5, int64_py_type)
    assert result == 8
def test_dtype():
    """Check ``dtype.get_py_obj_dtype`` on Python values, their types, and class attributes.

    Covers:
    * builtin scalars and containers, queried both as a value and as its type;
    * attribute values of ``Foo`` instances;
    * attribute types reported by ``get_class_attrib_types`` for an instance
      and for the class itself;
    * rejection (``NotImplementedError``) of a metatype (``type`` itself).
    """
    # Each value and its Python type are expected to map to the same dtype.
    value_cases = [
        (1.5, ms.float64),
        (100, ms.int64),
        (False, ms.bool_),
        ([1, 2, 3], ms.list_),
        ((2, 4, 5), ms.tuple_),
    ]
    for value, expected in value_cases:
        assert dtype.get_py_obj_dtype(value) == expected
        assert dtype.get_py_obj_dtype(type(value)) == expected
    # TODO: add a ``str`` case (e.g. "string type") once it is supported.

    # The dtype of an instance attribute follows the value it was built with.
    int_foo = Foo(3)
    assert dtype.get_py_obj_dtype(int_foo.x) == ms.int64
    assert dtype.get_py_obj_dtype(type(int_foo.x)) == ms.int64

    float_foo = Foo(3.1)
    assert dtype.get_py_obj_dtype(float_foo.x) == ms.float64
    assert dtype.get_py_obj_dtype(type(float_foo.x)) == ms.float64

    # get_class_attrib_types is expected to report a single int64 attribute
    # for both the float-valued instance and the class itself. Presumably this
    # comes from a class-level declaration on Foo rather than the runtime
    # value -- confirm against Foo's definition.
    for target in (float_foo, Foo):
        fields = get_class_attrib_types(target)
        assert len(fields) == 1
        assert dtype.get_py_obj_dtype(fields[0]) == ms.int64

    # type(type(1.5)) is ``type`` itself, which has no dtype mapping. Build it
    # outside the ``raises`` block so only the call under test can satisfy it.
    meta_type = type(type(1.5))
    with pytest.raises(NotImplementedError):
        dtype.get_py_obj_dtype(meta_type)
def typeof(x):
    """Return the dtype that ``get_py_obj_dtype`` reports for ``x``.

    ``x`` may be a Python value or a Python type; it is passed through as-is.
    """
    dtype_of_x = get_py_obj_dtype(x)
    return dtype_of_x