Exemplo n.º 1
0
  def federated_collect(self, value):
    """Implements `federated_collect` as defined in `api/intrinsics.py`.

    Args:
      value: A value convertible via `value_impl.to_value`; it is passed to
        `value_utils.ensure_federated_value` together with
        `placements.CLIENTS`.

    Returns:
      A `value_impl.ValueImpl` wrapping the building block produced by
      `building_block_factory.create_federated_collect`.
    """
    context_stack = self._context_stack
    converted = value_impl.to_value(value, None, context_stack)
    at_clients = value_utils.ensure_federated_value(converted,
                                                    placements.CLIENTS,
                                                    'value to be collected')
    collect_comp = building_block_factory.create_federated_collect(
        value_impl.ValueImpl.get_comp(at_clients))
    return value_impl.ValueImpl(collect_comp, context_stack)
Exemplo n.º 2
0
 def test_slicing_support_namedtuple(self):
     """Checks forward, strided, reversed and empty slices of a named tuple."""
     stack = context_stack_impl.context_stack
     foo = value_impl.ValueImpl(building_blocks.Reference('foo', tf.int32),
                                stack)
     bar = value_impl.ValueImpl(building_blocks.Reference('bar', tf.bool),
                                stack)
     pair = collections.namedtuple('_', 'a b')(foo, bar)
     v = value_impl.to_value(pair, None, stack)
     self.assertIsInstance(v[:len(v) // 2], value_base.Value)
     expectations = (
         (slice(None, 4, 2), '<foo>'),
         (slice(4, None, -1), '<bar,foo>'),
     )
     for key, expected in expectations:
         sliced = v[key]
         self.assertEqual(str(sliced), expected)
         self.assertIsInstance(sliced, value_base.Value)
     # A slice that selects no elements must be rejected.
     with self.assertRaisesRegex(IndexError, 'slice 0 elements'):
         _ = v[2:4]
Exemplo n.º 3
0
 def test_value_impl_with_reference(self):
     """Checks accessors and string forms of a `ValueImpl` over a Reference."""
     reference = building_blocks.Reference('foo', tf.int32)
     wrapped = value_impl.ValueImpl(reference, context_stack_impl.context_stack)
     self.assertIs(value_impl.ValueImpl.get_comp(wrapped), reference)
     self.assertEqual(str(wrapped.type_signature), 'int32')
     self.assertEqual(repr(wrapped), 'Reference(\'foo\', TensorType(tf.int32))')
     self.assertEqual(str(wrapped), 'foo')
     # Calling a non-functional value is expected to fail.
     with self.assertRaises(SyntaxError):
         wrapped(10)
Exemplo n.º 4
0
 def federated_sum(self, value):
     """Implements `federated_sum` as defined in `api/intrinsics.py`.

     Args:
       value: A value convertible via `value_impl.to_value`; it is checked
         against `placements.CLIENTS` by `value_utils.ensure_federated_value`.

     Returns:
       A `value_impl.ValueImpl` wrapping the `federated_sum` building block.

     Raises:
       Whatever `type_utils.check_is_sum_compatible` raises when the value's
       type cannot be summed.
     """
     stack = self._context_stack
     summand = value_utils.ensure_federated_value(
         value_impl.to_value(value, None, stack), placements.CLIENTS,
         'value to be summed')
     type_utils.check_is_sum_compatible(summand.type_signature)
     sum_comp = building_block_factory.create_federated_sum(
         value_impl.ValueImpl.get_comp(summand))
     return value_impl.ValueImpl(sum_comp, stack)
Exemplo n.º 5
0
def zero_or_one_arg_fn_to_building_block(fn,
                                         parameter_name,
                                         parameter_type,
                                         context_stack,
                                         suggested_name=None):
    """Converts a zero- or one-argument `fn` into a computation building block.

    `fn` is invoked inside a fresh `FederatedComputationContext` installed on
    `context_stack`. If a federated computation context is already current, it
    becomes the parent of the new one.

    Args:
      fn: A function with 0 or 1 arguments that contains orchestration logic,
        i.e., that expects zero or one `values_base.Value` and returns a result
        convertible to the same.
      parameter_name: The name of the parameter, or `None` if there isn't any.
      parameter_type: The TFF type of the parameter, or `None` if there's none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts. If not `None`, it must be a string.

    Returns:
      An instance of `computation_building_blocks.ComputationBuildingBlock`
      that contains the logic from `fn`: the result comp itself when there is
      no parameter, otherwise a `Lambda` over the parameter.

    Raises:
      ValueError: if `fn` is incompatible with `parameter_type`.
    """
    py_typecheck.check_callable(fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, six.string_types)
    parameter_type = computation_types.to_type(parameter_type)

    current = context_stack.current
    nested_in_federated = isinstance(
        current, federated_computation_context.FederatedComputationContext)
    parent_context = current if nested_in_federated else None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)

    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, six.string_types)
        # Prefix with the context name (presumably to keep parameter names
        # distinct across nested contexts).
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))

    with context_stack.install(context):
        if parameter_type is None:
            result = fn()
        else:
            parameter = value_impl.ValueImpl(
                computation_building_blocks.Reference(parameter_name,
                                                      parameter_type),
                context_stack)
            result = fn(parameter)
        result_comp = value_impl.ValueImpl.get_comp(
            value_impl.to_value(result, None, context_stack))
        if parameter_type is None:
            return result_comp
        return computation_building_blocks.Lambda(parameter_name,
                                                  parameter_type, result_comp)
Exemplo n.º 6
0
 def test_to_value_for_dict(self):
     """Checks that dict insertion order does not affect the converted value."""
     stack = context_stack_impl.context_stack
     foo = value_impl.ValueImpl(
         computation_building_blocks.Reference('foo', tf.int32), stack)
     bar = value_impl.ValueImpl(
         computation_building_blocks.Reference('bar', tf.bool), stack)
     # Both insertion orders are expected to render as `<a=foo,b=bar>`.
     for mapping in ({'a': foo, 'b': bar}, {'b': bar, 'a': foo}):
         converted = value_impl.to_value(mapping, None, stack)
         self.assertIsInstance(converted, value_base.Value)
         self.assertEqual(str(converted), '<a=foo,b=bar>')
Exemplo n.º 7
0
  def federated_value(self, value, placement):
    """Implements `federated_value` as defined in `api/intrinsics.py`.

    Args:
      value: A value convertible via `value_impl.to_value`.
      placement: The placement forwarded to
        `building_block_factory.create_federated_value`.

    Returns:
      A `value_impl.ValueImpl` wrapping the resulting building block.
    """
    # TODO(b/113112108): Verify that neither the value, nor any of its parts
    # are of a federated type.
    stack = self._context_stack
    value_comp = value_impl.ValueImpl.get_comp(
        value_impl.to_value(value, None, stack))
    return value_impl.ValueImpl(
        building_block_factory.create_federated_value(value_comp, placement),
        stack)
Exemplo n.º 8
0
 def test_value_impl_with_call(self):
     """Checks calling a functional `ValueImpl` with valid and invalid args."""
     stack = context_stack_impl.context_stack
     fn_type = computation_types.FunctionType(tf.int32, tf.bool)
     fn = value_impl.ValueImpl(
         computation_building_blocks.Reference('foo', fn_type), stack)
     int_arg = value_impl.ValueImpl(
         computation_building_blocks.Reference('bar', tf.int32), stack)
     call = fn(int_arg)
     self.assertIsInstance(call, value_base.Value)
     self.assertEqual(str(call.type_signature), 'bool')
     self.assertEqual(str(call), 'foo(bar)')
     # Omitting the argument is an error.
     with self.assertRaises(TypeError):
         fn()
     float_arg = value_impl.ValueImpl(
         computation_building_blocks.Reference('bak', tf.float32), stack)
     # An argument of the wrong type is an error.
     with self.assertRaises(TypeError):
         fn(float_arg)
Exemplo n.º 9
0
  def federated_map(self, mapping_fn, value):
    """Implements `federated_map` as defined in `api/intrinsics.py`.

    A `value` whose type is a named tuple with two or more elements is first
    zipped via `self.federated_zip`. The (possibly zipped) value must then be
    placed at `placements.CLIENTS`.

    Args:
      mapping_fn: As in `api/intrinsics.py`.
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """

    # TODO(b/113112108): Possibly lift the restriction that the mapped value
    # must be placed at the clients after adding support for placement labels
    # in the federated types, and expanding the type specification of the
    # intrinsic this is based on to work with federated values of arbitrary
    # placement.

    stack = self._context_stack
    value = value_impl.to_value(value, None, stack)
    is_tuple = isinstance(value.type_signature,
                          computation_types.NamedTupleType)
    if is_tuple and len(anonymous_tuple.to_elements(value.type_signature)) >= 2:
      # We've been passed a value which the user expects to be zipped.
      value = self.federated_zip(value)
    type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                               'value to be mapped')

    # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
    # here based on the actual type of the argument.
    mapping_fn = value_impl.to_value(mapping_fn, None, stack)
    py_typecheck.check_type(mapping_fn, value_base.Value)
    fn_type = mapping_fn.type_signature
    py_typecheck.check_type(fn_type, computation_types.FunctionType)
    arg_type = value.type_signature
    if not type_utils.is_assignable_from(fn_type.parameter, arg_type.member):
      raise TypeError(
          'The mapping function expects a parameter of type {}, but member '
          'constituents of the mapped value are of incompatible type {}.'
          .format(str(fn_type.parameter), str(arg_type.member)))

    # TODO(b/113112108): Replace this as noted above.
    result_type = computation_types.FederatedType(fn_type.result,
                                                  placements.CLIENTS,
                                                  arg_type.all_equal)
    intrinsic_type = computation_types.FunctionType([fn_type, arg_type],
                                                    result_type)
    intrinsic = value_impl.ValueImpl(
        computation_building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                              intrinsic_type), stack)
    return intrinsic(mapping_fn, value)
Exemplo n.º 10
0
    def federated_mean(self, value, weight):
        """Implements `federated_mean` as defined in `api/intrinsics.py`.

        Args:
          value: As in `api/intrinsics.py`.
          weight: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`.
        """
        # TODO(b/113112108): Possibly relax the constraints on numeric types, and
        # inject implicit casts where appropriate. For instance, we might want to
        # allow `tf.int32` values as the input, and automatically cast them to
        # `tf.float32` before invoking the average, thus producing a
        # floating-point result.

        # TODO(b/120439632): Possibly allow the weight to be either structured or
        # non-scalar, e.g., for the case of averaging a convolutional layer, when
        # we would want to use a different weight for every filter, and where it
        # might be cumbersome for users to have to manually slice and assemble a
        # variable.

        stack = self._context_stack
        value = value_impl.to_value(value, None, stack)
        type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                                   'value to be averaged')
        if not type_utils.is_average_compatible(value.type_signature):
            raise TypeError(
                'The value type {} is not compatible with the average operator.'
                .format(value.type_signature))

        if weight is not None:
            weight = value_impl.to_value(weight, None, stack)
            type_utils.check_federated_value_placement(
                weight, placements.CLIENTS, 'weight to use in averaging')
            weight_member = weight.type_signature.member
            py_typecheck.check_type(weight_member,
                                    computation_types.TensorType)
            # Only rank-0 (scalar) weights are accepted.
            if weight_member.shape.ndims != 0:
                raise TypeError(
                    'The weight type {} is not a federated scalar.'.format(
                        weight.type_signature))
            weight_dtype = weight_member.dtype
            if not (weight_dtype.is_integer or weight_dtype.is_floating):
                raise TypeError(
                    'The weight type {} is not a federated integer or floating-point '
                    'tensor.'.format(weight.type_signature))

        value_comp = value_impl.ValueImpl.get_comp(value)
        weight_comp = (None if weight is None else
                       value_impl.ValueImpl.get_comp(weight))
        return value_impl.ValueImpl(
            computation_constructing_utils.create_federated_mean(
                value_comp, weight_comp), stack)
Exemplo n.º 11
0
    def sequence_reduce(self, value, zero, op):
        """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

        Args:
          value: A sequence, or a federated value whose member is a sequence.
          zero: The initial accumulator value.
          op: The reduction operator. Its type must be assignable to
            `type_factory.reduction_op(<zero type>, <element type>)`.

        Returns:
          A `value_impl.ValueImpl`. For a plain sequence, it wraps the result
          of binding a `sequence_reduce` comp via
          `self._bind_comp_as_reference`. For a federated sequence, it is the
          result of `self.federated_map` applying a per-member reduction.

        Raises:
          TypeError: If `op` has an incompatible type. The `py_typecheck`
            checks may also reject a `value` that is neither a sequence nor a
            federated sequence.
        """
        stack = self._context_stack
        value, zero, op = [
            value_impl.to_value(operand, None, stack)
            for operand in (value, zero, op)
        ]
        if value.type_signature.is_sequence():
            element_type = value.type_signature.element
        else:
            py_typecheck.check_type(value.type_signature,
                                    computation_types.FederatedType)
            py_typecheck.check_type(value.type_signature.member,
                                    computation_types.SequenceType)
            element_type = value.type_signature.member.element
        expected_op_type = type_factory.reduction_op(zero.type_signature,
                                                     element_type)
        if not expected_op_type.is_assignable_from(op.type_signature):
            raise TypeError('Expected an operator of type {}, got {}.'.format(
                expected_op_type, op.type_signature))

        value_comp, zero_comp, op_comp = [
            value_impl.ValueImpl.get_comp(operand)
            for operand in (value, zero, op)
        ]
        if value_comp.type_signature.is_sequence():
            reduce_comp = building_block_factory.create_sequence_reduce(
                value_comp, zero_comp, op_comp)
            return value_impl.ValueImpl(
                self._bind_comp_as_reference(reduce_comp), stack)

        # Federated case: build a lambda that reduces one member sequence and
        # map it over the federated value.
        sequence_type = computation_types.SequenceType(element_type)
        intrinsic = building_blocks.Intrinsic(
            intrinsic_defs.SEQUENCE_REDUCE.uri,
            computation_types.FunctionType(
                (sequence_type, zero_comp.type_signature,
                 op_comp.type_signature), op_comp.type_signature.result))
        arg_ref = building_blocks.Reference('arg', sequence_type)
        reduce_call = building_blocks.Call(
            intrinsic, building_blocks.Struct((arg_ref, zero_comp, op_comp)))
        per_member_fn = building_blocks.Lambda(arg_ref.name,
                                               arg_ref.type_signature,
                                               reduce_call)
        return self.federated_map(
            value_impl.ValueImpl(per_member_fn, stack), value_comp)
Exemplo n.º 12
0
 def generic_plus(arg):
   """Adds two arguments when possible.

   Named tuples are added field by field via recursion. Any other pair is
   packed and combined with `tf.add` through
   `intrinsic_utils.apply_binary_operator_with_upcast`. `context_stack` is a
   free variable and must come from an enclosing scope.
   """
   lhs, rhs = arg[0], arg[1]
   _check_top_level_compatibility_with_generic_operators(lhs, rhs,
                                                         'Generic plus')
   # TODO(b/136587334): Push this logic down a level
   if not isinstance(lhs.type_signature, computation_types.NamedTupleType):
     packed = value_impl.ValueImpl.get_comp(_pack_binary_operator_args(lhs, rhs))
     summed = intrinsic_utils.apply_binary_operator_with_upcast(packed, tf.add)
     return value_impl.ValueImpl(summed, context_stack)
   # This case is needed if federated types are nested deeply.
   field_names = [
       element[0] for element in anonymous_tuple.to_elements(lhs.type_signature)
   ]
   field_sums = [
       value_impl.ValueImpl.get_comp(generic_plus([lhs[i], rhs[i]]))
       for i in range(len(field_names))
   ]
   named_sum = computation_constructing_utils.create_named_tuple(
       computation_building_blocks.Tuple(field_sums), field_names)
   return value_impl.ValueImpl(named_sum, context_stack)
Exemplo n.º 13
0
 def test_slicing_support_non_tuple_underlying_comp(self):
     """Checks slicing a `ValueImpl` whose comp is a Reference, not a Tuple."""
     reference = computation_building_blocks.Reference('test', [tf.int32] * 5)
     v = value_impl.ValueImpl(reference, context_stack_impl.context_stack)
     for key in (slice(None, 4, 2), slice(4, 2, -1)):
         self.assertIsInstance(v[key], value_base.Value)
     # A slice that selects no elements must be rejected.
     with self.assertRaisesRegex(IndexError, 'slice 0 elements'):
         _ = v[2:4:-1]
Exemplo n.º 14
0
 def _make_sequence_sum_for(type_spec):
   """Returns a `SEQUENCE_SUM` intrinsic specialized to `type_spec`.

   `self` is not a parameter here. It must be supplied by an enclosing scope.

   Args:
     type_spec: A `computation_types.SequenceType` whose element type must be
       sum-compatible.

   Returns:
     A `value_impl.ValueImpl` wrapping an intrinsic of type
     `(type_spec -> type_spec.element)`.

   Raises:
     TypeError: If the element type is not sum-compatible. The
       `py_typecheck` check also rejects a non-`SequenceType` argument.
   """
   py_typecheck.check_type(type_spec, computation_types.SequenceType)
   element_type = type_spec.element
   if not type_utils.is_sum_compatible(element_type):
     raise TypeError(
         'The value type {} is not compatible with the sum operator.'.format(
             str(type_spec)))
   sum_type = computation_types.FunctionType(type_spec, element_type)
   intrinsic = computation_building_blocks.Intrinsic(
       intrinsic_defs.SEQUENCE_SUM.uri, sum_type)
   return value_impl.ValueImpl(intrinsic, self._context_stack)
Exemplo n.º 15
0
  def federated_value(self, value, placement):
    """Implements `federated_value` as defined in `api/intrinsics.py`.

    Args:
      value: A value convertible via `value_impl.to_value`. Its type must not
        contain any federated type.
      placement: The placement forwarded to
        `building_block_factory.create_federated_value`.

    Returns:
      A `value_impl.ValueImpl` wrapping the comp returned by
      `self._bind_comp_as_reference`.

    Raises:
      TypeError: If the type of `value` contains a federated type.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    has_federated_part = type_analysis.contains(value.type_signature,
                                                lambda t: t.is_federated())
    if has_federated_part:
      raise TypeError('Cannt place value {} containing federated types at '
                      'another placement; requested to be placed at {}.'.format(
                          value, placement))

    placed = building_block_factory.create_federated_value(
        value_impl.ValueImpl.get_comp(value), placement)
    bound = self._bind_comp_as_reference(placed)
    return value_impl.ValueImpl(bound, self._context_stack)
def federated_computation_serializer(
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
):
  """Converts a function into a computation building block.

  This is a two-phase generator:

  1. Advance it once (e.g. with `next()`) to receive the argument to pass to
     the function being converted: `None` when `parameter_type` is `None`,
     otherwise a `value_impl.ValueImpl` wrapping a `Reference` to the
     parameter.
  2. Call the function with that argument and `send()` its result back in.
     The generator then yields the final
     `(building_blocks.Lambda, computation_types.FunctionType)` pair.

  Both phases run while a new `FederatedComputationContext` is installed on
  `context_stack`.

  Args:
    parameter_name: The name of the parameter, or `None` if there isn't any.
    parameter_type: The `tff.Type` of the parameter, or `None` if there's none.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated context
      that will be used to serialize this function's body (ideally the name of
      the underlying Python function). It might be modified to avoid conflicts.

  Yields:
    First, the argument to be passed to the function to be converted.
    Finally, a tuple of `(building_blocks.ComputationBuildingBlock,
    computation_types.Type)`: the function represented via building blocks and
    the inferred return type.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if suggested_name is not None:
    py_typecheck.check_type(suggested_name, str)
  # If we are already inside a federated computation context, nest the new
  # context under it; otherwise the new context has no parent.
  if isinstance(context_stack.current,
                federated_computation_context.FederatedComputationContext):
    parent_context = context_stack.current
  else:
    parent_context = None
  context = federated_computation_context.FederatedComputationContext(
      context_stack, suggested_name=suggested_name, parent=parent_context)
  if parameter_name is not None:
    py_typecheck.check_type(parameter_name, str)
    # Prefix with the context's name (presumably to keep parameter names
    # distinct across nested contexts).
    parameter_name = '{}_{}'.format(context.name, str(parameter_name))
  with context_stack.install(context):
    # Phase 1: hand the caller the argument and receive the traced result via
    # `send()`.
    if parameter_type is None:
      result = yield None
    else:
      result = yield (value_impl.ValueImpl(
          building_blocks.Reference(parameter_name, parameter_type),
          context_stack))
    annotated_result_type = type_conversions.infer_type(result)
    result = value_impl.to_value(result, annotated_result_type, context_stack)
    result_comp = value_impl.ValueImpl.get_comp(result)
    # Symbols bound in the current context while the function ran are
    # wrapped, together with the result, into a `Block`.
    symbols_bound_in_context = context_stack.current.symbol_bindings
    if symbols_bound_in_context:
      result_comp = building_blocks.Block(
          local_symbols=symbols_bound_in_context, result=result_comp)
    # NOTE(review): `parameter_type` may be `None` here; presumably this
    # yields a no-argument function type — confirm against FunctionType.
    annotated_type = computation_types.FunctionType(parameter_type,
                                                    annotated_result_type)
    # Phase 2: emit the serialized computation and its type.
    yield building_blocks.Lambda(parameter_name, parameter_type,
                                 result_comp), annotated_type
Exemplo n.º 17
0
  def federated_broadcast(self, value):
    """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

    Args:
      value: A value convertible via `value_impl.to_value`; it is checked
        against `placements.SERVER` by `value_utils.ensure_federated_value`.

    Returns:
      A `value_impl.ValueImpl` wrapping the `federated_broadcast` building
      block.

    Raises:
      TypeError: If the federated value's type is not `all_equal`.
    """
    stack = self._context_stack
    server_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None, stack), placements.SERVER,
        'value to be broadcasted')
    if not server_value.type_signature.all_equal:
      raise TypeError('The broadcasted value should be equal at all locations.')
    broadcast_comp = building_block_factory.create_federated_broadcast(
        value_impl.ValueImpl.get_comp(server_value))
    return value_impl.ValueImpl(broadcast_comp, stack)
Exemplo n.º 18
0
  def federated_value(self, value, placement):
    """Implements `federated_value` as defined in `api/intrinsics.py`.

    Args:
      value: A value convertible via `value_impl.to_value`. Its type tree must
        not contain a `computation_types.FederatedType`.
      placement: The placement forwarded to
        `building_block_factory.create_federated_value`.

    Returns:
      A `value_impl.ValueImpl` wrapping the resulting building block.

    Raises:
      TypeError: If the type of `value` contains a federated type.
    """
    stack = self._context_stack
    value = value_impl.to_value(value, None, stack)
    contains_federated = type_utils.type_tree_contains_types(
        value.type_signature, computation_types.FederatedType)
    if contains_federated:
      raise TypeError('Cannt place value {} containing federated types at '
                      'another placement; requested to be placed at {}.'.format(
                          value, placement))

    placed = building_block_factory.create_federated_value(
        value_impl.ValueImpl.get_comp(value), placement)
    return value_impl.ValueImpl(placed, stack)
Exemplo n.º 19
0
 def generic_multiply(arg):
   """Multiplies two arguments when possible.

   Named tuples are multiplied field by field via recursion. Any other pair
   is packed and combined with `tf.multiply` through
   `computation_constructing_utils.apply_binary_operator_with_upcast`.
   `context_stack` is a free variable and must come from an enclosing scope.
   """
   lhs, rhs = arg[0], arg[1]
   _check_top_level_compatibility_with_generic_operators(
       lhs, rhs, 'Generic multiply')
   if not isinstance(lhs.type_signature, computation_types.NamedTupleType):
     packed = value_impl.ValueImpl.get_comp(_pack_binary_operator_args(lhs, rhs))
     product = computation_constructing_utils.apply_binary_operator_with_upcast(
         packed, tf.multiply)
     return value_impl.ValueImpl(product, context_stack)
   # This case is needed if federated types are nested deeply.
   field_names = [
       element[0] for element in anonymous_tuple.to_elements(lhs.type_signature)
   ]
   field_products = [
       value_impl.ValueImpl.get_comp(generic_multiply([lhs[i], rhs[i]]))
       for i in range(len(field_names))
   ]
   named_product = computation_constructing_utils.create_named_tuple(
       computation_building_blocks.Tuple(field_products), field_names)
   return value_impl.ValueImpl(named_product, context_stack)
Exemplo n.º 20
0
 def test_value_impl_with_call(self):
     """Checks that calling a functional `ValueImpl` binds the call to a symbol."""
     stack = context_stack_impl.context_stack
     foo = value_impl.ValueImpl(
         building_blocks.Reference(
             'foo', computation_types.FunctionType(tf.int32, tf.bool)), stack)
     bar = value_impl.ValueImpl(building_blocks.Reference('bar', tf.int32),
                                stack)
     result = foo(bar)
     self.assertIsInstance(result, value_base.Value)
     self.assertEqual(str(result.type_signature), 'bool')
     self.assertEqual(str(result), 'fc_FEDERATED_symbol_0')
     # The call itself is recorded as the single symbol binding.
     bindings = stack.current.symbol_bindings
     self.assertLen(bindings, 1)
     self.assertEqual(bindings[0][0], str(result))
     self.assertEqual(str(bindings[0][1]), 'foo(bar)')
     # Omitting the argument is an error.
     with self.assertRaises(TypeError):
         foo()
     bak = value_impl.ValueImpl(building_blocks.Reference('bak', tf.float32),
                                stack)
     # An argument of the wrong type is an error.
     with self.assertRaises(TypeError):
         foo(bak)
Exemplo n.º 21
0
 def federated_secure_sum(self, value, bitwidth):
   """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

   Args:
     value: A value convertible via `value_impl.to_value`; it is checked
       against `placements.CLIENTS` and must be a structure of integers.
     bitwidth: A value whose type must be equivalent to the member type of
       `value`.

   Returns:
     A `value_impl.ValueImpl` wrapping the `federated_secure_sum` building
     block.
   """
   stack = self._context_stack
   summand = value_utils.ensure_federated_value(
       value_impl.to_value(value, None, stack), placements.CLIENTS,
       'value to be summed')
   type_utils.check_is_structure_of_integers(summand.type_signature)
   bitwidth = value_impl.to_value(bitwidth, None, stack)
   type_utils.check_equivalent_types(summand.type_signature.member,
                                     bitwidth.type_signature)
   secure_sum = building_block_factory.create_federated_secure_sum(
       value_impl.ValueImpl.get_comp(summand),
       value_impl.ValueImpl.get_comp(bitwidth))
   return value_impl.ValueImpl(secure_sum, stack)
Exemplo n.º 22
0
    def federated_map(self, fn, arg):
        """Implements `federated_map` as defined in `api/intrinsics.py`.

        Values placed at `SERVER` must be `all_equal` and are lowered with
        `building_block_factory.create_federated_apply`. Values placed at
        `CLIENTS` are lowered with `create_federated_map`. The resulting comp
        is bound via `self._bind_comp_as_reference` before being wrapped.

        Args:
          fn: A value convertible to a functional `Value`. The member type of
            `arg` is passed to `value_impl.to_value` as `parameter_type_hint`.
          arg: A value convertible to a federated value.

        Returns:
          A `value_impl.ValueImpl` wrapping the bound comp.

        Raises:
          TypeError: If `fn`'s parameter type is not assignable from the
            member type of `arg`, if a `SERVER` value is not `all_equal`, or if
            `arg` is placed anywhere other than `SERVER` or `CLIENTS`.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the server or clients. Would occur after adding support
        # for placement labels in the federated types, and expanding the type
        # specification of the intrinsic this is based on to work with federated
        # values of arbitrary placement.

        stack = self._context_stack
        arg = value_utils.ensure_federated_value(
            value_impl.to_value(arg, None, stack), label='value to be mapped')
        arg_type = arg.type_signature

        fn = value_impl.to_value(fn,
                                 None,
                                 stack,
                                 parameter_type_hint=arg_type.member)
        py_typecheck.check_type(fn, value_base.Value)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        if not fn.type_signature.parameter.is_assignable_from(
                arg_type.member):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(fn.type_signature.parameter, arg_type.member))

        # TODO(b/144384398): Change structure to one that maps the placement type
        # to the building_block function that fits it, in a way that allows the
        # appropriate type checks.
        placement = arg_type.placement
        if placement is placement_literals.SERVER:
            if not arg_type.all_equal:
                raise TypeError(
                    'Arguments placed at {} should be equal at all locations.'.
                    format(placement_literals.SERVER))
            build_comp = building_block_factory.create_federated_apply
        elif placement is placement_literals.CLIENTS:
            build_comp = building_block_factory.create_federated_map
        else:
            raise TypeError(
                'The argument should be placed at {} or {}, placed at {} instead.'
                .format(placement_literals.SERVER, placement_literals.CLIENTS,
                        placement))

        comp = build_comp(value_impl.ValueImpl.get_comp(fn),
                          value_impl.ValueImpl.get_comp(arg))
        return value_impl.ValueImpl(self._bind_comp_as_reference(comp), stack)
Exemplo n.º 23
0
    def federated_sum(self, value):
        """Implements `federated_sum` as defined in `api/intrinsics.py`.

        Args:
          value: A value convertible via `value_impl.to_value`; it is checked
            against `placements.CLIENTS` by
            `value_utils.ensure_federated_value`.

        Returns:
          A `value_impl.ValueImpl` wrapping the `federated_sum` building block.

        Raises:
          TypeError: If the value's type is not sum-compatible.
        """
        stack = self._context_stack
        summand = value_utils.ensure_federated_value(
            value_impl.to_value(value, None, stack), placements.CLIENTS,
            'value to be summed')
        summand_type = summand.type_signature
        if not type_utils.is_sum_compatible(summand_type):
            raise TypeError(
                'The value type {} is not compatible with the sum operator.'.
                format(summand_type))
        return value_impl.ValueImpl(
            building_block_factory.create_federated_sum(
                value_impl.ValueImpl.get_comp(summand)), stack)
Exemplo n.º 24
0
  def test_get_curried(self):
    """Checks that `get_curried` turns a two-argument add into a curried fn."""
    add_proto = computation_impl.ComputationImpl.get_proto(
        computations.tf_computation(tf.add, [tf.int32, tf.int32]))
    add_numbers = value_impl.ValueImpl(
        building_blocks.ComputationBuildingBlock.from_proto(add_proto),
        _context_stack)

    curried = value_utils.get_curried(add_numbers)
    self.assertEqual(str(curried.type_signature), '(int32 -> (int32 -> int32))')

    # Uniquify compiled computation names (presumably so the expected compact
    # representation below is stable).
    uniquified, _ = tree_transformations.uniquify_compiled_computation_names(
        value_impl.ValueImpl.get_comp(curried))
    self.assertEqual(uniquified.compact_representation(),
                     '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
Exemplo n.º 25
0
    def sequence_map(self, fn, arg):
        """Implements `sequence_map` as defined in `api/intrinsics.py`.

        Args:
          fn: A value convertible to a functional `Value`, applied to each
            element of the sequence.
          arg: Either a sequence, or a federated value whose member is a
            sequence and whose placement is `SERVER` or `CLIENTS`.

        Returns:
          For a sequence, a `value_impl.ValueImpl` wrapping the
          `sequence_map` building block. For a federated sequence, the result
          of `self.federated_map` applying the curried `sequence_map`
          intrinsic to each member.

        Raises:
          TypeError: If `fn` is not functional, if `arg` is neither a sequence
            nor a federated value with a sequence member, or if a federated
            `arg` has an unsupported placement.
        """
        fn = value_impl.to_value(fn, None, self._context_stack)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        arg = value_impl.to_value(arg, None, self._context_stack)

        if isinstance(arg.type_signature, computation_types.SequenceType):
            fn = value_impl.ValueImpl.get_comp(fn)
            arg = value_impl.ValueImpl.get_comp(arg)
            return value_impl.ValueImpl(
                building_block_factory.create_sequence_map(fn, arg),
                self._context_stack)
        elif isinstance(arg.type_signature, computation_types.FederatedType):
            # Reject non-sequence members up front. Without this check the
            # failure only surfaces later, inside `federated_map`, as an
            # unhelpful parameter-type mismatch.
            py_typecheck.check_type(arg.type_signature.member,
                                    computation_types.SequenceType)
            parameter_type = computation_types.SequenceType(
                fn.type_signature.parameter)
            result_type = computation_types.SequenceType(
                fn.type_signature.result)
            intrinsic_type = computation_types.FunctionType(
                (fn.type_signature, parameter_type), result_type)
            intrinsic = building_blocks.Intrinsic(
                intrinsic_defs.SEQUENCE_MAP.uri, intrinsic_type)
            intrinsic_impl = value_impl.ValueImpl(intrinsic,
                                                  self._context_stack)
            # Partially apply `fn`, leaving a function over one member
            # sequence that can be mapped over the federated value.
            local_fn = value_utils.get_curried(intrinsic_impl)(fn)

            if arg.type_signature.placement in [
                    placement_literals.SERVER, placement_literals.CLIENTS
            ]:
                return self.federated_map(local_fn, arg)
            else:
                raise TypeError('Unsupported placement {}.'.format(
                    arg.type_signature.placement))
        else:
            raise TypeError(
                'Cannot apply `tff.sequence_map()` to a value of type {}.'.
                format(arg.type_signature))
Exemplo n.º 26
0
    def sequence_sum(self, value):
        """Implements `sequence_sum` as defined in `api/intrinsics.py`.

        Args:
          value: Either a sequence, or a federated value whose member is a
            sequence and whose placement is `SERVER` or `CLIENTS`. The
            sequence element type must be sum-compatible.

        Returns:
          For a sequence, a `value_impl.ValueImpl` wrapping the
          `sequence_sum` building block. For a federated sequence, the result
          of `self.federated_map` applying the `SEQUENCE_SUM` intrinsic to
          each member.

        Raises:
          TypeError: If `value` is neither a sequence nor a federated
            sequence, or if a federated `value` has an unsupported placement.
            `type_analysis.check_is_sum_compatible` rejects element types that
            cannot be summed.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        value_type = value.type_signature

        if isinstance(value_type, computation_types.SequenceType):
            type_analysis.check_is_sum_compatible(value_type.element)
            return value_impl.ValueImpl(
                building_block_factory.create_sequence_sum(
                    value_impl.ValueImpl.get_comp(value)),
                self._context_stack)

        # Anything that is not a bare sequence must be a federated sequence;
        # these `py_typecheck` checks reject everything else. (This makes the
        # former trailing "Cannot apply `tff.sequence_sum()`" branch
        # unreachable, so it has been removed.)
        py_typecheck.check_type(value_type, computation_types.FederatedType)
        py_typecheck.check_type(value_type.member,
                                computation_types.SequenceType)
        type_analysis.check_is_sum_compatible(value_type.member.element)

        if value_type.placement not in [
                placement_literals.SERVER, placement_literals.CLIENTS
        ]:
            raise TypeError('Unsupported placement {}.'.format(
                value_type.placement))

        intrinsic_type = computation_types.FunctionType(
            value_type.member, value_type.member.element)
        intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_SUM.uri,
                                              intrinsic_type)
        intrinsic_impl = value_impl.ValueImpl(intrinsic, self._context_stack)
        return self.federated_map(intrinsic_impl, value)
Exemplo n.º 27
0
    def federated_reduce(self, value, zero, op):
        """Implements `federated_reduce` as defined in `api/intrinsics.py`.

        Builds and invokes a `FEDERATED_REDUCE` intrinsic over a
        `CLIENTS`-placed `value`, using `zero` and the operator `op`.

        Args:
          value: As in `api/intrinsics.py`; must pass
            `type_utils.check_federated_value_placement` for `CLIENTS`.
          zero: As in `api/intrinsics.py`.
          op: As in `api/intrinsics.py`; its function type must be assignable
            to `type_constructors.reduction_op(<type of zero>, <member type of
            value>)`.

        Returns:
          As in `api/intrinsics.py`. The intrinsic's result type is a
          federated type placed at `SERVER` whose member type is the type of
          `zero`.

        Raises:
          TypeError: As in `api/intrinsics.py`, including when `op` has a type
            incompatible with the expected reduction operator.
        """
        # TODO(b/113112108): Since in most cases, it can be assumed that CLIENTS is
        # a non-empty collective (or else, the computation fails), specifying zero
        # at this level of the API should probably be optional. TBD.

        value = value_impl.to_value(value, None, self._context_stack)
        type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                                   'value to be reduced')

        zero = value_impl.to_value(zero, None, self._context_stack)
        py_typecheck.check_type(zero, value_base.Value)

        # TODO(b/113112108): We need a check here that zero does not have federated
        # constituents.

        op = value_impl.to_value(op, None, self._context_stack)
        py_typecheck.check_type(op, value_base.Value)
        op_type = op.type_signature
        py_typecheck.check_type(op_type, computation_types.FunctionType)

        zero_type = zero.type_signature
        value_type = value.type_signature
        expected_op_type = type_constructors.reduction_op(zero_type,
                                                          value_type.member)
        if not type_utils.is_assignable_from(expected_op_type, op_type):
            raise TypeError('Expected an operator of type {}, got {}.'.format(
                str(expected_op_type), str(op_type)))

        # TODO(b/113112108): Replace this as noted above.
        # The trailing `True` is presumably the all-equal bit of the result.
        server_result_type = computation_types.FederatedType(
            zero_type, placements.SERVER, True)
        intrinsic_type = computation_types.FunctionType(
            [value_type, zero_type, expected_op_type], server_result_type)
        reduce_intrinsic = value_impl.ValueImpl(
            computation_building_blocks.Intrinsic(
                intrinsic_defs.FEDERATED_REDUCE.uri, intrinsic_type),
            self._context_stack)
        return reduce_intrinsic(value, zero, op)
Exemplo n.º 28
0
    def federated_map_all_equal(self, fn, arg):
        """Implements `federated_map` as defined in `api/intrinsics.py`.

        Variant of `federated_map` for an argument with the `all_equal` bit
        set. The computation is built with
        `building_block_factory.create_federated_map_all_equal`. This method
        does not itself inspect the `all_equal` bit.
        NOTE(review): presumably the factory enforces that bit; confirm.

        A non-federated named-tuple argument with two or more elements is
        first zipped with `self.federated_zip`. The resulting argument must
        then be placed at `CLIENTS`.

        Args:
          fn: As in `api/intrinsics.py`.
          arg: As in `api/intrinsics.py`, with the `all_equal` bit set.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`, including when the parameter
            type of `fn` does not accept the member type of `arg`.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the clients after adding support for placement labels
        # in the federated types, and expanding the type specification of the
        # intrinsic this is based on to work with federated values of arbitrary
        # placement.

        arg = value_impl.to_value(arg, None, self._context_stack)
        needs_zip = (
            isinstance(arg.type_signature, computation_types.NamedTupleType)
            and len(anonymous_tuple.to_elements(arg.type_signature)) >= 2)
        if needs_zip:
            # We've been passed a value which the user expects to be zipped.
            arg = self.federated_zip(arg)
        value_utils.check_federated_value_placement(arg, placements.CLIENTS,
                                                    'value to be mapped')

        # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
        # here based on the actual type of the argument.
        fn = value_impl.to_value(fn, None, self._context_stack)
        py_typecheck.check_type(fn, value_base.Value)
        fn_type = fn.type_signature
        py_typecheck.check_type(fn_type, computation_types.FunctionType)

        member_type = arg.type_signature.member
        if not type_utils.is_assignable_from(fn_type.parameter, member_type):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(fn_type.parameter, member_type))

        map_comp = building_block_factory.create_federated_map_all_equal(
            value_impl.ValueImpl.get_comp(fn),
            value_impl.ValueImpl.get_comp(arg))
        return value_impl.ValueImpl(map_comp, self._context_stack)
Exemplo n.º 29
0
    def federated_aggregate(self, value, zero, accumulate, merge, report):
        """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

        Checks that `zero`, `accumulate`, `merge`, and `report` have mutually
        consistent types. It then builds the computation with
        `building_block_factory.create_federated_aggregate`.

        Args:
          value: As in `api/intrinsics.py`; passed through
            `value_utils.ensure_federated_value` with placement `CLIENTS`.
          zero: As in `api/intrinsics.py`; must be assignable to the first
            parameter of `accumulate`.
          accumulate: As in `api/intrinsics.py`; must be assignable to
            `type_factory.reduction_op(<accumulate result>, <member type>)`.
          merge: As in `api/intrinsics.py`; must be assignable to
            `type_factory.reduction_op(<accumulate result>,
            <accumulate result>)`.
          report: As in `api/intrinsics.py`; its parameter must accept the
            result type of `merge`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: If `zero` or any of the operators has an incompatible
            type. Errors from `py_typecheck.check_type` also propagate.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        value = value_utils.ensure_federated_value(value, placements.CLIENTS,
                                                   'value to be aggregated')

        zero = value_impl.to_value(zero, None, self._context_stack)
        py_typecheck.check_type(zero, value_base.Value)
        accumulate, merge, report = [
            value_impl.to_value(fn, None, self._context_stack)
            for fn in (accumulate, merge, report)
        ]
        for fn in (accumulate, merge, report):
            py_typecheck.check_type(fn, value_base.Value)
            py_typecheck.check_type(fn.type_signature,
                                    computation_types.FunctionType)

        accumulate_type = accumulate.type_signature
        zero_target_type = accumulate_type.parameter[0]
        if not type_utils.is_assignable_from(zero_target_type,
                                             zero.type_signature):
            raise TypeError('Expected `zero` to be assignable to type {}, '
                            'but was of incompatible type {}.'.format(
                                zero_target_type, zero.type_signature))

        # All expected operator types are derived from `accumulate`'s result
        # (the accumulator type) before any of them is checked.
        accumulator_type = accumulate_type.result
        operator_checks = (
            ('accumulate', accumulate,
             type_factory.reduction_op(accumulator_type,
                                       value.type_signature.member)),
            ('merge', merge,
             type_factory.reduction_op(accumulator_type, accumulator_type)),
            ('report', report,
             computation_types.FunctionType(merge.type_signature.result,
                                            report.type_signature.result)),
        )
        for op_name, fn, expected_type in operator_checks:
            if not type_utils.is_assignable_from(expected_type,
                                                 fn.type_signature):
                raise TypeError(
                    'Expected parameter `{}` to be of type {}, but received {} instead.'
                    .format(op_name, expected_type, fn.type_signature))

        comps = [
            value_impl.ValueImpl.get_comp(v)
            for v in (value, zero, accumulate, merge, report)
        ]
        comp = building_block_factory.create_federated_aggregate(*comps)
        return value_impl.ValueImpl(comp, self._context_stack)
Exemplo n.º 30
0
def zero_for(type_spec, context_stack):
  """Constructs the zero value of TFF type `type_spec`.

  The zero is built as `construct_generic_constant(type_spec, 0)` and
  wrapped in a `value_impl.ValueImpl`.

  Args:
    type_spec: An instance of `types.Type` or something convertible to it via
      `computation_types.to_type`.
    context_stack: The context stack to use; checked to be an instance of
      `context_stack_base.ContextStack`.

  Returns:
    A `value_impl.ValueImpl` holding the zero constant, of the same TFF type
    as `type_spec`.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  zero_comp = construct_generic_constant(type_spec, 0)
  return value_impl.ValueImpl(zero_comp, context_stack)