def testGetArgSpecOnPartialWithVarkwargs(self):
  """Binding a leading positional arg keeps the remaining args and **kwarg."""

  def target(m, n, **kwarg):
    return m * n + len(kwarg)

  bound = functools.partial(target, 7)
  # `m` is consumed positionally; `n` and the **kwarg name remain.
  expected = tf_inspect.ArgSpec(
      args=['n'], varargs=None, keywords='kwarg', defaults=None)
  self.assertEqual(expected, tf_inspect.getargspec(bound))
def testGetArgSpecOnCallableObject(self):
  """getargspec on an instance reports its __call__ signature (incl. self)."""

  class CallableThing(object):

    def __call__(self, a, b=1, c='hello'):
      pass

  expected = tf_inspect.ArgSpec(
      args=['self', 'a', 'b', 'c'],
      varargs=None,
      keywords=None,
      defaults=(1, 'hello'))
  instance = CallableThing()
  self.assertEqual(expected, tf_inspect.getargspec(instance))
def testGetArgSpecOnPartialKeywordArgument(self):
  """A keyword bound by partial turns a required arg into a defaulted one."""

  def target(m, n):
    return 2 * m + n

  bound = functools.partial(target, n=7)
  # `n` stays in the arg list but now carries the bound value as its default.
  expected = tf_inspect.ArgSpec(
      args=['m', 'n'], varargs=None, keywords=None, defaults=(7,))
  self.assertEqual(expected, tf_inspect.getargspec(bound))
def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self):
  """A keyword bound by partial replaces that argument's existing default."""

  def target(m=1, n=2):
    return 2 * m + n

  bound = functools.partial(target, n=7)
  # `m` keeps its original default; `n`'s default is replaced by 7.
  expected = tf_inspect.ArgSpec(
      args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
  self.assertEqual(expected, tf_inspect.getargspec(bound))
def testGetArgSpecOnPartialNoArgumentsLeft(self):
  """Binding every positional argument leaves an empty argument list."""

  def target(m, n):
    return 2 * m + n

  bound = functools.partial(target, 7, 10)
  expected = tf_inspect.ArgSpec(
      args=[], varargs=None, keywords=None, defaults=None)
  self.assertEqual(expected, tf_inspect.getargspec(bound))
def testGetArgSpecOnPartialValidArgspec(self):
  """Keyword-bound middle args become defaulted while `m` stays required."""

  def target(m, n, l, k=4):
    return 2 * m + l + n * k

  bound = functools.partial(target, n=7, l=2)
  # Defaults line up with the trailing args n, l, k.
  expected = tf_inspect.ArgSpec(
      args=['m', 'n', 'l', 'k'],
      varargs=None,
      keywords=None,
      defaults=(7, 2, 4))
  self.assertEqual(expected, tf_inspect.getargspec(bound))
def testGetArgSpecOnPartialPositionalArgumentOnly(self):
  """A positionally bound leading argument is removed from the arg list."""

  def target(m, n):
    return 2 * m + n

  bound = functools.partial(target, 7)
  expected = tf_inspect.ArgSpec(
      args=['n'], varargs=None, keywords=None, defaults=None)
  self.assertEqual(expected, tf_inspect.getargspec(bound))
def testGetArgSpecOnPartialWithDecorator(self):
  """A partial over a decorated function sees the undecorated signature."""

  @test_decorator('decorator')
  def target(m=1, n=2):
    return 2 * m + n

  bound = functools.partial(target, n=7)
  expected = tf_inspect.ArgSpec(
      args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))
  self.assertEqual(expected, tf_inspect.getargspec(bound))
def test_argspec_for_functools_partial(self):
  """Checks parser._get_arg_spec on a plain function and on partials of it."""
  # pylint: disable=unused-argument
  def test_function_for_partial1(arg1, arg2, kwarg1=1, kwarg2=2):
    pass

  def test_function_for_partial2(arg1, arg2, *my_args, **my_kwargs):
    pass
  # pylint: enable=unused-argument

  arg_spec = tf_inspect.ArgSpec
  # Each entry is (expected ArgSpec, callable to inspect), checked in order.
  cases = [
      # A regular (non-partial) function is handled directly.
      (arg_spec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None, None, (1, 2)),
       test_function_for_partial1),
      # A partial that binds nothing leaves the signature unchanged.
      (arg_spec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None, None, (1, 2)),
       functools.partial(test_function_for_partial1)),
      # Positional bindings consume arguments from the front.
      (arg_spec(['arg2', 'kwarg1', 'kwarg2'], None, None, (1, 2)),
       functools.partial(test_function_for_partial1, 1)),
      (arg_spec(['kwarg2'], None, None, (2,)),
       functools.partial(test_function_for_partial1, 1, 2, 3)),
      # Keyword bindings remove the named argument (and its default, if any).
      (arg_spec(['arg1', 'arg2', 'kwarg2'], None, None, (2,)),
       functools.partial(test_function_for_partial1, kwarg1=0)),
      (arg_spec(['arg1', 'arg2', 'kwarg1'], None, None, (1,)),
       functools.partial(test_function_for_partial1, kwarg2=0)),
      (arg_spec(['arg1'], None, None, ()),
       functools.partial(test_function_for_partial1,
                         arg2=0, kwarg1=0, kwarg2=0)),
      # The *args and **kwargs names survive positional binding.
      (arg_spec([], 'my_args', 'my_kwargs', ()),
       functools.partial(test_function_for_partial2, 0, 1)),
  ]

  # pylint: disable=protected-access
  for expected, target in cases:
    self.assertEqual(expected, parser._get_arg_spec(target))
  # pylint: enable=protected-access
def testGetArgSpecOnNewClass(self):
  """getargspec on a class defining only __new__ reports that signature."""

  class OnlyNew(object):

    def __new__(cls, a, b=1, c='hello'):
      pass

  expected = tf_inspect.ArgSpec(
      args=['cls', 'a', 'b', 'c'],
      varargs=None,
      keywords=None,
      defaults=(1, 'hello'))
  self.assertEqual(expected, tf_inspect.getargspec(OnlyNew))
def testGetArgSpecOnInitClass(self):
  """getargspec on a class defining only __init__ reports that signature."""

  class OnlyInit(object):

    def __init__(self, a, b=1, c='hello'):
      pass

  expected = tf_inspect.ArgSpec(
      args=['self', 'a', 'b', 'c'],
      varargs=None,
      keywords=None,
      defaults=(1, 'hello'))
  self.assertEqual(expected, tf_inspect.getargspec(OnlyInit))
def _get_arg_spec(func):
  """Extracts signature information from a function or functools.partial object.

  For functions, uses `tf_inspect.getargspec`. For `functools.partial` objects,
  corrects the signature of the underlying function to take into account the
  arguments the partial has already bound (positionally or by keyword).

  Args:
    func: A function, or a `functools.partial` object, whose signature to
      extract.

  Returns:
    An `ArgSpec` namedtuple `(args, varargs, keywords, defaults)`, as returned
    by `tf_inspect.getargspec`. For partials, `defaults` is always a tuple
    (possibly empty), never None.
  """
  # getargspec does not work for functools.partial objects directly.
  if not isinstance(func, functools.partial):
    # Regular function or method, getargspec will work fine.
    return tf_inspect.getargspec(func)

  argspec = tf_inspect.getargspec(func.func)
  all_args = list(argspec.args or [])
  all_defaults = list(argspec.defaults or [])

  # Position (in the argument list) of the first argument that has a default.
  first_default_arg = len(all_args) - len(all_defaults)

  # Positional arguments bound by the partial consume names from the front.
  num_partial_args = len(func.args)
  argspec_args = all_args[num_partial_args:]
  # Drop the defaults that belonged to positionally consumed arguments.
  argspec_defaults = all_defaults[max(0, num_partial_args - first_default_arg):]
  first_default_arg = max(0, first_default_arg - num_partial_args)

  # Keyword arguments bound by the partial remove the matching names.
  for kwarg in (func.keywords or {}):
    # Keywords matching no remaining name are skipped: they either go to
    # **kwargs, or name an argument that was already bound positionally
    # (previously this case made `index` below raise ValueError).
    if kwarg not in argspec_args:
      continue
    i = argspec_args.index(kwarg)
    argspec_args.pop(i)
    if i >= first_default_arg:
      argspec_defaults.pop(i - first_default_arg)
    else:
      # A required argument was removed, so defaults now start one slot
      # earlier in the argument list.
      first_default_arg -= 1

  return tf_inspect.ArgSpec(args=argspec_args,
                            varargs=argspec.varargs,
                            keywords=argspec.keywords,
                            defaults=tuple(argspec_defaults))
def testUsesOutermostDecoratorsArgSpec(self):
  """getcallargs binds against the decorator_argspec, not the wrapped func."""

  def inner():
    pass

  def wrapper(*args, **kwargs):
    return inner(*args, **kwargs)

  outer_argspec = tf_inspect.ArgSpec(
      args=['a', 'b', 'c'], varargs=None, keywords=None,
      defaults=(3, 'hello'))
  decorated = tf_decorator.make_decorator(
      inner, wrapper, decorator_argspec=outer_argspec)
  # `b` falls back to its default; `a` and `c` come from the call.
  expected = dict(a=4, b=3, c='goodbye')
  self.assertEqual(expected, tf_inspect.getcallargs(decorated, 4, c='goodbye'))