def testMultipleDispatchers(self):
  """Dispatch picks the single matching target, and errors on ambiguity."""
  dispatcher = dispatch.PythonAPIDispatcher('tf.foo', ['x', 'y', 'name'],
                                            (None,))
  is_ragged = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor)
  ragged_in_x = dispatch.PySignatureChecker([(0, is_ragged)])
  ragged_in_y = dispatch.PySignatureChecker([(1, is_ragged)])
  target_for_x = lambda x, y, name=None: 'f1'
  target_for_y = lambda x, y, name=None: 'f2'
  ragged = ragged_factory_ops.constant([[1, 2], [3]])
  dispatcher.Register(ragged_in_x, target_for_x)
  dispatcher.Register(ragged_in_y, target_for_y)

  # Each entry is (positional args, keyword args, expected Dispatch result).
  expectations = [
      ((ragged, 5), None, 'f1'),
      (('foo', ragged), None, 'f2'),
      (('foo',), {'y': ragged}, 'f2'),
      (('foo', 'bar'), None, NotImplemented),
  ]
  for call_args, call_kwargs, expected in expectations:
    self.assertEqual(dispatcher.Dispatch(call_args, call_kwargs), expected)

  # With ragged values in both `x` and `y`, both registered targets match.
  with self.assertRaisesRegex(
      ValueError, 'Multiple dispatch targets .*'
      r'match the arguments to tf\.foo'):
    dispatcher.Dispatch((ragged, ragged), None)
def testSortByCost(self):
  """PySignatureChecker orders its per-argument checks by increasing cost."""
  int_check = dispatch.MakeInstanceChecker(int)
  float_check = dispatch.MakeInstanceChecker(float)
  number_check = dispatch.MakeUnionChecker([int_check, float_check])
  int_list_check = dispatch.MakeListChecker(int_check)
  number_list_check = dispatch.MakeListChecker(number_check)
  checker = dispatch.PySignatureChecker([
      (0, number_list_check),
      (1, number_check),
      (2, int_list_check),
      (3, int_check),
  ])
  # `repr(checker)` lists the args in the order they will be checked, so the
  # expected text below encodes the cheapest-first ordering.
  expected_repr = ''.join([
      '<PySignatureChecker ',
      'args[3]:int, ',  # int_check: cost=1
      'args[1]:Union[int, float], ',  # number_check: cost=3
      'args[2]:List[int], ',  # int_list_check: cost=10
      'args[0]:List[Union[int, float]]>',  # number_list_check: cost=30
  ])  # pyformat: disable
  self.assertEqual(repr(checker), expected_repr)
def testList(self):
  """A list checker on args[0] accepts only non-empty lists of RaggedTensors."""
  list_of_ragged = dispatch.MakeListChecker(
      dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor))
  checker = dispatch.PySignatureChecker([(0, list_of_ragged)])
  ragged = ragged_factory_ops.constant([[1, 2], [3]])
  # Each entry is (positional args, whether the signature should match).
  cases = [
      (([ragged],), True),
      (([],), False),  # An empty list is rejected by this checker.
      ((ragged,), False),  # A bare RaggedTensor is not a list.
      (([ragged, ragged + 3, ragged * 2],), True),
      # `ragged.values` is not a RaggedTensor, so the whole list is rejected.
      (([ragged, ragged.values, ragged * 2],), False),
  ]  # pyformat: disable
  self.check_signatures(checker, cases)
  self.assertEqual(repr(checker),
                   '<PySignatureChecker args[0]:List[RaggedTensor]>')
def _make_signature_checker(api_signature, signature):
  """Builds a PySignatureChecker for the given type signature.

  Args:
    api_signature: The `inspect.Signature` of the API whose signature is being
      checked.
    signature: Dictionary mapping parameter names (or non-negative positional
      indices into `api_signature.parameters`) to type annotations.

  Returns:
    A `PySignatureChecker`.

  Raises:
    TypeError: If `signature` is not a dict whose keys are all `str` or `int`.
    ValueError: If `signature` annotates a parameter that `api_signature` does
      not have (including negative or out-of-range positional indices), or a
      parameter that is not positional.
  """
  if not (isinstance(signature, dict) and
          all(isinstance(k, (str, int)) for k in signature)):
    raise TypeError(
        "signatures must be dictionaries mapping parameter names "
        "to type annotations.")

  checkers = []
  param_names = list(api_signature.parameters)
  for param_name, param_type in signature.items():
    # Convert positional parameters to named parameters.  Negative indices are
    # deliberately left unconverted: Python list indexing would otherwise map
    # them silently onto parameters counted from the end of the signature.
    # Left as ints, they fall through to the "unknown parameter" error below.
    if isinstance(param_name, int) and 0 <= param_name < len(param_names):
      param_name = param_names[param_name]

    # Check that the parameter exists, and has an appropriate kind.
    param = api_signature.parameters.get(param_name, None)
    if param is None:
      raise ValueError("signature includes annotation for unknown "
                       f"parameter {param_name!r}.")
    if param.kind not in (tf_inspect.Parameter.POSITIONAL_ONLY,
                          tf_inspect.Parameter.POSITIONAL_OR_KEYWORD):
      raise ValueError(
          "Dispatch currently only supports type annotations "
          "for positional parameters; can't handle annotation "
          f"for {param.kind!r} parameter {param_name}.")

    checker = make_type_checker(param_type)
    index = param_names.index(param_name)
    checkers.append((index, checker))

  return _api_dispatcher.PySignatureChecker(checkers)
def testSimpleSignature(self):
  """A signature constraining args[0] to int and args[2] to RaggedTensor."""
  checker = dispatch.PySignatureChecker([
      (0, dispatch.MakeInstanceChecker(int)),
      (2, dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor)),
  ])
  ragged = ragged_factory_ops.constant([[1, 2], [3]])
  # Each entry is (positional args, whether the signature should match).
  self.check_signatures(checker, [
      ((1, 2, ragged), True),
      ((1, 2, 3), False),
      ((1, 2), False),  # Too few arguments to reach args[2].
      ((), False),
      ((5, 'x', ragged, None), True),  # args[1] and args[3] are unconstrained.
      (([5], 'x', ragged, None), False),
      ((5, 'x', [ragged], None), False),
  ])  # pyformat: disable
  self.assertEqual(repr(checker),
                   '<PySignatureChecker args[0]:int, args[2]:RaggedTensor>')
def testUnion(self):
  """A union checker (RaggedTensor | Tensor) applied to args[0] and args[1]."""
  ragged_or_dense = dispatch.MakeUnionChecker([
      dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor),
      dispatch.MakeInstanceChecker(ops.Tensor),
  ])
  checker = dispatch.PySignatureChecker([(0, ragged_or_dense),
                                         (1, ragged_or_dense)])
  dense = constant_op.constant([[1, 2], [3, 4]])
  ragged = ragged_factory_ops.constant([[1, 2], [3]])
  # Each entry is (positional args, whether the signature should match).
  self.check_signatures(checker, [
      ((dense, dense), False),  # Two dense Tensors do not match.
      ((dense, ragged), True),
      ((ragged, dense), True),
      ((ragged, ragged), True),
      ((ragged, [ragged]), False),  # A list is not a Tensor/RaggedTensor.
      ((ragged, ragged, 1, 2, None), True),
  ])  # pyformat: disable
  self.assertEqual(
      repr(checker), '<PySignatureChecker args[0]:Union[RaggedTensor, Tensor], '
      'args[1]:Union[RaggedTensor, Tensor]>')
def testListAndUnionDispatch(self):
  """Dispatch on (RaggedTensor | Tensor, List[RaggedTensor | Tensor])."""
  dispatcher = dispatch.PythonAPIDispatcher('tf.foo', ['x', 'ys', 'name'],
                                            (None,))
  ragged_or_dense = dispatch.MakeUnionChecker([
      dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor),
      dispatch.MakeInstanceChecker(ops.Tensor),
  ])
  list_of_ragged_or_dense = dispatch.MakeListChecker(ragged_or_dense)
  target = lambda x, ys, name=None: 'f1'
  signature = dispatch.PySignatureChecker([(0, ragged_or_dense),
                                           (1, list_of_ragged_or_dense)])
  dispatcher.Register(signature, target)
  ragged = ragged_factory_ops.constant([[1, 2], [3]])
  dense = constant_op.constant(5)

  # (positional args, keyword args) pairs that must reach the target.
  matching_calls = [
      ((ragged, [dense]), None),
      ((ragged, [ragged]), None),
      ((dense, [ragged]), None),
      ((ragged, []), None),
      ((dense, [dense, dense, ragged, dense]), None),
      ((ragged, [dense], 'my_name'), None),
      ((), {'x': ragged, 'ys': [dense]}),
      ((), {'x': ragged, 'ys': [dense], 'name': 'x'}),
  ]
  for call_args, call_kwargs in matching_calls:
    self.assertEqual(dispatcher.Dispatch(call_args, call_kwargs), 'f1')

  # Positional argument tuples for which no target applies.
  non_matching_calls = [
      (dense, [dense]),
      (dense, []),
      ('foo', [ragged]),
      ('foo', 'bar'),
      ('foo', 'bar', 'baz'),
  ]
  for call_args in non_matching_calls:
    self.assertEqual(dispatcher.Dispatch(call_args, None), NotImplemented)
def testBasicDispatch(self):
  """A single target registered for a RaggedTensor in the `x` position."""
  dispatcher = dispatch.PythonAPIDispatcher('tf.foo', ['x', 'y', 'name'],
                                            (None,))
  ragged_in_x = dispatch.PySignatureChecker(
      [(0, dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor))])
  target = lambda x, y, name=None: 'f1'
  dispatcher.Register(ragged_in_x, target)
  ragged = ragged_factory_ops.constant([[1, 2], [3]])

  # Each entry is (positional args, keyword args, expected Dispatch result).
  expectations = [
      ((ragged, 5), None, 'f1'),
      ((ragged, 5, 'my_name'), None, 'f1'),
      ((), {'x': ragged, 'y': 5}, 'f1'),
      ((), {'x': ragged, 'y': 5, 'name': 'x'}, 'f1'),
      (('foo', ragged), None, NotImplemented),  # Ragged value in wrong slot.
      (('foo', 'bar'), None, NotImplemented),
      (('foo', 'bar', 'baz'), None, NotImplemented),
  ]
  for call_args, call_kwargs, expected in expectations:
    self.assertEqual(dispatcher.Dispatch(call_args, call_kwargs), expected)