def make_type_checker(annotation):
  """Returns a PyTypeChecker that tests values against `annotation`.

  Supported annotations:
    * a type object `T`: an instance checker for `T`, memoized in
      `_is_instance_checker_cache` so each type gets a single checker;
    * `None`: handled as `type(None)`;
    * `List[X]`: a list checker whose element checker is built from `X`;
    * `Union[X, Y, ...]`: a union checker over one checker per member.

  Args:
    annotation: The type annotation to convert.

  Returns:
    The checker object built by `_api_dispatcher` for `annotation`.

  Raises:
    AssertionError: If a `List[...]` annotation does not have exactly one
      type parameter.
    ValueError: If `annotation` is not one of the supported kinds.
  """
  if type_annotations.is_generic_union(annotation):
    member_checkers = [
        make_type_checker(member)
        for member in type_annotations.get_generic_type_args(annotation)
    ]
    return _api_dispatcher.MakeUnionChecker(member_checkers)

  if type_annotations.is_generic_list(annotation):
    params = type_annotations.get_generic_type_args(annotation)
    if len(params) != 1:
      raise AssertionError(
          "Expected List[...] to have a single type parameter")
    return _api_dispatcher.MakeListChecker(make_type_checker(params[0]))

  if isinstance(annotation, type):
    # Build each type's checker at most once and reuse it afterwards.
    if annotation not in _is_instance_checker_cache:
      _is_instance_checker_cache[annotation] = (
          _api_dispatcher.MakeInstanceChecker(annotation))
    return _is_instance_checker_cache[annotation]

  if annotation is None:
    # A bare `None` in an annotation stands for NoneType.
    return make_type_checker(type(None))

  raise ValueError(
      f"Type annotation {annotation} is not currently supported"
      " by dispatch. Supported annotations: type objects, "
      " List[...], and Union[...]")
def testSortByCost(self):
  int_checker = dispatch.MakeInstanceChecker(int)
  float_checker = dispatch.MakeInstanceChecker(float)
  int_or_float = dispatch.MakeUnionChecker([int_checker, float_checker])
  int_list = dispatch.MakeListChecker(int_checker)
  int_or_float_list = dispatch.MakeListChecker(int_or_float)
  # Parameters are supplied out of cost order (index i gets the i-th checker);
  # the signature checker is expected to reorder them cheapest-first.
  signature = dispatch.PySignatureChecker(
      list(enumerate([int_or_float_list, int_or_float, int_list, int_checker])))
  # Note: `repr(signature)` lists the args in the order they will be checked.
  expected_repr = ('<PySignatureChecker '
                   'args[3]:int, '  # instance check: cost=1
                   'args[1]:Union[int, float], '  # union: cost=3
                   'args[2]:List[int], '  # list of instance: cost=10
                   'args[0]:List[Union[int, float]]>'  # list of union: cost=30
                  )  # pyformat: disable
  self.assertEqual(repr(signature), expected_repr)
def testList(self):
  list_of_rt = dispatch.MakeListChecker(
      dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor))
  signature = dispatch.PySignatureChecker([(0, list_of_rt)])
  rt = ragged_factory_ops.constant([[1, 2], [3]])
  # Each entry is (positional args, whether `signature` should match them).
  cases = [
      (([rt], ), True),  # Single-element list of RaggedTensor.
      (([], ), False),  # Empty list: contains no RaggedTensor.
      ((rt, ), False),  # A bare RaggedTensor is not a list.
      (([rt, rt + 3, rt * 2], ), True),  # Every element passes the check.
      (([rt, rt.values, rt * 2], ), False),  # `rt.values` fails the check.
  ]  # pyformat: disable
  self.check_signatures(signature, cases)
  self.assertEqual(repr(signature),
                   '<PySignatureChecker args[0]:List[RaggedTensor]>')
def make_type_checker(annotation):
  """Returns a PyTypeChecker that tests values against `annotation`.

  Supported annotations:
    * a type object `T`: an instance checker for `T`, memoized in
      `_is_instance_checker_cache`;
    * `None`: handled as `type(None)`;
    * `List[X]`: a list checker whose element checker is built from `X`;
    * `Union[X, Y, ...]`: a union checker. If two or more members are plain
      type objects, they are folded into a single multi-type instance checker.
      That checker is memoized under the tuple of those types, so unions with
      the same set of plain types share it. Any remaining members each get
      their own checker.

  Args:
    annotation: The type annotation to convert.

  Returns:
    The checker object built by `_api_dispatcher` for `annotation`.

  Raises:
    AssertionError: If a `List[...]` annotation does not have exactly one
      type parameter.
    ValueError: If `annotation` is not one of the supported kinds.
  """
  if type_annotations.is_generic_union(annotation):
    members = type_annotations.get_generic_type_args(annotation)
    # Sorting by id() gives a canonical order for the plain types, so the
    # cache key does not depend on the order the members were written in.
    plain_types = tuple(
        sorted((m for m in members if isinstance(m, type)), key=id))
    if len(plain_types) <= 1:
      return _api_dispatcher.MakeUnionChecker(
          [make_type_checker(m) for m in members])
    # Two or more plain types: check them all with one instance checker.
    if plain_types not in _is_instance_checker_cache:
      _is_instance_checker_cache[plain_types] = (
          _api_dispatcher.MakeInstanceChecker(*plain_types))
    combined = _is_instance_checker_cache[plain_types]
    others = [make_type_checker(m) for m in members if not isinstance(m, type)]
    return _api_dispatcher.MakeUnionChecker([combined] + others)

  if type_annotations.is_generic_list(annotation):
    params = type_annotations.get_generic_type_args(annotation)
    if len(params) != 1:
      raise AssertionError(
          "Expected List[...] to have a single type parameter")
    return _api_dispatcher.MakeListChecker(make_type_checker(params[0]))

  if isinstance(annotation, type):
    # Build each type's checker at most once and reuse it afterwards.
    if annotation not in _is_instance_checker_cache:
      _is_instance_checker_cache[annotation] = (
          _api_dispatcher.MakeInstanceChecker(annotation))
    return _is_instance_checker_cache[annotation]

  if annotation is None:
    # A bare `None` in an annotation stands for NoneType.
    return make_type_checker(type(None))

  raise ValueError(
      f"Type annotation {annotation} is not currently supported"
      " by dispatch. Supported annotations: type objects, "
      " List[...], and Union[...]")
def testListAndUnionDispatch(self):
  dispatcher = dispatch.PythonAPIDispatcher('tf.foo', ['x', 'ys', 'name'],
                                            (None, ))
  rt_or_tensor = dispatch.MakeUnionChecker([
      dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor),
      dispatch.MakeInstanceChecker(ops.Tensor),
  ])
  list_of_rt_or_tensor = dispatch.MakeListChecker(rt_or_tensor)

  def handler(x, ys, name=None):
    del x, ys, name  # Unused: only the returned marker is checked.
    return 'f1'

  dispatcher.Register(
      dispatch.PySignatureChecker([(0, rt_or_tensor),
                                   (1, list_of_rt_or_tensor)]), handler)

  rt = ragged_factory_ops.constant([[1, 2], [3]])
  t = constant_op.constant(5)

  # Calls where every argument fits the signature and at least one
  # RaggedTensor is present among `x` and `ys`: the handler should run.
  dispatched_calls = [
      ((rt, [t]), None),
      ((rt, [rt]), None),
      ((t, [rt]), None),
      ((rt, []), None),
      ((t, [t, t, rt, t]), None),
      ((rt, [t], 'my_name'), None),
      ((), {'x': rt, 'ys': [t]}),
      ((), {'x': rt, 'ys': [t], 'name': 'x'}),
  ]
  for args, kwargs in dispatched_calls:
    self.assertEqual(dispatcher.Dispatch(args, kwargs), 'f1')

  # Calls with no RaggedTensor, or with an argument of the wrong type: the
  # dispatcher should decline.
  declined_calls = [
      (t, [t]),
      (t, []),
      ('foo', [rt]),
      ('foo', 'bar'),
      ('foo', 'bar', 'baz'),
  ]
  for args in declined_calls:
    self.assertEqual(dispatcher.Dispatch(args, None), NotImplemented)
def testListChecker(self): int_checker = dispatch.MakeInstanceChecker(int) tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor) ragged_checker = dispatch.MakeInstanceChecker( ragged_tensor.RaggedTensor) np_int_checker = dispatch.MakeInstanceChecker(np.integer) t = constant_op.constant([1, 2, 3]) rt = ragged_factory_ops.constant([[1, 2], [3, 4, 5]]) a = [1, 2, 3] b = ['a', 2, t] c = [t, t * 2, t - 2] d = [t, rt] e = [] f = (1, 2, 3) g = (rt, ) h = {1: 2, 3: 4} i = np.array([1, 2, 3]) with self.subTest('List[int]'): checker = dispatch.MakeListChecker(int_checker) self.assertEqual(checker.Check(a), MATCH) self.assertEqual(checker.Check(b), NO_MATCH) self.assertEqual(checker.Check(c), NO_MATCH) self.assertEqual(checker.Check(d), NO_MATCH) self.assertEqual(checker.Check(e), MATCH) self.assertEqual(checker.Check(f), MATCH) self.assertEqual(checker.Check(iter(a)), NO_MATCH) self.assertEqual(checker.Check(iter(b)), NO_MATCH) self.assertEqual(checker.Check(reversed(e)), NO_MATCH) self.assertEqual(checker.Check(h), NO_MATCH) self.assertEqual(checker.Check(i), NO_MATCH) self.assertEqual(checker.cost(), 10) self.assertEqual(repr(checker), '<PyTypeChecker List[int]>') with self.subTest('List[Tensor]'): checker = dispatch.MakeListChecker(tensor_checker) self.assertEqual(checker.Check(a), NO_MATCH) self.assertEqual(checker.Check(b), NO_MATCH) self.assertEqual(checker.Check(c), MATCH) self.assertEqual(checker.Check(d), NO_MATCH) self.assertEqual(checker.Check(e), MATCH) self.assertEqual(checker.cost(), 10) self.assertEqual(repr(checker), '<PyTypeChecker List[Tensor]>') with self.subTest('List[Union[Tensor, RaggedTensor]]'): checker = dispatch.MakeListChecker( dispatch.MakeUnionChecker([tensor_checker, ragged_checker])) self.assertEqual(checker.Check(a), NO_MATCH) self.assertEqual(checker.Check(b), NO_MATCH) self.assertEqual(checker.Check(c), MATCH) self.assertEqual(checker.Check(d), MATCH_DISPATCHABLE) self.assertEqual(checker.Check(e), MATCH) 
self.assertEqual(checker.Check(f), NO_MATCH) self.assertEqual(checker.Check(g), MATCH_DISPATCHABLE) self.assertEqual(checker.cost(), 30) self.assertEqual( repr(checker), '<PyTypeChecker List[Union[Tensor, RaggedTensor]]>') with self.subTest('List[Union[int, np.integer]]'): # Note: np.integer is a subtype of int in *some* Python versions. checker = dispatch.MakeListChecker( dispatch.MakeUnionChecker([int_checker, np_int_checker])) self.assertEqual(checker.Check(a), MATCH) self.assertEqual(checker.Check(np.array(a)), NO_MATCH) self.assertEqual(checker.Check(np.array(a) * 1.5), NO_MATCH)