async def compute_federated_mean(
    self, arg: FederatedResolvingStrategyValue
) -> FederatedResolvingStrategyValue:
    """Computes the unweighted mean of the member values of `arg`.

    The members are summed with `self.compute_federated_sum`. The sum is then
    multiplied by `1 / count` on the first SERVER target executor, where
    `count` is the number of entries in `arg.internal_representation`.

    Args:
      arg: The federated value to average. Presumably CLIENTS-placed; this
        method does not check the placement itself (NOTE(review): confirm
        callers guarantee it).

    Returns:
      A `FederatedResolvingStrategyValue` wrapping a one-element list with the
      scaled result, typed with the type signature of the computed sum.

    Raises:
      RuntimeError: If `arg` has no member values.
    """
    count = len(arg.internal_representation)
    # Reject an empty group before doing any work. Summing an empty group
    # only to raise afterwards wastes a full embed-and-reduce round trip.
    if not count:
        raise RuntimeError(
            'Cannot compute a federated mean over an empty group.')
    arg_sum = await self.compute_federated_sum(arg)
    member_type = arg_sum.type_signature.member
    # The scaling happens on the first SERVER target executor, where the
    # sum's single internal value is used below.
    child = self._target_executors[placement_literals.SERVER][0]
    factor, multiply = await asyncio.gather(
        executor_utils.embed_constant(
            child,
            member_type,
            1.0 / count,
            local_computation_factory=self._local_computation_factory),
        executor_utils.embed_multiply_operator(
            child,
            member_type,
            local_computation_factory=self._local_computation_factory))
    # Pair (sum, 1/count) as the argument to the embedded multiply operator.
    multiply_arg = await child.create_struct(
        structure.Struct([(None, arg_sum.internal_representation[0]),
                          (None, factor)]))
    result = await child.create_call(multiply, multiply_arg)
    return FederatedResolvingStrategyValue([result], arg_sum.type_signature)
async def compute_federated_sum(
    self, arg: FederatedResolvingStrategyValue
) -> FederatedResolvingStrategyValue:
    """Sums the member values of a federated `arg`.

    A zero constant and a plus operator for the member type are embedded in
    `self._executor`. The member values are then folded together with
    `self.reduce`, starting from that zero.

    Args:
      arg: The value to sum. Its type signature must be a
        `computation_types.FederatedType`.

    Returns:
      Whatever `self.reduce` produces for the folded members.

    Raises:
      Whatever `py_typecheck.check_type` raises when `arg.type_signature` is
      not a `computation_types.FederatedType`.
    """
    arg_type = arg.type_signature
    py_typecheck.check_type(arg_type, computation_types.FederatedType)
    member_type = arg_type.member
    factory = self._local_computation_factory
    zero_value, plus_op = await asyncio.gather(
        executor_utils.embed_constant(
            self._executor, member_type, 0, local_computation_factory=factory),
        executor_utils.embed_plus_operator(
            self._executor, member_type, local_computation_factory=factory))
    # `reduce` takes the operator's internal representation and its type
    # separately, rather than the embedded value itself.
    return await self.reduce(
        arg.internal_representation,
        zero_value,
        plus_op.internal_representation,
        plus_op.type_signature,
    )
async def compute_federated_sum(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
    """Sums a CLIENTS-placed value by delegating to federated aggregation.

    Three values are embedded in `self._executor` for the member type: a zero
    constant, a plus operator and an identity computation. These, together
    with `arg`, are packed into a struct `[arg, zero, plus, plus, identity]`,
    which is handed to `self.compute_federated_aggregate`.

    Args:
      arg: The value to sum; its type must be federated and placed at CLIENTS.

    Returns:
      The result of `self.compute_federated_aggregate` on the packed struct.

    Raises:
      Whatever `type_analysis.check_federated_type` raises when `arg` is not a
      CLIENTS-placed federated type.
    """
    arg_type = arg.type_signature
    type_analysis.check_federated_type(
        arg_type, placement=placement_literals.CLIENTS)
    member_type = arg_type.member
    # NOTE(review): the identity is built with tensorflow_computation_factory
    # directly, while zero/plus go through self._local_computation_factory.
    # Confirm whether the local factory should be used here too.
    identity_proto, identity_type = (
        tensorflow_computation_factory.create_identity(member_type))
    factory = self._local_computation_factory
    zero, plus, identity = await asyncio.gather(
        executor_utils.embed_constant(
            self._executor, member_type, 0, local_computation_factory=factory),
        executor_utils.embed_plus_operator(
            self._executor, member_type, local_computation_factory=factory),
        self._executor.create_value(identity_proto, identity_type))
    # `plus` appears twice. Presumably it fills both the accumulate and merge
    # slots of the aggregate argument, with `identity` as the report step
    # (verify against compute_federated_aggregate's expected layout).
    packed = await self._executor.create_struct(
        [arg, zero, plus, plus, identity])
    return await self.compute_federated_aggregate(packed)