def testGetArgSpecOnPartialWithVarkwargs(self):
        """Binding the leading positional arg leaves `n` and keeps `**kwarg`."""
        def target(m, n, **kwarg):
            return m * n + len(kwarg)

        expected = tf_inspect.ArgSpec(
            args=['n'], varargs=None, keywords='kwarg', defaults=None)
        bound = functools.partial(target, 7)
        self.assertEqual(expected, tf_inspect.getargspec(bound))
    def testGetArgSpecOnCallableObject(self):
        """An instance reports the signature of its `__call__`, `self` included."""

        class Callable(object):

            def __call__(self, a, b=1, c='hello'):
                pass

        instance = Callable()
        expected = tf_inspect.ArgSpec(
            args=['self', 'a', 'b', 'c'],
            varargs=None,
            keywords=None,
            defaults=(1, 'hello'))
        self.assertEqual(expected, tf_inspect.getargspec(instance))
    def testGetArgSpecOnPartialKeywordArgument(self):
        """Tests getargspec on partial function that binds an argument by keyword.

        The keyword-bound argument is not pruned: `n` stays in `args` and the
        value bound by the partial (7) is reported as its default.
        """
        def func(m, n):
            return 2 * m + n

        partial_func = functools.partial(func, n=7)
        argspec = tf_inspect.ArgSpec(args=['m', 'n'],
                                     varargs=None,
                                     keywords=None,
                                     defaults=(7, ))

        self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
    def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self):
        """Tests getargspec on partial function that overrides a default by keyword.

        No argument is pruned: `m` keeps its own default (1), and the value
        bound by the partial (7) replaces `n`'s original default (2).
        """
        def func(m=1, n=2):
            return 2 * m + n

        partial_func = functools.partial(func, n=7)
        argspec = tf_inspect.ArgSpec(args=['m', 'n'],
                                     varargs=None,
                                     keywords=None,
                                     defaults=(1, 7))

        self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
    def testGetArgSpecOnPartialNoArgumentsLeft(self):
        """Binding every positional argument yields an empty argspec."""
        def target(m, n):
            return 2 * m + n

        fully_bound = functools.partial(target, 7, 10)
        self.assertEqual(
            tf_inspect.ArgSpec(
                args=[], varargs=None, keywords=None, defaults=None),
            tf_inspect.getargspec(fully_bound))
    def testGetArgSpecOnPartialValidArgspec(self):
        """Keyword-bound middle arguments keep their slots and gain defaults.

        Binding `n` and `l` by keyword leaves every name in `args`; the bound
        values (7, 2) become defaults ahead of `k`'s own default (4).
        """
        def target(m, n, l, k=4):
            return 2 * m + l + n * k

        bound = functools.partial(target, n=7, l=2)
        expected = tf_inspect.ArgSpec(
            args=['m', 'n', 'l', 'k'],
            varargs=None,
            keywords=None,
            defaults=(7, 2, 4))
        self.assertEqual(expected, tf_inspect.getargspec(bound))
    def testGetArgSpecOnPartialPositionalArgumentOnly(self):
        """A positionally bound leading argument is dropped from the argspec."""
        def target(m, n):
            return 2 * m + n

        expected = tf_inspect.ArgSpec(
            args=['n'], varargs=None, keywords=None, defaults=None)
        self.assertEqual(expected,
                         tf_inspect.getargspec(functools.partial(target, 7)))
Example #8
0
  def testGetArgSpecOnPartialWithDecorator(self):
    """Tests getargspec on a partial wrapping a decorated function.

    The reported spec is the decorated function's, with the keyword-bound
    `n` reporting the partial's value (7) as its default.
    """

    @test_decorator('decorator')
    def func(m=1, n=2):
      return 2 * m + n

    expected = tf_inspect.ArgSpec(
        args=['m', 'n'],
        varargs=None,
        keywords=None,
        defaults=(1, 7))
    self.assertEqual(expected,
                     tf_inspect.getargspec(functools.partial(func, n=7)))
Example #9
0
  def test_argspec_for_functools_partial(self):
    """parser._get_arg_spec removes the arguments a functools.partial binds."""

    # pylint: disable=unused-argument
    def test_function_for_partial1(arg1, arg2, kwarg1=1, kwarg2=2):
      pass

    def test_function_for_partial2(arg1, arg2, *my_args, **my_kwargs):
      pass
    # pylint: enable=unused-argument

    full_spec = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None,
                                   None, (1, 2))
    # Each case pairs the expected ArgSpec with the callable under test.
    cases = [
        # Make sure everything works for regular functions.
        (full_spec, test_function_for_partial1),
        # Make sure doing nothing works.
        (full_spec, functools.partial(test_function_for_partial1)),
        # Positional bindings consume arguments from the front...
        (tf_inspect.ArgSpec(['arg2', 'kwarg1', 'kwarg2'], None, None, (1, 2)),
         functools.partial(test_function_for_partial1, 1)),
        # ...including arguments that carried defaults.
        (tf_inspect.ArgSpec(['kwarg2'], None, None, (2,)),
         functools.partial(test_function_for_partial1, 1, 2, 3)),
        # Keyword bindings remove the named argument (and its default).
        (tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg2'], None, None, (2,)),
         functools.partial(test_function_for_partial1, kwarg1=0)),
        (tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1'], None, None, (1,)),
         functools.partial(test_function_for_partial1, kwarg2=0)),
        (tf_inspect.ArgSpec(['arg1'], None, None, ()),
         functools.partial(test_function_for_partial1,
                           arg2=0, kwarg1=0, kwarg2=0)),
        # *args and **kwargs are still reported once named args are bound.
        (tf_inspect.ArgSpec([], 'my_args', 'my_kwargs', ()),
         functools.partial(test_function_for_partial2, 0, 1)),
    ]

    # pylint: disable=protected-access
    for expected, target in cases:
      self.assertEqual(expected, parser._get_arg_spec(target))
Example #10
0
  def testGetArgSpecOnNewClass(self):
    """A class constructed via `__new__` reports `__new__`'s arguments."""

    class NewClass(object):

      def __new__(cls, a, b=1, c='hello'):
        pass

    expected = tf_inspect.ArgSpec(
        args=['cls', 'a', 'b', 'c'], varargs=None, keywords=None,
        defaults=(1, 'hello'))
    self.assertEqual(expected, tf_inspect.getargspec(NewClass))
Example #11
0
  def testGetArgSpecOnInitClass(self):
    """A class constructed via `__init__` reports `__init__`'s arguments."""

    class InitClass(object):

      def __init__(self, a, b=1, c='hello'):
        pass

    expected = tf_inspect.ArgSpec(
        args=['self', 'a', 'b', 'c'], varargs=None, keywords=None,
        defaults=(1, 'hello'))
    self.assertEqual(expected, tf_inspect.getargspec(InitClass))
Example #12
0
def _get_arg_spec(func):
    """Extracts signature information from a function or functools.partial object.

    For functions, uses `tf_inspect.getargspec`. For `functools.partial` objects,
    corrects the signature of the underlying function to take into account the
    removed arguments: the leading arguments consumed by the partial's bound
    positional values are dropped, and every named argument bound by keyword is
    dropped as well. Remaining arguments keep the defaults they had on the
    underlying function.

    Args:
      func: A function whose signature to extract, or a `functools.partial`
        wrapping one.

    Returns:
      An `ArgSpec` namedtuple `(args, varargs, keywords, defaults)`, as returned
      by `tf_inspect.getargspec`. For partials, `defaults` is always a tuple
      (possibly empty).
    """
    # getargspec does not work for functools.partial objects directly.
    if not isinstance(func, functools.partial):
        # Regular function or method, getargspec will work fine.
        return tf_inspect.getargspec(func)

    argspec = tf_inspect.getargspec(func.func)
    all_args = list(argspec.args or ())
    all_defaults = tuple(argspec.defaults or ())

    # Defaults always belong to the trailing arguments. Map each defaulted
    # argument name to its value so it survives removal of earlier arguments.
    first_default_arg = len(all_args) - len(all_defaults)
    default_for = dict(zip(all_args[first_default_arg:], all_defaults))

    # Bound positional values consume the leading arguments; bound keywords
    # remove the matching named argument. Filtering against the *remaining*
    # arguments means a keyword naming an argument that was already bound
    # positionally is simply ignored instead of raising ValueError.
    bound_keywords = func.keywords or {}
    remaining_args = [name for name in all_args[len(func.args):]
                      if name not in bound_keywords]

    # Removal preserves order, so the surviving defaulted arguments are still
    # a suffix of remaining_args and the defaults tuple lines up with them.
    remaining_defaults = tuple(default_for[name] for name in remaining_args
                               if name in default_for)

    return tf_inspect.ArgSpec(args=remaining_args,
                              varargs=argspec.varargs,
                              keywords=argspec.keywords,
                              defaults=remaining_defaults)
    def testUsesOutermostDecoratorsArgSpec(self):
        """getcallargs binds against the decorator's argspec, not the target's."""
        def func():
            pass

        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        outer_argspec = tf_inspect.ArgSpec(
            args=['a', 'b', 'c'],
            varargs=None,
            keywords=None,
            defaults=(3, 'hello'))
        decorated = tf_decorator.make_decorator(
            func, wrapper, decorator_argspec=outer_argspec)

        expected_callargs = {'a': 4, 'b': 3, 'c': 'goodbye'}
        self.assertEqual(expected_callargs,
                         tf_inspect.getcallargs(decorated, 4, c='goodbye'))