コード例 #1
0
def _federated_computation_wrapper_fn(target_fn,
                                      parameter_type,
                                      unpack,
                                      name=None):
  """Wrapper function to plug orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that
  class.

  Args:
    target_fn: The Python function to serialize as a federated computation.
    parameter_type: The TFF type of the computation parameter, or `None`.
    unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.
    name: Optional name suggested for the serialized computation.

  Returns:
    A `computation_impl.ComputationImpl` wrapping `target_fn`.

  Raises:
    ValueError: If `target_fn` returns `None`.
  """
  unpack_arguments = function_utils.create_argument_unpacking_fn(
      target_fn, parameter_type, unpack)
  ctx_stack = context_stack_impl.context_stack
  fn_generator = federated_computation_utils.federated_computation_serializer(
      'arg' if parameter_type else None,
      parameter_type,
      ctx_stack,
      suggested_name=name)
  # Drive the serializer: it yields the packed argument, we trace the user
  # function with it, and send the result back for serialization.
  args, kwargs = unpack_arguments(next(fn_generator))
  result = target_fn(*args, **kwargs)
  if result is None:
    line_number = target_fn.__code__.co_firstlineno
    filename = target_fn.__code__.co_filename
    # Bug fix: interpolate the actual filename; previously the literal
    # placeholder '(unknown)' was emitted and `filename` was unused.
    raise ValueError(
        f'The function defined on line {line_number} of file {filename} '
        'returned `None`, but `federated_computation`s must return some '
        'non-`None` value.')
  target_lambda, extra_type_spec = fn_generator.send(result)
  return computation_impl.ComputationImpl(target_lambda.proto, ctx_stack,
                                          extra_type_spec)
コード例 #2
0
def _federated_computation_serializer(fn, parameter_name, parameter_type):
    """Traces `fn` through the federated-computation serializer.

    Drives the serializer generator: obtains the packed argument from it,
    unpacks that argument, calls `fn`, and sends the result back to be
    serialized.
    """
    argument_unpacker = function_utils.create_argument_unpacking_fn(
        fn, parameter_type)
    serializer_gen = federated_computation_utils.federated_computation_serializer(
        parameter_name, parameter_type, context_stack_impl.context_stack)
    packed_arg = next(serializer_gen)
    call_args, call_kwargs = argument_unpacker(packed_arg)
    return serializer_gen.send(fn(*call_args, **call_kwargs))
コード例 #3
0
 def test_wrap_as_zero_or_one_arg_callable(self, unused_index, fn,
                                           parameter_type, unpack, arg,
                                           expected_result):
     """Checks that unpacking `arg` and invoking `fn` yields the expectation."""
     tff_type = computation_types.to_type(parameter_type)
     unpacking_fn = function_utils.create_argument_unpacking_fn(
         fn, tff_type, unpack)
     call_args, call_kwargs = unpacking_fn(arg)
     self.assertEqual(fn(*call_args, **call_kwargs), expected_result)
コード例 #4
0
 def _polymorphic_wrapper(parameter_type: computation_types.Type,
                          unpack: Optional[bool]):
     """Traces `fn_to_wrap` at a concrete `parameter_type` and serializes it."""
     unpack_args = function_utils.create_argument_unpacking_fn(
         fn_to_wrap, parameter_type, unpack=unpack)
     fn_gen = _wrap_concrete(fn_name, self._wrapper_fn, parameter_type)
     call_args, call_kwargs = unpack_args(next(fn_gen))
     outcome = fn_to_wrap(*call_args, **call_kwargs)
     if outcome is None:
         # Federated computations must produce a value.
         raise ComputationReturnedNoneError(fn_to_wrap)
     return fn_gen.send(outcome)
コード例 #5
0
    def __init__(self, fn, parameter_type, unpack, name=None):
        """Wraps `fn` behind an argument-unpacking shim and registers its type."""
        del name  # Unused.
        unpack_arguments = function_utils.create_argument_unpacking_fn(
            fn, parameter_type, unpack)

        def _invoke(arg):
            # Split the single packed argument into positional/keyword parts.
            call_args, call_kwargs = unpack_arguments(arg)
            return fn(*call_args, **call_kwargs)

        self._fn = _invoke
        super().__init__(
            computation_types.FunctionType(parameter_type, tf.string),
            context_stack_impl.context_stack)
コード例 #6
0
  def test_nested_structure_type_signature_roundtrip(self):
    """Round-trips a nested-structure parameter type through serialization."""

    def traced_fn(x):
      return x[0][0]

    parameter_type = computation_types.to_type([(np.int32,)])
    unpacking_fn = function_utils.create_argument_unpacking_fn(
        traced_fn, parameter_type)
    serialized = jax_serialization.serialize_jax_computation(
        traced_fn, unpacking_fn, parameter_type,
        context_stack_impl.context_stack)
    self.assertIsInstance(serialized, pb.Computation)
    self.assertEqual(serialized.WhichOneof('computation'), 'xla')
    round_tripped = type_serialization.deserialize_type(serialized.type)
    self.assertEqual(str(round_tripped), '(<<int32>> -> int32)')
コード例 #7
0
  def test_arg_ordering(self):
    """Checks the signature follows the TFF struct order, not Python arg names."""
    parameter_type = computation_types.to_type(
        (computation_types.TensorType(np.int32, 10),
         computation_types.TensorType(np.int32)))

    def traced_fn(b, a):
      return jax.numpy.add(a, jax.numpy.sum(b))

    unpacking_fn = function_utils.create_argument_unpacking_fn(
        traced_fn, parameter_type)
    serialized = jax_serialization.serialize_jax_computation(
        traced_fn, unpacking_fn, parameter_type,
        context_stack_impl.context_stack)
    self.assertIsInstance(serialized, pb.Computation)
    self.assertEqual(serialized.WhichOneof('computation'), 'xla')
    deserialized = type_serialization.deserialize_type(serialized.type)
    self.assertEqual(str(deserialized), '(<int32[10],int32> -> int32)')
コード例 #8
0
  def test_serialize_jax_with_int32_to_int32(self):
    """Serializes a scalar int32 -> int32 JAX function and checks the proto."""

    def traced_fn(x):
      return x + 10

    param_type = computation_types.to_type(np.int32)
    arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
    ctx_stack = context_stack_impl.context_stack
    comp_pb = jax_serialization.serialize_jax_computation(
        traced_fn, arg_fn, param_type, ctx_stack)
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    type_spec = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(type_spec), '(int32 -> int32)')
    # Inspect the embedded HLO to confirm the add was traced into XLA.
    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)', xla_comp.as_hlo_text())
    # The scalar is bound as a single tensor at index 0, and the parameter and
    # result bindings are identical for this computation.
    self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
    self.assertEqual(str(comp_pb.xla.result), 'tensor {\n' '  index: 0\n' '}\n')
コード例 #9
0
  def test_serialize_jax_with_2xint32_to_2xint32(self):
    """Serializes a two-field struct -> two-field struct JAX computation."""

    def traced_fn(x):
      return collections.OrderedDict([('sum', x['foo'] + x['bar']),
                                      ('difference', x['bar'] - x['foo'])])

    param_type = computation_types.to_type(
        collections.OrderedDict([('foo', np.int32), ('bar', np.int32)]))
    arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
    ctx_stack = context_stack_impl.context_stack
    comp_pb = jax_serialization.serialize_jax_computation(
        traced_fn, arg_fn, param_type, ctx_stack)
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    type_spec = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(
        str(type_spec),
        '(<foo=int32,bar=int32> -> <sum=int32,difference=int32>)')
    # The HLO text pins the exact lowering: the struct arrives as one XLA
    # tuple parameter whose elements are extracted, combined, and re-tupled.
    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    self.assertIn(
        # pylint: disable=line-too-long
        '  constant.4 = pred[] constant(false)\n'
        '  parameter.1 = (s32[], s32[]) parameter(0)\n'
        '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
        '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
        '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
        '  subtract.6 = s32[] subtract(get-tuple-element.3, get-tuple-element.2)\n'
        '  ROOT tuple.7 = (s32[], s32[]) tuple(add.5, subtract.6)\n',
        xla_comp.as_hlo_text())
    # Parameter and result share the same binding: a struct of two tensors.
    self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
    self.assertEqual(
        str(comp_pb.xla.parameter), 'struct {\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 0\n'
        '    }\n'
        '  }\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 1\n'
        '    }\n'
        '  }\n'
        '}\n')
コード例 #10
0
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug Tensorflow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that
  class.

  Args:
    target_fn: The Python function to serialize as a TensorFlow computation.
    parameter_type: The TFF type of the computation parameter, or `None`.
    unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.
    name: Unused.

  Returns:
    A `computation_impl.ComputationImpl` wrapping `target_fn`.

  Raises:
    TypeError: If `parameter_type` is not representable in TensorFlow.
  """
  del name  # Unused.
  # Fail fast: reject an incompatible parameter type before constructing any
  # argument-unpacking or serialization machinery (previously the unpacking
  # function was built before this validation ran).
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  unpack_arguments = function_utils.create_argument_unpacking_fn(
      target_fn, parameter_type, unpack)
  ctx_stack = context_stack_impl.context_stack
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  # Drive the serializer: it yields the packed argument, we trace the user
  # function with it, and send the result back for serialization.
  args, kwargs = unpack_arguments(next(tf_serializer))
  result = target_fn(*args, **kwargs)
  comp_pb, extra_type_spec = tf_serializer.send(result)
  return computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
コード例 #11
0
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
    """Serializes a Python function containing JAX code as a TFF computation.

    Args:
      fn_to_wrap: The Python function containing JAX code to be serialized as
        a computation containing XLA.
      fn_name: The name for the constructed computation (currently ignored).
      parameter_type: An instance of `computation_types.Type` that represents
        the TFF type of the computation parameter, or `None` if there's none.
      unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.

    Returns:
      An instance of `computation_impl.ComputationImpl` with the constructed
      computation.
    """
    del fn_name  # Unused.
    argument_unpacker = function_utils.create_argument_unpacking_fn(
        fn_to_wrap, parameter_type, unpack=unpack)
    stack = context_stack_impl.context_stack
    serialized = jax_serialization.serialize_jax_computation(
        fn_to_wrap, argument_unpacker, parameter_type, stack)
    return computation_impl.ComputationImpl(serialized, stack)
コード例 #12
0
  def test_serialize_jax_with_two_args(self):
    """Serializes a two-argument JAX function taken as a named struct."""

    def traced_fn(x, y):
      return x + y

    param_type = computation_types.to_type(
        collections.OrderedDict([('x', np.int32), ('y', np.int32)]))
    arg_fn = function_utils.create_argument_unpacking_fn(traced_fn, param_type)
    ctx_stack = context_stack_impl.context_stack
    comp_pb = jax_serialization.serialize_jax_computation(
        traced_fn, arg_fn, param_type, ctx_stack)
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    type_spec = type_serialization.deserialize_type(comp_pb.type)
    self.assertEqual(str(type_spec), '(<x=int32,y=int32> -> int32)')
    # The HLO text pins the lowering: both args arrive in one XLA tuple
    # parameter, are extracted, added, and the scalar result is re-tupled.
    xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module)
    self.assertIn(
        # pylint: disable=line-too-long
        '  constant.4 = pred[] constant(false)\n'
        '  parameter.1 = (s32[], s32[]) parameter(0)\n'
        '  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n'
        '  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n'
        '  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n'
        '  ROOT tuple.6 = (s32[]) tuple(add.5)\n',
        xla_comp.as_hlo_text())
    # The parameter binding is a struct of two tensors; the result is a
    # single tensor at index 0.
    self.assertEqual(
        str(comp_pb.xla.parameter), 'struct {\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 0\n'
        '    }\n'
        '  }\n'
        '  element {\n'
        '    tensor {\n'
        '      index: 1\n'
        '    }\n'
        '  }\n'
        '}\n')
    self.assertEqual(str(comp_pb.xla.result), 'tensor {\n' '  index: 0\n' '}\n')
コード例 #13
0
 def __call__(self, fn_to_wrap, fn_name, parameter_type, unpack):
     """Traces `fn_to_wrap` at `parameter_type` and serializes the result.

     Drives the concrete-wrapper generator: receives the packed argument,
     unpacks it, calls the user function, and sends the result back to be
     serialized. On any tracing error the generator stack is unwound first
     so its cleanup runs before the original exception propagates.
     """
     unpack_arguments_fn = function_utils.create_argument_unpacking_fn(
         fn_to_wrap, parameter_type, unpack=unpack)
     wrapped_fn_generator = _wrap_concrete(fn_name, self._wrapper_fn,
                                           parameter_type)
     packed_args = next(wrapped_fn_generator)
     try:
         args, kwargs = unpack_arguments_fn(packed_args)
         result = fn_to_wrap(*args, **kwargs)
         if result is None:
             # Computations must return a non-`None` value.
             raise ComputationReturnedNoneError(fn_to_wrap)
     except Exception:
         # Give nested generators an opportunity to clean up, then
         # re-raise the original error without extra context.
         # We don't want to simply pass the error into the generators,
         # as that would result in the whole generator stack being added
         # to the error message.
         try:
             wrapped_fn_generator.throw(_TracingError())
         except _TracingError:
             pass
         raise
     return wrapped_fn_generator.send(result)
コード例 #14
0
    def __call__(self, *args, tff_internal_types=None):
        """Handles the different modes of usage of the decorator/wrapper.

    Args:
      *args: Positional arguments (the decorator at this point does not accept
        keyword arguments, although that might change in the future).
      tff_internal_types: TFF internal usage only. This argument should be
        considered private.

    Returns:
      Either a result of wrapping, or a callable that expects a function,
      method, or a tf.function and performs wrapping on it, depending on
      specific usage pattern.

    Raises:
      TypeError: if the arguments are of the wrong types.
      ValueError: if the function to wrap returns `None`.
    """
        if not args or not is_function(args[0]):
            # If invoked as a decorator, and with an empty argument list as "@xyz()"
            # applied to a function definition, expect the Python function being
            # decorated to be passed in the subsequent call, and potentially create
            # a polymorphic callable. The parameter type is unspecified.
            # Deliberate wrapping with a lambda to prevent the caller from being able
            # to accidentally specify parameter type as a second argument.
            # The tricky partial recursion is needed to inline the logic in the
            # "success" case below.
            if tff_internal_types is not None:
                raise TypeError(f'Expected a function to wrap, found {args}.')
            provided_types = tuple(map(computation_types.to_type, args))
            return functools.partial(self.__call__,
                                     tff_internal_types=provided_types)
        # If the first argument on the list is a Python function, instance method,
        # or a tf.function, this is the one that's being wrapped. This is the case
        # of either a decorator invocation without arguments as "@xyz" applied to
        # a function definition, of an inline invocation as
        # `... = xyz(lambda....).`
        # Any of the following arguments, if present, are the arguments to the
        # wrapper that are to be interpreted as the type specification.
        fn_to_wrap = args[0]
        if not tff_internal_types:
            tff_internal_types = tuple(map(computation_types.to_type,
                                           args[1:]))
        else:
            # Types were already supplied via the partial-recursion path above,
            # so no further positional arguments are allowed.
            if len(args) > 1:
                raise TypeError(
                    f'Expected no further arguments, found {args[1:]}.')

        parameter_types = tff_internal_types
        parameters = _parameters(fn_to_wrap)

        # NOTE: many of the properties checked here are only necessary for
        # non-polymorphic computations whose type signatures must be resolved
        # prior to use. However, we continue to enforce these requirements even
        # in the polymorphic case in order to avoid creating an inconsistency.
        _check_parameters(parameters)

        try:
            fn_name = fn_to_wrap.__name__
        except AttributeError:
            # Some callables (e.g. partials) have no `__name__`.
            fn_name = None

        if (not parameter_types) and parameters:
            # There is no TFF type specification, and the function/tf.function
            # declares parameters. Create a polymorphic template.
            def _polymorphic_wrapper(parameter_type: computation_types.Type,
                                     unpack: Optional[bool]):
                # Traces `fn_to_wrap` at a concrete `parameter_type`; invoked
                # lazily per call signature by `PolymorphicFunction`.
                unpack_arguments_fn = function_utils.create_argument_unpacking_fn(
                    fn_to_wrap, parameter_type, unpack=unpack)
                wrapped_fn_generator = _wrap_concrete(fn_name,
                                                      self._wrapper_fn,
                                                      parameter_type)
                args, kwargs = unpack_arguments_fn(next(wrapped_fn_generator))
                result = fn_to_wrap(*args, **kwargs)
                if result is None:
                    raise ComputationReturnedNoneError(fn_to_wrap)
                return wrapped_fn_generator.send(result)

            wrapped_func = function_utils.PolymorphicFunction(
                _polymorphic_wrapper)
        else:
            # Either we have a concrete parameter type, or this is no-arg function.
            parameter_type = _parameter_type(parameters, parameter_types)
            unpack_arguments_fn = function_utils.create_argument_unpacking_fn(
                fn_to_wrap, parameter_type, unpack=None)
            wrapped_fn_generator = _wrap_concrete(fn_name, self._wrapper_fn,
                                                  parameter_type)
            args, kwargs = unpack_arguments_fn(next(wrapped_fn_generator))
            result = fn_to_wrap(*args, **kwargs)
            if result is None:
                raise ComputationReturnedNoneError(fn_to_wrap)
            wrapped_func = wrapped_fn_generator.send(result)

        # Copy the __doc__ attribute with the documentation in triple-quotes from
        # the decorated function.
        wrapped_func.__doc__ = getattr(fn_to_wrap, '__doc__', None)

        return wrapped_func