def test_polymorphic_function(self):
  """Checks that PolymorphicFunction reuses one concrete function per type.

  Each distinct argument signature triggers one factory call (names '1'..'4');
  repeating a signature later must reuse the concrete function built earlier.
  """

  class _EchoContext(context_base.Context):
    """Context whose `invoke` renders the call as a descriptive string."""

    def ingest(self, val, type_spec):
      return val

    def invoke(self, comp, arg):
      return (f'name={comp.name},type={comp.type_signature.parameter!s},'
              f'arg={arg!s}')

  class _NamedFunction(function_utils.ConcreteFunction):
    """Concrete `parameter_type -> tf.string` function with a fixed name."""

    def __init__(self, name, parameter_type):
      self._name = name
      signature = computation_types.FunctionType(parameter_type, tf.string)
      super().__init__(signature, context_stack_impl.context_stack)

    @property
    def name(self):
      return self._name

  class _CountingFactory:
    """Builds `_NamedFunction`s named after a running invocation count."""

    def __init__(self):
      self._calls = 0

    def __call__(self, parameter_type):
      self._calls += 1
      return _NamedFunction(str(self._calls), parameter_type)

  # (positional args, keyword args, expected rendering). The second half
  # repeats the signatures of the first half, so the names must repeat too.
  expectations = (
      ((10,), {}, 'name=1,type=<int32>,arg=<10>'),
      ((20,), dict(x=True), 'name=2,type=<int32,x=bool>,arg=<20,x=True>'),
      ((True,), {}, 'name=3,type=<bool>,arg=<True>'),
      ((30,), dict(x=40), 'name=4,type=<int32,x=int32>,arg=<30,x=40>'),
      ((50,), {}, 'name=1,type=<int32>,arg=<50>'),
      ((0,), dict(x=False), 'name=2,type=<int32,x=bool>,arg=<0,x=False>'),
      ((False,), {}, 'name=3,type=<bool>,arg=<False>'),
      ((60,), dict(x=70), 'name=4,type=<int32,x=int32>,arg=<60,x=70>'),
  )
  with context_stack_impl.context_stack.install(_EchoContext()):
    polymorphic = function_utils.PolymorphicFunction(_CountingFactory())
    for positional, keywords, expected in expectations:
      self.assertEqual(polymorphic(*positional, **keywords), expected)
def test_call_returns_result(self):
  """Checks call results, including explicit `fn_for_argument_type` lookups.

  Calls through the polymorphic function report `unpack=True`, while the
  concrete function obtained via `fn_for_argument_type` reports `unpack=None`.
  """

  class _RenderingContext(context_base.Context):
    """Context whose `invoke` renders the call as a descriptive string."""

    def ingest(self, val, type_spec):
      return val

    def invoke(self, comp, arg):
      return (f'name={comp.name},type={comp.type_signature.parameter},'
              f'arg={arg},unpack={comp.unpack}')

  class _FixedContextStack(context_stack_base.ContextStack):
    """Stack that always reports the same context and ignores `install`."""

    def __init__(self):
      super().__init__()
      self._context = _RenderingContext()

    @property
    def current(self):
      return self._context

    def install(self, ctx):
      del ctx  # Unused
      return self._context

  stack = _FixedContextStack()

  class _NamedFunction(function_utils.ConcreteFunction):
    """Concrete `parameter_type -> tf.string` function exposing `unpack`."""

    def __init__(self, name, unpack, parameter_type):
      self._name = name
      self._unpack = unpack
      super().__init__(
          computation_types.FunctionType(parameter_type, tf.string), stack)

    @property
    def name(self):
      return self._name

    @property
    def unpack(self):
      return self._unpack

  class _CountingFactory:
    """Builds `_NamedFunction`s named after a running invocation count."""

    def __init__(self):
      self._calls = 0

    def __call__(self, parameter_type, unpack):
      self._calls += 1
      return _NamedFunction(str(self._calls), str(unpack), parameter_type)

  polymorphic = function_utils.PolymorphicFunction(_CountingFactory())

  def call_with_bool_type(value):
    # Fetches the bool-typed concrete function explicitly, then invokes it.
    bool_fn = polymorphic.fn_for_argument_type(
        computation_types.to_type(tf.bool))
    return bool_fn(value)

  # Evaluated strictly in order: the factory's count depends on call order.
  checks = (
      (lambda: polymorphic(10), 'name=1,type=<int32>,arg=<10>,unpack=True'),
      (lambda: polymorphic(20, x=True),
       'name=2,type=<int32,x=bool>,arg=<20,x=True>,unpack=True'),
      (lambda: call_with_bool_type(True),
       'name=3,type=bool,arg=True,unpack=None'),
      (lambda: polymorphic(30, x=40),
       'name=4,type=<int32,x=int32>,arg=<30,x=40>,unpack=True'),
      (lambda: polymorphic(50), 'name=1,type=<int32>,arg=<50>,unpack=True'),
      (lambda: polymorphic(0, x=False),
       'name=2,type=<int32,x=bool>,arg=<0,x=False>,unpack=True'),
      (lambda: call_with_bool_type(False),
       'name=3,type=bool,arg=False,unpack=None'),
      (lambda: polymorphic(60, x=70),
       'name=4,type=<int32,x=int32>,arg=<60,x=70>,unpack=True'),
  )
  for run, expected in checks:
    self.assertEqual(run(), expected)
def _wrap(fn, parameter_type, wrapper_fn):
  """Wraps a possibly-polymorphic `fn` in `wrapper_fn`.

  If `parameter_type` is `None` and `fn` takes any arguments (even with default
  values), `fn` is inferred to be polymorphic and won't be passed to
  `wrapper_fn` until invocation time (when concrete parameter types are
  available).

  `wrapper_fn` must accept three positional arguments and one defaulted
  argument `name`:

  * `target_fn`, the Python function to be wrapped.
  * `parameter_type`, the optional type of the computation's parameter (an
    instance of `computation_types.Type`).
  * `unpack`, an argument which will be passed on to
    `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping
    `target_fn`. See that function for details.
  * Optional `name`, the name of the function that is being wrapped (only for
    debugging purposes). It is passed in both the polymorphic and the concrete
    case.

  Args:
    fn: The function or defun to wrap as a computation.
    parameter_type: Optional type of any arguments to `fn`.
    wrapper_fn: The Python callable that performs actual wrapping. The object
      to be returned by this function should be an instance of a
      `ConcreteFunction`.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types. The returned function still may accept multiple
    arguments (it has not yet had
    `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

  Raises:
    TypeError: if the arguments are of the wrong types, if the `wrapper_fn`
      constructs something that isn't a ConcreteFunction, or if the parameter
      of the constructed function does not match `parameter_type` (including
      the case where exactly one of the two is `None`).
  """
  try:
    fn_name = fn.__name__
  except AttributeError:
    fn_name = None
  signature = function_utils.get_signature(fn)
  parameter_type = computation_types.to_type(parameter_type)

  if parameter_type is None and signature.parameters:
    # There is no TFF type specification, and the function/defun declares
    # parameters. Create a polymorphic template.
    def _wrap_polymorphic(parameter_type: computation_types.Type,
                          unpack: Optional[bool]):
      return wrapper_fn(fn, parameter_type, unpack=unpack, name=fn_name)

    polymorphic_fn = function_utils.PolymorphicFunction(_wrap_polymorphic)
    # When applying a decorator, the __doc__ attribute with the documentation
    # in triple-quotes is not automatically transferred from the function on
    # which it was applied to the wrapped object, so we must transfer it here
    # explicitly.
    polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
    return polymorphic_fn

  # Either we have a concrete parameter type, or this is a no-arg function.
  concrete_fn = wrapper_fn(fn, parameter_type, unpack=None, name=fn_name)
  py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                          'value returned by the wrapper')

  # Compare the declared and constructed parameter types symmetrically: a
  # missing parameter on exactly one side is a mismatch, not a pass.
  actual_parameter_type = concrete_fn.type_signature.parameter
  if actual_parameter_type is None or parameter_type is None:
    parameters_match = actual_parameter_type is None and parameter_type is None
  else:
    parameters_match = actual_parameter_type.is_equivalent_to(parameter_type)
  if not parameters_match:
    raise TypeError(
        'Expected a concrete function that takes parameter {}, got one '
        'that takes {}.'.format(str(parameter_type), str(actual_parameter_type)))

  # When applying a decorator, the __doc__ attribute with the documentation
  # in triple-quotes is not automatically transferred from the function on
  # which it was applied to the wrapped object, so we must transfer it here
  # explicitly.
  concrete_fn.__doc__ = getattr(fn, '__doc__', None)
  return concrete_fn
def __call__(self, *args, tff_internal_types=None):
  """Handles the different modes of usage of the decorator/wrapper.

  Supported forms:

  * `xyz(fn, *types)` or the bare decorator `@xyz`: wraps `fn` right away, or
    builds a polymorphic function when no types are given and `fn` declares
    parameters.
  * `@xyz(*types)`, including `@xyz()`: returns a callable that expects the
    function to wrap as its only argument.

  Args:
    *args: Positional arguments (the decorator at this point does not accept
      keyword arguments, although that might change in the future).
    tff_internal_types: TFF internal usage only. This argument should be
      considered private. It carries the types converted by an earlier
      `@xyz(...)` call (an empty tuple for `@xyz()`).

  Returns:
    Either a result of wrapping, or a callable that expects a function,
    method, or a tf.function and performs wrapping on it, depending on
    specific usage pattern.

  Raises:
    TypeError: if the arguments are of the wrong types, including when more
      positional arguments follow the function to wrap after types were
      already supplied through `@xyz(...)`.
    ValueError: if the function to wrap returns `None`.
  """
  if not args or not is_function(args[0]):
    # Called with a (possibly empty) list of types, as in "@xyz(tf.int32)" or
    # "@xyz()". The function being decorated arrives in the next call. The
    # converted types are bound to the keyword-only `tff_internal_types` via
    # `functools.partial`, so the next call cannot also receive type
    # arguments positionally.
    if tff_internal_types is not None:
      raise TypeError(f'Expected a function to wrap, found {args}.')
    provided_types = tuple(map(computation_types.to_type, args))
    return functools.partial(self.__call__, tff_internal_types=provided_types)

  # The first argument is a Python function, instance method, or tf.function.
  # This is a bare decorator "@xyz" applied to a function definition, an
  # inline invocation `... = xyz(lambda ...)`, or the second half of
  # "@xyz(...)". Any remaining arguments are the type specification.
  fn_to_wrap = args[0]
  if tff_internal_types is None:
    parameter_types = tuple(map(computation_types.to_type, args[1:]))
  elif len(args) > 1:
    # Test against `None` rather than truthiness: "@xyz()" binds an empty
    # tuple, and extra arguments must still be rejected in that case.
    raise TypeError(f'Expected no further arguments, found {args[1:]}.')
  else:
    parameter_types = tff_internal_types

  parameters = _parameters(fn_to_wrap)
  # NOTE: many of the properties checked here are only necessary for
  # non-polymorphic computations whose type signatures must be resolved
  # prior to use. However, we continue to enforce these requirements even
  # in the polymorphic case in order to avoid creating an inconsistency.
  _check_parameters(parameters)

  try:
    fn_name = fn_to_wrap.__name__
  except AttributeError:
    fn_name = None

  def _trace_and_wrap(parameter_type: computation_types.Type,
                      unpack: Optional[bool]):
    """Calls `fn_to_wrap` for `parameter_type` and returns the wrapped result.

    `_wrap_concrete` is driven as a generator. Its first yielded value is
    turned into call arguments for `fn_to_wrap`. The Python result is then
    sent back to obtain the wrapped function.
    """
    unpack_arguments_fn = function_utils.create_argument_unpacking_fn(
        fn_to_wrap, parameter_type, unpack=unpack)
    wrapped_fn_generator = _wrap_concrete(fn_name, self._wrapper_fn,
                                          parameter_type)
    fn_args, fn_kwargs = unpack_arguments_fn(next(wrapped_fn_generator))
    result = fn_to_wrap(*fn_args, **fn_kwargs)
    if result is None:
      raise ComputationReturnedNoneError(fn_to_wrap)
    return wrapped_fn_generator.send(result)

  if (not parameter_types) and parameters:
    # There is no TFF type specification, and the function/tf.function
    # declares parameters. Create a polymorphic template that performs the
    # wrapping upon invocation, based on the argument types.
    wrapped_func = function_utils.PolymorphicFunction(_trace_and_wrap)
  else:
    # Either we have a concrete parameter type, or this is a no-arg function.
    parameter_type = _parameter_type(parameters, parameter_types)
    wrapped_func = _trace_and_wrap(parameter_type, unpack=None)

  # Copy the __doc__ attribute with the documentation in triple-quotes from
  # the decorated function.
  wrapped_func.__doc__ = getattr(fn_to_wrap, '__doc__', None)
  return wrapped_func
def _wrap(fn, parameter_type, wrapper_fn):
  """Wrap a given `fn` with a given `parameter_type` using `wrapper_fn`.

  This method does not handle the multiple modes of usage as wrapper/decorator,
  as those are handled by `ComputationWrapper`. It focuses on the simple case
  with a function/defun (always present) and either a valid parameter type or
  an indication that there's no parameter (None).

  The only ambiguity left to resolve is whether `fn` should be immediately
  wrapped, or treated as a polymorphic callable to be wrapped upon invocation
  based on actual parameter types. The determination is based on the presence
  or absence of parameters in the declaration of `fn`. In order to be treated
  as a concrete no-argument computation, `fn` shouldn't declare any arguments
  (even with default values).

  The `wrapper_fn` must accept three arguments, and an optional fourth keyword
  argument `name`:

  * `target_fn`, the Python function to be wrapped, possibly accepting
    *args and **kwargs.
  * Either None for a no-parameter computation, or the type of the
    computation's parameter (an instance of `computation_types.Type`) if the
    computation has one.
  * `unpack`, an argument which will be passed on to
    `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping
    `target_fn`. See that function for details.
  * Optional `name`, the name of the function that is being wrapped (only for
    debugging purposes). It is passed in both the polymorphic and the concrete
    case.

  Args:
    fn: The function or defun to wrap as a computation.
    parameter_type: The parameter type accepted by the computation, or None if
      there is no parameter.
    wrapper_fn: The Python callable that performs actual wrapping. The object
      to be returned by this function should be an instance of a
      `ConcreteFunction`.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types. The returned function still may accept multiple
    arguments (it has not yet had
    `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

  Raises:
    TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
      constructs something that isn't a ConcreteFunction.
  """
  try:
    fn_name = fn.__name__
  except AttributeError:
    fn_name = None
  argspec = function_utils.get_argspec(fn)
  parameter_type = computation_types.to_type(parameter_type)

  if parameter_type is None and (argspec.args or argspec.varargs or
                                 argspec.keywords):
    # There is no TFF type specification, and the function/defun declares
    # parameters. Create a polymorphic template.
    # NOTE(review): keyword-only arguments are not inspected here; confirm
    # that `function_utils.get_argspec` rejects or reports them.
    def _wrap_polymorphic(concrete_parameter_type):
      return wrapper_fn(
          fn, concrete_parameter_type, unpack=True, name=fn_name)

    polymorphic_fn = function_utils.PolymorphicFunction(_wrap_polymorphic)
    # When applying a decorator, the __doc__ attribute with the documentation
    # in triple-quotes is not automatically transferred from the function on
    # which it was applied to the wrapped object, so we must transfer it here
    # explicitly.
    polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
    return polymorphic_fn

  # Either we have a concrete parameter type, or this is a no-arg function.
  concrete_fn = wrapper_fn(fn, parameter_type, unpack=None, name=fn_name)
  py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                          'value returned by the wrapper')
  if not type_utils.are_equivalent_types(
      concrete_fn.type_signature.parameter, parameter_type):
    raise TypeError(
        'Expected a concrete function that takes parameter {}, got one '
        'that takes {}.'.format(str(parameter_type),
                                str(concrete_fn.type_signature.parameter)))
  # When applying a decorator, the __doc__ attribute with the documentation
  # in triple-quotes is not automatically transferred from the function on
  # which it was applied to the wrapped object, so we must transfer it here
  # explicitly.
  concrete_fn.__doc__ = getattr(fn, '__doc__', None)
  return concrete_fn