Ejemplo n.º 1
0
def _check_returns_type_helper(fn, expected_return_type):
  """Helper for `check_returns_type`.

  Wraps `fn` so that every call checks the type of the value it returns.

  Args:
    fn: The function to wrap. Must satisfy `computation_wrapper.is_function`.
    expected_return_type: The TFF type that the inferred type of every value
      returned by `fn` must be identical to.

  Returns:
    A function carrying `fn`'s metadata (via `functools.wraps`) that forwards
    all arguments to `fn` and returns its result unchanged.

  Raises:
    ValueError: If `fn` is not a function, or (at call time) if `fn` returns
      `None`.
    TypeError: At call time, if the inferred type of the returned value is not
      identical to `expected_return_type`.
  """
  if not computation_wrapper.is_function(fn):
    raise ValueError(
        f'`check_returns_type` expected a function, but found {fn}.')

  # Not every callable has a `__name__` (e.g. `functools.partial` objects);
  # fall back to its repr so building the error message cannot itself fail.
  fn_name = getattr(fn, '__name__', repr(fn))

  @functools.wraps(fn)
  def wrapped_func(*args, **kwargs):
    result = fn(*args, **kwargs)
    if result is None:
      raise ValueError('TFF computations may not return `None`. '
                       'Consider instead returning `()`.')
    result_type = type_conversions.infer_type(result)
    if not result_type.is_identical_to(expected_return_type):
      # The trailing newline keeps the first sentence from running into the
      # "Expected type" section of the message.
      raise TypeError(
          f'Value returned from `{fn_name}` did not match asserted type.\n'
          f'Expected type:\n{expected_return_type}\n'
          f'Found type:\n{result_type}\n')
    return result

  return wrapped_func
def assert_returns(*args):
  """Asserts that the decorated function returns values of the provided type.

  Supports two call shapes:

  * Decorator style with a type argument: `@assert_returns(return_type)`.
  * Inline style with a function and a type: `assert_returns(fn, return_type)`.

  Args:
    *args: Either a TFF type spec alone (decorator style), or a Python function
      followed by a TFF type spec (inline style).

  Returns:
    In inline style, the result of `_assert_returns_helper(fn, return_type)`.
    In decorator style, a callable that expects the function to wrap and
    applies `_assert_returns_helper` to it with the given return type.

  Raises:
    ValueError: If called without arguments, with the wrong number of
      arguments, or with a return type that converts to `None`.
  """
  if not args:
    raise ValueError('`assert_returns` called without a return type')
  if computation_wrapper.is_function(args[0]):
    # Inline style: the function to wrap comes first, followed by the
    # expected return type, e.g. `... = assert_returns(lambda ...: ..., T)`.
    if len(args) != 2:
      raise ValueError(
          f'`assert_returns` expected two arguments: a function to decorate '
          f'and an expected return type. Found {len(args)} arguments: {args}')
    func, type_spec = args
  else:
    # Decorator style with arguments: the return type comes first, and the
    # returned callable is then applied to the function to be wrapped.
    if len(args) != 1:
      raise ValueError(
          f'`assert_returns` expected a single argument specifying the '
          f'return type. Found {len(args)} arguments: {args}')
    func, type_spec = None, args[0]

  return_type = computation_types.to_type(type_spec)
  # Validated on both paths so `assert_returns(fn, None)` is rejected exactly
  # like `@assert_returns(None)`.
  if return_type is None:
    raise ValueError('Asserted return type may not be `None`. '
                     'Consider instead a return type of `()`')
  if func is not None:
    return _assert_returns_helper(func, return_type)
  return lambda fn: _assert_returns_helper(fn, return_type)
def _check_returns_type_helper(fn, expected_return_type):
  """Helper for `check_returns_type`.

  Wraps `fn` so that every call checks the type of the value it returns.

  Args:
    fn: The function to wrap. Must satisfy `computation_wrapper.is_function`.
    expected_return_type: The TFF type that the inferred type of every value
      returned by `fn` must be identical to.

  Returns:
    A function carrying `fn`'s metadata (via `functools.wraps`) that forwards
    all arguments to `fn` and returns its result unchanged.

  Raises:
    ValueError: If `fn` is not a function, or (at call time) if `fn` returns
      `None`.
    TypeError: At call time, if the inferred type of the returned value is not
      identical to `expected_return_type`. The message is built by
      `computation_types.type_mismatch_error_message`.
  """
  if not computation_wrapper.is_function(fn):
    raise ValueError(
        f'`check_returns_type` expected a function, but found {fn}.')

  # Not every callable has a `__name__` (e.g. `functools.partial` objects);
  # fall back to its repr so building the error message cannot itself fail.
  fn_name = getattr(fn, '__name__', repr(fn))

  @functools.wraps(fn)
  def wrapped_func(*args, **kwargs):
    result = fn(*args, **kwargs)
    if result is None:
      raise ValueError('TFF computations may not return `None`. '
                       'Consider instead returning `()`.')
    result_type = type_conversions.infer_type(result)
    if not result_type.is_identical_to(expected_return_type):
      raise TypeError(
          f'Value returned from `{fn_name}` did not match asserted type.\n'
          + computation_types.type_mismatch_error_message(
              result_type,
              expected_return_type,
              computation_types.TypeRelation.IDENTICAL,
              second_is_expected=True))
    return result

  return wrapped_func
def check_returns_type(*args):
  """Checks that the decorated function returns values of the provided type.

  This decorator can be used to ensure that a TFF computation returns a value
  of the expected type. For example:

  ```
  @tff.tf_computation(tf.int32, tf.int32)
  @tff.check_returns_type(tf.int32)
  def add(a, b):
    return a + b
  ```

  It can also be applied to non-TFF (Python) functions to ensure that the values
  they return conform to the expected type.

  Note that this assertion is run whenever the function is called. In the case
  of `@tff.tf_computation` and `@tff.federated_computation`s, this means that
  the assertion will run when the computation is traced. To enable this,
  `@tff.check_returns_type` should be applied *inside* the `tff.tf_computation`:

  ```
  # YES:
  @tff.tf_computation(...)
  @tff.check_returns_type(...)
  ...

  # NO:
  @tff.check_returns_type(...) # Don't put this before the line below
  @tff.tf_computation(...)
  ...
  ```

  Args:
    *args: Either a TFF type spec alone (decorator style), or a Python function
      followed by a TFF type spec (inline style).

  Returns:
    In inline style, a wrapped version of the function that checks the type of
    every returned value. In decorator style, a callable that expects the
    function to wrap and returns such a wrapped version of it.

  Raises:
    ValueError: If called without arguments, with the wrong number of
      arguments, or with a return type that converts to `None`.
  """
  if not args:
    raise ValueError('`check_returns_type` called without a return type')
  if computation_wrapper.is_function(args[0]):
    # Inline style: the function to wrap comes first, followed by the
    # expected return type, e.g. `... = check_returns_type(lambda ...: ..., T)`.
    if len(args) != 2:
      raise ValueError(
          f'`check_returns_type` expected two arguments: a function to '
          f'decorate and an expected return type. Found {len(args)} '
          f'arguments: {args}')
    func, type_spec = args
  else:
    # Decorator style with arguments: the return type comes first, and the
    # returned callable is then applied to the function to be wrapped.
    if len(args) != 1:
      raise ValueError(
          f'`check_returns_type` expected a single argument specifying the '
          f'return type. Found {len(args)} arguments: {args}')
    func, type_spec = None, args[0]

  return_type = computation_types.to_type(type_spec)
  # Validated on both paths so `check_returns_type(fn, None)` is rejected
  # exactly like `@check_returns_type(None)`.
  if return_type is None:
    raise ValueError('Asserted return type may not be `None`. '
                     'Consider instead a return type of `()`')
  if func is not None:
    return _check_returns_type_helper(func, return_type)
  return lambda fn: _check_returns_type_helper(fn, return_type)