def test_array_cast():
    """Check that ``array_cast`` returns an ndarray with the requested dtype.

    A float array cast to ``i64`` must come back as an ``int64`` ndarray,
    and a float array cast to ``f16`` as a ``float16`` ndarray.
    """
    int64_result = array_cast(np.array([1.5, 1.7]), type_to_abstract(i64))
    assert isinstance(int64_result, np.ndarray)
    assert int64_result.dtype == np.dtype(np.int64)

    float16_result = array_cast(np.array([1.5, 1.7]), type_to_abstract(f16))
    assert isinstance(float16_result, np.ndarray)

    # NOTE(review): this dtype check uses [1.5, 1.] rather than [1.5, 1.7]
    # like the other casts — confirm whether the difference is intentional.
    float16_other = array_cast(np.array([1.5, 1.]), type_to_abstract(f16))
    assert float16_other.dtype == np.dtype(np.float16)
def test_merge_from_types():
    """Check forced ``amerge`` of tuple annotations against a concrete tuple.

    The concrete value is a two-element tuple. Both the bare ``typing.Tuple``
    annotation and ``Tuple[Int[64], Int[64]]`` must be returned unchanged
    (by identity) from a forced merge. The one-element
    ``Tuple[Int[64]]`` annotation is expected to raise ``MyiaTypeError``.
    """
    concrete = T([S(1), S(t=ty.Int[64])])
    generic_tuple = type_to_abstract(typing.Tuple)
    pair_tuple = type_to_abstract(typing.Tuple[ty.Int[64], ty.Int[64]])
    single_tuple = type_to_abstract(typing.Tuple[ty.Int[64]])

    for annotation in (generic_tuple, pair_tuple):
        assert amerge(annotation, concrete, forced=True) is annotation

    with pytest.raises(MyiaTypeError):
        amerge(single_tuple, concrete, forced=True)
def test_type_to_abstract():
    """Check ``type_to_abstract`` on builtin scalars and generic containers.

    ``int`` and ``float`` map to 64-bit scalars and ``bool`` to ``ty.Bool``.
    ``typing.List`` maps to the union of the ``Empty`` and ``Cons``
    abstracts, and ``typing.Tuple`` maps to ``T(ANYTHING)``. Every comparison
    is by identity.
    """
    # NOTE(review): another function named ``test_type_to_abstract`` appears
    # later in this source and expects ``typing.List`` to map to
    # ``L(ANYTHING)`` instead. If both are in one module, the later
    # definition replaces this one and pytest never runs it — confirm which
    # expectation is current and rename or drop the other.
    scalar_expectations = (
        (int, ty.Int[64]),
        (float, ty.Float[64]),
        (bool, ty.Bool),
    )
    for python_type, myia_type in scalar_expectations:
        assert type_to_abstract(python_type) is S(t=myia_type)

    list_abstract = type_to_abstract(typing.List)
    assert list_abstract is U(type_to_abstract(Empty), type_to_abstract(Cons))

    assert type_to_abstract(typing.Tuple) is T(ANYTHING)
def test_annotation_merge():
    """Check ``annotation_merge`` on unions, scalars and list annotations.

    - Merging an ``i16 | f32`` union with an ``i32`` scalar raises
      ``MyiaTypeError``.
    - Merging an ``i16 | f32`` union with its ``f32`` member returns that
      member. With ``forced=True`` it returns the union annotation instead.
    - The generic ``list`` annotation is an ``AbstractUnion``. A forced merge
      with a concrete list keeps the annotation, and an unforced merge
      yields an ``AbstractADT``.
    """
    with pytest.raises(MyiaTypeError):
        mismatched = AbstractUnion(
            [AbstractScalar({TYPE: i16}), AbstractScalar({TYPE: f32})]
        )
        annotation_merge(mismatched, AbstractScalar({TYPE: i32}))

    f32_scalar = AbstractScalar({TYPE: f32})
    i16_or_f32 = AbstractUnion([AbstractScalar({TYPE: i16}), f32_scalar])
    assert annotation_merge(i16_or_f32, f32_scalar) is f32_scalar
    assert annotation_merge(i16_or_f32, f32_scalar, forced=True) is i16_or_f32

    list_annotation = type_to_abstract(list)
    concrete_list = to_abstract([1, 2])
    assert isinstance(list_annotation, AbstractUnion)

    forced_result = annotation_merge(list_annotation, concrete_list, forced=True)
    assert forced_result is list_annotation

    unforced_result = annotation_merge(list_annotation, concrete_list)
    assert isinstance(unforced_result, AbstractADT)
def test_scalar_cast():
    """Check that ``scalar_cast`` returns NumPy scalars of the target type.

    Casting ``1.5`` to ``i64`` must give an ``np.int64``, and casting it to
    ``f16`` must give an ``np.float16``.
    """
    as_int64 = scalar_cast(1.5, type_to_abstract(i64))
    assert isinstance(as_int64, np.int64)

    as_float16 = scalar_cast(1.5, type_to_abstract(f16))
    assert isinstance(as_float16, np.float16)
def test_type_to_abstract():
    """Check ``type_to_abstract`` on ``bool`` and the generic containers.

    ``bool`` maps to the ``ty.Bool`` scalar, ``typing.List`` to
    ``L(ANYTHING)`` and ``typing.Tuple`` to ``T(ANYTHING)``. Every
    comparison is by identity.
    """
    # NOTE(review): an earlier function in this source has the same name and
    # expects ``typing.List`` to map to a union of ``Empty``/``Cons``
    # instead. Only one definition per module name is collected by pytest —
    # confirm which expectation is current.
    bool_abstract = type_to_abstract(bool)
    assert bool_abstract is S(t=ty.Bool)

    for container, constructor in ((typing.List, L), (typing.Tuple, T)):
        assert type_to_abstract(container) is constructor(ANYTHING)
def Ty(t):
    """Generate a symbolic type.

    ``t`` is converted with ``type_to_abstract`` unless it is ``ANYTHING``,
    which is used as-is. The result is wrapped in an ``AbstractType``.
    """
    # NOTE(review): ``Ty`` is defined more than once in this source. If both
    # definitions share a module, the later one is what callers get.
    if t is not ANYTHING:
        t = type_to_abstract(t)
    return AbstractType(t)
def Ty(t):
    """Generate a symbolic type.

    Return an ``AbstractType`` wrapping ``ANYTHING`` directly, or wrapping
    ``type_to_abstract(t)`` for any other ``t``.
    """
    if t is ANYTHING:
        return AbstractType(ANYTHING)
    return AbstractType(type_to_abstract(t))
def test_prim_unsafe_static_cast():
    """Check that ``unsafe_static_cast`` does not alter the value.

    Casting the int ``1234`` to the abstract ``float`` type must still
    compare equal to ``1234``.
    """
    float_abstract = type_to_abstract(float)
    result = unsafe_static_cast(1234, float_abstract)
    assert result == 1234