예제 #1
0
 def sequence_reduce(self, value, zero, op):
     """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

     A non-federated sequence is reduced directly. For a federated sequence the
     reduction is wrapped in a lambda over `<value member, zero member>` and
     applied member-wise through `self.federated_map`.
     """
     value, zero, op = [
         value_impl.to_value(operand, None, self._context_stack)
         for operand in (value, zero, op)
     ]
     value_type = value.type_signature

     if not value_type.is_federated():
         value_type.check_sequence()
         reduce_comp = building_block_factory.create_sequence_reduce(
             value_impl.ValueImpl.get_comp(value),
             value_impl.ValueImpl.get_comp(zero),
             value_impl.ValueImpl.get_comp(op))
         return value_impl.ValueImpl(
             self._bind_comp_as_reference(reduce_comp), self._context_stack)

     # Federated case: the member must be a sequence, and `zero` is read as a
     # federated value too (its `.member` type is accessed here).
     member_type = value_type.member
     member_type.check_sequence()
     zero_member_type = zero.type_signature.member
     value_comp = value_impl.ValueImpl.get_comp(value)
     zero_comp = value_impl.ValueImpl.get_comp(zero)
     op_comp = value_impl.ValueImpl.get_comp(op)

     param = building_blocks.Reference(
         'arg', computation_types.StructType([member_type, zero_member_type]))
     local_reduce = building_blocks.Lambda(
         param.name, param.type_signature,
         building_block_factory.create_sequence_reduce(
             building_blocks.Selection(param, index=0),
             building_blocks.Selection(param, index=1), op_comp))
     return self.federated_map(
         value_impl.ValueImpl(local_reduce, self._context_stack),
         building_blocks.Struct([value_comp, zero_comp]))
예제 #2
0
 def test_getattr_raises_federated_value_unknown_attr(self):
     # The same all-equal struct member type, placed once at CLIENTS and once
     # at SERVER; looking up a missing field must fail the same way for both.
     cases = (
         (placements.CLIENTS, '<a=int32,b=bool>@CLIENTS'),
         (placements.SERVER, '<a=int32,b=bool>@SERVER'),
     )
     for placement, expected_type_str in cases:
         federated_value = value_impl.to_value(
             building_blocks.Reference(
                 'test',
                 computation_types.FederatedType(
                     [('a', tf.int32), ('b', tf.bool)], placement, True)),
             None)
         self.assertEqual(str(federated_value.type_signature),
                          expected_type_str)
         with self.assertRaisesRegex(AttributeError,
                                     r'There is no such attribute \'c\''):
             _ = federated_value.c
예제 #3
0
    def sequence_map(self, fn, arg):
        """Implements `sequence_map` as defined in `api/intrinsics.py`.

        A plain sequence is mapped directly. A federated sequence is handled by
        currying the `SEQUENCE_MAP` intrinsic with `fn` and applying the
        resulting local function through `self.federated_map`. Any other type
        of `arg` raises `TypeError`.
        """
        fn = value_impl.to_value(fn, None, self._context_stack)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        arg = value_impl.to_value(arg, None, self._context_stack)
        arg_type = arg.type_signature

        if arg_type.is_sequence():
            map_comp = building_block_factory.create_sequence_map(
                value_impl.ValueImpl.get_comp(fn),
                value_impl.ValueImpl.get_comp(arg))
            return value_impl.ValueImpl(
                self._bind_comp_as_reference(map_comp), self._context_stack)

        if not arg_type.is_federated():
            raise TypeError(
                'Cannot apply `tff.sequence_map()` to a value of type {}.'.
                format(arg_type))

        fn_type = fn.type_signature
        seq_map_type = computation_types.FunctionType(
            (fn_type, computation_types.SequenceType(fn_type.parameter)),
            computation_types.SequenceType(fn_type.result))
        seq_map = value_impl.ValueImpl(
            building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri,
                                      seq_map_type),
            self._context_stack)
        # Binding `fn` as the first argument leaves a `T* -> U*` function.
        return self.federated_map(value_utils.get_curried(seq_map)(fn), arg)
예제 #4
0
 def federated_secure_sum(self, value, bitwidth):
     """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

     `value` must be a federated structure of integers placed at CLIENTS.
     `bitwidth` must match that structure (one integer bitwidth per tensor);
     a single tensor-typed bitwidth is replicated across a struct-typed member.
     """
     value = value_utils.ensure_federated_value(
         value_impl.to_value(value, None, self._context_stack),
         placement_literals.CLIENTS, 'value to be summed')
     type_analysis.check_is_structure_of_integers(value.type_signature)
     bitwidth_value = value_impl.to_value(bitwidth, None, self._context_stack)

     member_type = value.type_signature.member
     bitwidth_type = bitwidth_value.type_signature
     bitwidth_ok = type_analysis.is_valid_bitwidth_type_for_value_type(
         bitwidth_type, member_type)
     if not bitwidth_ok:
         raise TypeError(
             'Expected `federated_secure_sum` parameter `bitwidth` to match '
             'the structure of `value`, with one integer bitwidth per tensor in '
             '`value`. Found `value` of `{}` and `bitwidth` of `{}`.'.format(
                 member_type, bitwidth_type))

     if bitwidth_type.is_tensor() and member_type.is_struct():
         # Replicate the single bitwidth over the member structure.
         replicated = structure.map_structure(lambda _: bitwidth, member_type)
         bitwidth_value = value_impl.to_value(replicated, None,
                                              self._context_stack)

     sum_comp = building_block_factory.create_federated_secure_sum(
         value_impl.ValueImpl.get_comp(value),
         value_impl.ValueImpl.get_comp(bitwidth_value))
     return value_impl.ValueImpl(self._bind_comp_as_reference(sum_comp),
                                 self._context_stack)
예제 #5
0
def federated_map_all_equal(fn, arg):
    """Maps `fn` over a CLIENTS-placed `arg`, with `all_equal` set on input and output.

    This is the `federated_map` counterpart built on
    `create_federated_map_all_equal`.

    Raises:
      TypeError: If `fn`'s parameter type cannot accept `arg`'s member type.
    """
    # TODO(b/113112108): Possibly lift the restriction that the mapped value
    # must be placed at the clients after adding support for placement labels
    # in the federated types, and expanding the type specification of the
    # intrinsic this is based on to work with federated values of arbitrary
    # placement.
    arg = value_utils.ensure_federated_value(
        value_impl.to_value(arg, None), placements.CLIENTS,
        'value to be mapped')
    member_type = arg.type_signature.member

    fn = value_impl.to_value(fn, None, parameter_type_hint=member_type)
    py_typecheck.check_type(fn, value_impl.Value)
    fn_type = fn.type_signature
    py_typecheck.check_type(fn_type, computation_types.FunctionType)

    if not fn_type.parameter.is_assignable_from(member_type):
        raise TypeError(
            'The mapping function expects a parameter of type {}, but member '
            'constituents of the mapped value are of incompatible type {}.'.
            format(fn_type.parameter, member_type))

    map_comp = building_block_factory.create_federated_map_all_equal(
        fn.comp, arg.comp)
    return value_impl.Value(_bind_comp_as_reference(map_comp))
예제 #6
0
def sequence_map(fn, arg):
    """Maps a TFF sequence `arg` pointwise using a given function `fn`.

    This function supports two modes of usage:

    * When applied to a non-federated sequence, it maps individual elements of
      the sequence pointwise. If the supplied `fn` is of type `T->U` and
      the sequence `arg` is of type `T*` (a sequence of `T`-typed elements),
      the result is a sequence of type `U*` (a sequence of `U`-typed elements),
      with each element of the input sequence individually mapped by `fn`.
      In this mode of usage, `sequence_map` behaves like a computation with
      type signature `<T->U,T*> -> U*`.

    * When applied to a federated sequence, `sequence_map` behaves as if it
      were individually applied to each member constituent. In this mode of
      usage, one can think of `sequence_map` as a specialized variant of
      `federated_map` that is designed to work with sequences and allows one
      to specify a `fn` that operates at the level of individual elements.
      Indeed, under the hood, when `sequence_map` is invoked on a federated
      type, it injects `federated_map`, thus emitting expressions like
      `federated_map(x -> sequence_map(fn, x), arg)`.

    Args:
      fn: A mapping function to apply pointwise to elements of `arg`.
      arg: A value of a TFF type that is either a sequence, or a federated
        sequence.

    Returns:
      A sequence with the result of applying `fn` pointwise to each
      element of `arg`, or if `arg` was federated, a federated sequence
      with the result of invoking `sequence_map` on member sequences locally
      and independently at each location.

    Raises:
      TypeError: If the arguments are not of the appropriate types.
    """
    fn = value_impl.to_value(fn, None)
    py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
    arg = value_impl.to_value(arg, None)

    if arg.type_signature.is_sequence():
        comp = building_block_factory.create_sequence_map(fn.comp, arg.comp)
        comp = _bind_comp_as_reference(comp)
        return value_impl.Value(comp)
    elif arg.type_signature.is_federated():
        # Curry the SEQUENCE_MAP intrinsic with `fn` to get a local
        # `T* -> U*` function, then apply it member-wise via `federated_map`.
        parameter_type = computation_types.SequenceType(
            fn.type_signature.parameter)
        result_type = computation_types.SequenceType(fn.type_signature.result)
        intrinsic_type = computation_types.FunctionType(
            (fn.type_signature, parameter_type), result_type)
        intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri,
                                              intrinsic_type)
        intrinsic_impl = value_impl.Value(intrinsic)
        local_fn = value_utils.get_curried(intrinsic_impl)(fn)
        return federated_map(local_fn, arg)
    else:
        raise TypeError(
            'Cannot apply `tff.sequence_map()` to a value of type {}.'.format(
                arg.type_signature))
예제 #7
0
 def test_tf_mapping_raises_helpful_error(self):
     # Raw TensorFlow objects (a constant tensor and a variable) must be
     # rejected by `to_value` with an explanatory TypeError.
     expected_pattern = ('TensorFlow construct (.*) has been '
                         'encountered in a federated context.')
     tf_value_factories = (
         lambda: tf.constant(10),
         lambda: tf.Variable(np.array([10.0])),
     )
     for make_tf_value in tf_value_factories:
         with self.assertRaisesRegex(TypeError, expected_pattern):
             _ = value_impl.to_value(make_tf_value(), None)
예제 #8
0
    def federated_map(self, fn, arg):
        """Implements `federated_map` as defined in `api/intrinsics.py`.

        A SERVER-placed `arg` (which must be `all_equal`) is lowered to
        `create_federated_apply`; a CLIENTS-placed `arg` to
        `create_federated_map`. Any other placement raises `TypeError`.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the server or clients. Would occur after adding support
        # for placement labels in the federated types, and expanding the type
        # specification of the intrinsic this is based on to work with federated
        # values of arbitrary placement.

        arg = value_utils.ensure_federated_value(
            value_impl.to_value(arg, None, self._context_stack),
            label='value to be mapped')
        member_type = arg.type_signature.member

        fn = value_impl.to_value(fn,
                                 None,
                                 self._context_stack,
                                 parameter_type_hint=member_type)
        py_typecheck.check_type(fn, value_base.Value)
        fn_type = fn.type_signature
        py_typecheck.check_type(fn_type, computation_types.FunctionType)
        if not fn_type.parameter.is_assignable_from(member_type):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(fn_type.parameter, member_type))

        # TODO(b/144384398): Change structure to one that maps the placement type
        # to the building_block function that fits it, in a way that allows the
        # appropriate type checks.
        placement = arg.type_signature.placement
        if placement is placement_literals.SERVER:
            if not arg.type_signature.all_equal:
                raise TypeError(
                    'Arguments placed at {} should be equal at all locations.'.
                    format(placement_literals.SERVER))
            build_comp = building_block_factory.create_federated_apply
        elif placement is placement_literals.CLIENTS:
            build_comp = building_block_factory.create_federated_map
        else:
            raise TypeError(
                'Expected `arg` to have a type with a supported placement, '
                'found {}.'.format(placement))

        mapped = build_comp(value_impl.ValueImpl.get_comp(fn),
                            value_impl.ValueImpl.get_comp(arg))
        return value_impl.ValueImpl(self._bind_comp_as_reference(mapped),
                                    self._context_stack)
예제 #9
0
 def test_to_value_for_dict(self):
     foo = value_impl.Value(building_blocks.Reference('foo', tf.int32))
     bar = value_impl.Value(building_blocks.Reference('bar', tf.bool))
     # Both insertion orders of a plain dict are expected to convert to the
     # same struct, '<a=foo,b=bar>'.
     for mapping in ({'a': foo, 'b': bar}, {'b': bar, 'a': foo}):
         converted = value_impl.to_value(mapping, None)
         self.assertIsInstance(converted, value_impl.Value)
         self.assertEqual(str(converted), '<a=foo,b=bar>')
예제 #10
0
 def test_to_value_for_ordered_dict(self):
     foo = value_impl.Value(building_blocks.Reference('foo', tf.int32))
     bar = value_impl.Value(building_blocks.Reference('bar', tf.bool))
     ordered = collections.OrderedDict()
     ordered['b'] = bar
     ordered['a'] = foo
     converted = value_impl.to_value(ordered, None)
     # An OrderedDict's insertion order ('b' before 'a') is kept.
     self.assertIsInstance(converted, value_impl.Value)
     self.assertEqual(str(converted), '<b=bar,a=foo>')
예제 #11
0
def federated_value(value, placement):
  """Places a non-federated `value` at `placement` as an all-equal federated value.

  Deprecation warning: Using `tff.federated_value` with arguments other than
  simple Python constants is deprecated. When placing the result of a
  `tf_computation`, prefer `tff.federated_eval`.

  Args:
    value: A value of a non-federated TFF type to be placed.
    placement: The desired result placement (either `tff.SERVER` or
      `tff.CLIENTS`).

  Returns:
    A federated value with the given placement `placement`, and the member
    constituent `value` equal at all locations.

  Raises:
    TypeError: If `value`'s type contains any federated type.
  """
  # Already-constructed `Value`s (as opposed to plain Python constants) take
  # the deprecated path, so warn before converting.
  is_prebuilt_value = isinstance(value, value_impl.Value)
  if is_prebuilt_value:
    warnings.warn(
        'Deprecation warning: Using `tff.federated_value` with arguments '
        'other than simple Python constants is deprecated. When placing the '
        'result of a `tf_computation`, prefer `tff.federated_eval`.',
        DeprecationWarning)
  value = value_impl.to_value(value, None)

  has_federated_part = type_analysis.contains(value.type_signature,
                                              lambda t: t.is_federated())
  if has_federated_part:
    raise TypeError(
        'Cannot place value {} containing federated types at another '
        'placement; requested to be placed at {}.'.format(value, placement))

  placed = building_block_factory.create_federated_value(value.comp, placement)
  return value_impl.Value(_bind_comp_as_reference(placed))
예제 #12
0
def federated_broadcast(value):
    """Broadcasts a federated value from the `tff.SERVER` to the `tff.CLIENTS`.

    Args:
      value: A value of a TFF federated type placed at the `tff.SERVER`, all
        members of which are equal (the `tff.FederatedType.all_equal` property
        of `value` is `True`).

    Returns:
      A representation of the result of broadcasting: a value of a TFF
      federated type placed at the `tff.CLIENTS`, all members of which are
      equal.

    Raises:
      TypeError: If the argument is not a federated TFF value placed at the
        `tff.SERVER`, or is not `all_equal`.
    """
    server_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None), placements.SERVER,
        'value to be broadcasted')

    if server_value.type_signature.all_equal:
        broadcast = building_block_factory.create_federated_broadcast(
            server_value.comp)
        return value_impl.Value(_bind_comp_as_reference(broadcast))

    raise TypeError('The broadcasted value should be equal at all locations.')
예제 #13
0
def federated_eval(fn, placement):
    """Evaluates a no-argument computation `fn` at `placement`.

    Args:
      fn: A no-arg TFF computation.
      placement: The desired result placement (either `tff.SERVER` or
        `tff.CLIENTS`).

    Returns:
      A federated value with the given placement `placement`.

    Raises:
      TypeError: If the arguments are not of the appropriate types, or if `fn`
        declares a parameter.
    """
    # TODO(b/113112108): Verify that neither the value, nor any of its parts
    # are of a federated type.

    fn = value_impl.to_value(fn, None)
    py_typecheck.check_type(fn, value_impl.Value)
    fn_type = fn.type_signature
    py_typecheck.check_type(fn_type, computation_types.FunctionType)

    param_type = fn_type.parameter
    if param_type is not None:
        raise TypeError(
            '`federated_eval` expects a `fn` that accepts no arguments, but '
            'the `fn` provided has a parameter of type {}.'.format(param_type))

    eval_comp = building_block_factory.create_federated_eval(fn.comp, placement)
    return value_impl.Value(_bind_comp_as_reference(eval_comp))
예제 #14
0
def federated_zip(value):
    """Converts an N-tuple of federated values into a federated N-tuple value.

    Args:
      value: A value of a TFF named tuple type, the elements of which are
        federated values with the same placement.

    Returns:
      A federated value placed at the same location as the members of `value`,
      in which every member component is a named tuple that consists of the
      corresponding member components of the elements of `value`.

    Raises:
      TypeError: If the argument is not a named tuple of federated values with
        the same placement.
    """
    # TODO(b/113112108): We use the iterate/unwrap approach below because
    # our type system is not powerful enough to express the concept of
    # "an operation that takes tuples of T of arbitrary length", and therefore
    # the intrinsic federated_zip must only take a fixed number of arguments,
    # here fixed at 2. There are other potential approaches to getting around
    # this problem (e.g. having the operator act on sequences and thereby
    # sidestepping the issue) which we may want to explore.
    struct_value = value_impl.to_value(value, None)
    py_typecheck.check_type(struct_value, value_impl.Value)
    py_typecheck.check_type(struct_value.type_signature,
                            computation_types.StructType)

    zipped = building_block_factory.create_federated_zip(struct_value.comp)
    return value_impl.Value(_bind_comp_as_reference(zipped))
예제 #15
0
 def test_to_value_for_nested_attrs_class(self):
   foo = value_impl.Value(building_blocks.Reference('foo', tf.int32))
   bar = value_impl.Value(building_blocks.Reference('bar', tf.int32))
   # Each field of the outer attrs instance holds its own inner instance; the
   # conversion is expected to produce a nested named struct.
   nested = TestAttrClass(TestAttrClass(foo, bar), TestAttrClass(foo, bar))
   converted = value_impl.to_value(nested, None)
   self.assertIsInstance(converted, value_impl.Value)
   self.assertEqual(str(converted), '<x=<x=foo,y=bar>,y=<x=foo,y=bar>>')
예제 #16
0
    def sequence_sum(self, value):
        """Implements `sequence_sum` as defined in `api/intrinsics.py`.

        A plain sequence is summed directly. A federated sequence is summed
        member-wise by applying the `SEQUENCE_SUM` intrinsic through
        `self.federated_map`.

        Raises:
          TypeError: If `value` is neither a sequence nor a federated value
            whose member is a sequence. Element-type compatibility is checked
            by `type_analysis.check_is_sum_compatible`.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        value_type = value.type_signature

        if value_type.is_sequence():
            element_type = value_type.element
        elif value_type.is_federated():
            py_typecheck.check_type(value_type.member,
                                    computation_types.SequenceType)
            element_type = value_type.member.element
        else:
            # Reject non-sequence, non-federated inputs with the descriptive
            # error here. A trailing `else` after the dispatch below could
            # never be reached, because the type was already validated.
            raise TypeError(
                'Cannot apply `tff.sequence_sum()` to a value of type {}.'.
                format(value_type))
        type_analysis.check_is_sum_compatible(element_type)

        if value_type.is_sequence():
            comp = building_block_factory.create_sequence_sum(
                value_impl.ValueImpl.get_comp(value))
            comp = self._bind_comp_as_reference(comp)
            return value_impl.ValueImpl(comp, self._context_stack)

        # Federated sequence: apply the local `T* -> T` sum under federated_map.
        intrinsic_type = computation_types.FunctionType(value_type.member,
                                                        element_type)
        intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_SUM.uri,
                                              intrinsic_type)
        intrinsic_impl = value_impl.ValueImpl(intrinsic, self._context_stack)
        return self.federated_map(intrinsic_impl, value)
예제 #17
0
    def federated_mean(self, value, weight):
        """Implements `federated_mean` as defined in `api/intrinsics.py`.

        Args:
          value: A value convertible to a federated value placed at CLIENTS,
            whose type must pass `type_analysis.is_average_compatible`.
          weight: `None` for an unweighted mean, or a value convertible to a
            CLIENTS-placed federated scalar tensor with an integer or
            floating-point dtype.

        Returns:
          A `value_impl.ValueImpl` wrapping the bound `federated_mean`
          building block.

        Raises:
          TypeError: If `value` is not average-compatible, or if `weight` is
            not a scalar or has a non-integer, non-floating dtype. Errors from
            `value_utils.ensure_federated_value` and `py_typecheck.check_type`
            also propagate.
        """
        # TODO(b/113112108): Possibly relax the constraints on numeric types, and
        # inject implicit casts where appropriate. For instance, we might want to
        # allow `tf.int32` values as the input, and automatically cast them to
        # `tf.float32` before invoking the average, thus producing a floating-point
        # result.

        # TODO(b/120439632): Possibly allow the weight to be either structured or
        # non-scalar, e.g., for the case of averaging a convolutional layer, when
        # we would want to use a different weight for every filter, and where it
        # might be cumbersome for users to have to manually slice and assemble a
        # variable.

        value = value_impl.to_value(value, None, self._context_stack)
        value = value_utils.ensure_federated_value(value,
                                                   placement_literals.CLIENTS,
                                                   'value to be averaged')
        if not type_analysis.is_average_compatible(value.type_signature):
            raise TypeError(
                'The value type {} is not compatible with the average operator.'
                .format(value.type_signature))

        if weight is not None:
            weight = value_impl.to_value(weight, None, self._context_stack)
            weight = value_utils.ensure_federated_value(
                weight, placement_literals.CLIENTS,
                'weight to use in averaging')
            py_typecheck.check_type(weight.type_signature.member,
                                    computation_types.TensorType)
            # Only rank-0 (scalar) weights are supported; see TODO above.
            if weight.type_signature.member.shape.ndims != 0:
                raise TypeError(
                    'The weight type {} is not a federated scalar.'.format(
                        weight.type_signature))
            if not (weight.type_signature.member.dtype.is_integer
                    or weight.type_signature.member.dtype.is_floating):
                raise TypeError(
                    'The weight type {} is not a federated integer or floating-point '
                    'tensor.'.format(weight.type_signature))

        value = value_impl.ValueImpl.get_comp(value)
        if weight is not None:
            weight = value_impl.ValueImpl.get_comp(weight)
        comp = building_block_factory.create_federated_mean(value, weight)
        comp = self._bind_comp_as_reference(comp)
        return value_impl.ValueImpl(comp, self._context_stack)
예제 #18
0
def _federated_select(client_keys, max_key, server_val, select_fn, secure):
    """Internal helper for `federated_select` and `federated_secure_select`.

    Validates each argument, then builds and binds a `federated_select`
    building block.

    Args:
      client_keys: Keys held at clients; validated by `_check_select_keys_type`.
      max_key: A value whose type must be assignable to `tf.int32` placed at
        the server.
      server_val: A federated value placed at the server.
      select_fn: A function whose parameter accepts
        `<member type of server_val, tf.int32>`.
      secure: Whether the secure variant is being built; forwarded to the
        validation helpers and to `create_federated_select`.

    Returns:
      A `value_impl.Value` referencing the bound select computation.
    """
    client_keys = value_impl.to_value(client_keys, None)
    _check_select_keys_type(client_keys.type_signature, secure)
    max_key = value_impl.to_value(max_key, None)
    expected_max_key_type = computation_types.at_server(tf.int32)
    if not expected_max_key_type.is_assignable_from(max_key.type_signature):
        # The expected type is `tf.int32`, which is signed, so the description
        # must not say "unsigned".
        _select_parameter_mismatch(
            max_key.type_signature,
            'a 32-bit integer placed at server',
            'max_key',
            secure,
            expected_type=expected_max_key_type)
    server_val = value_impl.to_value(server_val, None)
    server_val = value_utils.ensure_federated_value(server_val,
                                                    label='server_val')
    expected_server_val_type = computation_types.at_server(
        computation_types.AbstractType('T'))
    if (not server_val.type_signature.is_federated()
            or not server_val.type_signature.placement.is_server()):
        _select_parameter_mismatch(server_val.type_signature,
                                   'a value placed at server',
                                   'server_val',
                                   secure,
                                   expected_type=expected_server_val_type)
    # `select_fn` receives the server state together with a single int32 key.
    select_fn_param_type = computation_types.to_type(
        [server_val.type_signature.member, tf.int32])
    select_fn = value_impl.to_value(select_fn,
                                    None,
                                    parameter_type_hint=select_fn_param_type)
    expected_select_fn_type = computation_types.FunctionType(
        select_fn_param_type, computation_types.AbstractType('U'))
    if (not select_fn.type_signature.is_function()
            or not select_fn.type_signature.parameter.is_assignable_from(
                select_fn_param_type)):
        _select_parameter_mismatch(select_fn.type_signature,
                                   'a function from state and key to result',
                                   'select_fn',
                                   secure,
                                   expected_type=expected_select_fn_type)
    comp = building_block_factory.create_federated_select(
        client_keys.comp, max_key.comp, server_val.comp, select_fn.comp,
        secure)
    comp = _bind_comp_as_reference(comp)
    return value_impl.Value(comp)
예제 #19
0
 def test_getattr_non_federated_value_with_none_default_missing_name(self):
     struct_type = computation_types.StructType([('a', tf.int32),
                                                 ('b', tf.bool)])
     struct_value = value_impl.to_value(
         building_blocks.Reference('test', struct_type), None)
     self.assertEqual(str(struct_value.type_signature), '<a=int32,b=bool>')
     # `getattr` with a default must return that default for an unknown field.
     self.assertIsNone(getattr(struct_value, 'c', None))
예제 #20
0
 def test_getitem_resolution_federated_value_server(self):
   server_type = computation_types.FederatedType([tf.int32, tf.bool],
                                                 placements.SERVER, True)
   federated_value = value_impl.to_value(
       building_blocks.Reference('test', server_type), None)
   self.assertEqual(str(federated_value.type_signature), '<int32,bool>@SERVER')
   # Indexing selects one member element and keeps the SERVER placement.
   first_element = federated_value[0]
   self.assertEqual(str(first_element.type_signature), 'int32@SERVER')
예제 #21
0
 def test_getattr_federated_value_with_none_default_missing_name(self):
   member_fields = [('a', tf.int32), ('b', tf.bool)]
   server_type = computation_types.FederatedType(member_fields,
                                                 placements.SERVER, True)
   federated_value = value_impl.to_value(
       building_blocks.Reference('test', server_type), None)
   self.assertEqual(
       str(federated_value.type_signature), '<a=int32,b=bool>@SERVER')
   # An unknown field looked up with a default yields that default.
   self.assertIsNone(getattr(federated_value, 'c', None))
예제 #22
0
 def test_getattr_resolution_federated_value_clients(self):
   clients_type = computation_types.FederatedType(
       [('a', tf.int32), ('b', tf.bool)], placements.CLIENTS, False)
   federated_value = value_impl.to_value(
       building_blocks.Reference('test', clients_type), None)
   self.assertEqual(
       str(federated_value.type_signature), '{<a=int32,b=bool>}@CLIENTS')
   # Attribute access projects the named field and keeps the non-all-equal
   # CLIENTS placement.
   self.assertEqual(str(federated_value.a.type_signature), '{int32}@CLIENTS')
예제 #23
0
 def test_to_value_for_list(self):
   stack = context_stack_impl.context_stack
   foo, bar = (
       value_impl.ValueImpl(building_blocks.Reference(name, dtype), stack)
       for name, dtype in (('foo', tf.int32), ('bar', tf.bool)))
   # A Python list of values converts to an unnamed struct.
   converted = value_impl.to_value([foo, bar], None, stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted), '<foo,bar>')
예제 #24
0
 def test_getitem_federated_slice_constructs_comp_clients(self):
   clients_type = computation_types.FederatedType([tf.int32, tf.bool],
                                                  placements.CLIENTS, False)
   federated_value = value_impl.to_value(
       building_blocks.Reference('test', clients_type), None)
   self.assertEqual(
       str(federated_value.type_signature), '{<int32,bool>}@CLIENTS')
   # A full slice at CLIENTS becomes a `federated_map` of a lambda that
   # rebuilds the struct element by element.
   full_slice = federated_value[:]
   self.assertEqual(str(full_slice.type_signature), '{<int32,bool>}@CLIENTS')
   self.assertEqual(
       str(full_slice), 'federated_map(<(x -> <x[0],x[1]>),test>)')
예제 #25
0
 def test_getattr_resolution_federated_value_server(self):
   server_type = computation_types.FederatedType(
       [('a', tf.int32), ('b', tf.bool)], placement_literals.SERVER, True)
   federated_value = value_impl.to_value(
       building_blocks.Reference('test', server_type), None,
       context_stack_impl.context_stack)
   self.assertEqual(
       str(federated_value.type_signature), '<a=int32,b=bool>@SERVER')
   # Attribute access projects the named field and keeps the SERVER placement.
   self.assertEqual(str(federated_value.a.type_signature), 'int32@SERVER')
예제 #26
0
 def federated_sum(self, value):
     """Implements `federated_sum` as defined in `api/intrinsics.py`.

     `value` is converted, required to be federated at CLIENTS, and checked for
     sum compatibility before the `federated_sum` building block is built.
     """
     federated = value_utils.ensure_federated_value(
         value_impl.to_value(value, None, self._context_stack),
         placement_literals.CLIENTS, 'value to be summed')
     type_analysis.check_is_sum_compatible(federated.type_signature)
     sum_comp = building_block_factory.create_federated_sum(
         value_impl.ValueImpl.get_comp(federated))
     return value_impl.ValueImpl(self._bind_comp_as_reference(sum_comp),
                                 self._context_stack)
예제 #27
0
 def test_getitem_federated_slice_constructs_comp_server(self):
   server_type = computation_types.FederatedType([tf.int32, tf.bool],
                                                 placements.SERVER, True)
   federated_value = value_impl.to_value(
       building_blocks.Reference('test', server_type), None)
   self.assertEqual(str(federated_value.type_signature), '<int32,bool>@SERVER')
   # A full slice at SERVER becomes a `federated_apply` of a lambda that
   # rebuilds the struct element by element.
   full_slice = federated_value[:]
   self.assertEqual(str(full_slice.type_signature), '<int32,bool>@SERVER')
   self.assertEqual(
       str(full_slice), 'federated_apply(<(x -> <x[0],x[1]>),test>)')
 def invoke(self, comp, arg):
     """Calls `comp` on `arg` (or with no argument) and type-checks the result.

     Args:
       comp: A value convertible via `value_impl.to_value`; its type signature
         must be a `computation_types.FunctionType`.
       arg: The argument to pass, or `None` to invoke without an argument.

     Returns:
       The result of calling the converted function, after it has been checked
       against the function's declared result type.

     Raises:
       ValueError: If an argument is supplied to a computation without a
         parameter, or no argument is supplied to one that expects a parameter.
       Errors from `py_typecheck.check_type` and `type_analysis.check_type`
       also propagate.
     """
     fn = value_impl.to_value(comp, None)
     tys = fn.type_signature
     py_typecheck.check_type(tys, computation_types.FunctionType)
     if arg is not None:
         if tys.parameter is None:
             raise ValueError(
                 'A computation of type {} does not expect any arguments, but got '
                 'an argument {}.'.format(tys, arg))
         # `zip_if_needed` lets a struct of federated values be zipped to
         # match a federated parameter type.
         arg = value_impl.to_value(arg, tys.parameter, zip_if_needed=True)
         type_analysis.check_type(arg, tys.parameter)
         ret_val = fn(arg)
     else:
         if tys.parameter is not None:
             # Fragments previously joined as 'got ' + ' no', producing a double
             # space in the message.
             raise ValueError(
                 'A computation of type {} expects an argument of type {}, but got '
                 'no argument.'.format(tys, tys.parameter))
         ret_val = fn()
     type_analysis.check_type(ret_val, tys.result)
     return ret_val
예제 #29
0
 def test_getitem_resolution_federated_value_clients(self):
   clients_type = computation_types.FederatedType(
       [tf.int32, tf.bool], placement_literals.CLIENTS, False)
   federated_value = value_impl.to_value(
       building_blocks.Reference('test', clients_type), None,
       context_stack_impl.context_stack)
   self.assertEqual(
       str(federated_value.type_signature), '{<int32,bool>}@CLIENTS')
   # Indexing selects one member element and keeps the CLIENTS placement.
   self.assertEqual(str(federated_value[0].type_signature), '{int32}@CLIENTS')
예제 #30
0
def federated_computation_serializer(
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
):
    """Converts a function into a computation building block.

    This is a two-step generator driven by the caller:

    1. The first `next()` yields the argument to pass to the Python function
       being converted. It is `None` if `parameter_type` is `None`; otherwise
       it is a `value_impl.ValueImpl` wrapping a `building_blocks.Reference` to
       the (context-prefixed) parameter.
    2. The caller then `send()`s the function's return value back in. The
       generator converts it to a value and wraps it in a
       `building_blocks.Block` if any symbols were bound in the federated
       context. It then yields the final `(Lambda, FunctionType)` pair.

    Both steps run while a new `FederatedComputationContext` is installed on
    `context_stack`. If the current context is already a federated
    computation context, it becomes the new context's parent.

    Args:
      parameter_name: The name of the parameter, or `None` if there isn't one.
        When given, it is rewritten as `'<context name>_<parameter_name>'`.
      parameter_type: The `tff.Type` of the parameter, or `None` if there's
        none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts.

    Yields:
      First, the argument to be passed to the function to be converted.
      Finally, a tuple of `(building_blocks.Lambda,
      computation_types.FunctionType)`: the function represented via building
      blocks, and its function type (`parameter_type` -> the result type
      inferred from the value sent back by the caller).
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, str)
    # Nest inside an enclosing federated context, if there is one, so the new
    # context is linked to it as its parent.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, str)
        # Prefix with the context name so parameters of nested computations
        # get distinct reference names.
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        # The caller invokes the Python function between these two yields,
        # while `context` is installed, and sends its return value back here.
        if parameter_type is None:
            result = yield None
        else:
            result = yield (value_impl.ValueImpl(
                building_blocks.Reference(parameter_name, parameter_type),
                context_stack))
        annotated_result_type = type_conversions.infer_type(result)
        result = value_impl.to_value(result, annotated_result_type,
                                     context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        # Symbols bound while tracing the body (presumably in `context`, the
        # currently installed context) become locals of a wrapping Block.
        symbols_bound_in_context = context_stack.current.symbol_bindings
        if symbols_bound_in_context:
            result_comp = building_blocks.Block(
                local_symbols=symbols_bound_in_context, result=result_comp)
        annotated_type = computation_types.FunctionType(
            parameter_type, annotated_result_type)
        # NOTE(review): this final yield is still inside the `with`, so
        # `context` stays installed until the generator is resumed again or
        # closed. Callers should exhaust or close it after taking this pair.
        yield building_blocks.Lambda(parameter_name, parameter_type,
                                     result_comp), annotated_type