def federated_map(self, fn, arg):
  """Implements `federated_map` as defined in `api/intrinsics.py`.

  The mapped value is converted to a federated value. `fn` is converted to a
  functional value and checked against the member type of that value. The
  result is built with `create_federated_apply` for server-placed values and
  with `create_federated_map` for client-placed values.

  Args:
    fn: As in `api/intrinsics.py`.
    arg: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the server or clients. Would occur after adding support
  # for placement labels in the federated types, and expanding the type
  # specification of the intrinsic this is based on to work with federated
  # values of arbitrary placement.
  arg = value_impl.to_value(arg, None, self._context_stack)
  arg = value_utils.ensure_federated_value(arg, label='value to be mapped')

  # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
  # here based on the actual type of the argument.
  fn = value_impl.to_value(fn, None, self._context_stack)
  py_typecheck.check_type(fn, value_base.Value)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)

  arg_type = arg.type_signature
  param_type = fn.type_signature.parameter
  if not type_utils.is_assignable_from(param_type, arg_type.member):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'
        .format(param_type, arg_type.member))

  # TODO(b/144384398): Change structure to one that maps the placement type
  # to the building_block function that fits it, in a way that allows the
  # appropriate type checks.
  placement = arg_type.placement
  if placement is placements.SERVER:
    # Only a single, all-equal server value can be mapped via `federated_apply`.
    if not arg_type.all_equal:
      raise TypeError(
          'Arguments placed at {} should be equal at all locations.'.format(
              placements.SERVER))
    make_comp = building_block_factory.create_federated_apply
  elif placement is placements.CLIENTS:
    make_comp = building_block_factory.create_federated_map
  else:
    raise TypeError(
        'The argument should be placed at {} or {}, placed at {} instead.'
        .format(placements.SERVER, placements.CLIENTS, placement))

  comp = make_comp(
      value_impl.ValueImpl.get_comp(fn), value_impl.ValueImpl.get_comp(arg))
  return value_impl.ValueImpl(comp, self._context_stack)
def _create_complex_computation():
  """Builds a lambda holding two broadcast -> map -> mean pipelines.

  Each intermediate result is bound to a local name (`broadcast_<i>`,
  `map_<i>`, `mean_<i>`) in a single `Block`. The block's result is a
  `Struct` of references to the two `mean_<i>` locals.

  Returns:
    A `building_blocks.Lambda` with parameter `arg` wrapping that block.
  """
  compiled = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32), 'a')
  arg_ref = building_blocks.Reference(
      'arg', computation_types.FederatedType(tf.int32, placements.SERVER))
  bindings = []
  mean_refs = []

  def _bind(local_name, comp):
    """Records `(local_name, comp)` as a block local; returns a reference."""
    bindings.append((local_name, comp))
    return building_blocks.Reference(local_name, comp.type_signature)

  for idx in range(2):
    broadcast_ref = _bind(
        f'broadcast_{idx}',
        building_block_factory.create_federated_broadcast(arg_ref))
    map_ref = _bind(
        f'map_{idx}',
        building_block_factory.create_federated_map(compiled, broadcast_ref))
    mean_refs.append(
        _bind(f'mean_{idx}',
              building_block_factory.create_federated_mean(map_ref, None)))

  # NOTE(review): the lambda parameter is declared as `tf.int32`, while the
  # `arg` references above carry the federated type int32@SERVER; confirm
  # this mismatch is intentional for the tests using this helper.
  return building_blocks.Lambda(
      'arg', tf.int32,
      building_blocks.Block(bindings, building_blocks.Struct(mean_refs)))
def test_cannot_split_on_chained_intrinsic(self):
  """Splitting fails when one `federated_map` consumes another's result."""
  int_type = computation_types.TensorType(tf.int32)
  client_int_type = computation_types.at_clients(int_type)

  def int_ref(name):
    return building_blocks.Reference(name, int_type)

  def client_int_ref(name):
    return building_blocks.Reference(name, client_int_type)

  def identity_map(param_name, arg_name):
    # federated_map of an identity lambda over the client value `arg_name`.
    return building_block_factory.create_federated_map(
        building_blocks.Lambda(param_name, int_type, int_ref(param_name)),
        client_int_ref(arg_name))

  # `b` maps over `a`, so the two intrinsics are chained, not independent.
  body = building_blocks.Block(
      [('a', identity_map('p1', 'param')), ('b', identity_map('p2', 'a'))],
      client_int_ref('b'))
  comp = building_blocks.Lambda('param', int_type, body)

  with self.assertRaises(transformations._NonAlignableAlongIntrinsicError):
    transformations.force_align_and_split_by_intrinsics(
        comp, [building_block_factory.create_null_federated_map()])
def _create_chained_whimsy_federated_maps(functions, arg):
  """Chains `federated_map` calls, feeding each result into the next one.

  Args:
    functions: A non-empty iterable of
      `building_blocks.ComputationBuildingBlock`s, applied in order. Each must
      have a `parameter_type` attribute, which is checked for assignability
      from the member type of the value being mapped.
    arg: A `building_blocks.ComputationBuildingBlock` with a federated type;
      the value mapped by the first function.

  Returns:
    The outermost `federated_map` call, i.e. the result of mapping the last
    function over the output of the previous maps.

  Raises:
    TypeError: If `arg` or any function is not a `ComputationBuildingBlock`,
      or if a function's parameter type cannot accept the current member type.
    ValueError: If `functions` is empty.
  """
  py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock)
  call = None
  for fn in functions:
    py_typecheck.check_type(fn, building_blocks.ComputationBuildingBlock)
    if not fn.parameter_type.is_assignable_from(arg.type_signature.member):
      raise TypeError(
          'The parameter of the function is of type {}, and the argument is of '
          'an incompatible type {}.'.format(
              str(fn.parameter_type), str(arg.type_signature.member)))
    call = building_block_factory.create_federated_map(fn, arg)
    # The next function maps over the output of this call.
    arg = call
  # Previously an empty `functions` surfaced as an opaque UnboundLocalError.
  if call is None:
    raise ValueError('Expected at least one function to chain, found none.')
  return call
def _create_complex_computation():
  """Builds `(b -> <mean, mean>)` where mean = federated_mean(map(broadcast(b))).

  The same `federated_mean` building block appears in both tuple slots.

  Returns:
    A `building_blocks.Lambda` with parameter `b`.
  """
  compiled = building_block_factory.create_compiled_identity(tf.int32, 'a')
  server_ref = building_blocks.Reference(
      'b', computation_types.FederatedType(tf.int32, placements.SERVER))
  mean = building_block_factory.create_federated_mean(
      building_block_factory.create_federated_map(
          compiled,
          building_block_factory.create_federated_broadcast(server_ref)),
      None)
  # NOTE(review): the lambda parameter is typed `tf.int32`, while the `b`
  # reference is int32@SERVER; confirm this mismatch is intentional.
  return building_blocks.Lambda('b', tf.int32,
                                building_blocks.Tuple([mean, mean]))
def federated_map(self, fn, arg):
  """Implements `federated_map` as defined in `api/intrinsics.py`.

  If `arg` has a named-tuple type with two or more elements, it is first
  zipped via `self.federated_zip`. The (possibly zipped) value must be placed
  at clients; `fn` is then checked against its member type.

  Args:
    fn: As in `api/intrinsics.py`.
    arg: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the clients after adding support for placement labels
  # in the federated types, and expanding the type specification of the
  # intrinsic this is based on to work with federated values of arbitrary
  # placement.
  arg = value_impl.to_value(arg, None, self._context_stack)
  wants_zip = (
      isinstance(arg.type_signature, computation_types.NamedTupleType) and
      len(anonymous_tuple.to_elements(arg.type_signature)) >= 2)
  if wants_zip:
    # A tuple of several values is treated as something the user expects to
    # be zipped before mapping.
    arg = self.federated_zip(arg)

  value_utils.check_federated_value_placement(arg, placements.CLIENTS,
                                              'value to be mapped')

  # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
  # here based on the actual type of the argument.
  fn = value_impl.to_value(fn, None, self._context_stack)
  py_typecheck.check_type(fn, value_base.Value)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)

  param_type = fn.type_signature.parameter
  member_type = arg.type_signature.member
  if not type_utils.is_assignable_from(param_type, member_type):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'
        .format(param_type, member_type))

  mapped = building_block_factory.create_federated_map(
      value_impl.ValueImpl.get_comp(fn), value_impl.ValueImpl.get_comp(arg))
  return value_impl.ValueImpl(mapped, self._context_stack)
def test_removes_federated_map_with_named_result(self):
  """An identity federated_map over named-tuple data collapses to the data."""
  named_pair_type = [('a', tf.int32), ('b', tf.int32)]
  identity_fn = building_block_test_utils.create_identity_function(
      'c', named_pair_type)
  data = building_blocks.Data(
      'data', computation_types.FederatedType(named_pair_type,
                                              placements.CLIENTS))
  comp = building_block_factory.create_federated_map(identity_fn, data)

  transformed_comp, modified = (
      tree_transformations.remove_mapped_or_applied_identity(comp))

  self.assertEqual(comp.compact_representation(),
                   'federated_map(<(c -> c),data>)')
  self.assertEqual(transformed_comp.compact_representation(), 'data')
  self.assertEqual(transformed_comp.type_signature, comp.type_signature)
  self.assertTrue(modified)
def create_dummy_called_federated_map(parameter_name, parameter_type=tf.int32):
  r"""Returns a dummy called federated map.

            Call
           /    \
  federated_map  Tuple
                 |
                 [Lambda(x), data]
                  |
                  Ref(x)

  Args:
    parameter_name: The name of the parameter of the identity lambda.
    parameter_type: The type of the parameter; the mapped `data` is a
      `FederatedType` of this type placed at `placements.CLIENTS`.

  Returns:
    The result of `building_block_factory.create_federated_map` applied to an
    identity function and a `Data` building block named 'data'.
  """
  identity = create_identity_function(parameter_name, parameter_type)
  client_data = building_blocks.Data(
      'data', computation_types.FederatedType(parameter_type,
                                              placements.CLIENTS))
  return building_block_factory.create_federated_map(identity, client_data)
def federated_map(fn, arg):
  """Maps a federated value pointwise using a mapping function.

  The function `fn` is applied separately across the group of devices
  represented by the placement type of `arg`. For example, if `arg` has
  placement type `tff.CLIENTS`, then `fn` is applied to each client
  individually. In particular, this operation does not alter the placement of
  the federated value.

  Args:
    fn: A mapping function to apply pointwise to member constituents of `arg`.
      The parameter type of this function must be assignable from the type of
      the member constituents of `arg`.
    arg: A value of a TFF federated type (or a value that can be implicitly
      converted into a TFF federated type, e.g., by zipping) placed at
      `tff.CLIENTS` or `tff.SERVER`.

  Returns:
    A federated value with the same placement as `arg` that represents the
    result of `fn` on the member constituents of `arg`.

  Raises:
    TypeError: If the arguments are not of the appropriate types, e.g. if the
      parameter of `fn` is not assignable from the member type of `arg`, if
      `arg` is placed at `tff.SERVER` but is not `all_equal`, or if `arg` is
      placed somewhere other than `tff.CLIENTS` or `tff.SERVER`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the server or clients. Would occur after adding support
  # for placement labels in the federated types, and expanding the type
  # specification of the intrinsic this is based on to work with federated
  # values of arbitrary placement.
  arg = value_impl.to_value(arg, None)
  arg = value_utils.ensure_federated_value(arg, label='value to be mapped')

  # The member type is passed as a hint, presumably so that `fn` can be
  # converted with a concrete parameter type (e.g. for polymorphic functions).
  fn = value_impl.to_value(
      fn, None, parameter_type_hint=arg.type_signature.member)

  py_typecheck.check_type(fn, value_impl.Value)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  if not fn.type_signature.parameter.is_assignable_from(
      arg.type_signature.member):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'.format(
            fn.type_signature.parameter, arg.type_signature.member))

  # TODO(b/144384398): Change structure to one that maps the placement type
  # to the building_block function that fits it, in a way that allows the
  # appropriate type checks.
  if arg.type_signature.placement is placements.SERVER:
    # A server value is mapped with `federated_apply`, which requires a
    # single (all-equal) value.
    if not arg.type_signature.all_equal:
      raise TypeError(
          'Arguments placed at {} should be equal at all locations.'.format(
              placements.SERVER))
    comp = building_block_factory.create_federated_apply(fn.comp, arg.comp)
  elif arg.type_signature.placement is placements.CLIENTS:
    comp = building_block_factory.create_federated_map(fn.comp, arg.comp)
  else:
    raise TypeError(
        'Expected `arg` to have a type with a supported placement, '
        'found {}.'.format(arg.type_signature.placement))

  # NOTE(review): presumably binds `comp` in the surrounding context and
  # returns a reference to it; see `_bind_comp_as_reference`.
  comp = _bind_comp_as_reference(comp)
  return value_impl.Value(comp)