예제 #1
0
 async def compute_federated_mean(
     self, arg: FederatedResolvingStrategyValue
 ) -> FederatedResolvingStrategyValue:
     """Computes the mean of the members of a federated value.

     The members of `arg` are summed with `self.compute_federated_sum`. The sum
     is then multiplied by `1 / count` on the first SERVER target executor,
     where `count` is `len(arg.internal_representation)`.

     Args:
       arg: The federated value to average. One entry of its
         `internal_representation` is counted per group member.

     Returns:
       A `FederatedResolvingStrategyValue` wrapping a single-element list that
       holds the scaled sum, typed with the type signature of the sum.

     Raises:
       RuntimeError: If `arg` has no members, since the mean is undefined.
     """
     count = len(arg.internal_representation)
     # Reject the empty group before issuing any executor work; previously the
     # full sum was computed first and then thrown away.
     if count < 1:
         raise RuntimeError(
             'Cannot compute a federated mean over an empty group.')
     arg_sum = await self.compute_federated_sum(arg)
     member_type = arg_sum.type_signature.member
     child = self._target_executors[placement_literals.SERVER][0]
     # Embed the scaling constant and the multiply operator concurrently.
     factor, multiply = await asyncio.gather(
         executor_utils.embed_constant(
             child,
             member_type,
             1.0 / count,
             local_computation_factory=self._local_computation_factory),
         executor_utils.embed_multiply_operator(
             child,
             member_type,
             local_computation_factory=self._local_computation_factory))
     # Assumes the sum holds exactly one value (index 0); the multiply operator
     # is called on an unnamed 2-element struct of (sum, factor).
     multiply_arg = await child.create_struct(
         structure.Struct([(None, arg_sum.internal_representation[0]),
                           (None, factor)]))
     result = await child.create_call(multiply, multiply_arg)
     return FederatedResolvingStrategyValue([result],
                                            arg_sum.type_signature)
예제 #2
0
 async def compute_federated_sum(
     self,
     arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
   """Sums the members of a federated value by folding them with `self.reduce`.

   Args:
     arg: The value to sum; its `type_signature` must be a `FederatedType`.

   Returns:
     The result of `self.reduce` over `arg.internal_representation`, starting
     from a `0` constant of the member type and combining with a plus operator.
   """
   py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
   member_type = arg.type_signature.member
   factory = self._local_computation_factory
   # Build both embeddings first, then await them together.
   zero_pending = executor_utils.embed_constant(
       self._executor, member_type, 0, local_computation_factory=factory)
   plus_pending = executor_utils.embed_plus_operator(
       self._executor, member_type, local_computation_factory=factory)
   zero, plus = await asyncio.gather(zero_pending, plus_pending)
   return await self.reduce(
       arg.internal_representation,
       zero,
       plus.internal_representation,
       plus.type_signature,
   )
예제 #3
0
 async def compute_federated_sum(
     self, arg: FederatedComposingStrategyValue
 ) -> FederatedComposingStrategyValue:
     """Sums a CLIENTS-placed value by delegating to the federated aggregate.

     A `0` constant, a plus operator (passed twice) and an identity computation
     over the member type are embedded in `self._executor`. They are packed
     with `arg` into a struct that is handed to
     `self.compute_federated_aggregate`.

     Args:
       arg: The federated value to sum; must be placed at CLIENTS.

     Returns:
       Whatever `self.compute_federated_aggregate` returns for that struct.
     """
     type_analysis.check_federated_type(
         arg.type_signature, placement=placement_literals.CLIENTS)
     member_type = arg.type_signature.member
     factory = self._local_computation_factory
     identity_comp, identity_type = (
         tensorflow_computation_factory.create_identity(member_type))
     zero, plus, identity = await asyncio.gather(
         executor_utils.embed_constant(
             self._executor, member_type, 0,
             local_computation_factory=factory),
         executor_utils.embed_plus_operator(
             self._executor, member_type, local_computation_factory=factory),
         self._executor.create_value(identity_comp, identity_type))
     # NOTE(review): positional order presumably follows federated_aggregate's
     # (value, zero, accumulate, merge, report) — confirm against the callee.
     accumulate = merge = plus
     aggregate_args = await self._executor.create_struct(
         [arg, zero, accumulate, merge, identity])
     return await self.compute_federated_aggregate(aggregate_args)