async def _compute_intrinsic_federated_map(self, arg):
  """Computes a `federated_map` by fanning it out to every child executor.

  Each child executor receives the mapping function and its own slice of the
  CLIENTS-placed operand, and runs a `FEDERATED_MAP` intrinsic locally.

  Args:
    arg: An executor value whose internal representation is a 2-element
      `anonymous_tuple.AnonymousTuple` of (mapping function as a
      `pb.Computation`, operand as a list with one
      `executor_value_base.ExecutorValue` per child executor), and whose type
      signature is (function type, CLIENTS-placed operand type).

  Returns:
    A `CompositeValue` holding one call result per child executor, typed as
    the function's result placed at CLIENTS.

  Raises:
    The `py_typecheck` / `type_utils` checks raise if `arg` is malformed,
    including (via `check_len`) when the operand list does not hold exactly
    one entry per child executor.
  """
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)
  fn_type = arg.type_signature[0]
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  val_type = arg.type_signature[1]
  type_utils.check_federated_type(val_type, fn_type.parameter,
                                  placement_literals.CLIENTS)
  fn = arg.internal_representation[0]
  val = arg.internal_representation[1]
  py_typecheck.check_type(fn, pb.Computation)
  py_typecheck.check_type(val, list)
  # `zip` below would silently drop unmatched entries and return a short
  # result on a mismatch; fail loudly instead.
  py_typecheck.check_len(val, len(self._child_executors))
  map_type = computation_types.FunctionType(
      [fn_type, type_factory.at_clients(fn_type.parameter)],
      type_factory.at_clients(fn_type.result))
  map_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_MAP, map_type)

  async def _child_fn(ex, v):
    py_typecheck.check_type(v, executor_value_base.ExecutorValue)
    # The function value must exist before it can be packed into the tuple.
    fn_val = await ex.create_value(fn, fn_type)
    map_val, map_arg = await asyncio.gather(
        ex.create_value(map_comp, map_type), ex.create_tuple([fn_val, v]))
    return await ex.create_call(map_val, map_arg)

  result_vals = await asyncio.gather(
      *[_child_fn(c, v) for c, v in zip(self._child_executors, val)])
  return CompositeValue(result_vals, type_factory.at_clients(fn_type.result))
async def _eval(self, arg, intrinsic, placement, all_equal):
  """Evaluates `arg` under an eval-style intrinsic on every child executor.

  Args:
    arg: An executor value holding a `pb.Computation` whose type signature is
      a `computation_types.FunctionType`.
    intrinsic: The intrinsic definition to wrap around `arg`.
    placement: The `placement_literals.PlacementLiteral` of the result.
    all_equal: Whether the resulting federated type is all-equal.

  Returns:
    A `CompositeValue` with one call result per child executor, typed as the
    function's result federated at `placement`.
  """
  fn_type = arg.type_signature
  fn = arg.internal_representation
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  py_typecheck.check_type(fn, pb.Computation)
  py_typecheck.check_type(placement, placement_literals.PlacementLiteral)

  # The same federated type serves as the intrinsic's result type and as the
  # type of the composite value returned to the caller.
  federated_result_type = computation_types.FederatedType(
      fn_type.result, placement, all_equal=all_equal)
  intrinsic_type = computation_types.FunctionType(fn_type,
                                                  federated_result_type)
  intrinsic_comp = executor_utils.create_intrinsic_comp(intrinsic,
                                                        intrinsic_type)

  async def _run_on_child(child):
    py_typecheck.check_type(child, executor_base.Executor)
    intrinsic_val, fn_val = await asyncio.gather(
        child.create_value(intrinsic_comp, intrinsic_type),
        child.create_value(fn, fn_type))
    return await child.create_call(intrinsic_val, fn_val)

  pending = [_run_on_child(child) for child in self._child_executors]
  per_child = await asyncio.gather(*pending)
  return CompositeValue(per_child, federated_result_type)
async def _get_cardinalities_helper():
  """Fetches, for each child executor, the number of leaf executors under it.

  Returns:
    A list with one count per entry of `self._child_executors`, in order.
  """
  # An all-equal `1` at CLIENTS summed to SERVER gives the leaf count under
  # a child.
  ones_type = type_factory.at_clients(tf.int32, all_equal=True)
  summation_type = computation_types.FunctionType(
      type_factory.at_clients(tf.int32), type_factory.at_server(tf.int32))
  summation_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_SUM, summation_type)

  async def _count_leaf_executors(ex):
    """Counts the total number of leaf executors under `ex`."""
    ones_val, summation_val = await asyncio.gather(
        ex.create_value(1, ones_type),
        ex.create_value(summation_comp, summation_type))
    call_val = await ex.create_call(summation_val, ones_val)
    total = await call_val.compute()
    # Unwrap eager tensors to plain numpy values.
    return total.numpy() if isinstance(total, tf.Tensor) else total

  return await asyncio.gather(
      *map(_count_leaf_executors, self._child_executors))
async def _map(self, arg, all_equal=None):
  """Maps a function over a federated value across all target executors.

  Args:
    arg: An executor value whose internal representation is a 2-element
      `structure.Struct` of (mapping function as a `pb.Computation`, operand
      as a list with one `executor_value_base.ExecutorValue` per target
      executor), and whose type signature is (function type, federated
      operand type).
    all_equal: Whether the result is declared all-equal. If `None`, the
      operand's `all_equal` is reused.

  Returns:
    A `FederatedComposingStrategyValue` with one call result per target
    executor, typed as the function's result at the operand's placement with
    the chosen `all_equal`.

  Raises:
    ValueError: If `all_equal` is requested but the operand is not all-equal.
    The `py_typecheck` checks also raise if `arg` is malformed, including
    (via `check_len`) when the operand list does not hold exactly one entry
    per target executor.
  """
  py_typecheck.check_type(arg.internal_representation, structure.Struct)
  py_typecheck.check_len(arg.internal_representation, 2)
  fn_type = arg.type_signature[0]
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  val_type = arg.type_signature[1]
  py_typecheck.check_type(val_type, computation_types.FederatedType)
  if all_equal is None:
    all_equal = val_type.all_equal
  elif all_equal and not val_type.all_equal:
    raise ValueError(
        'Cannot map a non-all_equal argument into an all_equal result.')
  fn = arg.internal_representation[0]
  py_typecheck.check_type(fn, pb.Computation)
  val = arg.internal_representation[1]
  py_typecheck.check_type(val, list)
  # `zip` below would silently drop unmatched entries and return a short
  # result on a mismatch; fail loudly instead.
  py_typecheck.check_len(val, len(self._target_executors))

  # Children always run a CLIENTS-typed federated_map. The composite result
  # is then labelled with the operand's placement and the chosen all_equal.
  map_type = computation_types.FunctionType(
      [fn_type, computation_types.at_clients(fn_type.parameter)],
      computation_types.at_clients(fn_type.result))
  map_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_MAP, map_type)

  async def _child_fn(ex, v):
    py_typecheck.check_type(v, executor_value_base.ExecutorValue)
    # The function value must exist before it can be packed into the struct.
    fn_val = await ex.create_value(fn, fn_type)
    map_val, map_arg = await asyncio.gather(
        ex.create_value(map_comp, map_type), ex.create_struct([fn_val, v]))
    return await ex.create_call(map_val, map_arg)

  result_vals = await asyncio.gather(
      *[_child_fn(c, v) for c, v in zip(self._target_executors, val)])
  federated_type = computation_types.FederatedType(
      fn_type.result, val_type.placement, all_equal=all_equal)
  return FederatedComposingStrategyValue(result_vals, federated_type)
async def compute_federated_aggregate(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Runs `federated_aggregate` hierarchically over the target executors.

  Steps:
    1. Each target executor aggregates its own slice of the value using the
       given zero, accumulate and merge. An identity function replaces the
       report, so the child returns its partial accumulator (of the zero's
       type).
    2. The partials are embedded in the server executor.
    3. The partials are folded with `merge` in the order the children finish.
    4. `report` is applied once to the merged accumulator.

  Args:
    arg: A `FederatedComposingStrategyValue` whose internal representation is
      a 5-element `structure.Struct` of (value, zero, accumulate, merge,
      report). The value is a list with one entry per target executor.

  Returns:
    A `FederatedComposingStrategyValue` wrapping the reported result, typed as
    `report`'s result placed at SERVER.
  """
  value_type, zero_type, accumulate_type, merge_type, report_type = (
      executor_utils.parse_federated_aggregate_argument_types(
          arg.type_signature))
  py_typecheck.check_type(arg.internal_representation, structure.Struct)
  py_typecheck.check_len(arg.internal_representation, 5)
  client_vals = arg.internal_representation[0]
  py_typecheck.check_type(client_vals, list)
  py_typecheck.check_len(client_vals, len(self._target_executors))
  identity_report, identity_report_type = (
      tensorflow_computation_factory.create_identity(zero_type))
  # Children stop at the accumulator and return it placed at SERVER.
  child_aggr_type = computation_types.FunctionType(
      computation_types.StructType([
          value_type, zero_type, accumulate_type, merge_type,
          identity_report_type
      ]), computation_types.at_server(zero_type))
  child_aggr_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_AGGREGATE, child_aggr_type)
  # Materialize the zero once so it can be re-embedded into every child.
  zero_selection = await self._executor.create_selection(arg, index=1)
  zero = await zero_selection.compute()
  _, _, accumulate, merge, report = arg.internal_representation

  async def _partial_on_child(child, child_val):
    py_typecheck.check_type(child_val, executor_value_base.ExecutorValue)
    zero_val, accumulate_val, merge_val, identity_val = await asyncio.gather(
        child.create_value(zero, zero_type),
        child.create_value(accumulate, accumulate_type),
        child.create_value(merge, merge_type),
        child.create_value(identity_report, identity_report_type))
    aggr_fn_val, aggr_arg_val = await asyncio.gather(
        child.create_value(child_aggr_comp, child_aggr_type),
        child.create_struct(
            [child_val, zero_val, accumulate_val, merge_val, identity_val]))
    call_val = await child.create_call(aggr_fn_val, aggr_arg_val)
    partial = await call_val.compute()
    return await self._server_executor.create_value(partial, zero_type)

  # Partials are consumed in completion order, not executor order.
  partials_by_completion = asyncio.as_completed([
      _partial_on_child(child, child_val)
      for child, child_val in zip(self._target_executors, client_vals)
  ])
  server_merge, server_report = await asyncio.gather(
      self._server_executor.create_value(merge, merge_type),
      self._server_executor.create_value(report, report_type))
  merged = await next(partials_by_completion)
  for pending_partial in partials_by_completion:
    partial = await pending_partial
    pair = await self._server_executor.create_struct([merged, partial])
    merged = await self._server_executor.create_call(server_merge, pair)
  reported = await self._server_executor.create_call(server_report, merged)
  return FederatedComposingStrategyValue(
      reported, computation_types.at_server(report_type.result))
async def compute_federated_aggregate(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Runs `federated_aggregate` hierarchically over the target executors.

  Steps:
    1. Each target executor aggregates its own slice of the value using the
       given zero, accumulate and merge. An identity function replaces the
       report, so the child returns its computed partial accumulator.
    2. All partials are embedded into the server executor.
    3. The partials are folded with `merge` in target-executor order.
    4. `report` is applied once to the merged accumulator.

  Args:
    arg: A `FederatedComposingStrategyValue` whose internal representation is
      a 5-element `anonymous_tuple.AnonymousTuple` of (value, zero,
      accumulate, merge, report). The value is a list with one entry per
      target executor.

  Returns:
    A `FederatedComposingStrategyValue` wrapping the reported result, typed as
    `report`'s result placed at SERVER.
  """
  value_type, zero_type, accumulate_type, merge_type, report_type = (
      executor_utils.parse_federated_aggregate_argument_types(
          arg.type_signature))
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 5)
  client_vals = arg.internal_representation[0]
  py_typecheck.check_type(client_vals, list)
  py_typecheck.check_len(client_vals, len(self._target_executors))
  identity_report = tensorflow_computation_factory.create_identity(zero_type)
  identity_report_type = type_factory.unary_op(zero_type)
  # Children stop at the accumulator and return it placed at SERVER.
  child_aggr_type = computation_types.FunctionType(
      computation_types.NamedTupleType([
          value_type, zero_type, accumulate_type, merge_type,
          identity_report_type
      ]), type_factory.at_server(zero_type))
  child_aggr_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_AGGREGATE, child_aggr_type)
  zero_selection = await self._executor.create_selection(arg, index=1)
  zero = await zero_selection.compute()
  accumulate = arg.internal_representation[2]
  merge = arg.internal_representation[3]
  report = arg.internal_representation[4]

  async def _partial_on_child(child, child_val):
    py_typecheck.check_type(child_val, executor_value_base.ExecutorValue)
    components = await asyncio.gather(
        child.create_value(zero, zero_type),
        child.create_value(accumulate, accumulate_type),
        child.create_value(merge, merge_type),
        child.create_value(identity_report, identity_report_type))
    aggr_fn_val, aggr_arg_val = await asyncio.gather(
        child.create_value(child_aggr_comp, child_aggr_type),
        child.create_tuple([child_val, *components]))
    call_val = await child.create_call(aggr_fn_val, aggr_arg_val)
    return await call_val.compute()

  partials = await asyncio.gather(*[
      _partial_on_child(child, child_val)
      for child, child_val in zip(self._target_executors, client_vals)
  ])
  server_partials = await asyncio.gather(*[
      self._server_executor.create_value(partial, zero_type)
      for partial in partials
  ])
  server_merge, server_report = await asyncio.gather(
      self._server_executor.create_value(merge, merge_type),
      self._server_executor.create_value(report, report_type))
  merged = server_partials[0]
  for partial in server_partials[1:]:
    pair = await self._server_executor.create_tuple([merged, partial])
    merged = await self._server_executor.create_call(server_merge, pair)
  reported = await self._server_executor.create_call(server_report, merged)
  return FederatedComposingStrategyValue(
      reported, type_factory.at_server(report_type.result))
async def _num_clients(executor):
  """Returns how many clients sit under `executor`.

  Sums an all-equal `1` placed at CLIENTS with `FEDERATED_SUM`. Eager tensor
  results are unwrapped with `.numpy()`; anything else is returned as-is.
  """
  summation_type = computation_types.FunctionType(
      type_factory.at_clients(tf.int32), type_factory.at_server(tf.int32))
  summation_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_SUM, summation_type)
  ones_type = type_factory.at_clients(tf.int32, all_equal=True)
  summation_val, ones_val = await asyncio.gather(
      executor.create_value(summation_comp, summation_type),
      executor.create_value(1, ones_type))
  call_val = await executor.create_call(summation_val, ones_val)
  count = await call_val.compute()
  if not isinstance(count, tf.Tensor):
    return count
  return count.numpy()
async def compute_federated_select(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Runs `federated_select` on every target executor with its own keys.

  The server-placed `max_key` and server value are computed once, locally.
  They are then re-embedded into each target executor, together with the
  select function and that executor's slice of the client keys.

  Args:
    arg: A `FederatedComposingStrategyValue` whose internal representation is
      a `structure.Struct` of (client_keys, max_key, server_val, select_fn).
      `client_keys` is a list with one entry per target executor, and
      `select_fn` is a `pb.Computation`.

  Returns:
    A `FederatedComposingStrategyValue` holding one result per target
    executor, typed as a sequence of `select_fn`'s result at CLIENTS.
  """
  client_keys_type, max_key_type, server_val_type, select_fn_type = (
      arg.type_signature)
  del client_keys_type  # Unused
  py_typecheck.check_type(arg.internal_representation, structure.Struct)
  client_keys, max_key, server_val, select_fn = arg.internal_representation
  py_typecheck.check_type(client_keys, list)
  py_typecheck.check_len(client_keys, len(self._target_executors))
  py_typecheck.check_type(max_key, executor_value_base.ExecutorValue)
  py_typecheck.check_type(server_val, executor_value_base.ExecutorValue)
  py_typecheck.check_type(select_fn, pb.Computation)
  unplaced_server_val, unplaced_max_key = await asyncio.gather(
      server_val.compute(), max_key.compute())
  result_type = computation_types.at_clients(
      computation_types.SequenceType(select_fn_type.result))
  select_type = computation_types.FunctionType(arg.type_signature, result_type)
  select_pb = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_SELECT, select_type)

  async def _select_on_child(child, keys_for_child):
    max_key_val, server_val_val, select_fn_val = await asyncio.gather(
        child.create_value(unplaced_max_key, max_key_type),
        child.create_value(unplaced_server_val, server_val_type),
        child.create_value(select_fn, select_fn_type))
    struct_arg = structure.Struct([
        (None, keys_for_child),
        (None, max_key_val),
        (None, server_val_val),
        (None, select_fn_val),
    ])
    intrinsic_val, arg_val = await asyncio.gather(
        child.create_value(select_pb, select_type),
        child.create_struct(struct_arg))
    return await child.create_call(intrinsic_val, arg_val)

  per_child = await asyncio.gather(*[
      _select_on_child(ex, ex_keys)
      for ex, ex_keys in zip(self._target_executors, client_keys)
  ])
  return FederatedComposingStrategyValue(per_child, result_type)
async def compute_federated_zip_at_clients(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Zips a pair of CLIENTS-placed values on every target executor.

  Element names from `arg`'s type signature are carried over to the zipped
  member struct. Unnamed elements stay unnamed.

  Args:
    arg: A `FederatedComposingStrategyValue` with a 2-element struct type
      signature of CLIENTS-placed values. Its internal representation is a
      2-element `anonymous_tuple.AnonymousTuple` of lists, each with one entry
      per target executor.

  Returns:
    A `FederatedComposingStrategyValue` with one zipped result per target
    executor, typed as the struct of member types placed at CLIENTS.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.StructType)
  py_typecheck.check_len(arg.type_signature, 2)
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)
  names = [name for name, _ in anonymous_tuple.to_elements(arg.type_signature)]
  operands = []
  clients_types = []
  for index in (0, 1):
    operand = arg.internal_representation[index]
    operand_type = arg.type_signature[index]
    type_analysis.check_federated_type(
        operand_type, placement=placement_literals.CLIENTS)
    # Normalize to a plain (non-all-equal) CLIENTS type.
    clients_types.append(type_factory.at_clients(operand_type.member))
    py_typecheck.check_type(operand, list)
    py_typecheck.check_len(operand, len(self._target_executors))
    operands.append(operand)

  def _maybe_named(name, value):
    return (name, value) if name else value

  item_type = computation_types.StructType([
      _maybe_named(name, clients_type.member)
      for name, clients_type in zip(names, clients_types)
  ])
  result_type = type_factory.at_clients(item_type)
  zip_type = computation_types.FunctionType(
      computation_types.StructType([
          _maybe_named(name, clients_type)
          for name, clients_type in zip(names, clients_types)
      ]), result_type)
  zip_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

  async def _zip_on_child(child, first, second):
    py_typecheck.check_type(first, executor_value_base.ExecutorValue)
    py_typecheck.check_type(second, executor_value_base.ExecutorValue)
    zip_val = await child.create_value(zip_comp, zip_type)
    zip_arg = await child.create_struct(
        anonymous_tuple.AnonymousTuple([(names[0], first),
                                        (names[1], second)]))
    return await child.create_call(zip_val, zip_arg)

  first_vals, second_vals = operands
  per_child = await asyncio.gather(*[
      _zip_on_child(child, first, second)
      for child, first, second in zip(self._target_executors, first_vals,
                                      second_vals)
  ])
  return FederatedComposingStrategyValue(per_child, result_type)
async def _num_clients(executor):
  """Returns how many clients sit under `executor`.

  Counting is expressed as a `federated_aggregate` rather than a
  `federated_sum`, so that op-resolving strategies implementing only the
  minimal set of intrinsics can serve it. An all-equal `1` at CLIENTS is
  aggregated from a zero of `0`, with `tf.add` as both accumulate and merge
  and an identity report. Eager tensor results are unwrapped with `.numpy()`.
  """
  int32_pair = computation_types.StructType([tf.int32, tf.int32])
  binary_op_type = computation_types.FunctionType(int32_pair, tf.int32)
  aggregate_type = computation_types.FunctionType(
      computation_types.StructType([
          computation_types.at_clients(tf.int32),
          tf.int32,
          binary_op_type,
          binary_op_type,
          computation_types.FunctionType(tf.int32, tf.int32),
      ]), computation_types.at_server(tf.int32))
  aggregate_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_AGGREGATE, aggregate_type)
  add_proto = (
      building_block_factory.create_tensorflow_binary_operator_with_upcast(
          int32_pair, tf.add).proto)
  identity_proto = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32)).proto
  (aggregate_val, ones_val, zero_val, add_val,
   identity_val) = await asyncio.gather(
       executor.create_value(aggregate_comp, aggregate_type),
       executor.create_value(
           1, computation_types.at_clients(tf.int32, all_equal=True)),
       executor.create_value(0, tf.int32),
       executor.create_value(add_proto),
       executor.create_value(identity_proto))
  # The same add computation serves as both accumulate and merge.
  aggregate_arg = await executor.create_struct(
      [ones_val, zero_val, add_val, add_val, identity_val])
  call_val = await executor.create_call(aggregate_val, aggregate_arg)
  total = await call_val.compute()
  return total.numpy() if isinstance(total, tf.Tensor) else total
async def compute_federated_zip_at_clients(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Zips a struct of CLIENTS-placed values on every target executor.

  Each child gets its slice of the struct (assembled by
  `self._zip_struct_into_child`) and runs `FEDERATED_ZIP_AT_CLIENTS` locally.

  Args:
    arg: A `FederatedComposingStrategyValue` whose type signature is a
      `computation_types.StructType`.

  Returns:
    A `FederatedComposingStrategyValue` with one zipped result per target
    executor, typed as the placement-stripped struct placed at CLIENTS.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.StructType)
  member_struct_type = type_transformations.strip_placement(arg.type_signature)
  result_type = computation_types.at_clients(member_struct_type)
  zip_type = computation_types.FunctionType(arg.type_signature, result_type)
  zip_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

  async def _zip_on_child(child, child_index):
    child_struct = await self._zip_struct_into_child(
        child, child_index, arg.internal_representation, arg.type_signature)
    zip_val = await child.create_value(zip_comp, zip_type)
    return await child.create_call(zip_val, child_struct)

  per_child = await asyncio.gather(*[
      _zip_on_child(child, index)
      for index, child in enumerate(self._target_executors)
  ])
  return FederatedComposingStrategyValue(per_child, result_type)