async def _compute_intrinsic_federated_sum(self, arg):
    """Computes `federated_sum` of a CLIENTS-placed value via federated aggregate.

    Embeds three helpers for the member type of `arg`:
      * a scalar constant `0` (the zero),
      * a `tf.add` binary operator,
      * an identity lambda.
    It then forwards `[arg, zero, add, add, identity]` to
    `_compute_intrinsic_federated_aggregate`. The addition is passed twice,
    presumably filling both the accumulate and merge slots. The identity is
    presumably the report step.

    Args:
      arg: The federated value to sum. Must be placed at CLIENTS.

    Returns:
      Whatever `_compute_intrinsic_federated_aggregate` produces for these
      arguments.

    Raises:
      Whatever `type_utils.check_federated_type` raises (presumably
      `TypeError`) when `arg` is not a federated value placed at CLIENTS.
    """
    type_utils.check_federated_type(
        arg.type_signature, placement=placement_literals.CLIENTS)
    member_type = arg.type_signature.member
    # Build the three embeddings up front, then await them concurrently.
    zero_job = executor_utils.embed_tf_scalar_constant(self, member_type, 0)
    add_job = executor_utils.embed_tf_binary_operator(self, member_type, tf.add)
    report_job = self.create_value(
        _create_lambda_identity_comp(member_type),
        type_factory.unary_op(member_type))
    zero, add_op, report_op = await asyncio.gather(zero_job, add_job, report_job)
    aggregate_args = await self.create_tuple(
        [arg, zero, add_op, add_op, report_op])
    return await self._compute_intrinsic_federated_aggregate(aggregate_args)
async def _compute_intrinsic_federated_mean(self, arg):
    """Computes the unweighted `federated_mean` of a CLIENTS-placed value.

    The client values are summed with `_compute_intrinsic_federated_sum`.
    The sum is then multiplied by `1 / number_of_clients` on the first
    executor registered for SERVER in `self._target_executors`.

    Args:
      arg: The federated value to average. Its internal representation is
        sized with `len`, one entry per client.

    Returns:
      A `FederatedExecutorValue` wrapping a one-element list with the scaled
      sum, typed with the federated sum's type signature.

    Raises:
      RuntimeError: If `arg` has no client entries.
      Whatever `type_utils.check_federated_type` raises (presumably
      `TypeError`) when `arg` is not a federated value placed at CLIENTS.
    """
    # Validate type/placement, then emptiness, before any remote work. An
    # empty group then fails fast instead of first running a full aggregation
    # whose result would be discarded. The type check comes first so an
    # ill-typed argument still gets the type error, not the empty-group one.
    type_utils.check_federated_type(
        arg.type_signature, placement=placement_literals.CLIENTS)
    count = len(arg.internal_representation)
    if count < 1:
        raise RuntimeError('Cannot compute a federated mean over an empty group.')
    arg_sum = await self._compute_intrinsic_federated_sum(arg)
    member_type = arg_sum.type_signature.member
    child = self._target_executors[placement_literals.SERVER][0]
    factor, multiply = await asyncio.gather(
        executor_utils.embed_tf_scalar_constant(child, member_type, 1.0 / count),
        executor_utils.embed_tf_binary_operator(child, member_type, tf.multiply))
    multiply_arg = await child.create_tuple(
        anonymous_tuple.AnonymousTuple([
            (None, arg_sum.internal_representation[0]),
            (None, factor),
        ]))
    result = await child.create_call(multiply, multiply_arg)
    return FederatedExecutorValue([result], arg_sum.type_signature)
async def _compute_intrinsic_federated_sum(self, arg):
    """Computes `federated_sum` of a CLIENTS-placed value via federated reduce.

    Embeds a scalar zero and a `tf.add` operator for the member type of
    `arg`. Packs `(arg, zero, plus)` into a `FederatedExecutorValue` that
    holds an `AnonymousTuple` of their internal representations, and passes
    it to `_compute_intrinsic_federated_reduce`.

    Args:
      arg: The federated value to sum. Must be placed at CLIENTS.

    Returns:
      The result of `_compute_intrinsic_federated_reduce`.

    Raises:
      Whatever `py_typecheck.check_type` / `type_utils.check_federated_type`
      raise (presumably `TypeError`) when `arg` is not a federated value
      placed at CLIENTS.
    """
    py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
    # `federated_sum` is only defined for CLIENTS-placed values. Check the
    # placement too, so a misplaced value fails here with a clear error
    # instead of being fed into the reduction.
    type_utils.check_federated_type(
        arg.type_signature, placement=placement_literals.CLIENTS)
    member_type = arg.type_signature.member
    zero, plus = await asyncio.gather(
        executor_utils.embed_tf_scalar_constant(self, member_type, 0),
        executor_utils.embed_tf_binary_operator(self, member_type, tf.add))
    reduce_arg = FederatedExecutorValue(
        anonymous_tuple.AnonymousTuple([
            (None, arg.internal_representation),
            (None, zero.internal_representation),
            (None, plus.internal_representation),
        ]),
        computation_types.NamedTupleType(
            [arg.type_signature, zero.type_signature, plus.type_signature]))
    return await self._compute_intrinsic_federated_reduce(reduce_arg)
async def _compute_intrinsic_federated_mean(self, arg):
    """Computes the unweighted `federated_mean` of a CLIENTS-placed value.

    Each client value is zipped with a CLIENTS-placed constant `1`, and the
    zipped pairs are summed with `_compute_intrinsic_federated_sum`. Element
    0 of the resulting pair is the sum of the values. Element 1 is the sum of
    the ones, i.e. the client count. The count is computed eagerly, and the
    sum is multiplied by `1 / count` on `self._parent_executor`.

    Args:
      arg: The federated value to average. Its member type is used for the
        ones, the scaling constant, and the multiply operator.

    Returns:
      A `CompositeValue` holding the scaled sum, typed as the member type
      placed at SERVER.

    Raises:
      RuntimeError: If the computed count is below one (empty group).
    """
    member_type = arg.type_signature.member
    parent = self._parent_executor
    ones = await self.create_value(
        1, type_factory.at_clients(member_type, all_equal=True))
    zipped = await self._compute_intrinsic_federated_zip_at_clients(
        await self.create_tuple([arg, ones]))
    totals = (
        await self._compute_intrinsic_federated_sum(zipped)).internal_representation
    py_typecheck.check_type(totals, executor_value_base.ExecutorValue)
    fed_sum, count = await asyncio.gather(
        parent.create_selection(totals, index=0),
        parent.create_selection(totals, index=1))
    count_val = float(await count.compute())
    # Without this guard a zero count either raises ZeroDivisionError or
    # silently yields an inf factor (e.g. numpy/tensor division), corrupting
    # the result.
    if count_val < 1.0:
        raise RuntimeError('Cannot compute a federated mean over an empty group.')
    factor, multiply = await asyncio.gather(
        executor_utils.embed_tf_scalar_constant(parent, member_type,
                                                1.0 / count_val),
        executor_utils.embed_tf_binary_operator(parent, member_type, tf.multiply))
    multiply_arg = await parent.create_tuple([fed_sum, factor])
    result = await parent.create_call(multiply, multiply_arg)
    return CompositeValue(result, type_factory.at_server(member_type))