async def _map(self, arg, all_equal=None):
    """Applies a computation to every member of a federated value.

    Args:
        arg: A value whose internal representation is a two-element structure
            `(fn, val)`, where `fn` is a `pb.Computation` and `val` is a list
            of `ExecutorValue`s that is paired positionally with the target
            executors at the placement of `val`.
        all_equal: Whether the result type is marked `all_equal`. When `None`,
            the flag is copied from the type of `val`.

    Returns:
        A `FederatedResolvingStrategyValue` holding one call result per
        (member, child executor) pair.

    Raises:
        ValueError: If `all_equal` is truthy but `val` is not `all_equal`.
    """
    self._check_arg_is_structure(arg)
    representation = arg.internal_representation
    py_typecheck.check_len(representation, 2)
    fn_type = arg.type_signature[0]
    py_typecheck.check_type(fn_type, computation_types.FunctionType)
    val_type = arg.type_signature[1]
    py_typecheck.check_type(val_type, computation_types.FederatedType)

    source_all_equal = val_type.all_equal
    if all_equal is None:
        all_equal = source_all_equal
    elif all_equal and not source_all_equal:
        raise ValueError(
            'Cannot map a non-all_equal argument into an all_equal result.')

    fn = representation[0]
    py_typecheck.check_type(fn, pb.Computation)
    members = representation[1]
    py_typecheck.check_type(members, list)
    for member in members:
        py_typecheck.check_type(member, executor_value_base.ExecutorValue)

    placement = val_type.placement
    self._check_strategy_compatible_with_placement(placement)
    children = self._target_executors[placement]

    async def _apply_at(child, member):
        # Embed the function in the child first, then invoke it there.
        embedded_fn = await child.create_value(fn, fn_type)
        return await child.create_call(embedded_fn, member)

    pending = [
        _apply_at(child, member) for member, child in zip(members, children)
    ]
    results = await asyncio.gather(*pending)
    result_type = computation_types.FederatedType(
        fn_type.result, placement, all_equal=all_equal)
    return FederatedResolvingStrategyValue(results, result_type)
async def compute_federated_aggregate(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
    """Computes a federated aggregate across the target executors.

    Each target executor runs the `FEDERATED_AGGREGATE` intrinsic on its own
    element of the value, using an identity function in place of `report`.
    The partial results are embedded in the server executor, folded together
    with `merge` in the order the children complete, and finally passed
    through `report`.

    Args:
        arg: A value whose internal representation is a five-element
            `structure.Struct` of (value, zero, accumulate, merge, report).

    Returns:
        A `FederatedComposingStrategyValue` placed at the server, typed with
        the result type of `report`.
    """
    value_type, zero_type, accumulate_type, merge_type, report_type = (
        executor_utils.parse_federated_aggregate_argument_types(
            arg.type_signature))
    components = arg.internal_representation
    py_typecheck.check_type(components, structure.Struct)
    py_typecheck.check_len(components, 5)
    val = components[0]
    py_typecheck.check_type(val, list)
    py_typecheck.check_len(val, len(self._target_executors))

    identity_report, identity_report_type = (
        tensorflow_computation_factory.create_identity(zero_type))
    child_parameter_type = computation_types.StructType([
        value_type, zero_type, accumulate_type, merge_type,
        identity_report_type
    ])
    aggr_type = computation_types.FunctionType(
        child_parameter_type, computation_types.at_server(zero_type))
    aggr_comp = executor_utils.create_intrinsic_comp(
        intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)

    zero_selection = await self._executor.create_selection(arg, index=1)
    zero = await zero_selection.compute()
    accumulate = components[2]
    merge = components[3]
    report = components[4]

    async def _aggregate_on_child(child, child_val):
        py_typecheck.check_type(child_val, executor_value_base.ExecutorValue)
        embedded_ops = await asyncio.gather(
            child.create_value(zero, zero_type),
            child.create_value(accumulate, accumulate_type),
            child.create_value(merge, merge_type),
            child.create_value(identity_report, identity_report_type))
        aggr_func, aggr_args = await asyncio.gather(
            child.create_value(aggr_comp, aggr_type),
            child.create_struct([child_val, *embedded_ops]))
        call_result = await child.create_call(aggr_func, aggr_args)
        partial = await call_result.compute()
        # Re-embed the materialized partial aggregate in the server executor.
        return await self._server_executor.create_value(partial, zero_type)

    partials = asyncio.as_completed([
        _aggregate_on_child(child, child_val)
        for child, child_val in zip(self._target_executors, val)
    ])
    parent_merge, parent_report = await asyncio.gather(
        self._server_executor.create_value(merge, merge_type),
        self._server_executor.create_value(report, report_type))

    merged = await next(partials)
    for pending in partials:
        operands = await self._server_executor.create_struct(
            [merged, await pending])
        merged = await self._server_executor.create_call(
            parent_merge, operands)

    reported = await self._server_executor.create_call(parent_report, merged)
    return FederatedComposingStrategyValue(
        reported, computation_types.at_server(report_type.result))
def _check_iterative_process_compatible_with_canonical_form(
        initialize_tree, next_tree):
    """Tests compatibility with `tff.backends.mapreduce.CanonicalForm`.

    Args:
        initialize_tree: An instance of
            `building_blocks.ComputationBuildingBlock` that maps to the
            `initialize` property of a `tff.utils.IterativeProcess`.
        next_tree: An instance of `building_blocks.ComputationBuildingBlock`
            that maps to the `next` property of a
            `tff.utils.IterativeProcess`.

    Raises:
        TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(initialize_tree,
                            building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(initialize_tree.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(next_tree,
                            building_blocks.ComputationBuildingBlock)
    next_type = next_tree.type_signature
    py_typecheck.check_type(next_type, computation_types.FunctionType)
    py_typecheck.check_type(next_type.parameter,
                            computation_types.NamedTupleType)
    py_typecheck.check_len(next_type.parameter, 2)
    py_typecheck.check_type(next_type.result,
                            computation_types.NamedTupleType)
    # The result may carry two or three elements, so its length is checked
    # here explicitly. The parameter length was already validated above and
    # is not re-checked.
    next_result_len = len(next_type.result)
    if next_result_len not in (2, 3):
        raise TypeError(
            'Expected length of 2 or 3, found {}.'.format(next_result_len))
async def _map(self, arg, all_equal=None):
    """Maps a computation over a federated value via the child executors.

    Every target executor receives the function plus its own element of the
    value and evaluates the `FEDERATED_MAP` intrinsic on that pair.

    Args:
        arg: A value whose internal representation is a two-element
            `structure.Struct` `(fn, val)`, where `fn` is a `pb.Computation`
            and `val` is a list paired positionally with the target
            executors.
        all_equal: Whether the result type is marked `all_equal`. When `None`,
            the flag is copied from the type of `val`.

    Returns:
        A `FederatedComposingStrategyValue` holding one result per child.

    Raises:
        ValueError: If `all_equal` is truthy but `val` is not `all_equal`.
    """
    representation = arg.internal_representation
    py_typecheck.check_type(representation, structure.Struct)
    py_typecheck.check_len(representation, 2)
    fn_type = arg.type_signature[0]
    py_typecheck.check_type(fn_type, computation_types.FunctionType)
    val_type = arg.type_signature[1]
    py_typecheck.check_type(val_type, computation_types.FederatedType)

    source_all_equal = val_type.all_equal
    if all_equal is None:
        all_equal = source_all_equal
    elif all_equal and not source_all_equal:
        raise ValueError(
            'Cannot map a non-all_equal argument into an all_equal result.')

    fn = representation[0]
    py_typecheck.check_type(fn, pb.Computation)
    val = representation[1]
    py_typecheck.check_type(val, list)

    map_parameter_type = [
        fn_type, computation_types.at_clients(fn_type.parameter)
    ]
    map_type = computation_types.FunctionType(
        map_parameter_type, computation_types.at_clients(fn_type.result))
    map_comp = executor_utils.create_intrinsic_comp(
        intrinsic_defs.FEDERATED_MAP, map_type)

    async def _map_on_child(child, child_val):
        py_typecheck.check_type(child_val, executor_value_base.ExecutorValue)
        embedded_fn = await child.create_value(fn, fn_type)
        map_val, map_arg = await asyncio.gather(
            child.create_value(map_comp, map_type),
            child.create_struct([embedded_fn, child_val]))
        return await child.create_call(map_val, map_arg)

    pending = [
        _map_on_child(child, child_val)
        for child, child_val in zip(self._target_executors, val)
    ]
    result_vals = await asyncio.gather(*pending)
    federated_type = computation_types.FederatedType(
        fn_type.result, val_type.placement, all_equal=all_equal)
    return FederatedComposingStrategyValue(result_vals, federated_type)
def _check_iterative_process_compatible_with_canonical_form(
        initialize_tree, next_tree):
    """Tests compatibility with `tff.backends.mapreduce.CanonicalForm`.

    Args:
        initialize_tree: An instance of
            `building_blocks.ComputationBuildingBlock` representing the
            `initialize` component of a `tff.templates.IterativeProcess`.
        next_tree: An instance of `building_blocks.ComputationBuildingBlock`
            representing the `next` component of a
            `tff.templates.IterativeProcess`.

    Raises:
        TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(initialize_tree,
                            building_blocks.ComputationBuildingBlock)
    init_tree_ty = initialize_tree.type_signature
    _check_type_is_no_arg_fn(init_tree_ty, TypeError)
    _check_type(init_tree_ty.result, computation_types.FederatedType,
                TypeError)
    _check_placement(init_tree_ty.result, placements.SERVER, TypeError)

    py_typecheck.check_type(next_tree,
                            building_blocks.ComputationBuildingBlock)
    next_type = next_tree.type_signature
    py_typecheck.check_type(next_type, computation_types.FunctionType)
    py_typecheck.check_type(next_type.parameter, computation_types.StructType)
    py_typecheck.check_len(next_type.parameter, 2)
    py_typecheck.check_type(next_type.result, computation_types.StructType)
    # The result length is checked explicitly below; the parameter length
    # was already validated above and is not re-checked.
    next_result_len = len(next_type.result)
    if next_result_len != 2:
        raise TypeError(
            'Expected length of 2, found {}.'.format(next_result_len))
async def compute_federated_secure_sum(
    self, arg: federated_resolving_strategy.FederatedResolvingStrategyValue
) -> federated_resolving_strategy.FederatedResolvingStrategyValue:
    """Computes a secure sum without cryptography (test runtime only).

    The value is summed with `compute_federated_sum`; `tf.math.mod` is then
    applied on the server to the pair (sum, `2**bitwidth - 1`).

    Args:
        arg: A value whose internal representation is a two-element
            `structure.Struct` of (value, bitwidth).

    Returns:
        The result of `compute_federated_map_all_equal` applied to the
        modulo function and its server-placed argument.
    """
    representation = arg.internal_representation
    py_typecheck.check_type(representation, structure.Struct)
    py_typecheck.check_len(representation, 2)
    logging.warning(
        'The implementation of the `tff.federated_secure_sum` intrinsic '
        'provided by the `tff.backends.test` runtime uses no cryptography.')

    server_ex = self._target_executors[placement_literals.SERVER][0]
    summands = federated_resolving_strategy.FederatedResolvingStrategyValue(
        representation[0], arg.type_signature[0])
    bitwidth = await representation[1].compute()
    bitwidth_type = arg.type_signature[1]

    sum_result, mask, mod_fn = await asyncio.gather(
        self.compute_federated_sum(summands),
        executor_utils.embed_tf_constant(self._executor, bitwidth_type,
                                         2**bitwidth - 1),
        executor_utils.embed_tf_binary_operator(self._executor, bitwidth_type,
                                                tf.math.mod))

    mod_operands = await server_ex.create_struct([
        sum_result.internal_representation[0],
        mask.internal_representation,
    ])
    mod_operands_type = computation_types.FederatedType(
        mod_operands.type_signature, placement_literals.SERVER,
        all_equal=True)
    map_arg = federated_resolving_strategy.FederatedResolvingStrategyValue(
        structure.Struct([
            (None, mod_fn.internal_representation),
            (None, [mod_operands]),
        ]),
        computation_types.StructType(
            [mod_fn.type_signature, mod_operands_type]))
    return await self.compute_federated_map_all_equal(map_arg)
async def compute_federated_aggregate(
    self, arg: FederatedResolvingStrategyValue
) -> FederatedResolvingStrategyValue:
    """Computes a federated aggregate as a reduce followed by `report`.

    Args:
        arg: A value whose internal representation is a five-element
            `structure.Struct` of (value, zero, accumulate, merge, report).

    Returns:
        The result of `compute_federated_apply` on `report` and the reduced
        value.
    """
    # Only the accumulate and report types are needed below.
    _, _, accumulate_type, _, report_type = (
        executor_utils.parse_federated_aggregate_argument_types(
            arg.type_signature))
    py_typecheck.check_type(arg.internal_representation, structure.Struct)
    py_typecheck.check_len(arg.internal_representation, 5)

    # `merge` is discarded: since all aggregation happens on a single
    # executor, there's no need for this additional layer.
    val, zero, accumulate, _, report = arg.internal_representation

    pre_report = await self.reduce(val, zero, accumulate, accumulate_type)
    pre_report_type = pre_report.type_signature
    py_typecheck.check_type(pre_report_type, computation_types.FederatedType)
    pre_report_type.member.check_equivalent_to(report_type.parameter)

    apply_arg = FederatedResolvingStrategyValue(
        structure.Struct([(None, report),
                          (None, pre_report.internal_representation)]),
        computation_types.StructType(
            (report_type, pre_report.type_signature)))
    return await self.compute_federated_apply(apply_arg)
async def _map(self, arg, all_equal=None):
    """Applies a computation to every member of a federated value.

    The function is first embedded in every child executor at the placement
    of the value; the embedded functions are then called on the members.

    Args:
        arg: A value whose internal representation is a two-element tuple
            `(fn, val)`, where `fn` is a `pb.Computation` and `val` is a list
            of `ExecutorValue`s paired positionally with the child executors.
        all_equal: Whether the result type is marked `all_equal`. When `None`,
            the flag is copied from the type of `val`.

    Returns:
        A `FederatingExecutorValue` holding one call result per child.

    Raises:
        ValueError: If `all_equal` is truthy but `val` is not `all_equal`.
    """
    self._check_arg_is_anonymous_tuple(arg)
    representation = arg.internal_representation
    py_typecheck.check_len(representation, 2)
    fn_type = arg.type_signature[0]
    py_typecheck.check_type(fn_type, computation_types.FunctionType)
    val_type = arg.type_signature[1]
    py_typecheck.check_type(val_type, computation_types.FederatedType)

    source_all_equal = val_type.all_equal
    if all_equal is None:
        all_equal = source_all_equal
    elif all_equal and not source_all_equal:
        raise ValueError(
            'Cannot map a non-all_equal argument into an all_equal result.')

    fn = representation[0]
    py_typecheck.check_type(fn, pb.Computation)
    members = representation[1]
    py_typecheck.check_type(members, list)
    for member in members:
        py_typecheck.check_type(member, executor_value_base.ExecutorValue)

    children = self._target_executors[val_type.placement]
    embedded_fns = await asyncio.gather(
        *[child.create_value(fn, fn_type) for child in children])
    results = await asyncio.gather(*[
        child.create_call(embedded_fn, member)
        for child, embedded_fn, member in zip(children, embedded_fns, members)
    ])
    result_type = computation_types.FederatedType(
        fn_type.result, val_type.placement, all_equal=all_equal)
    return FederatingExecutorValue(results, result_type)
async def compute_federated_aggregate(
    self, arg: FederatedResolvingStrategyValue
) -> FederatedResolvingStrategyValue:
    """Computes a federated aggregate as a reduce followed by `report`.

    Args:
        arg: A value whose internal representation is a five-element
            `structure.Struct` of (value, zero, accumulate, merge, report).

    Returns:
        The result of `compute_federated_apply` on `report` and the reduced
        value.
    """
    # The value and merge types are not needed below.
    _, zero_type, accumulate_type, _, report_type = (
        executor_utils.parse_federated_aggregate_argument_types(
            arg.type_signature))
    py_typecheck.check_type(arg.internal_representation, structure.Struct)
    py_typecheck.check_len(arg.internal_representation, 5)

    # `merge` is discarded: since all aggregation happens on a single
    # executor, there's no need for this additional layer.
    val, raw_zero, accumulate, _, report = arg.internal_representation

    # Re-wrap `zero` in a `FederatedResolvingStrategyValue` to ensure that it
    # is an `ExecutorValue` rather than a `Struct` (the internal
    # representation can include embedded values, lists of embedded values
    # in the case of federated values, or `Struct`s).
    zero = FederatedResolvingStrategyValue(raw_zero, zero_type)

    pre_report = await self.reduce(val, zero, accumulate, accumulate_type)
    pre_report_type = pre_report.type_signature
    py_typecheck.check_type(pre_report_type, computation_types.FederatedType)
    pre_report_type.member.check_equivalent_to(report_type.parameter)

    apply_arg = FederatedResolvingStrategyValue(
        structure.Struct([(None, report),
                          (None, pre_report.internal_representation)]),
        computation_types.StructType(
            (report_type, pre_report.type_signature)))
    return await self.compute_federated_apply(apply_arg)
async def _compute_intrinsic_federated_map(self, arg):
    """Computes `federated_map` by delegating to the child executors.

    Args:
        arg: A value whose internal representation is a two-element
            `AnonymousTuple` `(fn, val)`, where `fn` is a `pb.Computation`
            and `val` is a list paired positionally with the child
            executors.

    Returns:
        A `CompositeValue` holding one result per child, typed as the
        function's result placed at the clients.
    """
    representation = arg.internal_representation
    py_typecheck.check_type(representation, anonymous_tuple.AnonymousTuple)
    py_typecheck.check_len(representation, 2)
    fn_type = arg.type_signature[0]
    py_typecheck.check_type(fn_type, computation_types.FunctionType)
    val_type = arg.type_signature[1]
    type_utils.check_federated_type(val_type, fn_type.parameter,
                                    placement_literals.CLIENTS)
    fn = representation[0]
    val = representation[1]
    py_typecheck.check_type(fn, pb.Computation)
    py_typecheck.check_type(val, list)

    map_parameter_type = [fn_type, type_factory.at_clients(fn_type.parameter)]
    map_type = computation_types.FunctionType(
        map_parameter_type, type_factory.at_clients(fn_type.result))
    map_comp = executor_utils.create_intrinsic_comp(
        intrinsic_defs.FEDERATED_MAP, map_type)

    async def _map_on_child(child, child_val):
        py_typecheck.check_type(child_val, executor_value_base.ExecutorValue)
        embedded_fn = await child.create_value(fn, fn_type)
        map_val, map_arg = await asyncio.gather(
            child.create_value(map_comp, map_type),
            child.create_tuple([embedded_fn, child_val]))
        return await child.create_call(map_val, map_arg)

    pending = [
        _map_on_child(child, child_val)
        for child, child_val in zip(self._child_executors, val)
    ]
    result_vals = await asyncio.gather(*pending)
    return CompositeValue(result_vals,
                          type_factory.at_clients(fn_type.result))
def parse_federated_aggregate_argument_types(type_spec):
    """Verifies and parses `type_spec` into constituents.

    Args:
        type_spec: An instance of `computation_types.StructType`.

    Returns:
        A tuple of (value_type, zero_type, accumulate_type, merge_type,
        report_type) for the 5 type constituents.
    """
    py_typecheck.check_type(type_spec, computation_types.StructType)
    py_typecheck.check_len(type_spec, 5)

    value_type = type_spec[0]
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    zero_type = type_spec[1]

    # `accumulate` must fold one member of the value into the zero type.
    accumulate_type = type_spec[2]
    expected_accumulate_type = type_factory.reduction_op(
        zero_type, value_type.member)
    accumulate_type.check_equivalent_to(expected_accumulate_type)

    # `merge` must be a binary operator over the zero type.
    merge_type = type_spec[3]
    expected_merge_type = type_factory.binary_op(zero_type)
    merge_type.check_equivalent_to(expected_merge_type)

    # `report` must be a function that accepts the zero type.
    report_type = type_spec[4]
    py_typecheck.check_type(report_type, computation_types.FunctionType)
    report_type.parameter.check_equivalent_to(zero_type)

    return value_type, zero_type, accumulate_type, merge_type, report_type
async def compute_federated_value(
    self, value: Any, type_signature: computation_types.Type
) -> FederatedComposingStrategyValue:
    """Embeds a federated value in the server or target executors.

    Args:
        value: The payload. For a non-`all_equal` clients value this must be
            a list whose length equals the total number of clients reported
            by `_get_cardinalities`.
        type_signature: The federated type of `value`.

    Returns:
        A `FederatedComposingStrategyValue` wrapping the embedded value(s).

    Raises:
        ValueError: If a server-placed type is not `all_equal`, or if the
            placement is neither `SERVER` nor `CLIENTS`.
    """
    placement = type_signature.placement

    if placement == placement_literals.SERVER:
        if not type_signature.all_equal:
            raise ValueError(
                'Expected an all equal value at the `SERVER` placement, '
                'found {}.'.format(type_signature))
        embedded = await self._server_executor.create_value(
            value, type_signature.member)
        return FederatedComposingStrategyValue(embedded, type_signature)

    if placement == placement_literals.CLIENTS:
        if type_signature.all_equal:
            # Every child receives the same value.
            embedded = await asyncio.gather(*[
                child.create_value(value, type_signature)
                for child in self._target_executors
            ])
            return FederatedComposingStrategyValue(embedded, type_signature)

        py_typecheck.check_type(value, list)
        cardinalities = await self._get_cardinalities()
        py_typecheck.check_len(value, sum(cardinalities))
        # Hand each child a contiguous slice sized by its client count.
        pending = []
        start = 0
        for child, num_clients in zip(self._target_executors, cardinalities):
            stop = start + num_clients
            pending.append(
                child.create_value(value[start:stop], type_signature))
            start = stop
        return FederatedComposingStrategyValue(
            await asyncio.gather(*pending), type_signature)

    raise ValueError('Unexpected placement {}.'.format(placement))
async def _encrypt_values_on_clients(self, val, sender, receiver):
    """Encrypts each sender-placed value on its own child executor.

    Case 2: sender=CLIENTS
      plaintext:   Fed(Tensor, CLIENTS, all_equal=False)
      pk_receiver: Fed(Tensor, CLIENTS, all_equal=True)
      sk_sender:   Fed(Tensor, CLIENTS, all_equal=False)
    Returns:
      encrypted_values: Fed(Tensor, CLIENTS, all_equal=False)

    Side effect: stores the member type of `val` in `self._input_type_cache`.
    """
    ### Check proper key placement
    sk_sender = self.key_references.get_secret_key(sender)
    pk_receiver = self.key_references.get_public_key(receiver)
    type_analysis.check_federated_type(sk_sender.type_signature,
                                       placement=sender)
    assert sk_sender.type_signature.placement is sender
    assert pk_receiver.type_signature.placement is sender

    ### Check placement cardinalities
    sender_executors = self.strategy._get_child_executors(sender)
    receiver_executors = self.strategy._get_child_executors(receiver)
    py_typecheck.check_len(receiver_executors, 1)

    ### Check value cardinalities
    type_analysis.check_federated_type(val.type_signature, placement=sender)
    federated_value_internals = [
        val.internal_representation,
        pk_receiver.internal_representation,
        sk_sender.internal_representation,
    ]
    expected_len = len(sender_executors)
    for internals in federated_value_internals:
        py_typecheck.check_len(internals, expected_len)

    ### Materialize encryptor function definition & type spec
    input_type = val.type_signature.member
    self._input_type_cache = input_type
    pk_element_type = pk_receiver.type_signature.member
    sk_snd_type = sk_sender.type_signature.member
    encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
    encryptor_proto, encryptor_type = (
        utils.materialize_computation_from_cache(
            sodium_comp.make_encryptor, self._encryptor_cache,
            encryptor_arg_spec))

    ### Encrypt values and return them
    encryptor_fns, encryptor_args = await asyncio.gather(
        asyncio.gather(*[
            child.create_value(encryptor_proto, encryptor_type)
            for child in sender_executors
        ]),
        asyncio.gather(*[
            child.create_struct([plaintext, pk, sk])
            for plaintext, pk, sk, child in zip(*federated_value_internals,
                                                sender_executors)
        ]))
    encrypted_values = await asyncio.gather(*[
        child.create_call(encryptor, encryptor_arg)
        for encryptor, encryptor_arg, child in zip(
            encryptor_fns, encryptor_args, sender_executors)
    ])
    return federated_resolving_strategy.FederatedResolvingStrategyValue(
        encrypted_values,
        tff.FederatedType(encryptor_type.result, sender,
                          all_equal=val.type_signature.all_equal))
async def compute_federated_aggregate(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
    """Computes a federated aggregate across the target executors.

    Each target executor runs the `FEDERATED_AGGREGATE` intrinsic on its own
    element of the value, using an identity function in place of `report`.
    Once all children have finished, the partial results are embedded in the
    server executor, folded left-to-right with `merge`, and passed through
    `report`.

    Args:
        arg: A value whose internal representation is a five-element
            `AnonymousTuple` of (value, zero, accumulate, merge, report).

    Returns:
        A `FederatedComposingStrategyValue` placed at the server, typed with
        the result type of `report`.
    """
    value_type, zero_type, accumulate_type, merge_type, report_type = (
        executor_utils.parse_federated_aggregate_argument_types(
            arg.type_signature))
    components = arg.internal_representation
    py_typecheck.check_type(components, anonymous_tuple.AnonymousTuple)
    py_typecheck.check_len(components, 5)
    val = components[0]
    py_typecheck.check_type(val, list)
    py_typecheck.check_len(val, len(self._target_executors))

    identity_report = tensorflow_computation_factory.create_identity(
        zero_type)
    identity_report_type = type_factory.unary_op(zero_type)
    child_parameter_type = computation_types.NamedTupleType([
        value_type, zero_type, accumulate_type, merge_type,
        identity_report_type
    ])
    aggr_type = computation_types.FunctionType(
        child_parameter_type, type_factory.at_server(zero_type))
    aggr_comp = executor_utils.create_intrinsic_comp(
        intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)

    zero_selection = await self._executor.create_selection(arg, index=1)
    zero = await zero_selection.compute()
    accumulate = components[2]
    merge = components[3]
    report = components[4]

    async def _aggregate_on_child(child, child_val):
        py_typecheck.check_type(child_val, executor_value_base.ExecutorValue)
        embedded_ops = await asyncio.gather(
            child.create_value(zero, zero_type),
            child.create_value(accumulate, accumulate_type),
            child.create_value(merge, merge_type),
            child.create_value(identity_report, identity_report_type))
        aggr_func, aggr_args = await asyncio.gather(
            child.create_value(aggr_comp, aggr_type),
            child.create_tuple([child_val, *embedded_ops]))
        call_result = await child.create_call(aggr_func, aggr_args)
        return await call_result.compute()

    partials = await asyncio.gather(*[
        _aggregate_on_child(child, child_val)
        for child, child_val in zip(self._target_executors, val)
    ])
    parent_vals = await asyncio.gather(*[
        self._server_executor.create_value(partial, zero_type)
        for partial in partials
    ])
    parent_merge, parent_report = await asyncio.gather(
        self._server_executor.create_value(merge, merge_type),
        self._server_executor.create_value(report, report_type))

    merged = parent_vals[0]
    for next_val in parent_vals[1:]:
        operands = await self._server_executor.create_tuple(
            [merged, next_val])
        merged = await self._server_executor.create_call(
            parent_merge, operands)

    reported = await self._server_executor.create_call(parent_report, merged)
    return FederatedComposingStrategyValue(
        reported, type_factory.at_server(report_type.result))
async def _decrypt_values_on_clients(self, val, sender, receiver):
    """Decrypts each receiver-placed value on its own child executor.

    Reads `self._input_type_cache.dtype` to tell the decryptor which dtype
    to produce.

    Args:
        val: The federated value to decrypt; its internal representation
            must have one entry per child executor of `receiver`.
        sender: The placement whose public key is used.
        receiver: The placement whose secret key and child executors are
            used.

    Returns:
        A `FederatedResolvingStrategyValue` of decrypted values placed at
        `receiver`, with the same `all_equal` flag as `val`.
    """
    ### Check proper key placement
    pk_sender = self.key_references.get_public_key(sender)
    sk_receiver = self.key_references.get_secret_key(receiver)
    type_analysis.check_federated_type(pk_sender.type_signature,
                                       placement=receiver)
    type_analysis.check_federated_type(sk_receiver.type_signature,
                                       placement=receiver)
    pk_snd_type = pk_sender.type_signature.member
    sk_rcv_type = sk_receiver.type_signature.member

    ### Check value cardinalities
    receiver_executors = self.strategy._get_child_executors(receiver)
    federated_value_internals = [
        val.internal_representation,
        pk_sender.internal_representation,
        sk_receiver.internal_representation,
    ]
    expected_len = len(receiver_executors)
    for internals in federated_value_internals:
        py_typecheck.check_len(internals, expected_len)

    ### Materialize decryptor type_spec & function definition
    input_type = val.type_signature.member
    # input_type[0] is a tff.TensorType, thus input_type represents the
    # tuple needed for a single value to be decrypted.
    py_typecheck.check_type(input_type[0], tff.TensorType)
    py_typecheck.check_type(pk_snd_type, tff.TensorType)
    decryptor_arg_spec = (input_type, pk_snd_type, sk_rcv_type)
    decryptor_proto, decryptor_type = (
        utils.materialize_computation_from_cache(
            sodium_comp.make_decryptor,
            self._decryptor_cache,
            decryptor_arg_spec,
            orig_tensor_dtype=self._input_type_cache.dtype))

    ### Decrypt values and return them
    decryptor_fns, decryptor_args = await asyncio.gather(
        asyncio.gather(*[
            child.create_value(decryptor_proto, decryptor_type)
            for child in receiver_executors
        ]),
        asyncio.gather(*[
            child.create_struct([ciphertext, pk, sk])
            for ciphertext, pk, sk, child in zip(*federated_value_internals,
                                                 receiver_executors)
        ]))
    decrypted_values = await asyncio.gather(*[
        child.create_call(decryptor, decryptor_arg)
        for decryptor, decryptor_arg, child in zip(
            decryptor_fns, decryptor_args, receiver_executors)
    ])
    return federated_resolving_strategy.FederatedResolvingStrategyValue(
        decrypted_values,
        tff.FederatedType(decryptor_type.result, receiver,
                          all_equal=val.type_signature.all_equal))
async def compute_federated_select(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
    """Computes `federated_select` by delegating to the target executors.

    The server value and max key are materialized once, re-embedded in each
    child together with the select function, and the `FEDERATED_SELECT`
    intrinsic is invoked in each child on its own client keys.

    Args:
        arg: A value whose internal representation is a four-element
            `structure.Struct` of (client_keys, max_key, server_val,
            select_fn).

    Returns:
        A `FederatedComposingStrategyValue` holding one result per child,
        typed as a clients-placed sequence of the select function's result.
    """
    # The client keys type is not needed below.
    _, max_key_type, server_val_type, select_fn_type = arg.type_signature
    py_typecheck.check_type(arg.internal_representation, structure.Struct)
    client_keys, max_key, server_val, select_fn = arg.internal_representation
    py_typecheck.check_type(client_keys, list)
    py_typecheck.check_len(client_keys, len(self._target_executors))
    py_typecheck.check_type(max_key, executor_value_base.ExecutorValue)
    py_typecheck.check_type(server_val, executor_value_base.ExecutorValue)
    py_typecheck.check_type(select_fn, pb.Computation)

    unplaced_server_val, unplaced_max_key = await asyncio.gather(
        server_val.compute(), max_key.compute())
    select_type = computation_types.FunctionType(
        arg.type_signature,
        computation_types.at_clients(
            computation_types.SequenceType(select_fn_type.result)))
    select_pb = executor_utils.create_intrinsic_comp(
        intrinsic_defs.FEDERATED_SELECT, select_type)

    async def _select_on_child(child, child_client_keys):
        child_max_key, child_server_val, child_select_fn = (
            await asyncio.gather(
                child.create_value(unplaced_max_key, max_key_type),
                child.create_value(unplaced_server_val, server_val_type),
                child.create_value(select_fn, select_fn_type)))
        intrinsic, intrinsic_arg = await asyncio.gather(
            child.create_value(select_pb, select_type),
            child.create_struct(
                structure.Struct([(None, child_client_keys),
                                  (None, child_max_key),
                                  (None, child_server_val),
                                  (None, child_select_fn)])))
        return await child.create_call(intrinsic, intrinsic_arg)

    child_results = await asyncio.gather(*[
        _select_on_child(child, child_keys)
        for child, child_keys in zip(self._target_executors, client_keys)
    ])
    return FederatedComposingStrategyValue(
        list(child_results),
        computation_types.at_clients(
            computation_types.SequenceType(select_fn_type.result)))
async def _compute_intrinsic_federated_apply(self, arg):
    """Computes `federated_apply` on the parent executor.

    Args:
        arg: A value whose internal representation is a two-element
            `AnonymousTuple` `(fn, val)`, where `fn` is a `pb.Computation`
            and `val` is an `ExecutorValue` placed at the server.

    Returns:
        A `CompositeValue` wrapping the call result, typed as the function's
        result placed at the server.
    """
    representation = arg.internal_representation
    py_typecheck.check_type(representation, anonymous_tuple.AnonymousTuple)
    py_typecheck.check_len(representation, 2)
    fn_type = arg.type_signature[0]
    py_typecheck.check_type(fn_type, computation_types.FunctionType)
    val_type = arg.type_signature[1]
    type_utils.check_federated_type(
        val_type,
        fn_type.parameter,
        placement_literals.SERVER,
        all_equal=True)
    fn = representation[0]
    val = representation[1]
    py_typecheck.check_type(fn, pb.Computation)
    py_typecheck.check_type(val, executor_value_base.ExecutorValue)

    embedded_fn = await self._parent_executor.create_value(fn, fn_type)
    applied = await self._parent_executor.create_call(embedded_fn, val)
    return CompositeValue(applied, type_factory.at_server(fn_type.result))
async def _decrypt_values_on_singleton(self, val, sender, receiver):
    """Decrypts a struct of values on the single receiver executor.

    Reads `self._input_type_cache.dtype` to tell the decryptor which dtype
    to produce.

    Args:
        val: A struct-typed value whose elements are paired positionally
            with the sender's public keys.
        sender: The placement whose public keys are used.
        receiver: The placement with exactly one child executor, whose
            secret key is used.

    Returns:
        A `FederatedResolvingStrategyValue` holding a structure of decrypted
        values, each typed as an `all_equal` federated value at `receiver`.
    """
    ### Check proper key placement
    pk_sender = self.key_references.get_public_key(sender)
    sk_receiver = self.key_references.get_secret_key(receiver)
    type_analysis.check_federated_type(pk_sender.type_signature,
                                       placement=receiver)
    type_analysis.check_federated_type(sk_receiver.type_signature,
                                       placement=receiver)
    pk_snd_type = pk_sender.type_signature.member
    sk_rcv_type = sk_receiver.type_signature.member

    ### Check placement cardinalities
    sender_executors = self.strategy._get_child_executors(sender)
    receiver_executors = self.strategy._get_child_executors(receiver)
    py_typecheck.check_len(receiver_executors, 1)
    receiver_executor = receiver_executors[0]

    ### Check value cardinalities
    py_typecheck.check_len(pk_sender.internal_representation,
                           len(sender_executors))
    py_typecheck.check_len(sk_receiver.internal_representation, 1)

    ### Materialize decryptor type_spec & function definition
    py_typecheck.check_type(val.type_signature, tff.StructType)
    type_analysis.check_federated_type(val.type_signature[0],
                                       placement=receiver,
                                       all_equal=True)
    input_type = val.type_signature[0].member
    # each input_type is a tuple needed for one value to be decrypted
    py_typecheck.check_type(input_type, tff.StructType)
    py_typecheck.check_type(pk_snd_type, tff.StructType)
    py_typecheck.check_len(val.type_signature, len(pk_snd_type))
    decryptor_arg_spec = (input_type, pk_snd_type[0], sk_rcv_type)
    decryptor_proto, decryptor_type = (
        utils.materialize_computation_from_cache(
            sodium_comp.make_decryptor,
            self._decryptor_cache,
            decryptor_arg_spec,
            orig_tensor_dtype=self._input_type_cache.dtype))

    ### Decrypt values and return them
    ciphertexts = val.internal_representation
    secret_key = sk_receiver.internal_representation[0]
    decryptor_fn = await receiver_executor.create_value(
        decryptor_proto, decryptor_type)
    decryptor_args = await asyncio.gather(*[
        receiver_executor.create_struct([ciphertext, public_key, secret_key])
        for ciphertext, public_key in zip(ciphertexts,
                                          pk_sender.internal_representation)
    ])
    decrypted_values = await asyncio.gather(*[
        receiver_executor.create_call(decryptor_fn, decryptor_arg)
        for decryptor_arg in decryptor_args
    ])
    # Every decrypted element shares the decryptor's result type.
    element_types = [
        tff.FederatedType(decryptor_type.result, receiver, all_equal=True)
        for _ in decrypted_values
    ]
    return federated_resolving_strategy.FederatedResolvingStrategyValue(
        structure.from_container(decrypted_values),
        tff.StructType(element_types))
async def compute_federated_apply(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
    """Computes `federated_apply` on the server executor.

    Args:
        arg: A value whose internal representation is a two-element
            `structure.Struct` `(fn, val)`, where `fn` is a `pb.Computation`
            and `val` is an `ExecutorValue` placed at the server.

    Returns:
        A `FederatedComposingStrategyValue` wrapping the call result, typed
        as the function's result placed at the server.
    """
    representation = arg.internal_representation
    py_typecheck.check_type(representation, structure.Struct)
    py_typecheck.check_len(representation, 2)
    fn_type = arg.type_signature[0]
    py_typecheck.check_type(fn_type, computation_types.FunctionType)
    val_type = arg.type_signature[1]
    type_analysis.check_federated_type(
        val_type,
        fn_type.parameter,
        placement_literals.SERVER,
        all_equal=True)
    fn = representation[0]
    py_typecheck.check_type(fn, pb.Computation)
    val = representation[1]
    py_typecheck.check_type(val, executor_value_base.ExecutorValue)

    embedded_fn = await self._server_executor.create_value(fn, fn_type)
    applied = await self._server_executor.create_call(embedded_fn, val)
    return FederatedComposingStrategyValue(
        applied, computation_types.at_server(fn_type.result))
async def compute_federated_zip_at_server(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
    """Zips two server-placed values into one server-placed struct.

    Args:
        arg: A value whose type and internal representation are both
            two-element structures of `all_equal` server-placed values.

    Returns:
        A `FederatedComposingStrategyValue` holding a struct built on the
        server executor, typed as a server-placed struct of the two member
        types.
    """
    signature = arg.type_signature
    py_typecheck.check_type(signature, computation_types.StructType)
    py_typecheck.check_len(signature, 2)
    representation = arg.internal_representation
    py_typecheck.check_type(representation, structure.Struct)
    py_typecheck.check_len(representation, 2)

    indices = (0, 1)
    for index in indices:
        type_analysis.check_federated_type(
            signature[index],
            placement=placement_literals.SERVER,
            all_equal=True)

    zipped = await self._server_executor.create_struct(
        [representation[index] for index in indices])
    zipped_member_type = computation_types.StructType(
        [signature[0].member, signature[1].member])
    return FederatedComposingStrategyValue(
        zipped, computation_types.at_server(zipped_member_type))
async def _compute_intrinsic_federated_zip_at_server(self, arg):
    """Zips two server-placed values into one server-placed tuple.

    Args:
        arg: A value whose type is a two-element `NamedTupleType` and whose
            internal representation is a two-element `AnonymousTuple`, both
            elements being `all_equal` server-placed values.

    Returns:
        A `CompositeValue` holding a tuple built on the parent executor,
        typed as a server-placed named tuple of the two member types.
    """
    signature = arg.type_signature
    py_typecheck.check_type(signature, computation_types.NamedTupleType)
    py_typecheck.check_len(signature, 2)
    representation = arg.internal_representation
    py_typecheck.check_type(representation, anonymous_tuple.AnonymousTuple)
    py_typecheck.check_len(representation, 2)

    indices = (0, 1)
    for index in indices:
        type_utils.check_federated_type(
            signature[index],
            placement=placement_literals.SERVER,
            all_equal=True)

    zipped = await self._parent_executor.create_tuple(
        [representation[index] for index in indices])
    zipped_member_type = computation_types.NamedTupleType(
        [signature[0].member, signature[1].member])
    return CompositeValue(zipped, type_factory.at_server(zipped_member_type))
def _check_key_inputter(fn_value):
    """Validates the type signature of a `key_inputter` function.

    The function must return a two-element result `(ek, dk)`, where `ek` is
    a `tff.TensorType` and `dk` is a two-element `tff.StructType` whose
    elements are both `tff.TensorType`.

    Args:
        fn_value: A value whose `type_signature` is expected to be a
            `tff.FunctionType`.

    Raises:
        ValueError: If the result, or the decryption key inside it, does not
            have exactly two elements.
    """
    fn_type = fn_value.type_signature
    py_typecheck.check_type(fn_type, tff.FunctionType)
    try:
        py_typecheck.check_len(fn_type.result, 2)
    except ValueError as err:
        raise ValueError(
            'Expected 2 elements in the output of key_inputter, '
            'found {}.'.format(len(fn_type.result))) from err

    ek_type, dk_type = fn_type.result
    py_typecheck.check_type(ek_type, tff.TensorType)
    py_typecheck.check_type(dk_type, tff.StructType)
    try:
        py_typecheck.check_len(dk_type, 2)
    except ValueError as err:
        # Report the size of the decryption key itself, not of the whole
        # result (which is already known to be 2 at this point).
        raise ValueError(
            'Expected a two element tuple for the decryption key from '
            'key_inputter, found {} elements.'.format(len(dk_type))) from err
    py_typecheck.check_type(dk_type[0], tff.TensorType)
    py_typecheck.check_type(dk_type[1], tff.TensorType)
async def compute_federated_aggregate(
    self, arg: FederatedResolvingStrategyValue
) -> FederatedResolvingStrategyValue:
    """Computes a federated aggregate as a reduce followed by `report`.

    Args:
        arg: A value whose internal representation is a five-element
            `AnonymousTuple` of (value, zero, accumulate, merge, report).

    Returns:
        The result of `compute_federated_apply` on `report` and the reduced
        value.
    """
    val_type, zero_type, accumulate_type, _, report_type = (
        executor_utils.parse_federated_aggregate_argument_types(
            arg.type_signature))
    components = arg.internal_representation
    py_typecheck.check_type(components, anonymous_tuple.AnonymousTuple)
    py_typecheck.check_len(components, 5)

    # Note: This is a simple initial implementation that simply forwards this
    # to `federated_reduce()`. The more complete implementation would be able
    # to take advantage of the parallelism afforded by `merge` to reduce the
    # cost from linear (with respect to the number of clients) to sub-linear.
    # TODO(b/134543154): Expand this implementation to take advantage of the
    # parallelism afforded by `merge`.
    reduce_arg = FederatedResolvingStrategyValue(
        anonymous_tuple.AnonymousTuple([
            (None, components[0]),
            (None, components[1]),
            (None, components[2]),
        ]),
        computation_types.NamedTupleType(
            (val_type, zero_type, accumulate_type)))
    pre_report = await self.compute_federated_reduce(reduce_arg)

    pre_report_type = pre_report.type_signature
    py_typecheck.check_type(pre_report_type, computation_types.FederatedType)
    pre_report_type.member.check_equivalent_to(report_type.parameter)

    report = components[4]
    apply_arg = FederatedResolvingStrategyValue(
        anonymous_tuple.AnonymousTuple([
            (None, report),
            (None, pre_report.internal_representation),
        ]),
        computation_types.NamedTupleType(
            (report_type, pre_report.type_signature)))
    return await self.compute_federated_apply(apply_arg)
async def compute_federated_secure_sum(self, arg):
  """Computes a secure sum of a client-placed tensor using Paillier helpers.

  The client values are encrypted on `tff.CLIENTS`, moved to
  `paillier_placement.AGGREGATOR` where the ciphertexts are summed, and the
  encrypted sum is moved to `tff.SERVER` for decryption. The decrypted result
  is reshaped back to the shape of the input tensor.

  Args:
    arg: A two-element structure whose first element is a tensor placed at
      `tff.CLIENTS`. The second element is not read here.

  Returns:
    The value returned by `_compute_reshape_on_tensor` for the decrypted sum.
  """
  self._check_arg_is_structure(arg)
  py_typecheck.check_len(arg.internal_representation, 2)
  value_type = arg.type_signature[0]
  type_analysis.check_federated_type(value_type, placement=tff.CLIENTS)
  member_type = value_type.member
  py_typecheck.check_type(member_type, tff.TensorType)

  # Remember the input dtype; it is passed to decryption as `export_dtype`.
  input_tensor_dtype = value_type.member.dtype

  # One-time Paillier setup, skipped once the flag has been cleared.
  if self._requires_setup:
    await self._paillier_setup()
    self._requires_setup = False

  # Remember the input shape; anything that is not rank 2 is flattened into
  # a single-row matrix before encryption.
  input_tensor_shape = value_type.member.shape
  needs_flatten = len(input_tensor_shape) != 2
  clients_value = await self._executor.create_selection(arg, index=0)
  if needs_flatten:
    clients_value = await self._compute_reshape_on_tensor(
        clients_value, output_shape=[1, input_tensor_shape.num_elements()])

  # Encrypt the summands on tff.CLIENTS.
  client_ciphertexts = await self._compute_paillier_encryption(
      self.encryption_key_clients, clients_value)

  # Sum the ciphertexts on the aggregator.
  aggregator_ciphertexts = await self._move(
      client_ciphertexts, tff.CLIENTS, paillier_placement.AGGREGATOR)
  aggregator_sum = await self._compute_paillier_sum(
      self.encryption_key_paillier, aggregator_ciphertexts)

  # Move the encrypted sum to the server and decrypt it there.
  server_sum = await self._move(
      aggregator_sum, paillier_placement.AGGREGATOR, tff.SERVER)
  decrypted_result = await self._compute_paillier_decryption(
      self.decryption_key,
      self.encryption_key_server,
      server_sum,
      export_dtype=input_tensor_dtype)

  # Restore the original tensor shape.
  return await self._compute_reshape_on_tensor(
      decrypted_result, output_shape=input_tensor_shape.as_list())
async def _compute_reshape_on_tensor(self, tensor, output_shape):
  """Reshapes a federated tensor to `output_shape` on each child executor.

  Args:
    tensor: A federated value; its internal representation must hold one
      entry per child executor of its placement.
    output_shape: The shape passed to the reshape computation and used for
      the member type of the result.

  Returns:
    A `FederatedResolvingStrategyValue` holding the per-child results, with
    the dtype, placement and `all_equal` bit of `tensor` and the new shape.
  """
  federated_type = tensor.type_signature
  tensor_type = federated_type.member
  # NOTE(review): the inferred type is not used below; the call is kept
  # because it may reject unsupported shapes -- confirm whether it is needed.
  type_conversions.infer_type(output_shape)
  reshaper_proto, reshaper_type = utils.materialize_computation_from_cache(
      paillier_comp.make_reshape_tensor,
      self._reshape_function_cache,
      arg_spec=(tensor_type,),
      output_shape=output_shape)

  tensor_placement = tensor.type_signature.placement
  children = self._get_child_executors(tensor_placement)
  py_typecheck.check_len(tensor.internal_representation, len(children))

  # Phase 1: embed the reshape computation in every child executor.
  embed_requests = [
      child.create_value(reshaper_proto, reshaper_type) for child in children
  ]
  reshaper_fns = await asyncio.gather(*embed_requests)

  # Phase 2: call the embedded computation on each child's tensor.
  call_requests = [
      child.create_call(reshaper_fn, child_tensor)
      for child, reshaper_fn, child_tensor in zip(
          children, reshaper_fns, tensor.internal_representation)
  ]
  reshaped_tensors = await asyncio.gather(*call_requests)

  reshaped_member = tff.TensorType(tensor_type.dtype, output_shape)
  output_tensor_spec = tff.FederatedType(
      reshaped_member, tensor_placement, tensor.type_signature.all_equal)
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      reshaped_tensors, output_tensor_spec)
async def _compute_paillier_encryption(
    self,
    client_encryption_keys: federated_resolving_strategy.
    FederatedResolvingStrategyValue,
    clients_value: federated_resolving_strategy.
    FederatedResolvingStrategyValue):
  """Applies the Paillier encryptor to each client's value.

  Args:
    client_encryption_keys: Federated value with one encryption key per
      client executor.
    clients_value: Federated value with one summand per client executor.

  Returns:
    A `FederatedResolvingStrategyValue` placed at `tff.CLIENTS` holding one
    encryptor result per client, with the `all_equal` bit of
    `clients_value`.
  """
  client_children = self._get_child_executors(tff.CLIENTS)
  num_clients = len(client_children)
  key_reprs = client_encryption_keys.internal_representation
  py_typecheck.check_len(key_reprs, num_clients)
  py_typecheck.check_len(clients_value.internal_representation, num_clients)

  # The encryptor takes a (key, value) pair typed after the member types.
  encryptor_input_type = tff.StructType(
      (client_encryption_keys.type_signature.member,
       clients_value.type_signature.member))
  encryptor_proto, encryptor_type = utils.lift_to_computation_spec(
      self._paillier_encryptor, input_arg_type=encryptor_input_type)

  # Embed the encryptor and build its argument struct on every client, with
  # both batches of requests awaited together.
  encryptor_fns, encryptor_args = await asyncio.gather(
      asyncio.gather(*[
          child.create_value(encryptor_proto, encryptor_type)
          for child in client_children
      ]),
      asyncio.gather(*[
          child.create_struct((key, summand))
          for child, key, summand in zip(
              client_children,
              client_encryption_keys.internal_representation,
              clients_value.internal_representation)
      ]))

  encryption_calls = [
      child.create_call(encryptor_fn, encryptor_arg)
      for child, encryptor_fn, encryptor_arg in zip(
          client_children, encryptor_fns, encryptor_args)
  ]
  encrypted_values = await asyncio.gather(*encryption_calls)

  result_type = tff.FederatedType(encryptor_type.result, tff.CLIENTS,
                                  clients_value.type_signature.all_equal)
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      encrypted_values, result_type)
async def _share_public_key(self, key_owner, key_receiver):
  """Copies the public key of `key_owner` onto the `key_receiver` executors.

  The key is materialized with `compute()` and re-created on the receiving
  child executors; the resulting federated value is stored back through
  `key_references.update_keys` for `key_owner`.

  Only two layouts are supported: n keys shared with a single executor, or a
  single key shared with n executors.

  Args:
    key_owner: The key passed to `key_references.get_public_key` and
      `key_references.update_keys`.
    key_receiver: The placement whose child executors receive the key(s).
  """
  owner_key = self.key_references.get_public_key(key_owner)
  receivers = self.strategy._get_child_executors(key_receiver)
  materialized = await owner_key.compute()
  key_type = owner_key.type_signature.member

  if not isinstance(materialized, list):
    # Sharing 1 key with n executors: the value is a single tensor that is
    # replicated to every receiver.
    pending = [
        receiver.create_value(materialized, key_type)
        for receiver in receivers
    ]
    received_type = tff.FederatedType(key_type, key_receiver, all_equal=True)
  else:
    # Sharing n keys with 1 executor.
    py_typecheck.check_len(receivers, 1)
    sole_receiver = receivers[0]
    pending = [
        sole_receiver.create_value(key, key_type) for key in materialized
    ]
    received_type = tff.FederatedType(
        type_conversions.infer_type(materialized), key_receiver)

  received_key = federated_resolving_strategy.FederatedResolvingStrategyValue(
      await asyncio.gather(*pending), received_type)
  self.key_references.update_keys(key_owner, public_key=received_key)
async def compute_federated_zip_at_clients(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Zips a pair of client-placed values by delegating to each target executor.

  Args:
    arg: A two-element structure of values placed at clients. Each element's
      internal representation must be a list with one entry per target
      executor.

  Returns:
    A `FederatedComposingStrategyValue` with one zipped result per target
    executor, typed as a client-placed structure of the two member types.
    Element names from the argument's type signature are carried over.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.StructType)
  py_typecheck.check_len(arg.type_signature, 2)
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)

  def _maybe_named(name, element_type):
    # Unnamed elements are passed bare rather than as (None, type) pairs.
    if name:
      return (name, element_type)
    return element_type

  keys = [k for k, _ in anonymous_tuple.to_elements(arg.type_signature)]
  vals = [arg.internal_representation[0], arg.internal_representation[1]]
  declared_types = [arg.type_signature[0], arg.type_signature[1]]

  # Validate each element and rebuild its type with `type_factory.at_clients`.
  types = []
  for declared_type, val in zip(declared_types, vals):
    type_analysis.check_federated_type(
        declared_type, placement=placement_literals.CLIENTS)
    types.append(type_factory.at_clients(declared_type.member))
    py_typecheck.check_type(val, list)
    py_typecheck.check_len(val, len(self._target_executors))

  item_type = computation_types.StructType(
      [_maybe_named(key, t.member) for key, t in zip(keys, types)])
  result_type = type_factory.at_clients(item_type)
  zip_parameter_type = computation_types.StructType(
      [_maybe_named(key, t) for key, t in zip(keys, types)])
  zip_type = computation_types.FunctionType(zip_parameter_type, result_type)
  zip_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

  async def _child_fn(ex, x, y):
    py_typecheck.check_type(x, executor_value_base.ExecutorValue)
    py_typecheck.check_type(y, executor_value_base.ExecutorValue)
    zip_fn = await ex.create_value(zip_comp, zip_type)
    zip_arg = await ex.create_struct(
        anonymous_tuple.AnonymousTuple([(keys[0], x), (keys[1], y)]))
    return await ex.create_call(zip_fn, zip_arg)

  first_vals, second_vals = vals
  result = await asyncio.gather(*[
      _child_fn(child, x, y)
      for child, x, y in zip(self._target_executors, first_vals, second_vals)
  ])
  return FederatedComposingStrategyValue(result, result_type)
async def create_value(self, value, type_spec=None):
  """Embeds `value` of type `type_spec` in this executor.

  Dispatch, in order:
    * Intrinsic definitions are wrapped directly after a type check.
    * `pb.Computation` protos: 'tensorflow' and 'lambda' are wrapped
      directly; 'intrinsic' is resolved to its definition and re-dispatched.
    * Tuple types are embedded element by element.
    * Federated types go to the parent executor (server) or are distributed
      over the child executors (clients).
    * Everything else is embedded in the parent executor.

  Args:
    value: The value to embed.
    type_spec: Anything `computation_types.to_type` converts to a type.

  Returns:
    A `CompositeValue`, or for tuple types whatever `create_tuple` returns.

  Raises:
    TypeError: If `type_spec` is incompatible with an intrinsic.
    ValueError: For an unrecognized intrinsic, a non-all_equal server value,
      or an unexpected placement.
    NotImplementedError: For an unsupported computation kind.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)

  if isinstance(value, intrinsic_defs.IntrinsicDef):
    if not type_utils.is_concrete_instance_of(type_spec, value.type_signature):  # pytype: disable=attribute-error
      raise TypeError('Incompatible type {} used with intrinsic {}.'.format(
          type_spec, value.uri))  # pytype: disable=attribute-error
    return CompositeValue(value, type_spec)

  if isinstance(value, pb.Computation):
    which_computation = value.WhichOneof('computation')
    if which_computation in ['tensorflow', 'lambda']:
      return CompositeValue(value, type_spec)
    if which_computation == 'intrinsic':
      intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
      if intr is None:
        raise ValueError('Encountered an unrecognized intrinsic "{}".'.format(
            value.intrinsic.uri))
      py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
      return await self.create_value(intr, type_spec)
    raise NotImplementedError(
        'Unimplemented computation type {}.'.format(which_computation))

  if isinstance(type_spec, computation_types.NamedTupleType):
    value_elements = anonymous_tuple.to_elements(
        anonymous_tuple.from_container(value))
    type_elements = anonymous_tuple.to_elements(type_spec)
    embedded = await asyncio.gather(*[
        self.create_value(v, t)
        for (_, v), (_, t) in zip(value_elements, type_elements)
    ])
    # Element names come from the type, not from the value container.
    named_items = [
        (name, item) for (name, _), item in zip(type_elements, embedded)
    ]
    # NOTE(review): the result of `create_tuple` is returned without being
    # awaited; confirm against its definition whether that is intended.
    return self.create_tuple(anonymous_tuple.AnonymousTuple(named_items))

  if not isinstance(type_spec, computation_types.FederatedType):
    return CompositeValue(
        await self._parent_executor.create_value(value, type_spec), type_spec)

  if type_spec.placement == placement_literals.SERVER:
    if not type_spec.all_equal:
      raise ValueError('A non-all_equal value on the server is unexpected.')
    return CompositeValue(
        await self._parent_executor.create_value(value, type_spec.member),
        type_spec)

  if type_spec.placement == placement_literals.CLIENTS:
    if type_spec.all_equal:
      # The same value is handed to every child executor.
      return CompositeValue(
          await asyncio.gather(*[
              child.create_value(value, type_spec)
              for child in self._child_executors
          ]), type_spec)

    py_typecheck.check_type(value, list)
    # The cardinalities future is created on first use and reused afterwards.
    if self._cardinalities is None:
      self._cardinalities = asyncio.ensure_future(self._get_cardinalities())
    cardinalities = await self._cardinalities
    py_typecheck.check_len(cardinalities, len(self._child_executors))
    py_typecheck.check_len(value, sum(cardinalities))

    # Hand each child the contiguous slice matching its cardinality.
    pending = []
    start = 0
    for child, cardinality in zip(self._child_executors, cardinalities):
      stop = start + cardinality
      # The slice operator is not supported on all the types `value`
      # supports.
      # pytype: disable=unsupported-operands
      pending.append(child.create_value(value[start:stop], type_spec))
      # pytype: enable=unsupported-operands
      start = stop
    return CompositeValue(await asyncio.gather(*pending), type_spec)

  raise ValueError('Unexpected placement {}.'.format(type_spec.placement))
def __init__(self,
             initialize,
             prepare,
             work,
             zero,
             accumulate,
             merge,
             report,
             bitwidth,
             update,
             server_state_label=None,
             client_data_label=None):
  """Constructs a representation of a MapReduce-like iterative process.

  Note: All the computations supplied here as arguments must be TensorFlow
  computations, i.e., instances of `tff.Computation` constructed by the
  `tff.tf_computation` decorator/wrapper.

  The constructor only validates that the type signatures of the supplied
  computations fit together and then stores them on the instance.

  Args:
    initialize: The computation that produces the initial server state.
    prepare: The computation that prepares the input for the clients.
    work: The client-side work computation.
    zero: The computation that produces the initial state for accumulators.
    accumulate: The computation that adds a client update to an accumulator.
    merge: The computation to use for merging pairs of accumulators.
    report: The computation that produces the final server-side aggregate for
      the top level accumulator (the global update).
    bitwidth: The computation that produces the bitwidth for secure sum.
    update: The computation that takes the global update and the server state
      and produces the new server state, as well as server-side output.
    server_state_label: Optional string label for the server state.
    client_data_label: Optional string label for the client data.

  Raises:
    TypeError: If the Python or TFF types of the arguments are invalid or not
      compatible with each other.
    AssertionError: If the manner in which the given TensorFlow computations
      are represented by TFF does not match what this code is expecting (this
      is an internal error that requires code update).
  """
  # Every argument must be a computation backed by a plain TensorFlow proto.
  for label, comp in [
      ('initialize', initialize),
      ('prepare', prepare),
      ('work', work),
      ('zero', zero),
      ('accumulate', accumulate),
      ('merge', merge),
      ('report', report),
      ('bitwidth', bitwidth),
      ('update', update),
  ]:
    py_typecheck.check_type(comp, computation_base.Computation, label)

    # TODO(b/130633916): Remove private access once an appropriate API for it
    # becomes available.
    comp_proto = comp._computation_proto  # pylint: disable=protected-access

    if not isinstance(comp_proto, computation_pb2.Computation):
      # Explicitly raised to force it to be done in non-debug mode as well.
      raise AssertionError(
          'Cannot find the embedded computation definition.')

    which_comp = comp_proto.WhichOneof('computation')
    if which_comp != 'tensorflow':
      raise TypeError(
          'Expected all computations supplied as arguments to '
          'be plain TensorFlow, found {}.'.format(which_comp))

  def is_assignable_from_or_both_none(first, second):
    # A `None` type (e.g. a computation without a parameter) only matches
    # another `None`; otherwise defer to the type's own assignability check.
    if first is None:
      return second is None
    return first.is_assignable_from(second)

  # `prepare` consumes the server state produced by `initialize`.
  prepare_arg_type = prepare.type_signature.parameter
  init_result_type = initialize.type_signature.result
  if not is_assignable_from_or_both_none(prepare_arg_type, init_result_type):
    raise TypeError(
        'The `prepare` computation expects an argument of type {}, '
        'which does not match the result type {} of `initialize`.'.
        format(prepare_arg_type, init_result_type))

  # `work` takes a two-tuple whose second element is the output of `prepare`.
  if (not work.type_signature.parameter.is_struct() or
      len(work.type_signature.parameter) != 2):
    raise TypeError(
        'The `work` computation expects an argument of type {} that is not '
        'a two-tuple.'.format(work.type_signature.parameter))

  work_2nd_arg_type = work.type_signature.parameter[1]
  prepare_result_type = prepare.type_signature.result
  if not is_assignable_from_or_both_none(work_2nd_arg_type,
                                         prepare_result_type):
    raise TypeError(
        'The `work` computation expects an argument tuple with type {} as '
        'the second element (the initial client state from the server), '
        'which does not match the result type {} of `prepare`.'.format(
            work_2nd_arg_type, prepare_result_type))

  # `work` must also return a two-tuple; element 0 feeds `accumulate` and
  # element 1 is part of the expected `update` parameter below.
  if (not work.type_signature.result.is_struct() or
      len(work.type_signature.result) != 2):
    raise TypeError(
        'The `work` computation returns a result of type {} that is not a '
        'two-tuple.'.format(work.type_signature.result))

  py_typecheck.check_type(zero.type_signature, computation_types.FunctionType)

  # `accumulate` takes (accumulator, client update): the accumulator slot
  # must accept both the result of `zero` and `accumulate`'s own result.
  py_typecheck.check_type(accumulate.type_signature,
                          computation_types.FunctionType)
  py_typecheck.check_len(accumulate.type_signature.parameter, 2)
  accumulate.type_signature.parameter[0].check_assignable_from(
      zero.type_signature.result)
  accumulate_2nd_arg_type = accumulate.type_signature.parameter[1]
  work_client_update_type = work.type_signature.result[0]
  if not is_assignable_from_or_both_none(accumulate_2nd_arg_type,
                                         work_client_update_type):
    raise TypeError(
        'The `accumulate` computation expects a second argument of type {}, '
        'which does not match the expected {} as implied by the type '
        'signature of `work`.'.format(accumulate_2nd_arg_type,
                                      work_client_update_type))
  accumulate.type_signature.parameter[0].check_assignable_from(
      accumulate.type_signature.result)

  # `merge` takes two accumulators and its result must be usable as its own
  # first argument.
  py_typecheck.check_type(merge.type_signature,
                          computation_types.FunctionType)
  py_typecheck.check_len(merge.type_signature.parameter, 2)
  merge.type_signature.parameter[0].check_assignable_from(
      accumulate.type_signature.result)
  merge.type_signature.parameter[1].check_assignable_from(
      accumulate.type_signature.result)
  merge.type_signature.parameter[0].check_assignable_from(
      merge.type_signature.result)

  # `report` consumes the merged accumulator.
  py_typecheck.check_type(report.type_signature,
                          computation_types.FunctionType)
  report.type_signature.parameter.check_assignable_from(
      merge.type_signature.result)

  # Only the Python type of `bitwidth`'s signature is checked here.
  py_typecheck.check_type(bitwidth.type_signature,
                          computation_types.FunctionType)

  # `update` takes <server state, <report result, second output of work>>.
  expected_update_parameter_type = computation_types.to_type([
      initialize.type_signature.result,
      [report.type_signature.result, work.type_signature.result[1]],
  ])
  if not is_assignable_from_or_both_none(update.type_signature.parameter,
                                         expected_update_parameter_type):
    raise TypeError(
        'The `update` computation expects an argument of type {}, '
        'which does not match the expected {} as implied by the type '
        'signatures of `initialize`, `report`, and `work`.'.format(
            update.type_signature.parameter,
            expected_update_parameter_type))

  if (not update.type_signature.result.is_struct() or
      len(update.type_signature.result) != 2):
    raise TypeError(
        'The `update` computation returns a result of type {} that is not '
        'a two-tuple.'.format(update.type_signature.result))

  # The updated server state must be accepted by `prepare` again, so that
  # the process can be iterated.
  updated_state_type = update.type_signature.result[0]
  if not prepare_arg_type.is_assignable_from(updated_state_type):
    raise TypeError(
        'The `update` computation returns a result tuple whose first element '
        f'(the updated state type of the server) is type:\n'
        f'{updated_state_type}\n'
        f'which is not assignable to the state parameter type of `prepare`:\n'
        f'{prepare_arg_type}')

  self._initialize = initialize
  self._prepare = prepare
  self._work = work
  self._zero = zero
  self._accumulate = accumulate
  self._merge = merge
  self._report = report
  self._bitwidth = bitwidth
  self._update = update

  # Labels are optional; they are type-checked only when provided.
  if server_state_label is not None:
    py_typecheck.check_type(server_state_label, str)
  self._server_state_label = server_state_label
  if client_data_label is not None:
    py_typecheck.check_type(client_data_label, str)
  self._client_data_label = client_data_label