def test_get_signature_with_class_property(self):
  """Checks that `get_signature` raises `TypeError` for a property's value."""

  class HasProperty:

    @property
    def x(self):
      return 99

  instance = HasProperty()
  # Reading `instance.x` yields the plain value 99 rather than a function, so
  # requesting its signature is expected to fail.
  with self.assertRaises(TypeError):
    function_utils.get_signature(instance.x)
def test_get_callargs_for_signature(self, fn, args, kwargs):
  """Checks that `get_signature(fn)` binds arguments like `inspect` does.

  The reference behavior is `inspect.signature(fn).bind(*args, **kwargs)`:
  when the reference binds successfully, the signature under test must produce
  the same bound arguments; when the reference raises `TypeError`, the
  signature under test must raise `TypeError` as well.

  Args:
    fn: The callable whose signature is examined.
    args: Positional arguments to bind.
    kwargs: Keyword arguments to bind.

  Raises:
    AssertionError: If the signature under test binds differently from the
      reference, or fails to bind where the reference succeeds.
  """
  signature = function_utils.get_signature(fn)
  expected_error = None
  try:
    # Keep the reference signature under its own name. Rebinding `signature`
    # here would make the test compare `inspect` with itself and never
    # exercise `function_utils.get_signature`.
    expected_signature = inspect.signature(fn)
    expected_callargs = expected_signature.bind(*args, **kwargs).arguments
  except TypeError as e:
    expected_error = e
    expected_callargs = None

  if expected_error is not None:
    with self.assertRaises(TypeError):
      _ = signature.bind(*args, **kwargs)
    return

  # Stays `None` in the failure message if binding itself raises.
  result_callargs = None
  try:
    result_callargs = signature.bind(*args, **kwargs).arguments
    self.assertEqual(result_callargs, expected_callargs)
  except (TypeError, AssertionError) as test_err:
    raise AssertionError(
        'With signature `{!s}`, args {!s}, kwargs {!s}, expected bound '
        'args {!s} and error {!s}, tested function returned {!s} and the '
        'test has failed with message: {!s}'.format(
            signature, args, kwargs, expected_callargs, expected_error,
            result_callargs, test_err)) from test_err
def test_as_wrapper_with_classmethod(self):
  """Checks the signature of a bound classmethod lists only `x`."""

  class Doubler:

    @classmethod
    def foo(cls, x):
      return x * 2

  actual_parameters = function_utils.get_signature(Doubler.foo).parameters
  # `cls` must not appear: only the explicit parameter `x` is expected.
  expected_parameters = collections.OrderedDict(
      x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD))
  self.assertEqual(actual_parameters, expected_parameters)
def test_get_defun_argspec_with_untyped_non_eager_defun(self):
  """Checks the signature of a `tf.function` built without input signature."""
  # In a tf.function with no input signature, the same restrictions as in a
  # typed eager function apply.
  untyped_fn = tf.function(lambda x, y, *z: None)
  actual_parameters = collections.OrderedDict(
      function_utils.get_signature(untyped_fn).parameters)

  positional_or_keyword = inspect.Parameter.POSITIONAL_OR_KEYWORD
  expected_parameters = collections.OrderedDict(
      x=inspect.Parameter('x', positional_or_keyword),
      y=inspect.Parameter('y', positional_or_keyword),
      z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
  )
  self.assertEqual(actual_parameters, expected_parameters)
def test_get_signature_with_class_instance_method(self):
  """Checks the signature of a bound instance method lists only `y`."""

  class Scaler:

    def __init__(self, x):
      self._x = x

    def foo(self, y):
      return self._x * y

  bound_method = Scaler(5).foo
  actual_parameters = function_utils.get_signature(bound_method).parameters
  # `self` must not appear: only the explicit parameter `y` is expected.
  expected_parameters = collections.OrderedDict(
      y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD))
  self.assertEqual(actual_parameters, expected_parameters)
def test_get_defun_argspec_with_typed_non_eager_defun(self):
  """Checks the signature of a `tf.function` built with an input signature."""
  # In a tf.function with a defined input signature, **kwargs or default
  # values are not allowed, but *args are, and the input signature may overlap
  # with *args.
  input_signature = (
      tf.TensorSpec(None, tf.int32),
      tf.TensorSpec(None, tf.bool),
      tf.TensorSpec(None, tf.float32),
      tf.TensorSpec(None, tf.float32),
  )
  typed_fn = tf.function(lambda x, y, *z: None, input_signature)
  actual_parameters = collections.OrderedDict(
      function_utils.get_signature(typed_fn).parameters)

  positional_or_keyword = inspect.Parameter.POSITIONAL_OR_KEYWORD
  expected_parameters = collections.OrderedDict(
      x=inspect.Parameter('x', positional_or_keyword),
      y=inspect.Parameter('y', positional_or_keyword),
      z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
  )
  self.assertEqual(actual_parameters, expected_parameters)
class FunctionUtilsTest(test_case.TestCase, parameterized.TestCase):
  """Tests for signature inspection and argument packing in `function_utils`."""

  def test_get_defun_argspec_with_typed_non_eager_defun(self):
    """Checks the signature of a `tf.function` built with an input signature."""
    # In a tf.function with a defined input signature, **kwargs or default
    # values are not allowed, but *args are, and the input signature may overlap
    # with *args.
    fn = tf.function(lambda x, y, *z: None, (
        tf.TensorSpec(None, tf.int32),
        tf.TensorSpec(None, tf.bool),
        tf.TensorSpec(None, tf.float32),
        tf.TensorSpec(None, tf.float32),
    ))
    self.assertEqual(
        collections.OrderedDict(function_utils.get_signature(fn).parameters),
        collections.OrderedDict(
            x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
        ))

  def test_get_defun_argspec_with_untyped_non_eager_defun(self):
    """Checks the signature of a `tf.function` built without input signature."""
    # In a tf.function with no input signature, the same restrictions as in a
    # typed eager function apply.
    fn = tf.function(lambda x, y, *z: None)
    self.assertEqual(
        collections.OrderedDict(function_utils.get_signature(fn).parameters),
        collections.OrderedDict(
            x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD),
            z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
        ))

  def test_get_signature_with_class_instance_method(self):
    """Checks the signature of a bound instance method lists only `y`."""

    class C:

      def __init__(self, x):
        self._x = x

      def foo(self, y):
        return self._x * y

    c = C(5)
    signature = function_utils.get_signature(c.foo)
    self.assertEqual(
        signature.parameters,
        collections.OrderedDict(y=inspect.Parameter(
            'y', inspect.Parameter.POSITIONAL_OR_KEYWORD)))

  def test_get_signature_with_class_property(self):
    """Checks that `get_signature` raises `TypeError` for a property's value."""

    class C:

      @property
      def x(self):
        return 99

    c = C()
    with self.assertRaises(TypeError):
      function_utils.get_signature(c.x)

  def test_as_wrapper_with_classmethod(self):
    """Checks the signature of a bound classmethod lists only `x`."""

    class C:

      @classmethod
      def foo(cls, x):
        return x * 2

    signature = function_utils.get_signature(C.foo)
    self.assertEqual(
        signature.parameters,
        collections.OrderedDict(x=inspect.Parameter(
            'x', inspect.Parameter.POSITIONAL_OR_KEYWORD)))

  # pyformat: disable
  @parameterized.parameters(
      itertools.product(
          # Values of 'fn' to test.
          [lambda: None,
           lambda a: None,
           lambda a, b: None,
           lambda *a: None,
           lambda **a: None,
           lambda *a, **b: None,
           lambda a, *b: None,
           lambda a, **b: None,
           lambda a, b, **c: None,
           lambda a, b=10: None,
           lambda a, b=10, c=20: None,
           lambda a, b=10, *c: None,
           lambda a, b=10, **c: None,
           lambda a, b=10, *c, **d: None,
           lambda a, b, c=10, *d: None,
           lambda a=10, b=20, c=30, **d: None],
          # Values of 'args' to test.
          [[], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4]],
          # Values of 'kwargs' to test.
          [{}, {'b': 100}, {'name': 'foo'}, {'b': 100, 'name': 'foo'}]))
  # pyformat: enable
  def test_get_callargs_for_signature(self, fn, args, kwargs):
    """Checks that `get_signature(fn)` binds arguments like `inspect` does.

    The reference behavior is `inspect.signature(fn).bind(*args, **kwargs)`:
    when the reference binds successfully, the signature under test must
    produce the same bound arguments; when the reference raises `TypeError`,
    the signature under test must raise `TypeError` as well.

    Args:
      fn: The callable whose signature is examined.
      args: Positional arguments to bind.
      kwargs: Keyword arguments to bind.

    Raises:
      AssertionError: If the signature under test binds differently from the
        reference, or fails to bind where the reference succeeds.
    """
    signature = function_utils.get_signature(fn)
    expected_error = None
    try:
      # Keep the reference signature under its own name. Rebinding `signature`
      # here would make the test compare `inspect` with itself and never
      # exercise `function_utils.get_signature`.
      expected_signature = inspect.signature(fn)
      expected_callargs = expected_signature.bind(*args, **kwargs).arguments
    except TypeError as e:
      expected_error = e
      expected_callargs = None

    if expected_error is not None:
      with self.assertRaises(TypeError):
        _ = signature.bind(*args, **kwargs)
      return

    # Stays `None` in the failure message if binding itself raises.
    result_callargs = None
    try:
      result_callargs = signature.bind(*args, **kwargs).arguments
      self.assertEqual(result_callargs, expected_callargs)
    except (TypeError, AssertionError) as test_err:
      raise AssertionError(
          'With signature `{!s}`, args {!s}, kwargs {!s}, expected bound '
          'args {!s} and error {!s}, tested function returned {!s} and the '
          'test has failed with message: {!s}'.format(
              signature, args, kwargs, expected_callargs, expected_error,
              result_callargs, test_err)) from test_err

  # pyformat: disable
  @parameterized.named_parameters(
      ('args_only',
       function_utils.get_signature(lambda a: None),
       [tf.int32],
       collections.OrderedDict()),
      ('args_and_kwargs_unnamed',
       function_utils.get_signature(lambda a, b=True: None),
       [tf.int32, tf.bool],
       collections.OrderedDict()),
      ('args_and_kwargs_named',
       function_utils.get_signature(lambda a, b=True: None),
       [tf.int32],
       collections.OrderedDict(b=tf.bool)),
      ('args_and_kwargs_default_int',
       function_utils.get_signature(lambda a=10, b=True: None),
       [tf.int32],
       collections.OrderedDict(b=tf.bool)),
  )
  # pyformat: enable
  def test_is_signature_compatible_with_types_true(self, signature, *args,
                                                   **kwargs):
    """Runs `is_signature_compatible_with_types` on the cases listed above."""
    # NOTE(review): despite the `_true` suffix this asserts `False`. Each case
    # supplies a list and an OrderedDict, which arrive here as two positional
    # values in `args` and are forwarded as two positional types rather than
    # being unpacked into types and keyword types. Confirm the intended
    # expectation before changing the assertion.
    self.assertFalse(
        function_utils.is_signature_compatible_with_types(
            signature, *args, **kwargs))

  # pyformat: disable
  @parameterized.named_parameters(
      ('args_only',
       function_utils.get_signature(lambda a=True: None),
       [tf.int32],
       collections.OrderedDict()),
      ('args_and_kwargs',
       function_utils.get_signature(lambda a=10, b=True: None),
       [tf.bool],
       collections.OrderedDict(b=tf.bool)),
  )
  # pyformat: enable
  def test_is_signature_compatible_with_types_false(self, signature, *args,
                                                    **kwargs):
    """Checks the listed signature/type combinations are reported incompatible."""
    self.assertFalse(
        function_utils.is_signature_compatible_with_types(
            signature, *args, **kwargs))

  # pyformat: disable
  @parameterized.named_parameters(
      ('int', tf.int32, False),
      ('tuple_unnamed', [tf.int32, tf.int32], True),
      ('tuple_partially_named', [tf.int32, ('b', tf.int32)], True),
      ('tuple_named', [('a', tf.int32), ('b', tf.int32)], True),
      ('tuple_partially_named_kwargs_first', [('a', tf.int32), tf.int32],
       False),
      ('struct', structure.Struct([(None, 1), ('a', 2)]), True),
      ('struct_kwargs_first', structure.Struct([('a', 1), (None, 2)]), False))
  # pyformat: enable
  def test_is_argument_struct(self, arg, expected_result):
    """Checks `is_argument_struct(arg)` equals `expected_result`."""
    self.assertEqual(function_utils.is_argument_struct(arg), expected_result)

  # pyformat: disable
  @parameterized.named_parameters(
      ('tuple_unnamed', structure.Struct([(None, 1)]), [1], {}),
      ('tuple_partially_named', structure.Struct([(None, 1), ('a', 2)]),
       [1], {'a': 2}),
  )
  # pyformat: enable
  def test_unpack_args_from_structure(self, tuple_with_args, expected_args,
                                      expected_kwargs):
    """Checks a `Struct` unpacks into the expected `(args, kwargs)` pair."""
    self.assertEqual(
        function_utils.unpack_args_from_struct(tuple_with_args),
        (expected_args, expected_kwargs))

  # pyformat: disable
  @parameterized.named_parameters(
      ('tuple_unnamed_1', [tf.int32], [tf.int32], {}),
      ('tuple_named_1', [('a', tf.int32)], [], {'a': tf.int32}),
      ('tuple_unnamed_2', [tf.int32, tf.bool], [tf.int32, tf.bool], {}),
      ('tuple_partially_named',
       [tf.int32, ('b', tf.bool)], [tf.int32], {'b': tf.bool}),
      ('tuple_named_2',
       [('a', tf.int32), ('b', tf.bool)], [], {'a': tf.int32, 'b': tf.bool}),
  )
  # pyformat: enable
  def test_unpack_args_from_struct_type(self, tuple_with_args, expected_args,
                                        expected_kwargs):
    """Checks a type spec unpacks into equivalent positional/keyword types."""
    args, kwargs = function_utils.unpack_args_from_struct(tuple_with_args)
    self.assertEqual(len(args), len(expected_args))
    # Unpacked entries are compared through `is_equivalent_to` against the
    # expected specs converted with `computation_types.to_type`.
    for idx, arg in enumerate(args):
      self.assertTrue(
          arg.is_equivalent_to(
              computation_types.to_type(expected_args[idx])))
    self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys()))
    for k, v in kwargs.items():
      self.assertTrue(
          v.is_equivalent_to(
              computation_types.to_type(expected_kwargs[k])))

  def test_pack_args_into_struct_without_type_spec(self):
    """Checks packing of args/kwargs into a `Struct` with no type spec given."""
    self.assertEqual(
        function_utils.pack_args_into_struct([1], {'a': 10}),
        structure.Struct([(None, 1), ('a', 10)]))
    # With several keyword arguments either ordering of the named elements is
    # accepted, hence `assertIn` over both permutations.
    self.assertIn(
        function_utils.pack_args_into_struct([1, 2], {
            'a': 10,
            'b': 20
        }), [
            structure.Struct([
                (None, 1),
                (None, 2),
                ('a', 10),
                ('b', 20),
            ]),
            structure.Struct([
                (None, 1),
                (None, 2),
                ('b', 20),
                ('a', 10),
            ])
        ])
    self.assertIn(
        function_utils.pack_args_into_struct([], {
            'a': 10,
            'b': 20
        }), [
            structure.Struct([('a', 10), ('b', 20)]),
            structure.Struct([('b', 20), ('a', 10)])
        ])
    self.assertEqual(
        function_utils.pack_args_into_struct([1], {}),
        structure.Struct([(None, 1)]))

  # pyformat: disable
  @parameterized.named_parameters(
      ('int', [1], {}, [tf.int32], [(None, 1)]),
      ('tuple_unnamed_with_args', [1, True], {}, [tf.int32, tf.bool],
       [(None, 1), (None, True)]),
      ('tuple_named_with_args', [1, True], {},
       [('x', tf.int32), ('y', tf.bool)], [('x', 1), ('y', True)]),
      ('tuple_named_with_args_and_kwargs', [1], {'y': True},
       [('x', tf.int32), ('y', tf.bool)], [('x', 1), ('y', True)]),
      ('tuple_with_kwargs', [], {'x': 1, 'y': True},
       [('x', tf.int32), ('y', tf.bool)], [('x', 1), ('y', True)]),
      ('tuple_with_args_odict', [],
       collections.OrderedDict([('y', True), ('x', 1)]),
       [('x', tf.int32), ('y', tf.bool)], [('x', 1), ('y', True)]))
  # pyformat: enable
  def test_pack_args_into_struct_with_type_spec_expect_success(
      self, args, kwargs, type_spec, elements):
    """Checks packing with a type spec yields `Struct(elements)`."""
    self.assertEqual(
        function_utils.pack_args_into_struct(args, kwargs, type_spec,
                                             NoopIngestContextForTest()),
        structure.Struct(elements))

  # pyformat: disable
  @parameterized.named_parameters(
      ('wrong_type', [1], {}, [(tf.bool)]),
      ('wrong_structure', [], {'x': 1, 'y': True}, [(tf.int32), (tf.bool)]),
  )
  # pyformat: enable
  def test_pack_args_into_struct_with_type_spec_expect_failure(
      self, args, kwargs, type_spec):
    """Checks packing raises `TypeError` when arguments do not fit the spec."""
    with self.assertRaises(TypeError):
      function_utils.pack_args_into_struct(args, kwargs, type_spec,
                                           NoopIngestContextForTest())

  # pyformat: disable
  @parameterized.named_parameters(
      ('none', None, [], {}, 'None'),
      ('int', tf.int32, [1], {}, '1'),
      ('tuple_unnamed', [tf.int32, tf.bool], [1, True], {}, '<1,True>'),
      ('tuple_named_with_args', [('x', tf.int32), ('y', tf.bool)], [1, True],
       {}, '<x=1,y=True>'),
      ('tuple_named_with_kwargs', [('x', tf.int32), ('y', tf.bool)], [1],
       {'y': True}, '<x=1,y=True>'),
      ('tuple_with_args_struct', [tf.int32, tf.bool],
       [structure.Struct([(None, 1), (None, True)])], {}, '<1,True>'))
  # pyformat: enable
  def test_pack_args(self, parameter_type, args, kwargs,
                     expected_value_string):
    """Checks `str(pack_args(...))` equals `expected_value_string`."""
    self.assertEqual(
        str(
            function_utils.pack_args(parameter_type, args, kwargs,
                                     NoopIngestContextForTest())),
        expected_value_string)

  # pyformat: disable
  @parameterized.named_parameters(
      ('const', lambda: 10, None, None, None, 10),
      ('add_const', lambda x=1: x + 10, None, None, None, 11),
      ('add_const_with_type', lambda x=1: x + 10, tf.int32, None, 20, 30),
      ('add', lambda x, y: x + y, [tf.int32, tf.int32], None,
       structure.Struct([('x', 5), ('y', 6)]), 11),
      ('str_tuple', lambda *args: str(args), [tf.int32, tf.int32], True,
       structure.Struct([(None, 5), (None, 6)]), '(5, 6)'),
      ('str_tuple_with_named_type', lambda *args: str(args),
       [('x', tf.int32), ('y', tf.int32)], False,
       structure.Struct([('x', 5), ('y', 6)]),
       '(Struct([(\'x\', 5), (\'y\', 6)]),)'),
      ('str_ing', lambda x: str(x),  # pylint: disable=unnecessary-lambda
       [tf.int32], None, structure.Struct([(None, 10)]), '[10]'),
  )
  # pyformat: enable
  def test_wrap_as_zero_or_one_arg_callable(self, fn, parameter_type, unpack,
                                            arg, expected_result):
    """Checks `fn` called with unpacked `arg` returns `expected_result`.

    Args:
      fn: The callable to invoke with the unpacked arguments.
      parameter_type: Type spec, converted with `computation_types.to_type`.
      unpack: Value forwarded as the third argument of
        `create_argument_unpacking_fn`.
      arg: The single argument to unpack into `(args, kwargs)`.
      expected_result: The value `fn(*args, **kwargs)` must equal.
    """
    parameter_type = computation_types.to_type(parameter_type)
    unpack_arguments = function_utils.create_argument_unpacking_fn(
        fn, parameter_type, unpack)
    args, kwargs = unpack_arguments(arg)
    actual_result = fn(*args, **kwargs)
    self.assertEqual(actual_result, expected_result)
def _parameters(fn):
  """Returns the parameter values of the signature computed for `fn`.

  Args:
    fn: The callable passed to `function_utils.get_signature`.

  Returns:
    The `values()` view of the resulting signature's `parameters` mapping.
  """
  signature = function_utils.get_signature(fn)
  return signature.parameters.values()