示例#1
0
def sequence_map(fn, arg):
    """Maps a TFF sequence `arg` pointwise using a given function `fn`.

    This function supports two modes of usage:

    * When applied to a non-federated sequence, it maps individual elements of
      the sequence pointwise. If the supplied `fn` is of type `T->U` and
      the sequence `arg` is of type `T*` (a sequence of `T`-typed elements),
      the result is a sequence of type `U*` (a sequence of `U`-typed elements),
      with each element of the input sequence individually mapped by `fn`.
      In this mode of usage, `sequence_map` behaves like a computation with
      type signature `<T->U,T*> -> U*`.

    * When applied to a federated sequence, `sequence_map` behaves as if it were
      individually applied to each member constituent. In this mode of usage,
      one can think of `sequence_map` as a specialized variant of
      `federated_map` that is designed to work with sequences and allows one to
      specify a `fn` that operates at the level of individual elements.
      Indeed, under the hood, when `sequence_map` is invoked on a federated
      type, it injects `federated_map`, thus emitting expressions like
      `federated_map(a -> sequence_map(fn, a), arg)`.

    Args:
      fn: A mapping function to apply pointwise to elements of `arg`.
      arg: A value of a TFF type that is either a sequence, or a federated
        sequence.

    Returns:
      A sequence with the result of applying `fn` pointwise to each
      element of `arg`, or if `arg` was federated, a federated sequence
      with the result of invoking `sequence_map` on member sequences locally
      and independently at each location.

    Raises:
      TypeError: If `fn` is not a function, or if `arg` is neither a sequence
        nor a federated value whose member is a sequence.
    """
    fn = value_impl.to_value(fn, None)
    py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
    arg = value_impl.to_value(arg, None)
    arg_type = arg.type_signature

    if arg_type.is_sequence():
        comp = building_block_factory.create_sequence_map(fn.comp, arg.comp)
        comp = _bind_comp_as_reference(comp)
        return value_impl.Value(comp)
    elif arg_type.is_federated():
        # Reject non-sequence members up front with a message that names this
        # function, instead of deferring to `federated_map`'s own type checks.
        if not arg_type.member.is_sequence():
            raise TypeError(
                'Cannot apply `tff.sequence_map()` to a federated value whose '
                'member is not a sequence; got a value of type {}.'.format(
                    arg_type))
        # Build the local `SEQUENCE_MAP` intrinsic of type
        # `<(T -> U), T*> -> U*`. Currying it and applying it to `fn` yields a
        # `T* -> U*` function that `federated_map` applies to each member.
        parameter_type = computation_types.SequenceType(
            fn.type_signature.parameter)
        result_type = computation_types.SequenceType(fn.type_signature.result)
        intrinsic_type = computation_types.FunctionType(
            (fn.type_signature, parameter_type), result_type)
        intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri,
                                              intrinsic_type)
        intrinsic_impl = value_impl.Value(intrinsic)
        local_fn = value_utils.get_curried(intrinsic_impl)(fn)
        return federated_map(local_fn, arg)
    else:
        raise TypeError(
            'Cannot apply `tff.sequence_map()` to a value of type {}.'.format(
                arg.type_signature))
示例#2
0
 def test_to_value_for_nested_attrs_class(self):
     # Nested attrs instances convert to nested named structs.
     foo = value_impl.Value(building_blocks.Reference('foo', tf.int32))
     bar = value_impl.Value(building_blocks.Reference('bar', tf.int32))
     left = TestAttrClass(foo, bar)
     right = TestAttrClass(foo, bar)
     converted = value_impl.to_value(TestAttrClass(left, right), None)
     self.assertIsInstance(converted, value_impl.Value)
     self.assertEqual(str(converted), '<x=<x=foo,y=bar>,y=<x=foo,y=bar>>')
示例#3
0
 def test_to_value_for_ordered_dict(self):
     foo = value_impl.Value(building_blocks.Reference('foo', tf.int32))
     bar = value_impl.Value(building_blocks.Reference('bar', tf.bool))
     # Insertion order ('b' before 'a') must survive the conversion.
     ordered = collections.OrderedDict()
     ordered['b'] = bar
     ordered['a'] = foo
     converted = value_impl.to_value(ordered, None)
     self.assertIsInstance(converted, value_impl.Value)
     self.assertEqual(str(converted), '<b=bar,a=foo>')
示例#4
0
 def test_value_impl_with_plus(self):
     lhs = value_impl.Value(building_blocks.Reference('x', tf.int32))
     rhs = value_impl.Value(building_blocks.Reference('y', tf.int32))
     total = lhs + rhs
     self.assertIsInstance(total, value_impl.Value)
     self.assertEqual(str(total.type_signature), 'int32')
     # The sum is bound to a fresh symbol in the current context.
     self.assertEqual(str(total), 'fc_FEDERATED_symbol_0')
     bindings = self.bound_symbols()
     self.assertLen(bindings, 1)
     symbol_name, symbol_comp = bindings[0]
     self.assertEqual(symbol_name, 'fc_FEDERATED_symbol_0')
     self.assertEqual(symbol_comp.compact_representation(),
                      'generic_plus(<x,y>)')
示例#5
0
 def test_slicing_support_namedtuple(self):
     foo = value_impl.Value(building_blocks.Reference('foo', tf.int32))
     bar = value_impl.Value(building_blocks.Reference('bar', tf.bool))
     pair_type = collections.namedtuple('_', 'a b')
     v = value_impl.to_value(pair_type(foo, bar), None)

     first_half = v[:len(v) // 2]
     self.assertIsInstance(first_half, value_impl.Value)

     every_other = v[:4:2]
     self.assertEqual(str(every_other), '<foo>')
     self.assertIsInstance(every_other, value_impl.Value)

     backwards = v[4::-1]
     self.assertEqual(str(backwards), '<bar,foo>')
     self.assertIsInstance(backwards, value_impl.Value)

     # A slice that selects nothing is rejected.
     with self.assertRaisesRegex(IndexError, 'slice 0 elements'):
         _ = v[2:4]
示例#6
0
 def test_to_value_for_dict(self):
     foo = value_impl.Value(building_blocks.Reference('foo', tf.int32))
     bar = value_impl.Value(building_blocks.Reference('bar', tf.bool))
     # Plain dicts convert with keys in sorted order, whatever the insertion
     # order was.
     for mapping in ({'a': foo, 'b': bar}, {'b': bar, 'a': foo}):
         converted = value_impl.to_value(mapping, None)
         self.assertIsInstance(converted, value_impl.Value)
         self.assertEqual(str(converted), '<a=foo,b=bar>')
示例#7
0
def federated_value(value, placement):
  """Returns a federated value at `placement`, with `value` as the constituent.

  Deprecation warning: Using `tff.federated_value` with arguments other than
  simple Python constants is deprecated. When placing the result of a
  `tf_computation`, prefer `tff.federated_eval`.

  Args:
    value: A value of a non-federated TFF type to be placed.
    placement: The desired result placement (either `tff.SERVER` or
      `tff.CLIENTS`).

  Returns:
    A federated value with the given placement `placement`, and the member
    constituent `value` equal at all locations.

  Raises:
    TypeError: If the arguments are not of the appropriate types, or if the
      type of `value` contains a federated type.
  """
  if isinstance(value, value_impl.Value):
    # `stacklevel=2` attributes the warning to the caller's line, i.e. the code
    # that should migrate to `tff.federated_eval`, instead of to this function.
    warnings.warn(
        'Deprecation warning: Using `tff.federated_value` with arguments '
        'other than simple Python constants is deprecated. When placing the '
        'result of a `tf_computation`, prefer `tff.federated_eval`.',
        DeprecationWarning,
        stacklevel=2)
  value = value_impl.to_value(value, None)
  # A value that already contains placed (federated) parts cannot be placed
  # again; `type_analysis.contains` presumably also inspects nested types.
  if type_analysis.contains(value.type_signature, lambda t: t.is_federated()):
    raise TypeError('Cannot place value {} containing federated types at '
                    'another placement; requested to be placed at {}.'.format(
                        value, placement))

  comp = building_block_factory.create_federated_value(value.comp, placement)
  comp = _bind_comp_as_reference(comp)
  return value_impl.Value(comp)
示例#8
0
def federated_eval(fn, placement):
    """Evaluates a no-argument computation at `placement`.

    Args:
      fn: A no-arg TFF computation (anything `value_impl.to_value` converts
        into a functional value with no parameter).
      placement: The desired result placement (either `tff.SERVER` or
        `tff.CLIENTS`).

    Returns:
      A federated value with the given placement `placement`.

    Raises:
      TypeError: If the arguments are not of the appropriate types, or if `fn`
        declares a parameter.
    """
    # TODO(b/113112108): Verify that neither the value, nor any of its parts
    # are of a federated type.
    fn_value = value_impl.to_value(fn, None)
    py_typecheck.check_type(fn_value, value_impl.Value)
    fn_type = fn_value.type_signature
    py_typecheck.check_type(fn_type, computation_types.FunctionType)

    param_type = fn_type.parameter
    if param_type is not None:
        raise TypeError(
            '`federated_eval` expects a `fn` that accepts no arguments, but '
            'the `fn` provided has a parameter of type {}.'.format(param_type))

    eval_comp = building_block_factory.create_federated_eval(
        fn_value.comp, placement)
    return value_impl.Value(_bind_comp_as_reference(eval_comp))
示例#9
0
 def test_value_impl_with_selection(self):
     struct_type = [('bar', tf.int32), ('baz', tf.bool)]
     foo = value_impl.Value(building_blocks.Reference('foo', struct_type))
     self.assertContainsSubset(['bar', 'baz'], dir(foo))
     self.assertLen(foo, 2)

     # Selection by attribute name.
     by_attr = foo.bar
     self.assertIsInstance(by_attr, value_impl.Value)
     self.assertEqual(str(by_attr.type_signature), 'int32')
     self.assertEqual(str(by_attr), 'foo.bar')

     # Selection by string key.
     by_key = foo['baz']
     self.assertEqual(str(by_key.type_signature), 'bool')
     self.assertEqual(str(by_key), 'foo.baz')
     with self.assertRaises(AttributeError):
         _ = foo.bak

     # Selection by integer index; out-of-range and negative indices fail.
     first = foo[0]
     self.assertIsInstance(first, value_impl.Value)
     self.assertEqual(str(first.type_signature), 'int32')
     self.assertEqual(str(first), 'foo[0]')
     second = foo[1]
     self.assertEqual(str(second.type_signature), 'bool')
     self.assertEqual(str(second), 'foo[1]')
     for bad_index in (2, -1):
         with self.assertRaises(IndexError):
             _ = foo[bad_index]

     # Iteration yields index-based selections.
     self.assertEqual(','.join(map(str, foo)), 'foo[0],foo[1]')
     self.assertEqual(','.join(str(e.type_signature) for e in foo),
                      'int32,bool')
     with self.assertRaises(SyntaxError):
         foo(10)
示例#10
0
def federated_zip(value):
    """Turns a tuple of federated values into a federated tuple value.

    Args:
      value: A value of a TFF named tuple type, the elements of which are
        federated values with the same placement.

    Returns:
      A federated value placed at the same location as the members of `value`,
      in which every member component is a named tuple that consists of the
      corresponding member components of the elements of `value`.

    Raises:
      TypeError: If the argument is not a named tuple of federated values with
        the same placement.
    """
    # TODO(b/113112108): The type system cannot express "an operation that
    # takes tuples of T of arbitrary length", so the zip intrinsic only takes a
    # fixed number of arguments (2), and N-tuples must be zipped by
    # iterating/unwrapping. Alternatives, such as an operator over sequences,
    # may be worth exploring.
    # NOTE(review): that iteration is not visible here; presumably it now lives
    # in `building_block_factory.create_federated_zip`. Confirm.
    zippable = value_impl.to_value(value, None)
    py_typecheck.check_type(zippable, value_impl.Value)
    py_typecheck.check_type(zippable.type_signature,
                            computation_types.StructType)

    zipped_comp = building_block_factory.create_federated_zip(zippable.comp)
    return value_impl.Value(_bind_comp_as_reference(zipped_comp))
示例#11
0
def federated_broadcast(value):
    """Broadcasts a federated value from the `tff.SERVER` to the `tff.CLIENTS`.

    Args:
      value: A value of a TFF federated type placed at the `tff.SERVER`, all
        members of which are equal (the `tff.FederatedType.all_equal` property
        of `value` is `True`).

    Returns:
      A representation of the result of broadcasting: a value of a TFF
      federated type placed at the `tff.CLIENTS`, all members of which are
      equal.

    Raises:
      TypeError: If the argument is not a federated TFF value placed at the
        `tff.SERVER`, or if it is not equal at all locations.
    """
    server_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None), placements.SERVER,
        'value to be broadcasted')

    # Only an `all_equal` value has a single payload to send to every client.
    if not server_value.type_signature.all_equal:
        raise TypeError(
            'The broadcasted value should be equal at all locations.')

    broadcast_comp = building_block_factory.create_federated_broadcast(
        server_value.comp)
    return value_impl.Value(_bind_comp_as_reference(broadcast_comp))
示例#12
0
def get_curried(fn):
  """Returns a curried version of function `fn` that takes a parameter tuple.

  For functions `fn` of types <T1,T2,....,Tn> -> U, the result is a function
  of the form T1 -> (T2 -> (T3 -> .... (Tn -> U) ... )).

  Note: No attempt is made at avoiding naming conflicts in cases where `fn`
  contains references. The arguments of the curried function are named `argN`
  with `N` starting at 0.

  Args:
    fn: A value of a functional TFF type whose parameter is a struct.

  Returns:
    A value that represents the curried form of `fn`.

  Raises:
    TypeError: If `fn` is not a `value_impl.Value`, or if `fn` takes no
      parameter.
  """
  py_typecheck.check_type(fn, value_impl.Value)
  fn_type = fn.type_signature
  fn_type.check_function()
  # A no-arg function has a `None` parameter. Report that as a TypeError
  # rather than crashing with an AttributeError on `None.check_struct()`.
  if fn_type.parameter is None:
    raise TypeError(
        'Cannot curry a function that takes no parameter: {}.'.format(fn_type))
  fn_type.parameter.check_struct()
  param_elements = structure.to_elements(fn_type.parameter)
  references = [
      building_blocks.Reference('arg{}'.format(idx), elem_type)
      for idx, (_, elem_type) in enumerate(param_elements)
  ]
  result = building_blocks.Call(fn.comp, building_blocks.Struct(references))
  # Wrap from the last argument outward, so `arg0` ends up as the outermost
  # lambda.
  for ref in reversed(references):
    result = building_blocks.Lambda(ref.name, ref.type_signature, result)
  return value_impl.Value(result)
示例#13
0
def federated_map_all_equal(fn, arg):
    """`federated_map` with the `all_equal` bit set in the `arg` and return.

    Args:
      fn: A mapping function to apply to the member constituent of `arg`.
      arg: A value placed at `tff.CLIENTS`, or a structure that
        `value_utils.ensure_federated_value` can zip into one.

    Returns:
      A value wrapping the `federated_map_all_equal` computation built from
      `fn` and `arg`.

    Raises:
      TypeError: If `arg` is not (convertible to) a value at `tff.CLIENTS`, if
        `fn` is not a function, or if `fn` has no parameter or a parameter
        that does not accept the member type of `arg`.
    """
    # TODO(b/113112108): Possibly lift the restriction that the mapped value
    # must be placed at the clients after adding support for placement labels
    # in the federated types, and expanding the type specification of the
    # intrinsic this is based on to work with federated values of arbitrary
    # placement.
    arg = value_impl.to_value(arg, None)
    arg = value_utils.ensure_federated_value(arg, placements.CLIENTS,
                                             'value to be mapped')
    member_type = arg.type_signature.member

    # The member type is passed as a hint so that `fn` can be converted with
    # a matching parameter type.
    fn = value_impl.to_value(fn, None, parameter_type_hint=member_type)

    py_typecheck.check_type(fn, value_impl.Value)
    py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
    parameter_type = fn.type_signature.parameter
    # A no-arg `fn` has a `None` parameter. Report it as a TypeError instead of
    # failing with an AttributeError on `None.is_assignable_from`.
    if (parameter_type is None or
            not parameter_type.is_assignable_from(member_type)):
        raise TypeError(
            'The mapping function expects a parameter of type {}, but member '
            'constituents of the mapped value are of incompatible type {}.'.
            format(parameter_type, member_type))

    comp = building_block_factory.create_federated_map_all_equal(
        fn.comp, arg.comp)
    comp = _bind_comp_as_reference(comp)
    return value_impl.Value(comp)
示例#14
0
    def test_value_impl_dir(self):
        wrapped = value_impl.Value(building_blocks.Reference('foo', tf.int32))
        attributes = dir(wrapped)
        self.assertIsInstance(attributes, list)
        self.assertNotEmpty(attributes)
        # `type_signature` must be among the names `dir()` reports.
        self.assertIn('type_signature', attributes)
示例#15
0
 def test_value_impl_with_reference(self):
     reference = building_blocks.Reference('foo', tf.int32)
     wrapped = value_impl.Value(reference)
     # The wrapper exposes the exact building block it was constructed from.
     self.assertIs(wrapped.comp, reference)
     self.assertEqual(str(wrapped.type_signature), 'int32')
     self.assertEqual(repr(wrapped), 'Reference(\'foo\', TensorType(tf.int32))')
     self.assertEqual(str(wrapped), 'foo')
     # A non-functional value cannot be called.
     with self.assertRaises(SyntaxError):
         wrapped(10)
示例#16
0
 def test_value_impl_with_call(self):
     fn_type = computation_types.FunctionType(tf.int32, tf.bool)
     fn = value_impl.Value(building_blocks.Reference('foo', fn_type))
     int_arg = value_impl.Value(building_blocks.Reference('bar', tf.int32))
     result = fn(int_arg)
     self.assertIsInstance(result, value_impl.Value)
     self.assertEqual(str(result.type_signature), 'bool')
     self.assertEqual(str(result), 'fc_FEDERATED_symbol_0')
     bound_symbols = self.bound_symbols()
     self.assertLen(bound_symbols, 1)
     symbol_name, symbol_comp = bound_symbols[0]
     self.assertEqual(symbol_name, str(result))
     self.assertEqual(str(symbol_comp), 'foo(bar)')
     # Calls with a missing or mistyped argument are rejected.
     with self.assertRaises(TypeError):
         fn()
     float_arg = value_impl.Value(building_blocks.Reference('bak', tf.float32))
     with self.assertRaises(TypeError):
         fn(float_arg)
示例#17
0
 def test_slicing_support_non_tuple_underlying_comp(self):
     # The wrapped comp is a Reference to a struct, not a Struct literal.
     reference = building_blocks.Reference('test', [tf.int32] * 5)
     v = value_impl.Value(reference)
     for valid_slice in (slice(None, 4, 2), slice(4, 2, -1)):
         self.assertIsInstance(v[valid_slice], value_impl.Value)
     with self.assertRaisesRegex(IndexError, 'slice 0 elements'):
         _ = v[2:4:-1]
示例#18
0
def sequence_sum(value):
    """Computes a sum of elements in a sequence.

    Args:
      value: A value of a TFF type that is either a sequence, or a federated
        sequence.

    Returns:
      The sum of elements in the sequence. If the argument `value` is of a
      federated type, the result is also of a federated type, with the sum
      computed locally and independently at each location (see also a
      discussion on `sequence_map` and `sequence_reduce`).

    Raises:
      TypeError: If the arguments are of wrong or unsupported types.
    """
    value = value_impl.to_value(value, None)
    value_type = value.type_signature

    if value_type.is_sequence():
        type_analysis.check_is_sum_compatible(value_type.element)
        comp = building_block_factory.create_sequence_sum(value.comp)
        comp = _bind_comp_as_reference(comp)
        return value_impl.Value(comp)

    # Anything else must be a federated sequence. These checks raise
    # otherwise, so no trailing "unsupported type" branch is needed.
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    py_typecheck.check_type(value_type.member, computation_types.SequenceType)
    sequence_type = value_type.member
    type_analysis.check_is_sum_compatible(sequence_type.element)

    # Map the local `SEQUENCE_SUM` intrinsic (`T* -> T`) over the members, so
    # each location sums its own sequence independently.
    intrinsic_type = computation_types.FunctionType(sequence_type,
                                                    sequence_type.element)
    intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_SUM.uri,
                                          intrinsic_type)
    return federated_map(value_impl.Value(intrinsic), value)
示例#19
0
def federated_computation_serializer(
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
):
    """Converts a function into a computation building block.

    This is a two-step generator. The caller first receives the argument to
    pass to the Python function being traced. It then `send()`s that
    function's result back in and receives the finished building block
    together with its type.

    Args:
      parameter_name: The name of the parameter, or `None` if there isn't any.
      parameter_type: The `tff.Type` of the parameter, or `None` if there's
        none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts.

    Yields:
      First, the argument to be passed to the function to be converted.
      Finally, a tuple of `(building_blocks.ComputationBuildingBlock,
      computation_types.Type)`: the function represented via building blocks
      and the inferred return type.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, str)
    # When already tracing inside a federated context, link the new context to
    # it as its parent.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, str)
        # Prefix with the context name, presumably to keep parameter names from
        # different (possibly nested) contexts distinct.
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        if parameter_type is None:
            result = yield None
        else:
            result = yield value_impl.Value(
                building_blocks.Reference(parameter_name, parameter_type))
        annotated_result_type = type_conversions.infer_type(result)
        result = value_impl.to_value(result, annotated_result_type,
                                     context_stack)
        result_comp = result.comp
        # Symbols bound in this context while tracing become the locals of a
        # `Block` that wraps the result.
        symbols_bound_in_context = context_stack.current.symbol_bindings
        if symbols_bound_in_context:
            result_comp = building_blocks.Block(
                local_symbols=symbols_bound_in_context, result=result_comp)
        annotated_type = computation_types.FunctionType(
            parameter_type, annotated_result_type)
        yield building_blocks.Lambda(parameter_name, parameter_type,
                                     result_comp), annotated_type
示例#20
0
 def test_value_impl_with_lambda(self):
     arg_name = 'arg'
     f_type = computation_types.FunctionType(tf.int32, tf.int32)
     arg_type = [('f', f_type), ('x', tf.int32)]
     arg_value = value_impl.Value(building_blocks.Reference(arg_name, arg_type))

     def apply_twice(a):
         return a.f(a.f(a.x))

     result_value = apply_twice(arg_value)
     self.assertIsInstance(result_value, value_impl.Value)
     self.assertEqual(str(result_value.type_signature), 'int32')
     # Each call is bound to its own symbol; the outer call is bound last.
     self.assertEqual(str(result_value), 'fc_FEDERATED_symbol_1')
     bound_symbols = self.bound_symbols()
     self.assertLen(bound_symbols, 2)
     (inner_name, inner_comp), (outer_name, outer_comp) = bound_symbols
     self.assertEqual(outer_name, 'fc_FEDERATED_symbol_1')
     self.assertEqual(str(outer_comp), 'arg.f(fc_FEDERATED_symbol_0)')
     self.assertEqual(inner_name, 'fc_FEDERATED_symbol_0')
     self.assertEqual(str(inner_comp), 'arg.f(arg.x)')
示例#21
0
    def test_get_curried(self):
        add_fn = computations.tf_computation(
            lambda a, b: tf.add(a, b),  # pylint: disable=unnecessary-lambda
            [tf.int32, tf.int32])
        add_proto = computation_impl.ComputationImpl.get_proto(add_fn)
        add_numbers = value_impl.Value(
            building_blocks.ComputationBuildingBlock.from_proto(add_proto))

        curried = value_utils.get_curried(add_numbers)
        self.assertEqual(str(curried.type_signature),
                         '(int32 -> (int32 -> int32))')

        # Compiled computation names are uniquified, so the representation
        # below is deterministic.
        uniquified, _ = (
            tree_transformations.uniquify_compiled_computation_names(
                curried.comp))
        self.assertEqual(uniquified.compact_representation(),
                         '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
示例#22
0
    def test_get_curried(self):
        int_type = computation_types.TensorType(tf.int32)
        add_proto, add_type = (
            tensorflow_computation_factory.create_binary_operator(
                tf.add, int_type, int_type))
        add_numbers = value_impl.Value(
            building_blocks.CompiledComputation(
                proto=add_proto, name='test', type_signature=add_type))

        curried = value_utils.get_curried(add_numbers)

        # A binary operator curries into two nested single-argument lambdas.
        self.assertEqual(curried.type_signature.compact_representation(),
                         '(int32 -> (int32 -> int32))')
        self.assertEqual(curried.comp.compact_representation(),
                         '(arg0 -> (arg1 -> comp#test(<arg0,arg1>)))')
示例#23
0
def _federated_select(client_keys, max_key, server_val, select_fn, secure):
    """Internal helper for `federated_select` and `federated_secure_select`.

    Each argument is converted with `value_impl.to_value` and validated in
    order: `client_keys`, `max_key`, `server_val`, then `select_fn`. Mismatches
    are reported through `_check_select_keys_type` /
    `_select_parameter_mismatch`. Finally the select building block is built
    and bound as a reference.
    """
    keys_value = value_impl.to_value(client_keys, None)
    _check_select_keys_type(keys_value.type_signature, secure)

    max_key_value = value_impl.to_value(max_key, None)
    max_key_expected = computation_types.at_server(tf.int32)
    if not max_key_expected.is_assignable_from(max_key_value.type_signature):
        # NOTE(review): the message says "unsigned", but the expected type is
        # `tf.int32` (signed). Confirm which is intended.
        _select_parameter_mismatch(
            max_key_value.type_signature,
            'a 32-bit unsigned integer placed at server',
            'max_key',
            secure,
            expected_type=max_key_expected)

    state_value = value_utils.ensure_federated_value(
        value_impl.to_value(server_val, None), label='server_val')
    state_expected = computation_types.at_server(
        computation_types.AbstractType('T'))
    state_type = state_value.type_signature
    if not (state_type.is_federated() and state_type.placement.is_server()):
        _select_parameter_mismatch(state_type,
                                   'a value placed at server',
                                   'server_val',
                                   secure,
                                   expected_type=state_expected)

    # `select_fn` receives `<server member, int32 key>`.
    select_fn_param_type = computation_types.to_type(
        [state_type.member, tf.int32])
    select_fn_value = value_impl.to_value(
        select_fn, None, parameter_type_hint=select_fn_param_type)
    select_fn_expected = computation_types.FunctionType(
        select_fn_param_type, computation_types.AbstractType('U'))
    select_fn_type = select_fn_value.type_signature
    if not (select_fn_type.is_function() and
            select_fn_type.parameter.is_assignable_from(select_fn_param_type)):
        _select_parameter_mismatch(select_fn_type,
                                   'a function from state and key to result',
                                   'select_fn',
                                   secure,
                                   expected_type=select_fn_expected)

    select_comp = building_block_factory.create_federated_select(
        keys_value.comp, max_key_value.comp, state_value.comp,
        select_fn_value.comp, secure)
    return value_impl.Value(_bind_comp_as_reference(select_comp))
示例#24
0
 def test_value_impl_with_tuple(self):
     foo_ref = building_blocks.Reference('foo', tf.int32)
     bar_ref = building_blocks.Reference('bar', tf.bool)
     struct = value_impl.Value(
         building_blocks.Struct([foo_ref, ('y', bar_ref)]))
     self.assertIsInstance(struct, value_impl.Value)
     self.assertEqual(str(struct.type_signature), '<int32,y=bool>')
     self.assertEqual(str(struct), '<foo,y=bar>')
     self.assertContainsSubset(['y'], dir(struct))
     # Attribute access returns the original element building block.
     self.assertEqual(str(struct.y), 'bar')
     self.assertIs(struct.y.comp, bar_ref)
     self.assertLen(struct, 2)
     # Index and key access do too.
     self.assertEqual(str(struct[0]), 'foo')
     self.assertIs(struct[0].comp, foo_ref)
     self.assertEqual(str(struct['y']), 'bar')
     self.assertIs(struct['y'].comp, bar_ref)
     self.assertEqual(','.join(map(str, struct)), 'foo,bar')
     with self.assertRaises(SyntaxError):
         struct(10)
示例#25
0
def ensure_federated_value(value, placement=None, label=None):
  """Ensures `value` is a federated value placed at `placement`.

  If `value` is not a `computation_types.FederatedType` but is a
  `computation_types.StructType` that can be converted via `federated_zip`
  to a `computation_types.FederatedType`, inserts the call to `federated_zip`
  and returns the result. If `value` cannot be converted, raises a TypeError.

  Args:
    value: A `value_impl.Value` to check and convert to a federated value if
      possible.
    placement: The expected placement. If None, any placement is allowed.
    label: An optional string label that describes `value`.

  Returns:
    The value as a federated value, automatically zipping if necessary.

  Raises:
    TypeError: if `value` is not a `FederatedType` and cannot be converted to
      a `FederatedType` with `federated_zip`, or if it is not placed at
      `placement`.
  """
  py_typecheck.check_type(value, value_impl.Value)
  if label is not None:
    py_typecheck.check_type(label, str)
  description = label if label else 'value'

  if not value.type_signature.is_federated():
    comp = value.comp
    try:
      zipped = building_block_factory.create_federated_zip(comp)
    except (TypeError, ValueError) as e:
      # Chain the original error so the reason zipping failed stays visible.
      raise TypeError(
          'The {l} must be a FederatedType or implicitly convertible '
          'to a FederatedType (got a {t}).'.format(
              l=description, t=comp.type_signature)) from e
    value = value_impl.Value(zipped)

  # Compare by equality rather than identity, so an equal placement literal
  # that happens to be a distinct object is not rejected.
  if placement is not None and value.type_signature.placement != placement:
    raise TypeError(
        'The {} should be placed at {}, but it is placed at {}.'.format(
            description, placement, value.type_signature.placement))

  return value
示例#26
0
def federated_collect(value):
  """Gathers a `tff.CLIENTS`-placed value into a `tff.SERVER` sequence.

  Args:
    value: A value of a TFF federated type placed at the `tff.CLIENTS`, or a
      structure that `value_utils.ensure_federated_value` can zip into one.

  Returns:
    A stream of the same type as the member constituents of `value` placed at
    the `tff.SERVER`.

  Raises:
    TypeError: If the argument is not a federated TFF value placed at
      `tff.CLIENTS`.
  """
  client_value = value_utils.ensure_federated_value(
      value_impl.to_value(value, None), placements.CLIENTS,
      'value to be collected')
  collect_comp = building_block_factory.create_federated_collect(
      client_value.comp)
  return value_impl.Value(_bind_comp_as_reference(collect_comp))
示例#27
0
def federated_sum(value):
    """Sums a `tff.CLIENTS`-placed `value`, producing the total at `tff.SERVER`.

    To sum integer values with stronger privacy properties, consider using
    `tff.federated_secure_sum_bitwidth`.

    Args:
      value: A value of a TFF federated type placed at the `tff.CLIENTS`.

    Returns:
      A representation of the sum of the member constituents of `value` placed
      on the `tff.SERVER`.

    Raises:
      TypeError: If the argument is not a federated TFF value placed at
        `tff.CLIENTS`, or its type cannot be summed.
    """
    client_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None), placements.CLIENTS,
        'value to be summed')
    type_analysis.check_is_sum_compatible(client_value.type_signature)
    sum_comp = building_block_factory.create_federated_sum(client_value.comp)
    return value_impl.Value(_bind_comp_as_reference(sum_comp))
示例#28
0
def sequence_reduce(value, zero, op):
    """Reduces a TFF sequence `value` given a `zero` and reduction operator `op`.

    This method reduces a set of elements of a TFF sequence `value`, using a
    given `zero` in the algebra (i.e., the result of reducing an empty
    sequence) of some type `U`, and a reduction operator `op` with type
    signature `(<U,T> -> U)` that incorporates a single `T`-typed element of
    `value` into the `U`-typed result of partial reduction. In the special case
    of `T` equal to `U`, this corresponds to the classical notion of reduction
    of a set using a commutative associative binary operator. The generalized
    reduction (with `T` not equal to `U`) requires that repeated application of
    `op` to reduce a set of `T` always yields the same `U`-typed result,
    regardless of the order in which elements of `T` are processed in the
    course of the reduction.

    One can also invoke `sequence_reduce` on a federated sequence, in which
    case the reductions are performed pointwise and `zero` must be a federated
    value as well. Under the hood, `value` and `zero` are paired and an
    expression of the form
    `federated_map(<s,z> -> sequence_reduce(s, z, op), <value, zero>)` is
    constructed. See also the discussion on `sequence_map`.

    Args:
      value: A value that is either a TFF sequence, or a federated sequence.
      zero: The result of reducing a sequence with no elements. Must be a
        federated value when `value` is federated.
      op: An operator with type signature `(<U,T> -> U)`, where `T` is the type
        of the elements of the sequence, and `U` is the type of `zero` to be
        used in performing the reduction.

    Returns:
      The `U`-typed result of reducing elements in the sequence, or if the
      `value` is federated, a federated `U` that represents the result of
      locally reducing each member constituent of `value`.

    Raises:
      TypeError: If the arguments are not of the types specified above.
    """
    value = value_impl.to_value(value, None)
    zero = value_impl.to_value(zero, None)
    op = value_impl.to_value(op, None)
    value_type = value.type_signature

    if not value_type.is_federated():
        # Plain sequence: reduce it directly.
        value_type.check_sequence()
        comp = building_block_factory.create_sequence_reduce(
            value.comp, zero.comp, op.comp)
        comp = _bind_comp_as_reference(comp)
        return value_impl.Value(comp)

    # Federated sequence: reduce each member locally under a `federated_map`.
    value_member_type = value_type.member
    value_member_type.check_sequence()
    # `zero` is paired with `value` member-wise below, so it must be federated
    # too. Otherwise `.member` would fail with an AttributeError.
    if not zero.type_signature.is_federated():
        raise TypeError(
            'When `value` is a federated sequence, `zero` must also be a '
            'federated value, but got `zero` of type {}.'.format(
                zero.type_signature))
    zero_member_type = zero.type_signature.member

    # Build `arg -> sequence_reduce(arg[0], arg[1], op)` over the pair of
    # member types. It is then mapped over the struct `<value, zero>`, which
    # `federated_map` presumably zips into a single federated pair.
    ref_type = computation_types.StructType(
        [value_member_type, zero_member_type])
    ref = building_blocks.Reference('arg', ref_type)
    sequence_arg = building_blocks.Selection(ref, index=0)
    zero_arg = building_blocks.Selection(ref, index=1)
    call = building_block_factory.create_sequence_reduce(
        sequence_arg, zero_arg, op.comp)
    fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
    args = building_blocks.Struct([value.comp, zero.comp])
    return federated_map(value_impl.Value(fn), args)
示例#29
0
def federated_secure_sum_bitwidth(value, bitwidth):
    """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`.

  This function computes a sum such that it should not be possible for the
  server to learn any client's individual value. The specific algorithm and
  mechanism used to compute the secure sum may vary depending on the target
  runtime environment the computation is compiled for or executed on. See
  https://research.google/pubs/pub47246/ for more information.

  Not all executors support `tff.federated_secure_sum_bitwidth()`; consult the
  documentation for the specific executor or executor stack you plan on using
  for the specifics of how it's handled by that executor.

  The `bitwidth` argument represents the bitwidth of the aggregand, that is the
  bitwidth of the input `value`. The federated secure sum bitwidth (i.e., the
  bitwidth of the *sum* of the input `value`s over all clients) will be a
  function of this bitwidth and the number of participating clients.

  Example:

  ```python
  value = tff.federated_value(1, tff.CLIENTS)
  result = tff.federated_secure_sum_bitwidth(value, 2)

  value = tff.federated_value([1, 1], tff.CLIENTS)
  result = tff.federated_secure_sum_bitwidth(value, [2, 4])

  value = tff.federated_value([1, [1, 1]], tff.CLIENTS)
  result = tff.federated_secure_sum_bitwidth(value, [2, [4, 8]])
  ```

  Note: To sum non-integer values or to sum integers with fewer constraints and
  weaker privacy properties, consider using `federated_sum`.

  Args:
    value: An integer value of a TFF federated type placed at the `tff.CLIENTS`,
      in the range [0, 2^bitwidth - 1].
    bitwidth: An integer or nested structure of integers matching the structure
      of `value`. If a single integer `bitwidth` is used with a nested `value`,
      the same integer is used for each tensor in `value`.

  Returns:
    A representation of the sum of the member constituents of `value` placed
    on the `tff.SERVER`.

  Raises:
    TypeError: If the argument is not a federated TFF value placed at
      `tff.CLIENTS`, or if `bitwidth` is neither a single integer nor a
      structure matching the member type of `value`. `value` is also checked
      to be a structure of integers.
  """
    value = value_impl.to_value(value, None)
    value = value_utils.ensure_federated_value(value, placements.CLIENTS,
                                               'value to be summed')
    type_analysis.check_is_structure_of_integers(value.type_signature)
    bitwidth_value = value_impl.to_value(bitwidth, None)
    value_member_type = value.type_signature.member
    bitwidth_type = bitwidth_value.type_signature
    # `bitwidth` must be either one integer (shared by every tensor) or a
    # structure with one integer per tensor of the clients' member type.
    if not type_analysis.is_single_integer_or_matches_structure(
            bitwidth_type, value_member_type):
        raise TypeError(
            'Expected `federated_secure_sum_bitwidth` parameter `bitwidth` to match '
            'the structure of `value`, with one integer bitwidth per tensor in '
            '`value`. Found `value` of `{}` and `bitwidth` of `{}`.'.format(
                value_member_type, bitwidth_type))
    # A single integer supplied for a structured value is broadcast so that the
    # building block receives a bitwidth structure mirroring `value`.
    if bitwidth_type.is_tensor() and value_member_type.is_struct():
        bitwidth_value = value_impl.to_value(
            structure.map_structure(lambda _: bitwidth, value_member_type),
            None)
    comp = building_block_factory.create_federated_secure_sum_bitwidth(
        value.comp, bitwidth_value.comp)
    comp = _bind_comp_as_reference(comp)
    return value_impl.Value(comp)
示例#30
0
def federated_secure_sum(value, max_input):
    """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`.

  This function computes a sum such that it should not be possible for the
  server to learn any client's individual value. The specific algorithm and
  mechanism used to compute the secure sum may vary depending on the target
  runtime environment the computation is compiled for or executed on. See
  https://research.google/pubs/pub47246/ for more information.

  Not all executors support `tff.federated_secure_sum()`; consult the
  documentation for the specific executor or executor stack you plan on using
  for the specifics of how it's handled by that executor.

  The `max_input` argument is the maximum value (inclusive) that may appear in
  `value`. *Lower values may allow for improved communication efficiency.*
  Providing a `value` higher than `max_input` is invalid, and will result in a
  failure at the given client.

  Example:

  ```python
  value = tff.federated_value(1, tff.CLIENTS)
  result = tff.federated_secure_sum(value, 1)

  value = tff.federated_value((1, 2), tff.CLIENTS)
  result = tff.federated_secure_sum(value, (1, 2))
  ```

  Note: To sum non-integer values or to sum integers with fewer constraints and
  weaker privacy properties, consider using `federated_sum`.

  Args:
    value: An integer or nested structure of integers placed at `tff.CLIENTS`,
      in the range `[0, max_input]`.
    max_input: A Python integer or nested structure of integers matching the
      structure of `value`. If a single integer `max_input` is used with a
      nested `value`, the same integer is used for each tensor in `value`.

  Returns:
    A representation of the sum of the member constituents of `value` placed
    on the `tff.SERVER`.

  Raises:
    TypeError: If the argument is not a federated TFF value placed at
      `tff.CLIENTS`, or if `max_input` is neither a single integer nor a
      structure matching the member type of `value`. `value` is also checked
      to be a structure of integers.
  """
    value = value_impl.to_value(value, None)
    value = value_utils.ensure_federated_value(value, placements.CLIENTS,
                                               'value to be summed')
    type_analysis.check_is_structure_of_integers(value.type_signature)
    max_input_value = value_impl.to_value(max_input, None)
    value_member_type = value.type_signature.member
    max_input_type = max_input_value.type_signature
    # `max_input` must be either one integer (shared by every tensor) or a
    # structure with one integer per tensor of the clients' member type.
    if not type_analysis.is_single_integer_or_matches_structure(
            max_input_type, value_member_type):
        raise TypeError(
            'Expected `federated_secure_sum` parameter `max_input` to match '
            'the structure of `value`, with one integer max per tensor in '
            '`value`. Found `value` of `{}` and `max_input` of `{}`.'.format(
                value_member_type, max_input_type))
    # A single integer supplied for a structured value is broadcast so that the
    # building block receives a `max_input` structure mirroring `value`.
    if max_input_type.is_tensor() and value_member_type.is_struct():
        max_input_value = value_impl.to_value(
            structure.map_structure(lambda _: max_input, value_member_type),
            None)
    comp = building_block_factory.create_federated_secure_sum(
        value.comp, max_input_value.comp)
    comp = _bind_comp_as_reference(comp)
    return value_impl.Value(comp)