async def _compute_intrinsic_federated_map(self, arg):
  """Applies a function to a clients-placed value on every child executor.

  Args:
    arg: A value whose `internal_representation` is a 2-element
      `AnonymousTuple` of (function as a `pb.Computation`, list of per-child
      values), and whose `type_signature` holds the function type followed by
      the type of the clients-placed value.

  Returns:
    A `CompositeValue` wrapping the list of per-child call results, typed as
    the function's result type placed at clients.

  Raises:
    Whatever the `py_typecheck` / `type_utils` checks raise when the argument
    has the wrong structure, or when the number of per-child values differs
    from the number of child executors.
  """
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)
  fn_type = arg.type_signature[0]
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  val_type = arg.type_signature[1]
  type_utils.check_federated_type(val_type, fn_type.parameter,
                                  placement_literals.CLIENTS)
  fn = arg.internal_representation[0]
  val = arg.internal_representation[1]
  py_typecheck.check_type(fn, pb.Computation)
  py_typecheck.check_type(val, list)
  # `zip` below stops at the shorter input, so without this check a length
  # mismatch would silently skip child executors or drop values.
  py_typecheck.check_len(val, len(self._child_executors))

  result_type = type_factory.at_clients(fn_type.result)
  map_type = computation_types.FunctionType(
      [fn_type, type_factory.at_clients(fn_type.parameter)], result_type)
  map_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_MAP, map_type)

  async def _child_fn(ex, v):
    """Issues the federated map call for one child executor and its value."""
    py_typecheck.check_type(v, executor_value_base.ExecutorValue)
    # The function must exist in the child before it can be put in the tuple.
    fn_val = await ex.create_value(fn, fn_type)
    map_val, map_arg = await asyncio.gather(
        ex.create_value(map_comp, map_type),
        ex.create_tuple([fn_val, v]))
    return await ex.create_call(map_val, map_arg)

  result_vals = await asyncio.gather(
      *[_child_fn(c, v) for c, v in zip(self._child_executors, val)])
  return CompositeValue(result_vals, result_type)
async def _eval(self, arg, intrinsic, placement, all_equal):
  """Invokes an eval-style intrinsic with `arg` on every child executor.

  Args:
    arg: A value whose `type_signature` is a `FunctionType` and whose
      `internal_representation` is a `pb.Computation`.
    intrinsic: The intrinsic definition passed to
      `executor_utils.create_intrinsic_comp`.
    placement: The `PlacementLiteral` of the federated result.
    all_equal: The `all_equal` flag used for the federated result type.

  Returns:
    A `CompositeValue` wrapping one call result per child executor, typed as
    the function's result federated at `placement`.
  """
  fn_type = arg.type_signature
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  fn = arg.internal_representation
  py_typecheck.check_type(fn, pb.Computation)
  py_typecheck.check_type(placement, placement_literals.PlacementLiteral)

  # The same federated type is the intrinsic's result and the overall result.
  placed_result_type = computation_types.FederatedType(
      fn_type.result, placement, all_equal=all_equal)
  eval_type = computation_types.FunctionType(fn_type, placed_result_type)
  eval_comp = executor_utils.create_intrinsic_comp(intrinsic, eval_type)

  async def _run_on_child(child):
    """Creates the intrinsic and the function in `child`, then calls."""
    py_typecheck.check_type(child, executor_base.Executor)
    eval_val, fn_val = await asyncio.gather(
        child.create_value(eval_comp, eval_type),
        child.create_value(fn, fn_type))
    return await child.create_call(eval_val, fn_val)

  pending = [_run_on_child(child) for child in self._child_executors]
  child_results = await asyncio.gather(*pending)
  return CompositeValue(child_results, placed_result_type)
async def _compute_intrinsic_federated_aggregate(self, arg):
  """Aggregates in each child, then merges and reports in the parent executor.

  Each child executor runs the federated aggregate intrinsic with the caller's
  report function swapped for the computation built by
  `_create_lambda_identity_comp(zero_type)` (presumably an identity function;
  confirm against its definition). The per-child results are computed, loaded
  into the parent executor, folded left to right with `merge`, and passed to
  `report`.

  Args:
    arg: A value whose `internal_representation` is a 5-element
      `AnonymousTuple` of (list of per-child values, zero, accumulate, merge,
      report), with a type signature accepted by
      `executor_utils.parse_federated_aggregate_argument_types`.

  Returns:
    A `CompositeValue` wrapping the parent executor's result of calling
    `report`, typed as `report_type.result` placed at the server.
  """
  (value_type, zero_type, accumulate_type, merge_type,
   report_type) = executor_utils.parse_federated_aggregate_argument_types(
       arg.type_signature)

  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 5)
  val = arg.internal_representation[0]
  py_typecheck.check_type(val, list)
  py_typecheck.check_len(val, len(self._child_executors))

  identity_report = _create_lambda_identity_comp(zero_type)
  identity_report_type = type_factory.unary_op(zero_type)
  aggr_param_type = computation_types.NamedTupleType([
      value_type, zero_type, accumulate_type, merge_type, identity_report_type
  ])
  aggr_type = computation_types.FunctionType(
      aggr_param_type, type_factory.at_server(zero_type))
  aggr_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)

  # The zero is materialized once so that it can be re-created in each child.
  zero_selection = await self.create_selection(arg, index=1)
  zero = await zero_selection.compute()
  accumulate = arg.internal_representation[2]
  merge = arg.internal_representation[3]
  report = arg.internal_representation[4]

  async def _aggregate_in_child(child, child_val):
    """Runs the aggregate in `child` and returns its computed result."""
    py_typecheck.check_type(child_val, executor_value_base.ExecutorValue)
    operands = await asyncio.gather(
        child.create_value(zero, zero_type),
        child.create_value(accumulate, accumulate_type),
        child.create_value(merge, merge_type),
        child.create_value(identity_report, identity_report_type))
    aggr_func, aggr_args = await asyncio.gather(
        child.create_value(aggr_comp, aggr_type),
        child.create_tuple([child_val] + list(operands)))
    call_result = await child.create_call(aggr_func, aggr_args)
    return await call_result.compute()

  partials = await asyncio.gather(*[
      _aggregate_in_child(child, child_val)
      for child, child_val in zip(self._child_executors, val)
  ])

  parent = self._parent_executor
  parent_vals = await asyncio.gather(
      *[parent.create_value(partial, zero_type) for partial in partials])
  parent_merge, parent_report = await asyncio.gather(
      parent.create_value(merge, merge_type),
      parent.create_value(report, report_type))

  # Left-to-right fold; each step needs the previous merge result.
  merge_result = parent_vals[0]
  for next_val in parent_vals[1:]:
    merge_arg = await parent.create_tuple([merge_result, next_val])
    merge_result = await parent.create_call(parent_merge, merge_arg)

  reported = await parent.create_call(parent_report, merge_result)
  return CompositeValue(reported, type_factory.at_server(report_type.result))
async def _get_cardinalities(self):
  """Returns one federated-sum-of-ones result per child executor.

  In every child executor the constant `1` is created as an all-equal
  clients-placed `tf.int32` and passed to the federated sum intrinsic. The
  computed sum is presumably the number of clients under that child (confirm
  against the child executors' semantics).

  Returns:
    A list with one entry per child executor, in child order. `tf.Tensor`
    results are converted with `.numpy()`; anything else is returned as is.
  """
  one_type = type_factory.at_clients(tf.int32, all_equal=True)
  sum_type = computation_types.FunctionType(
      type_factory.at_clients(tf.int32), type_factory.at_server(tf.int32))
  sum_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_SUM, sum_type)

  async def _sum_ones(child):
    """Computes the federated sum of the constant 1 in `child`."""
    sum_val, one_val = await asyncio.gather(
        child.create_value(sum_comp, sum_type),
        child.create_value(1, one_type))
    call_result = await child.create_call(sum_val, one_val)
    return await call_result.compute()

  raw_counts = await asyncio.gather(
      *[_sum_ones(child) for child in self._child_executors])

  cardinalities = []
  for count in raw_counts:
    if isinstance(count, tf.Tensor):
      count = count.numpy()
    cardinalities.append(count)
  return cardinalities
async def _compute_intrinsic_federated_zip_at_clients(self, arg):
  """Zips two clients-placed values pairwise within each child executor.

  Args:
    arg: A value whose `type_signature` is a 2-element `NamedTupleType` of
      clients-placed federated types, and whose `internal_representation` is a
      2-element `AnonymousTuple` of lists, each holding one value per child
      executor.

  Returns:
    A `CompositeValue` wrapping one zipped result per child executor, typed
    as a clients-placed tuple of the two member types. Element names from the
    argument's type signature are kept where present.
  """
  type_spec = arg.type_signature
  py_typecheck.check_type(type_spec, computation_types.NamedTupleType)
  py_typecheck.check_len(type_spec, 2)
  payload = arg.internal_representation
  py_typecheck.check_type(payload, anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(payload, 2)

  names = [name for name, _ in anonymous_tuple.to_elements(type_spec)]
  num_children = len(self._child_executors)

  per_child_vals = []
  client_types = []
  for idx in range(2):
    elem_type = type_spec[idx]
    type_utils.check_federated_type(
        elem_type, placement=placement_literals.CLIENTS)
    # Rebuild the type from its member via `at_clients`, so both elements use
    # that factory's defaults.
    client_types.append(type_factory.at_clients(elem_type.member))
    elem_vals = payload[idx]
    py_typecheck.check_type(elem_vals, list)
    py_typecheck.check_len(elem_vals, num_children)
    per_child_vals.append(elem_vals)

  def _with_names(items):
    """Pairs each item with its element name, if that name is non-empty."""
    return [(name, item) if name else item
            for name, item in zip(names, items)]

  item_type = computation_types.NamedTupleType(
      _with_names([t.member for t in client_types]))
  result_type = type_factory.at_clients(item_type)
  zip_type = computation_types.FunctionType(
      computation_types.NamedTupleType(_with_names(client_types)), result_type)
  zip_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

  async def _zip_in_child(child, first, second):
    """Creates the intrinsic and argument tuple in `child`, then calls."""
    py_typecheck.check_type(first, executor_value_base.ExecutorValue)
    py_typecheck.check_type(second, executor_value_base.ExecutorValue)
    zip_val = await child.create_value(zip_comp, zip_type)
    zip_arg = await child.create_tuple(
        anonymous_tuple.AnonymousTuple([(names[0], first),
                                        (names[1], second)]))
    return await child.create_call(zip_val, zip_arg)

  pending = [
      _zip_in_child(child, first, second) for child, first, second in zip(
          self._child_executors, per_child_vals[0], per_child_vals[1])
  ]
  zipped = await asyncio.gather(*pending)
  return CompositeValue(zipped, result_type)
async def _get_cardinalities_helper():
  """Fetches one cardinality per child executor.

  This function takes no parameters; it reads `self._child_executors` from
  the enclosing scope.

  Returns:
    The list produced by `asyncio.gather`, holding one count per child
    executor in child order.
  """
  one_type = type_factory.at_clients(tf.int32, all_equal=True)
  sum_type = computation_types.FunctionType(
      type_factory.at_clients(tf.int32), type_factory.at_server(tf.int32))
  sum_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_SUM, sum_type)

  async def _count_leaf_executors(child):
    """Counts leaf executors under `child` by federated-summing a constant 1."""
    one_val, sum_comp_val = await asyncio.gather(
        child.create_value(1, one_type),
        child.create_value(sum_comp, sum_type))
    call_result = await child.create_call(sum_comp_val, one_val)
    total = await call_result.compute()
    # Tensor results are unwrapped to a NumPy value; others pass through.
    return total.numpy() if isinstance(total, tf.Tensor) else total

  pending = [_count_leaf_executors(child) for child in self._child_executors]
  return await asyncio.gather(*pending)