def federated_broadcast(self, value):
  """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

  Builds a FEDERATED_BROADCAST intrinsic whose result keeps the member type of
  `value`, placed at CLIENTS with the `all_equal` bit set, and applies it to
  `value`.

  Args:
    value: As in `api/intrinsics.py`. Must be placed at SERVER and be equal at
      all locations.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  server_value = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(server_value, placements.SERVER,
                                             'value to be broadcasted')
  source_type = server_value.type_signature
  if not source_type.all_equal:
    raise TypeError('The broadcasted value should be equal at all locations.')
  # TODO(b/113112108): Replace this hand-crafted logic here and below with
  # a call to a helper function that handles it in a uniform manner after
  # implementing support for correctly typechecking federated template types
  # and instantiating template types on concrete arguments.
  clients_type = computation_types.FederatedType(source_type.member,
                                                 placements.CLIENTS, True)
  broadcast_type = computation_types.FunctionType(source_type, clients_type)
  broadcast = value_impl.ValueImpl(
      computation_building_blocks.Intrinsic(
          intrinsic_defs.FEDERATED_BROADCAST.uri, broadcast_type),
      self._context_stack)
  return broadcast(server_value)
def test_getattr_call_named(self, placement):
  """Selecting named members of a federated tuple preserves its placement."""
  named_tuple_type = computation_types.FederatedType(
      [('a', tf.int32), ('b', tf.bool), tf.int32], placement, True)
  federated_comp_named = computation_building_blocks.Reference(
      'test', named_tuple_type)
  self.assertEqual(
      str(federated_comp_named.type_signature.member), '<a=int32,b=bool,int32>')

  # (attribute name, expected string form of the selected member type)
  expectations = [('a', 'int32'), ('b', 'bool')]
  selections = [
      computation_constructing_utils.construct_federated_getattr_call(
          federated_comp_named, attr_name) for attr_name, _ in expectations
  ]
  for selection in selections:
    self.assertIsInstance(selection.type_signature,
                          computation_types.FederatedType)
  for selection, (_, member_str) in zip(selections, expectations):
    self.assertEqual(str(selection.type_signature.member), member_str)
  for selection in selections:
    type_utils.check_federated_value_placement(
        value_impl.to_value(selection, None, context_stack_impl.context_stack),
        placement)

  # There is no member named 'c', so selecting it must fail.
  with self.assertRaisesRegex(ValueError, 'has no element of name c'):
    _ = computation_constructing_utils.construct_federated_getattr_call(
        federated_comp_named, 'c')
def federated_sum(self, value):
  """Implements `federated_sum` as defined in `api/intrinsics.py`.

  Builds a FEDERATED_SUM intrinsic mapping the CLIENTS-placed `value` to its
  member type placed at SERVER (with the `all_equal` bit set), and applies it.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  summands = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(summands, placements.CLIENTS,
                                             'value to be summed')
  summands_type = summands.type_signature
  if not type_utils.is_sum_compatible(summands_type):
    raise TypeError(
        'The value type {} is not compatible with the sum operator.'.format(
            summands_type))
  # TODO(b/113112108): Replace this as noted in `federated_broadcast()`.
  server_type = computation_types.FederatedType(summands_type.member,
                                                placements.SERVER, True)
  sum_intrinsic = value_impl.ValueImpl(
      computation_building_blocks.Intrinsic(
          intrinsic_defs.FEDERATED_SUM.uri,
          computation_types.FunctionType(summands_type, server_type)),
      self._context_stack)
  return sum_intrinsic(summands)
def federated_collect(self, value):
  """Implements `federated_collect` as defined in `api/intrinsics.py`.

  Builds a FEDERATED_COLLECT intrinsic whose result is a sequence of the
  CLIENTS member type, placed at SERVER with the `all_equal` bit set, and
  applies it to `value`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  client_values = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(client_values, placements.CLIENTS,
                                             'value to be collected')
  input_type = client_values.type_signature
  sequence_at_server = computation_types.FederatedType(
      computation_types.SequenceType(input_type.member), placements.SERVER,
      True)
  collect_intrinsic = value_impl.ValueImpl(
      computation_building_blocks.Intrinsic(
          intrinsic_defs.FEDERATED_COLLECT.uri,
          computation_types.FunctionType(input_type, sequence_at_server)),
      self._context_stack)
  return collect_intrinsic(client_values)
def federated_aggregate(self, value, zero, accumulate, merge, report):
  """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

  Each operator is checked to be a function whose signature is assignable to:
    * accumulate: `reduction_op(<zero type>, <member type of value>)`,
    * merge: `reduction_op(<zero type>, <zero type>)`,
    * report: `(<zero type> -> <report's own result type>)`.

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    accumulate: As in `api/intrinsics.py`.
    merge: As in `api/intrinsics.py`.
    report: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  context_stack = self._context_stack
  value = value_impl.to_value(value, None, context_stack)
  type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                             'value to be aggregated')

  zero = value_impl.to_value(zero, None, context_stack)
  py_typecheck.check_type(zero, value_base.Value)
  # TODO(b/113112108): We need a check here that zero does not have federated
  # constituents.

  accumulate, merge, report = [
      value_impl.to_value(operator, None, context_stack)
      for operator in (accumulate, merge, report)
  ]
  for operator in (accumulate, merge, report):
    py_typecheck.check_type(operator, value_base.Value)
    py_typecheck.check_type(operator.type_signature,
                            computation_types.FunctionType)

  zero_type = zero.type_signature
  expected_signatures = (
      ('accumulate', accumulate,
       type_constructors.reduction_op(zero_type, value.type_signature.member)),
      ('merge', merge, type_constructors.reduction_op(zero_type, zero_type)),
      ('report', report,
       computation_types.FunctionType(zero_type,
                                      report.type_signature.result)),
  )
  for op_name, op, type_expected in expected_signatures:
    if not type_utils.is_assignable_from(type_expected, op.type_signature):
      raise TypeError(
          'Expected parameter `{}` to be of type {}, but received {} instead.'
          .format(op_name, type_expected, op.type_signature))

  comp = computation_constructing_utils.create_federated_aggregate(*[
      value_impl.ValueImpl.get_comp(v)
      for v in (value, zero, accumulate, merge, report)
  ])
  return value_impl.ValueImpl(comp, context_stack)
def zip_two_tuple(input_val, context_stack):
  """Helper function to perform 2-tuple at a time zipping.

  Takes 2-tuple of federated values and returns federated 2-tuple of values.
  The element names of `input_val` are carried over to the member tuple of the
  result, whose `all_equal` bit is True at SERVER and False at CLIENTS.

  Args:
    input_val: 2-tuple TFF `Value` of `NamedTuple` type, whose elements must be
      `FederatedTypes` with the same placement.
    context_stack: The context stack to use, as in `impl.value_impl.to_value`.

  Returns:
    TFF `Value` of `FederatedType` with member of 2-tuple `NamedTuple` type.

  Raises:
    ValueError: If `input_val` does not have exactly two elements.
    TypeError: If the elements do not share a placement, or that placement is
      neither SERVER nor CLIENTS.
  """
  py_typecheck.check_type(input_val, value_base.Value)
  py_typecheck.check_type(input_val.type_signature,
                          computation_types.NamedTupleType)
  input_elements = anonymous_tuple.to_elements(input_val.type_signature)
  # Check the arity before indexing into `input_val`, so that an empty tuple
  # reports this ValueError rather than failing on `input_val[0]`.
  num_elements = len(input_elements)
  if num_elements != 2:
    raise ValueError('The argument of zip_two_tuple must be a 2-tuple, '
                     'not a tuple with {} elements.'.format(num_elements))
  py_typecheck.check_type(input_val[0].type_signature,
                          computation_types.FederatedType)
  zip_uris = {
      placements.CLIENTS: intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri,
      placements.SERVER: intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri,
  }
  zip_all_equal = {
      placements.CLIENTS: False,
      placements.SERVER: True,
  }
  output_placement = input_val[0].type_signature.placement
  if output_placement not in zip_uris:
    raise TypeError('The argument must have components placed at SERVER or '
                    'CLIENTS')
  output_all_equal_bit = zip_all_equal[output_placement]
  for elem in input_val:
    type_utils.check_federated_value_placement(elem, output_placement)

  result_type = computation_types.FederatedType(
      [(name, e.member) for name, e in input_elements], output_placement,
      output_all_equal_bit)

  def _adjust_all_equal_bit(x):
    # Re-declares a federated element type with the output's `all_equal` bit.
    return computation_types.FederatedType(x.member, x.placement,
                                           output_all_equal_bit)

  # Elements with a falsy name are passed bare rather than as (name, type).
  adjusted_input_type = computation_types.NamedTupleType([
      (k, _adjust_all_equal_bit(v)) if k else _adjust_all_equal_bit(v)
      for k, v in input_elements
  ])
  intrinsic = value_impl.ValueImpl(
      computation_building_blocks.Intrinsic(
          zip_uris[output_placement],
          computation_types.FunctionType(adjusted_input_type, result_type)),
      context_stack)
  return intrinsic(input_val)
def federated_map(self, mapping_fn, value):
  """Implements `federated_map` as defined in `api/intrinsics.py`.

  A `NamedTupleType` argument with two or more elements is first zipped via
  `self.federated_zip`; the (possibly zipped) value must be placed at CLIENTS.
  The result keeps the `all_equal` bit of that value.

  Args:
    mapping_fn: As in `api/intrinsics.py`.
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the clients after adding support for placement labels
  # in the federated types, and expanding the type specification of the
  # intrinsic this is based on to work with federated values of arbitrary
  # placement.
  context_stack = self._context_stack
  value = value_impl.to_value(value, None, context_stack)
  incoming_type = value.type_signature
  needs_zip = (
      isinstance(incoming_type, computation_types.NamedTupleType) and
      len(anonymous_tuple.to_elements(incoming_type)) >= 2)
  if needs_zip:
    # We've been passed a value which the user expects to be zipped.
    value = self.federated_zip(value)
  type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                             'value to be mapped')

  # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
  # here based on the actual type of the argument.
  mapping_fn = value_impl.to_value(mapping_fn, None, context_stack)
  py_typecheck.check_type(mapping_fn, value_base.Value)
  fn_type = mapping_fn.type_signature
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  value_type = value.type_signature
  if not type_utils.is_assignable_from(fn_type.parameter, value_type.member):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'.format(
            fn_type.parameter, value_type.member))

  # TODO(b/113112108): Replace this as noted in `federated_broadcast()`.
  result_type = computation_types.FederatedType(fn_type.result,
                                                placements.CLIENTS,
                                                value_type.all_equal)
  map_intrinsic = value_impl.ValueImpl(
      computation_building_blocks.Intrinsic(
          intrinsic_defs.FEDERATED_MAP.uri,
          computation_types.FunctionType([fn_type, value_type], result_type)),
      context_stack)
  return map_intrinsic(mapping_fn, value)
def federated_mean(self, value, weight):
  """Implements `federated_mean` as defined in `api/intrinsics.py`.

  When `weight` is given, it must be a CLIENTS-placed scalar tensor of an
  integer or floating-point dtype.

  Args:
    value: As in `api/intrinsics.py`.
    weight: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Possibly relax the constraints on numeric types, and
  # inject implicit casts where appropriate. For instance, we might want to
  # allow `tf.int32` values as the input, and automatically cast them to
  # `tf.float32` before invoking the average, thus producing a floating-point
  # result.

  # TODO(b/120439632): Possibly allow the weight to be either structured or
  # non-scalar, e.g., for the case of averaging a convolutional layer, when
  # we would want to use a different weight for every filter, and where it
  # might be cumbersome for users to have to manually slice and assemble a
  # variable.

  value = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                             'value to be averaged')
  if not type_utils.is_average_compatible(value.type_signature):
    raise TypeError(
        'The value type {} is not compatible with the average operator.'
        .format(value.type_signature))

  if weight is not None:
    weight = value_impl.to_value(weight, None, self._context_stack)
    type_utils.check_federated_value_placement(weight, placements.CLIENTS,
                                               'weight to use in averaging')
    weight_member = weight.type_signature.member
    py_typecheck.check_type(weight_member, computation_types.TensorType)
    # Only rank-0 weights are accepted; a shape of unknown rank (`ndims` is
    # None) is rejected as well.
    if weight_member.shape.ndims != 0:
      raise TypeError('The weight type {} is not a federated scalar.'.format(
          weight.type_signature))
    if not (weight_member.dtype.is_integer or weight_member.dtype.is_floating):
      raise TypeError(
          'The weight type {} is not a federated integer or floating-point '
          'tensor.'.format(weight.type_signature))

  value = value_impl.ValueImpl.get_comp(value)
  if weight is not None:
    weight = value_impl.ValueImpl.get_comp(weight)
  comp = computation_constructing_utils.create_federated_mean(value, weight)
  return value_impl.ValueImpl(comp, self._context_stack)
def federated_reduce(self, value, zero, op):
  """Implements `federated_reduce` as defined in `api/intrinsics.py`.

  `op` must be assignable to `reduction_op(<zero type>, <member type>)`. The
  result has `zero`'s type, placed at SERVER with the `all_equal` bit set.

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Since in most cases, it can be assumed that CLIENTS is
  # a non-empty collective (or else, the computation fails), specifying zero
  # at this level of the API should probably be optional. TBD.
  context_stack = self._context_stack
  value = value_impl.to_value(value, None, context_stack)
  type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                             'value to be reduced')

  zero = value_impl.to_value(zero, None, context_stack)
  py_typecheck.check_type(zero, value_base.Value)
  # TODO(b/113112108): We need a check here that zero does not have federated
  # constituents.

  op = value_impl.to_value(op, None, context_stack)
  py_typecheck.check_type(op, value_base.Value)
  py_typecheck.check_type(op.type_signature, computation_types.FunctionType)
  zero_type = zero.type_signature
  expected_op_type = type_constructors.reduction_op(
      zero_type, value.type_signature.member)
  if not type_utils.is_assignable_from(expected_op_type, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        expected_op_type, op.type_signature))

  # TODO(b/113112108): Replace this as noted in `federated_broadcast()`.
  server_result_type = computation_types.FederatedType(zero_type,
                                                       placements.SERVER, True)
  reduce_type = computation_types.FunctionType(
      [value.type_signature, zero_type, expected_op_type], server_result_type)
  reduce_intrinsic = value_impl.ValueImpl(
      computation_building_blocks.Intrinsic(intrinsic_defs.FEDERATED_REDUCE.uri,
                                            reduce_type), context_stack)
  return reduce_intrinsic(value, zero, op)
def federated_map_all_equal(self, fn, arg):
  """Implements `federated_map` as defined in `api/intrinsics.py`.

  Implements `federated_map` as defined in `api/intrinsics.py` with an argument
  with the `all_equal` bit set.

  A `NamedTupleType` argument with two or more elements is first zipped via
  `self.federated_zip`; the (possibly zipped) argument must be placed at
  CLIENTS.

  Args:
    fn: As in `api/intrinsics.py`.
    arg: As in `api/intrinsics.py`, with the `all_equal` bit set.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the clients after adding support for placement labels
  # in the federated types, and expanding the type specification of the
  # intrinsic this is based on to work with federated values of arbitrary
  # placement.
  arg = value_impl.to_value(arg, None, self._context_stack)
  if isinstance(arg.type_signature, computation_types.NamedTupleType):
    if len(anonymous_tuple.to_elements(arg.type_signature)) >= 2:
      # We've been passed a value which the user expects to be zipped.
      arg = self.federated_zip(arg)
  type_utils.check_federated_value_placement(arg, placements.CLIENTS,
                                             'value to be mapped')
  # NOTE(review): the `all_equal` bit of `arg` is not verified here; presumably
  # `create_federated_map_all_equal` relies on the caller for it — confirm.

  # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
  # here based on the actual type of the argument.
  fn = value_impl.to_value(fn, None, self._context_stack)
  py_typecheck.check_type(fn, value_base.Value)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  if not type_utils.is_assignable_from(fn.type_signature.parameter,
                                       arg.type_signature.member):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'.format(
            fn.type_signature.parameter, arg.type_signature.member))

  fn = value_impl.ValueImpl.get_comp(fn)
  arg = value_impl.ValueImpl.get_comp(arg)
  comp = computation_constructing_utils.create_federated_map_all_equal(fn, arg)
  return value_impl.ValueImpl(comp, self._context_stack)
def federated_apply(self, func, arg):
  """Implements `federated_apply` as defined in `api/intrinsics.py`.

  A `NamedTupleType` argument with two or more elements is first zipped via
  `self.federated_zip`. The argument must then be placed at SERVER and be equal
  at all locations; the result is `func`'s result type at SERVER with the
  `all_equal` bit set.

  Args:
    func: As in `api/intrinsics.py`.
    arg: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  context_stack = self._context_stack
  func = value_impl.to_value(func, None, context_stack)
  py_typecheck.check_type(func, value_base.Value)
  func_type = func.type_signature
  py_typecheck.check_type(func_type, computation_types.FunctionType)

  arg = value_impl.to_value(arg, None, context_stack)
  incoming_type = arg.type_signature
  if (isinstance(incoming_type, computation_types.NamedTupleType) and
      len(anonymous_tuple.to_elements(incoming_type)) >= 2):
    # We've been passed a value which the user expects to be zipped.
    arg = self.federated_zip(arg)
  type_utils.check_federated_value_placement(arg, placements.SERVER,
                                             'the argument')
  arg_type = arg.type_signature
  if not arg_type.all_equal:
    raise TypeError('The argument should be equal at all locations.')
  if not type_utils.is_assignable_from(func_type.parameter, arg_type.member):
    raise TypeError(
        'The function to apply expects a parameter of type {}, but member '
        'constituents of the argument are of an incompatible type {}.'.format(
            func_type.parameter, arg_type.member))

  # TODO(b/113112108): Replace this as noted in `federated_broadcast()`.
  result_type = computation_types.FederatedType(func_type.result,
                                                placements.SERVER, True)
  apply_intrinsic = value_impl.ValueImpl(
      computation_building_blocks.Intrinsic(
          intrinsic_defs.FEDERATED_APPLY.uri,
          computation_types.FunctionType([func_type, arg_type], result_type)),
      context_stack)
  return apply_intrinsic(func, arg)
def federated_collect(self, value):
  """Implements `federated_collect` as defined in `api/intrinsics.py`.

  The CLIENTS-placed `value` is unwrapped to its building block and handed to
  `computation_constructing_utils.create_federated_collect`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  client_values = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(client_values, placements.CLIENTS,
                                             'value to be collected')
  collect_comp = computation_constructing_utils.create_federated_collect(
      value_impl.ValueImpl.get_comp(client_values))
  return value_impl.ValueImpl(collect_comp, self._context_stack)
def test_getitem_call_unnamed(self, placement):
  """Indexing and slicing an unnamed federated tuple preserves placement."""
  federated_comp_unnamed = computation_building_blocks.Reference(
      'test',
      computation_types.FederatedType([tf.int32, tf.bool], placement, True))
  self.assertEqual(
      str(federated_comp_unnamed.type_signature.member), '<int32,bool>')

  def assert_placed(comp):
    # Wrap the building block as a value so its placement can be checked.
    type_utils.check_federated_value_placement(
        value_impl.to_value(comp, None, context_stack_impl.context_stack),
        placement)

  expected_member_strs = ['int32', 'bool']
  indexed = [
      computation_constructing_utils.construct_federated_getitem_call(
          federated_comp_unnamed, idx)
      for idx in range(len(expected_member_strs))
  ]
  for comp in indexed:
    self.assertIsInstance(comp.type_signature, computation_types.FederatedType)
  for comp, member_str in zip(indexed, expected_member_strs):
    self.assertEqual(str(comp.type_signature.member), member_str)
  for comp in indexed:
    assert_placed(comp)

  # A reversing slice flips the member order but must keep the placement.
  unnamed_flipped = (
      computation_constructing_utils.construct_federated_getitem_call(
          federated_comp_unnamed, slice(None, None, -1)))
  self.assertIsInstance(unnamed_flipped.type_signature,
                        computation_types.FederatedType)
  self.assertEqual(str(unnamed_flipped.type_signature.member), '<bool,int32>')
  assert_placed(unnamed_flipped)
def federated_apply(self, fn, arg):
  """Implements `federated_apply` as defined in `api/intrinsics.py`.

  A `NamedTupleType` argument with two or more elements is first zipped via
  `self.federated_zip`. The argument must then be placed at SERVER and be equal
  at all locations before `create_federated_apply` builds the call.

  Args:
    fn: As in `api/intrinsics.py`.
    arg: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  context_stack = self._context_stack
  fn = value_impl.to_value(fn, None, context_stack)
  py_typecheck.check_type(fn, value_base.Value)
  fn_type = fn.type_signature
  py_typecheck.check_type(fn_type, computation_types.FunctionType)

  arg = value_impl.to_value(arg, None, context_stack)
  incoming_type = arg.type_signature
  if (isinstance(incoming_type, computation_types.NamedTupleType) and
      len(anonymous_tuple.to_elements(incoming_type)) >= 2):
    # We've been passed a value which the user expects to be zipped.
    arg = self.federated_zip(arg)
  type_utils.check_federated_value_placement(arg, placements.SERVER,
                                             'the argument')
  arg_type = arg.type_signature
  if not arg_type.all_equal:
    raise TypeError('The argument should be equal at all locations.')
  if not type_utils.is_assignable_from(fn_type.parameter, arg_type.member):
    raise TypeError(
        'The function to apply expects a parameter of type {}, but member '
        'constituents of the argument are of an incompatible type {}.'.format(
            fn_type.parameter, arg_type.member))

  fn_comp = value_impl.ValueImpl.get_comp(fn)
  arg_comp = value_impl.ValueImpl.get_comp(arg)
  comp = computation_constructing_utils.create_federated_apply(fn_comp, arg_comp)
  return value_impl.ValueImpl(comp, context_stack)
def federated_broadcast(self, value):
  """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

  The SERVER-placed, all-equal `value` is unwrapped to its building block and
  handed to `computation_constructing_utils.create_federated_broadcast`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  server_value = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(server_value, placements.SERVER,
                                             'value to be broadcasted')
  if not server_value.type_signature.all_equal:
    raise TypeError('The broadcasted value should be equal at all locations.')
  broadcast_comp = computation_constructing_utils.create_federated_broadcast(
      value_impl.ValueImpl.get_comp(server_value))
  return value_impl.ValueImpl(broadcast_comp, self._context_stack)
def federated_sum(self, value):
  """Implements `federated_sum` as defined in `api/intrinsics.py`.

  The CLIENTS-placed, sum-compatible `value` is unwrapped to its building block
  and handed to `computation_constructing_utils.create_federated_sum`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  summands = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(summands, placements.CLIENTS,
                                             'value to be summed')
  summands_type = summands.type_signature
  if not type_utils.is_sum_compatible(summands_type):
    raise TypeError(
        'The value type {} is not compatible with the sum operator.'.format(
            summands_type))
  sum_comp = computation_constructing_utils.create_federated_sum(
      value_impl.ValueImpl.get_comp(summands))
  return value_impl.ValueImpl(sum_comp, self._context_stack)
def _(x):
  # `x` must pass a CLIENTS placement check and fail a SERVER one.
  type_utils.check_federated_value_placement(x, placements.CLIENTS)
  self.assertRaises(TypeError, type_utils.check_federated_value_placement, x,
                    placements.SERVER)
  return x