def test_subtype(self):
    """Exercise ``issubtype`` on plain types, ``Any``, TypeVars and generics."""
    from torch.utils.data._typing import issubtype

    plain_types = (int, str, bool, float, complex, list, tuple, dict, set, T_co)

    # Every type is a subtype of itself and of ``Any``; ``Any`` is only
    # accepted as a subtype of the unconstrained ``T_co``.
    for typ in plain_types:
        self.assertTrue(issubtype(typ, typ))
        self.assertTrue(issubtype(typ, Any))
        expect = self.assertTrue if typ == T_co else self.assertFalse
        expect(issubtype(Any, typ))

    # Distinct plain types are unrelated unless the parent is ``T_co``.
    for left, right in itertools.product(plain_types, repeat=2):
        related = left == right or right == T_co
        expect = self.assertTrue if related else self.assertFalse
        expect(issubtype(left, right))

    T = TypeVar('T', int, str)
    S = TypeVar('S', bool, Union[str, int], Tuple[int, T])  # type: ignore[valid-type]

    # Each entry is (subtype, parent); the relation must hold one way only.
    sub_par_pairs = (
        (int, Optional[int]),
        (List, Union[int, list]),
        (Tuple[int, str], S),
        (Tuple[int, str], tuple),
        (T, S),
        (S, T_co),
        (T, Union[S, Set]),
    )
    for child, parent in sub_par_pairs:
        self.assertTrue(issubtype(child, parent))
        self.assertFalse(issubtype(parent, child))

    # Generic containers paired with the number of parameters used to
    # subscript them in this test.
    generics_with_arity = (
        (List, 1),
        (Tuple, 2),  # use 2 parameters
        (Set, 1),
        (Dict, 2),
    )
    for generic, arity in generics_with_arity:
        for chosen in itertools.combinations(sub_par_pairs, arity):
            children, parents = zip(*chosen)
            narrow = generic[children]  # type: ignore[index]
            wide = generic[parents]  # type: ignore[index]
            self.assertTrue(issubtype(narrow, wide))
            self.assertFalse(issubtype(wide, narrow))
            # Non-recursive check: the reverse direction is expected to pass
            # (presumably because only the outer container is compared).
            self.assertTrue(issubtype(wide, narrow, recursive=False))
def test_compile_time(self):
    """Check the type validation done when an ``IterDataPipe`` subclass is defined.

    The first part defines classes whose ``__iter__`` return annotation
    conflicts with the declared ``IterDataPipe[...]`` parameter and expects a
    ``TypeError`` at class-definition time.  The second part defines valid
    classes and checks the ``type`` attribute on the class and its instances.
    """
    from torch.utils.data._typing import issubtype

    # --- Invalid declarations: each class body must raise TypeError ---

    # Return annotation is not an Iterator at all.
    with self.assertRaisesRegex(TypeError, r"Expected 'Iterator' as the return"):
        class InvalidDP1(IterDataPipe[int]):
            def __iter__(self) -> str:  # type: ignore[misc, override]
                yield 0

    # Iterator element type does not match the declared Tuple.
    with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):
        class InvalidDP2(IterDataPipe[Tuple]):
            def __iter__(self) -> Iterator[int]:  # type: ignore[override]
                yield 0

    # Plain ``tuple`` is rejected for the declared ``Tuple[int, str]``.
    with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):
        class InvalidDP3(IterDataPipe[Tuple[int, str]]):
            def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
                yield (0, )

    # Only exercised when the file-level flag says generic NamedTuple is allowed.
    if _generic_namedtuple_allowed:
        with self.assertRaisesRegex(TypeError, r"is not supported by Python typing"):
            class InvalidDP4(IterDataPipe["InvalidData[int]"]):  # type: ignore[type-arg, misc]
                pass

    # --- Valid declarations ---

    class DP1(IterDataPipe[Tuple[int, str]]):
        def __init__(self, length):
            self.length = length

        def __iter__(self) -> Iterator[Tuple[int, str]]:
            for d in range(self.length):
                yield d, str(d)

    self.assertTrue(issubclass(DP1, IterDataPipe))
    dp1 = DP1(10)
    # Class-level and instance-level types must be mutually compatible.
    self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type))
    dp2 = DP1(5)
    self.assertEqual(dp1.type, dp2.type)

    # A DataPipe with a fully fixed type cannot be subscripted again.
    with self.assertRaisesRegex(TypeError, r"is not a generic class"):
        class InvalidDP5(DP1[tuple]):  # type: ignore[type-arg]
            def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
                yield (0, )

    class DP2(IterDataPipe[T_co]):
        def __iter__(self) -> Iterator[T_co]:
            for d in range(10):
                yield d  # type: ignore[misc]

    self.assertTrue(issubclass(DP2, IterDataPipe))
    dp1 = DP2()  # type: ignore[assignment]
    self.assertTrue(DP2.type.issubtype(dp1.type) and dp1.type.issubtype(DP2.type))
    dp2 = DP2()  # type: ignore[assignment]
    self.assertEqual(dp1.type, dp2.type)

    class DP3(IterDataPipe[Tuple[T_co, str]]):
        r""" DataPipe without fixed type with __init__ function"""
        def __init__(self, datasource):
            self.datasource = datasource

        def __iter__(self) -> Iterator[Tuple[T_co, str]]:
            for d in self.datasource:
                yield d, str(d)

    self.assertTrue(issubclass(DP3, IterDataPipe))
    dp1 = DP3(range(10))  # type: ignore[assignment]
    self.assertTrue(DP3.type.issubtype(dp1.type) and dp1.type.issubtype(DP3.type))
    # The datasource is never iterated here; only the ``type`` attributes
    # of the two instances are compared.
    dp2 = DP3(5)  # type: ignore[assignment]
    self.assertEqual(dp1.type, dp2.type)

    class DP4(IterDataPipe[tuple]):
        r""" DataPipe without __iter__ annotation"""
        def __iter__(self):
            raise NotImplementedError

    self.assertTrue(issubclass(DP4, IterDataPipe))
    dp = DP4()
    self.assertTrue(dp.type.param == tuple)

    class DP5(IterDataPipe):
        r""" DataPipe without type annotation"""
        def __iter__(self) -> Iterator[str]:
            raise NotImplementedError

    self.assertTrue(issubclass(DP5, IterDataPipe))
    dp = DP5()  # type: ignore[assignment]
    # The parameter of an unsubscripted IterDataPipe must be equivalent to Any.
    self.assertTrue(issubtype(dp.type.param, Any) and issubtype(Any, dp.type.param))

    class DP6(IterDataPipe[int]):
        r""" DataPipe with plain Iterator"""
        def __iter__(self) -> Iterator:
            raise NotImplementedError

    self.assertTrue(issubclass(DP6, IterDataPipe))
    dp = DP6()  # type: ignore[assignment]
    self.assertTrue(dp.type.param == int)

    class DP7(IterDataPipe[Awaitable[T_co]]):
        r""" DataPipe with abstract base class"""

    # Check DP7 itself (previously this re-checked DP6 by mistake).
    self.assertTrue(issubclass(DP7, IterDataPipe))
    self.assertTrue(DP7.type.param == Awaitable[T_co])

    class DP8(DP7[str]):
        r""" DataPipe subclass from a DataPipe with abc type"""

    self.assertTrue(issubclass(DP8, IterDataPipe))
    self.assertTrue(DP8.type.param == Awaitable[str])