def _check_returns_type_helper(fn, expected_return_type):
  """Helper for `check_returns_type`.

  Wraps `fn` so that each call verifies the type of the value it returns.

  NOTE(review): this module defines `_check_returns_type_helper` more than
  once; at import time the last definition is the one that takes effect.

  Args:
    fn: The function to wrap. Must satisfy `computation_wrapper.is_function`.
    expected_return_type: The TFF type that the inferred type of `fn`'s return
      value must be identical to.

  Returns:
    A wrapper around `fn` (metadata copied via `functools.wraps`) that returns
    `fn`'s result unchanged when the type check passes.

  Raises:
    ValueError: Immediately, if `fn` is not a function; at call time, if `fn`
      returns `None`.
    TypeError: At call time, if the inferred type of the returned value is not
      identical to `expected_return_type`.
  """
  if not computation_wrapper.is_function(fn):
    # Name the decorator this helper actually backs.
    raise ValueError(
        f'`check_returns_type` expected a function, but found {fn}.')

  @functools.wraps(fn)
  def wrapped_func(*args, **kwargs):
    result = fn(*args, **kwargs)
    if result is None:
      raise ValueError('TFF computations may not return `None`. '
                       'Consider instead returning `()`.')
    result_type = type_conversions.infer_type(result)
    if not result_type.is_identical_to(expected_return_type):
      # Separate the summary line from the detailed mismatch description.
      raise TypeError(
          f'Value returned from `{fn.__name__}` did not match asserted type.\n'
          + computation_types.type_mismatch_error_message(
              result_type,
              expected_return_type,
              computation_types.TypeRelation.IDENTICAL,
              second_is_expected=True))
    return result

  return wrapped_func
def assert_returns(*args):
  """Asserts that the decorated function returns values of the provided type.

  Args:
    *args: Either a Python function, or TFF type spec, or both (function first).

  Returns:
    If invoked with a function and a type spec, returns the result of
    `_assert_returns_helper` applied to that function and the converted type.
    If invoked with only a type spec, as in the typical decorator style of
    usage, returns a callable that expects to be called with the function
    definition supplied as a parameter.

  Raises:
    ValueError: If called without arguments, with the wrong number of
      arguments, or with a return type that converts to `None`.

  See also `tff.tf_computation` for an extended documentation.
  """
  if not args:
    raise ValueError('`assert_returns` called without a return type')
  if computation_wrapper.is_function(args[0]):
    # The function to wrap was passed directly together with the type, as in
    # the inline form `assert_returns(fn, type_spec)`.
    if len(args) != 2:
      raise ValueError(
          f'`assert_returns` expected two arguments: a function to decorate '
          f'and an expected return type. Found {len(args)} arguments: {args}')
    fn, type_spec = args
  else:
    # The function is being invoked as a decorator with arguments.
    # The arguments come first, then the returned value is applied to
    # the function to be wrapped.
    if len(args) != 1:
      raise ValueError(
          f'`assert_returns` expected a single argument specifying the '
          f'return type. Found {len(args)} arguments: {args}')
    fn = None
    (type_spec,) = args
  return_type = computation_types.to_type(type_spec)
  # Apply the `None` check to both call forms, not just the decorator form.
  if return_type is None:
    raise ValueError('Asserted return type may not be `None`. '
                     'Consider instead a return type of `()`')
  if fn is None:
    return lambda func: _assert_returns_helper(func, return_type)
  return _assert_returns_helper(fn, return_type)
def _check_returns_type_helper(fn, expected_return_type):
  """Helper for `check_returns_type`.

  Wraps `fn` so that each call verifies the type of the value it returns.

  Args:
    fn: The function to wrap. Must satisfy `computation_wrapper.is_function`.
    expected_return_type: The TFF type that the inferred type of `fn`'s return
      value must be identical to.

  Returns:
    A wrapper around `fn` (metadata copied via `functools.wraps`) that returns
    `fn`'s result unchanged when the type check passes.

  Raises:
    ValueError: Immediately, if `fn` is not a function; at call time, if `fn`
      returns `None`.
    TypeError: At call time, if the inferred type of the returned value is not
      identical to `expected_return_type`.
  """
  if not computation_wrapper.is_function(fn):
    # Name the decorator this helper actually backs.
    raise ValueError(
        f'`check_returns_type` expected a function, but found {fn}.')

  @functools.wraps(fn)
  def wrapped_func(*args, **kwargs):
    result = fn(*args, **kwargs)
    if result is None:
      raise ValueError('TFF computations may not return `None`. '
                       'Consider instead returning `()`.')
    result_type = type_conversions.infer_type(result)
    if not result_type.is_identical_to(expected_return_type):
      raise TypeError(
          f'Value returned from `{fn.__name__}` did not match asserted type.\n'
          + computation_types.type_mismatch_error_message(
              result_type,
              expected_return_type,
              computation_types.TypeRelation.IDENTICAL,
              second_is_expected=True))
    return result

  return wrapped_func
def check_returns_type(*args):
  """Checks that the decorated function returns values of the provided type.

  This decorator can be used to ensure that a TFF computation returns a value
  of the expected type. For example:

  ```
  @tff.tf_computation(tf.int32, tf.int32)
  @tff.check_returns_type(tf.int32)
  def add(a, b):
    return a + b
  ```

  It can also be applied to non-TFF (Python) functions to ensure that the values
  they return conform to the expected type.

  Note that this assertion is run whenever the function is called. In the case
  of `@tff.tf_computation` and `@tff.federated_computation`s, this means that
  the assertion will run when the computation is traced. To enable this,
  `@tff.check_returns_type` should be applied *inside* the `tff.tf_computation`:

  ```
  # YES:
  @tff.tf_computation(...)
  @tff.check_returns_type(...)
  ...

  # NO:
  @tff.check_returns_type(...) # Don't put this before the line below
  @tff.tf_computation(...)
  ...
  ```

  Args:
    *args: Either a Python function, or TFF type spec, or both (function first).

  Returns:
    If invoked with a function and a type spec, returns the result of
    `_check_returns_type_helper`: a wrapper around the function that checks
    the type of each value it returns. If invoked with only a type spec, as in
    the typical decorator style of usage, returns a callable that expects to
    be called with the function definition supplied as a parameter.

  Raises:
    ValueError: If called without arguments, with the wrong number of
      arguments, or with a return type that converts to `None`.

  See also `tff.tf_computation` for an extended documentation.
  """
  if not args:
    raise ValueError('`check_returns_type` called without a return type')
  if computation_wrapper.is_function(args[0]):
    # The function to wrap was passed directly together with the type, as in
    # the inline form `check_returns_type(fn, type_spec)`.
    if len(args) != 2:
      raise ValueError(
          f'`check_returns_type` expected two arguments: a function to '
          f'decorate and an expected return type. Found {len(args)} '
          f'arguments: {args}')
    fn, type_spec = args
  else:
    # The function is being invoked as a decorator with arguments.
    # The arguments come first, then the returned value is applied to
    # the function to be wrapped.
    if len(args) != 1:
      raise ValueError(
          f'`check_returns_type` expected a single argument specifying the '
          f'return type. Found {len(args)} arguments: {args}')
    fn = None
    (type_spec,) = args
  return_type = computation_types.to_type(type_spec)
  # Apply the `None` check to both call forms, not just the decorator form.
  if return_type is None:
    raise ValueError('Asserted return type may not be `None`. '
                     'Consider instead a return type of `()`')
  if fn is None:
    return lambda func: _check_returns_type_helper(func, return_type)
  return _check_returns_type_helper(fn, return_type)