Example #1
0
  def test_polymorphic_function(self):
    """Exercises `PolymorphicFunction` with a counting factory.

    Each concrete function handed out by the factory is named after the
    number of factory calls made so far. The expected strings therefore
    assert that a repeated argument type is served by the concrete function
    created earlier (same name) rather than by a new factory call.
    """

    class ContextForTest(context_base.Context):
      """Context that passes values through and describes each invocation."""

      def ingest(self, val, type_spec):
        # Values are returned untouched; `type_spec` is not consulted.
        return val

      def invoke(self, comp, arg):
        # Render the invocation as text so the test can compare strings.
        name = comp.name
        parameter_repr = str(comp.type_signature.parameter)
        arg_repr = str(arg)
        return 'name={},type={},arg={}'.format(name, parameter_repr, arg_repr)

    class TestFunction(function_utils.ConcreteFunction):
      """Concrete function with a `tf.string` result that remembers a name."""

      def __init__(self, name, parameter_type):
        self._name = name
        signature = computation_types.FunctionType(parameter_type, tf.string)
        super().__init__(signature, context_stack_impl.context_stack)

      @property
      def name(self):
        return self._name

    class TestFunctionFactory:
      """Callable that hands out `TestFunction`s named '1', '2', ..."""

      def __init__(self):
        self._count = 0

      def __call__(self, parameter_type):
        self._count += 1
        return TestFunction(str(self._count), parameter_type)

    # (positional args, keyword args, expected description), in call order.
    # The second half repeats the argument types of the first half and
    # expects the same names 1-4 again.
    expectations = (
        ((10,), {}, 'name=1,type=<int32>,arg=<10>'),
        ((20,), {'x': True}, 'name=2,type=<int32,x=bool>,arg=<20,x=True>'),
        ((True,), {}, 'name=3,type=<bool>,arg=<True>'),
        ((30,), {'x': 40}, 'name=4,type=<int32,x=int32>,arg=<30,x=40>'),
        ((50,), {}, 'name=1,type=<int32>,arg=<50>'),
        ((0,), {'x': False}, 'name=2,type=<int32,x=bool>,arg=<0,x=False>'),
        ((False,), {}, 'name=3,type=<bool>,arg=<False>'),
        ((60,), {'x': 70}, 'name=4,type=<int32,x=int32>,arg=<60,x=70>'),
    )

    with context_stack_impl.context_stack.install(ContextForTest()):
      fn = function_utils.PolymorphicFunction(TestFunctionFactory())
      for positional, keyword, expected in expectations:
        self.assertEqual(fn(*positional, **keyword), expected)
    def test_call_returns_result(self):
        """Checks the strings produced by calling a `PolymorphicFunction`.

        The factory numbers the concrete functions it creates and records the
        `unpack` value it was given. The expectations cover direct calls on
        the polymorphic function (reported with `unpack=True`) and calls on
        the function obtained through `fn_for_argument_type` (reported with
        `unpack=None`), and assert that repeated argument types reuse the
        earlier function names.
        """

        class TestContext(context_base.Context):
            """Context that describes every invocation as a string."""

            def ingest(self, val, type_spec):
                # Pass-through; `type_spec` is not consulted.
                return val

            def invoke(self, comp, arg):
                template = 'name={},type={},arg={},unpack={}'
                return template.format(comp.name,
                                       comp.type_signature.parameter, arg,
                                       comp.unpack)

        class TestContextStack(context_stack_base.ContextStack):
            """Context stack that always exposes one fixed `TestContext`."""

            def __init__(self):
                super().__init__()
                self._context = TestContext()

            @property
            def current(self):
                return self._context

            def install(self, ctx):
                # The requested context is intentionally ignored; the fixed
                # test context is returned instead.
                return self._context

        context_stack = TestContextStack()

        class TestFunction(function_utils.ConcreteFunction):
            """Concrete `tf.string` function carrying a name and `unpack`."""

            def __init__(self, name, unpack, parameter_type):
                self._name = name
                self._unpack = unpack
                super().__init__(
                    computation_types.FunctionType(parameter_type, tf.string),
                    context_stack)

            @property
            def name(self):
                return self._name

            @property
            def unpack(self):
                return self._unpack

        class TestFunctionFactory:
            """Callable that creates consecutively numbered `TestFunction`s."""

            def __init__(self):
                self._count = 0

            def __call__(self, parameter_type, unpack):
                self._count += 1
                return TestFunction(name=str(self._count),
                                    unpack=str(unpack),
                                    parameter_type=parameter_type)

        fn = function_utils.PolymorphicFunction(TestFunctionFactory())

        def call_as_bool(value):
            # Look up the function for a `tf.bool` argument type explicitly,
            # then invoke it with `value`.
            bool_fn = fn.fn_for_argument_type(
                computation_types.to_type(tf.bool))
            return bool_fn(value)

        # Thunks are evaluated strictly in this order; the order matters
        # because the factory numbers functions by creation order.
        cases = (
            (lambda: fn(10), 'name=1,type=<int32>,arg=<10>,unpack=True'),
            (lambda: fn(20, x=True),
             'name=2,type=<int32,x=bool>,arg=<20,x=True>,unpack=True'),
            (lambda: call_as_bool(True),
             'name=3,type=bool,arg=True,unpack=None'),
            (lambda: fn(30, x=40),
             'name=4,type=<int32,x=int32>,arg=<30,x=40>,unpack=True'),
            (lambda: fn(50), 'name=1,type=<int32>,arg=<50>,unpack=True'),
            (lambda: fn(0, x=False),
             'name=2,type=<int32,x=bool>,arg=<0,x=False>,unpack=True'),
            (lambda: call_as_bool(False),
             'name=3,type=bool,arg=False,unpack=None'),
            (lambda: fn(60, x=70),
             'name=4,type=<int32,x=int32>,arg=<60,x=70>,unpack=True'),
        )
        for thunk, expected in cases:
            self.assertEqual(thunk(), expected)
def _wrap(fn, parameter_type, wrapper_fn):
    """Wraps `fn` with `wrapper_fn`, immediately or lazily.

    When `parameter_type` converts to `None` and the signature of `fn`
    declares any parameters, `fn` is treated as polymorphic: wrapping is
    deferred and a `function_utils.PolymorphicFunction` is returned, which
    calls `wrapper_fn` once concrete parameter types are available. In every
    other case `wrapper_fn` is called right away.

    `wrapper_fn` is invoked as `wrapper_fn(target_fn, parameter_type,
    unpack=...)`, with an additional `name=...` keyword in the polymorphic
    case:

    * `target_fn`, the Python function to be wrapped.

    * `parameter_type`, the optional type of the computation's parameter (an
      instance of `computation_types.Type`).

    * `unpack`, an argument which will be passed on to
      `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping
      `target_fn`. See that function for details. It is `None` for the
      immediate (concrete) case.

    * Optional `name`, the `__name__` of the wrapped function, or `None` if
      it has none (only for debugging purposes).

    Args:
      fn: The function or defun to wrap as a computation.
      parameter_type: Optional type of any arguments to `fn`; anything
        accepted by `computation_types.to_type`.
      wrapper_fn: The Python callable that performs actual wrapping. The
        object it returns should be an instance of a `ConcreteFunction`.

    Returns:
      Either the result of wrapping (an object that represents the
      computation), or a polymorphic callable that performs wrapping upon
      invocation based on argument types. The returned function still may
      accept multiple arguments (it has not yet had
      `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).
      In both cases `__doc__` is copied over from `fn`.

    Raises:
      TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
        constructs something that isn't a ConcreteFunction, or the concrete
        function's parameter type is not equivalent to `parameter_type`.
    """
    fn_name = getattr(fn, '__name__', None)
    signature = function_utils.get_signature(fn)
    parameter_type = computation_types.to_type(parameter_type)

    if parameter_type is not None or not signature.parameters:
        # Either we have a concrete parameter type, or this is a no-arg
        # function: wrap right away and validate what the wrapper produced.
        concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
        py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                                'value returned by the wrapper')
        actual_parameter = concrete_fn.type_signature.parameter
        if actual_parameter is not None:
            if not actual_parameter.is_equivalent_to(parameter_type):
                raise TypeError(
                    'Expected a concrete function that takes parameter {}, got one '
                    'that takes {}.'.format(str(parameter_type),
                                            str(actual_parameter)))
        # A decorator does not carry the docstring of the decorated function
        # over to the wrapped object automatically, so copy it explicitly.
        concrete_fn.__doc__ = getattr(fn, '__doc__', None)
        return concrete_fn

    # There is no type specification, and the function/defun declares
    # parameters: build a polymorphic template that wraps on demand.
    def _wrap_polymorphic(parameter_type: computation_types.Type,
                          unpack: Optional[bool]):
        return wrapper_fn(fn, parameter_type, unpack=unpack, name=fn_name)

    polymorphic_fn = function_utils.PolymorphicFunction(_wrap_polymorphic)
    # Same explicit docstring transfer as in the concrete case.
    polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
    return polymorphic_fn
    def __call__(self, *args, tff_internal_types=None):
        """Handles the different modes of usage of the decorator/wrapper.

        Three call shapes are distinguished:

        * No arguments, or a first argument that is not a function
          (`@xyz()` / `@xyz(some_type)`): the arguments are converted to
          types and a partial of this method is returned, which expects the
          function to wrap in the next call.
        * A function first, optionally followed by type specifications
          (`@xyz` / `xyz(lambda ...: ..., some_type)`): the function is
          wrapped using those types.
        * The second call of the two-step form, where the types arrive via
          `tff_internal_types` and only the function may be passed.

        Args:
          *args: Positional arguments (the decorator at this point does not
            accept keyword arguments, although that might change in the
            future).
          tff_internal_types: TFF internal usage only. This argument should
            be considered private.

        Returns:
          Either a result of wrapping, or a callable that expects a function,
          method, or a tf.function and performs wrapping on it, depending on
          specific usage pattern.

        Raises:
          TypeError: if the arguments are of the wrong types.
          ComputationReturnedNoneError: if the function to wrap returns
            `None`.
        """
        if not args or not is_function(args[0]):
            # Decorator-with-arguments form. Types captured by an earlier
            # call must be followed by the function itself, so reaching this
            # branch with types already present is an error.
            if tff_internal_types is not None:
                raise TypeError(f'Expected a function to wrap, found {args}.')
            provided_types = tuple(
                computation_types.to_type(spec) for spec in args)
            # Re-enter this method on the next call with the types bound, so
            # the wrapping logic below is shared by both forms.
            return functools.partial(self.__call__,
                                     tff_internal_types=provided_types)

        # The first argument is the Python function, instance method, or
        # tf.function being wrapped; anything after it is interpreted as the
        # type specification.
        fn_to_wrap = args[0]
        extra_args = args[1:]
        # NOTE(review): an empty tuple captured from `@xyz()` is falsy, so in
        # that case type arguments following the function are still honoured
        # rather than rejected -- confirm this is intended.
        if tff_internal_types:
            if extra_args:
                raise TypeError(
                    f'Expected no further arguments, found {args[1:]}.')
            parameter_types = tff_internal_types
        else:
            parameter_types = tuple(
                computation_types.to_type(spec) for spec in extra_args)

        parameters = _parameters(fn_to_wrap)

        # NOTE: many of the properties checked here are only necessary for
        # non-polymorphic computations whose type signatures must be resolved
        # prior to use. However, we continue to enforce these requirements
        # even in the polymorphic case in order to avoid creating an
        # inconsistency.
        _check_parameters(parameters)

        fn_name = getattr(fn_to_wrap, '__name__', None)

        def _wrap_for_type(parameter_type: computation_types.Type,
                           unpack: Optional[bool]):
            # Drives the `_wrap_concrete` generator: its first yielded value
            # is unpacked into the call arguments of `fn_to_wrap`, and the
            # function's result is sent back to obtain the wrapped object.
            unpack_arguments_fn = function_utils.create_argument_unpacking_fn(
                fn_to_wrap, parameter_type, unpack=unpack)
            wrapped_fn_generator = _wrap_concrete(fn_name, self._wrapper_fn,
                                                  parameter_type)
            call_args, call_kwargs = unpack_arguments_fn(
                next(wrapped_fn_generator))
            result = fn_to_wrap(*call_args, **call_kwargs)
            if result is None:
                raise ComputationReturnedNoneError(fn_to_wrap)
            return wrapped_fn_generator.send(result)

        if parameter_types or not parameters:
            # Either we have a concrete parameter type, or this is a no-arg
            # function: wrap immediately.
            concrete_type = _parameter_type(parameters, parameter_types)
            wrapped_func = _wrap_for_type(concrete_type, None)
        else:
            # There is no type specification, and the function/tf.function
            # declares parameters: defer wrapping to invocation time.
            wrapped_func = function_utils.PolymorphicFunction(_wrap_for_type)

        # Copy the __doc__ attribute with the documentation in triple-quotes
        # from the decorated function.
        wrapped_func.__doc__ = getattr(fn_to_wrap, '__doc__', None)

        return wrapped_func
def _wrap(fn, parameter_type, wrapper_fn):
    """Wraps `fn` for the given `parameter_type` by means of `wrapper_fn`.

    This function does not deal with the multiple modes of usage as
    wrapper/decorator; according to the original notes those are handled by
    `ComputationWrapper`. It covers the simple case of a function/defun
    (always present) plus either a parameter type or `None` to indicate that
    no parameter type was given.

    The one decision made here is whether `fn` is wrapped immediately or
    treated as a polymorphic callable that is wrapped upon invocation, based
    on actual parameter types. `fn` is polymorphic when `parameter_type` is
    `None` and its argspec declares positional arguments, `*args` or
    `**kwargs`. To be treated as a concrete no-argument computation, `fn`
    must not declare any arguments (even with default values).

    `wrapper_fn` is invoked as `wrapper_fn(target_fn, parameter_type,
    unpack=...)`, with an additional `name=...` keyword in the polymorphic
    case:

    * `target_fn`, the Python function to be wrapped, possibly accepting
      *args and **kwargs.

    * Either None for a no-parameter computation, or the type of the
      computation's parameter (an instance of `computation_types.Type`) if
      the computation has one.

    * `unpack`, an argument which will be passed on to
      `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping
      `target_fn`. See that function for details. It is `True` in the
      polymorphic case and `None` in the concrete case.

    * Optional `name`, the name of the function that is being wrapped (only
      for debugging purposes).

    Args:
      fn: The function or defun to wrap as a computation.
      parameter_type: The parameter type accepted by the computation, or None
        if there is no parameter.
      wrapper_fn: The Python callable that performs actual wrapping. The
        object it returns should be an instance of a `ConcreteFunction`.

    Returns:
      Either the result of wrapping (an object that represents the
      computation), or a polymorphic callable that performs wrapping upon
      invocation based on argument types. The returned function still may
      accept multiple arguments (it has not yet had
      `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

    Raises:
      TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
        constructs something that isn't a ConcreteFunction.
    """
    fn_name = getattr(fn, '__name__', None)
    argspec = function_utils.get_argspec(fn)
    parameter_type = computation_types.to_type(parameter_type)

    # The argspec attributes are read lazily, left to right, and only when
    # no parameter type was supplied.
    if parameter_type is None and (argspec.args or argspec.varargs
                                   or argspec.keywords):
        # There is no type specification, and the function/defun declares
        # parameters. Create a polymorphic template.
        def _wrap_for_type(pt):
            return wrapper_fn(fn, pt, unpack=True, name=fn_name)

        polymorphic_fn = function_utils.PolymorphicFunction(_wrap_for_type)

        # A decorator does not carry the docstring of the decorated function
        # over to the wrapped object automatically, so copy it explicitly.
        polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
        return polymorphic_fn

    # Either we have a concrete parameter type, or this is a no-arg function.
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
    py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                            'value returned by the wrapper')
    types_match = type_utils.are_equivalent_types(
        concrete_fn.type_signature.parameter, parameter_type)
    if not types_match:
        raise TypeError(
            'Expected a concrete function that takes parameter {}, got one '
            'that takes {}.'.format(str(parameter_type),
                                    str(concrete_fn.type_signature.parameter)))
    # Same explicit docstring transfer as in the polymorphic case.
    concrete_fn.__doc__ = getattr(fn, '__doc__', None)
    return concrete_fn