Пример #1
0
async def compute_intrinsic_federated_broadcast(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
    """Broadcasts a server-placed value to the clients on `executor`.

    The value is materialized with `arg.compute()` and then re-embedded in
    `executor` under the type `member@CLIENTS` with `all_equal=True`.

    Args:
      executor: The executor the broadcast result is embedded into.
      arg: The value to broadcast. Expected to be embedded in `executor` and
        to carry a federated type placed at `tff.SERVER` with `all_equal`
        set to `True`.

    Returns:
      The broadcast value, embedded in `executor`.

    Raises:
      TypeError: If the arguments are of the wrong types, or if `arg` is not
        an all-equal federated value placed at `tff.SERVER`.
    """
    py_typecheck.check_type(executor, executor_base.Executor)
    py_typecheck.check_type(arg, executor_value_base.ExecutorValue)
    server_type = arg.type_signature
    type_analysis.check_federated_type(
        server_type, placement=placement_literals.SERVER, all_equal=True)
    materialized = await arg.compute()
    clients_type = computation_types.FederatedType(
        server_type.member, placement_literals.CLIENTS, all_equal=True)
    return await executor.create_value(materialized, clients_type)
    async def compute_federated_mean(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Computes the mean of a clients-placed value on the server executor.

        The result is `federated_sum(arg) * (1.0 / count)`, where `count` is
        the sum of the cardinalities returned by `self._get_cardinalities()`.
        The summed value and the scalar factor are embedded in
        `self._server_executor`, and the multiplication is called there.

        Args:
          arg: A value whose type signature is federated and placed at
            `tff.CLIENTS`.

        Returns:
          A `FederatedComposingStrategyValue` holding the mean, typed as the
          member type of `arg` placed at the server.

        Raises:
          TypeError: If `arg` is not placed at `tff.CLIENTS`.
          ZeroDivisionError: If the cardinalities sum to zero.
        """
        type_analysis.check_federated_type(
            arg.type_signature, placement=placement_literals.CLIENTS)
        member_type = arg.type_signature.member

        async def _sum_on_server():
            # Aggregate across clients, materialize, then re-embed on server.
            federated_total = await self.compute_federated_sum(arg)
            concrete_total = await federated_total.compute()
            return await self._server_executor.create_value(
                concrete_total, member_type)

        async def _reciprocal_count_on_server():
            num_items = sum(await self._get_cardinalities())
            return await executor_utils.embed_tf_scalar_constant(
                self._server_executor, member_type, float(1.0 / num_items))

        async def _pack_operands():
            summed, scale = await asyncio.gather(
                _sum_on_server(), _reciprocal_count_on_server())
            return await self._server_executor.create_struct([summed, scale])

        embedded_multiply, operands = await asyncio.gather(
            executor_utils.embed_tf_binary_operator(
                self._server_executor, member_type, tf.multiply),
            _pack_operands())
        mean = await self._server_executor.create_call(embedded_multiply,
                                                       operands)
        return FederatedComposingStrategyValue(
            mean, type_factory.at_server(member_type))
Пример #3
0
 async def _encrypt_values_on_clients(self, val, sender, receiver):
     """Encrypts every sender child's value under the receiver's public key.

     This path handles `sender` being the clients placement. The shapes it
     is written for are:
       plaintext:   Fed(Tensor, CLIENTS, all_equal=False)
       pk_receiver: Fed(Tensor, CLIENTS, all_equal=True)
       sk_sender:   Fed(Tensor, CLIENTS, all_equal=False)
     Each sender child packs `[value, pk_receiver, sk_sender]` and calls the
     (cached) encryptor computation on it.

     Side effect: stores `val`'s member type in `self._input_type_cache`,
     which the decryption paths read to recover the original tensor dtype.

     Args:
       val: Federated value placed at `sender`, one entry per sender child.
       sender: Placement that holds the plaintexts and both keys.
       receiver: Placement whose public key is used. It must have exactly
         one child executor.

     Returns:
       A `FederatedResolvingStrategyValue` holding one ciphertext per sender
       child. Its type is `Fed(encryptor result, sender)` with the same
       `all_equal` flag as `val`.

     Raises:
       TypeError: If either key or `val` is not placed at `sender`.
     """
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     # Explicit checks rather than `assert`: asserts are stripped under
     # `python -O`, which would silently skip validating key placement.
     for key_name, key in (('sk_sender', sk_sender),
                           ('pk_receiver', pk_receiver)):
         key_placement = key.type_signature.placement
         if key_placement is not sender:
             raise TypeError(
                 f'Expected {key_name} to be placed at {sender}, but it is '
                 f'placed at {key_placement}.')
     ### Check placement cardinalities
     snd_children = self.strategy._get_child_executors(sender)
     rcv_children = self.strategy._get_child_executors(receiver)
     py_typecheck.check_len(rcv_children, 1)
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature,
                                        placement=sender)
     federated_value_internals = [
         val.internal_representation, pk_receiver.internal_representation,
         sk_sender.internal_representation
     ]
     for v in federated_value_internals:
         py_typecheck.check_len(v, len(snd_children))
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     self._input_type_cache = input_type
     pk_rcv_type = pk_receiver.type_signature.member
     sk_snd_type = sk_sender.type_signature.member
     encryptor_arg_spec = (input_type, pk_rcv_type, sk_snd_type)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Encrypt values and return them
     # Embed the encryptor and its arguments on every sender child
     # concurrently, then invoke the encryptor once per child.
     encryptor_fns = asyncio.gather(*[
         snd_child.create_value(encryptor_proto, encryptor_type)
         for snd_child in snd_children
     ])
     encryptor_args = asyncio.gather(*[
         snd_child.create_struct([v, pk, sk]) for v, pk, sk, snd_child in
         zip(*federated_value_internals, snd_children)
     ])
     encryptor_fns, encryptor_args = await asyncio.gather(
         encryptor_fns, encryptor_args)
     encrypted_values = [
         snd_child.create_call(encryptor, arg) for encryptor, arg, snd_child
         in zip(encryptor_fns, encryptor_args, snd_children)
     ]
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         await asyncio.gather(*encrypted_values),
         tff.FederatedType(encryptor_type.result,
                           sender,
                           all_equal=val.type_signature.all_equal))
Пример #4
0
 async def _decrypt_values_on_clients(self, val, sender, receiver):
     """Decrypts per-child ciphertexts with the receiver's secret key.

     Every receiver child packs `[ciphertext, pk_sender, sk_receiver]` and
     calls the (cached) decryptor computation on it.

     Relies on `self._input_type_cache` having been set before this call;
     its `dtype` is passed to the decryptor factory (presumably set by an
     earlier encryption call; verify for every caller).

     Args:
       val: Federated value with one entry per receiver child. Its member
         type is a tuple whose element 0 is a `tff.TensorType`.
       sender: Placement whose public key is looked up.
       receiver: Placement holding the ciphertexts and the secret keys.

     Returns:
       A `FederatedResolvingStrategyValue` of decrypted values. Its type is
       `Fed(decryptor result, receiver)` with the same `all_equal` flag as
       `val`.
     """
     public_key = self.key_references.get_public_key(sender)
     secret_key = self.key_references.get_secret_key(receiver)
     # Both keys are expected to already live at the receiver placement.
     for key in (public_key, secret_key):
         type_analysis.check_federated_type(key.type_signature,
                                            placement=receiver)
     pk_type = public_key.type_signature.member
     sk_type = secret_key.type_signature.member
     # Every per-child component must line up with the receiver children.
     receiver_executors = self.strategy._get_child_executors(receiver)
     per_child = [
         val.internal_representation,
         public_key.internal_representation,
         secret_key.internal_representation,
     ]
     for component in per_child:
         py_typecheck.check_len(component, len(receiver_executors))
     ciphertext_type = val.type_signature.member
     # Element 0 being a TensorType marks `ciphertext_type` as the tuple that
     # describes a single value to be decrypted.
     py_typecheck.check_type(ciphertext_type[0], tff.TensorType)
     py_typecheck.check_type(pk_type, tff.TensorType)
     dec_proto, dec_type = utils.materialize_computation_from_cache(
         sodium_comp.make_decryptor,
         self._decryptor_cache,
         (ciphertext_type, pk_type, sk_type),
         orig_tensor_dtype=self._input_type_cache.dtype)
     pending_fns = asyncio.gather(*[
         child.create_value(dec_proto, dec_type)
         for child in receiver_executors
     ])
     pending_args = asyncio.gather(*[
         child.create_struct([ct, pk, sk])
         for ct, pk, sk, child in zip(*per_child, receiver_executors)
     ])
     embedded_fns, embedded_args = await asyncio.gather(pending_fns,
                                                        pending_args)
     plaintexts = await asyncio.gather(*[
         child.create_call(fn, fn_arg) for fn, fn_arg, child in zip(
             embedded_fns, embedded_args, receiver_executors)
     ])
     result_type = tff.FederatedType(dec_type.result,
                                     receiver,
                                     all_equal=val.type_signature.all_equal)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         plaintexts, result_type)
 async def compute_federated_sum(
     self, arg: FederatedComposingStrategyValue
 ) -> FederatedComposingStrategyValue:
     """Sums a clients-placed value by delegating to federated aggregate.

     `arg` is packed into the struct
     `[arg, zero, add_op, add_op, identity_fn]`, which is then handed to
     `self.compute_federated_aggregate`. The zero, the `tf.add` operator and
     the identity computation are all embedded in `self._executor`.

     Args:
       arg: A value whose type signature is federated and placed at
         `tff.CLIENTS`.

     Returns:
       Whatever `self.compute_federated_aggregate` returns for that struct.

     Raises:
       TypeError: If `arg` is not placed at `tff.CLIENTS`.
     """
     type_analysis.check_federated_type(
         arg.type_signature, placement=placement_literals.CLIENTS)
     member = arg.type_signature.member
     identity_proto, identity_type = (
         tensorflow_computation_factory.create_identity(member))
     zero, add_op, identity_fn = await asyncio.gather(
         executor_utils.embed_tf_constant(self._executor, member, 0),
         executor_utils.embed_tf_binary_operator(self._executor, member,
                                                 tf.add),
         self._executor.create_value(identity_proto, identity_type))
     packed = await self._executor.create_struct(
         [arg, zero, add_op, add_op, identity_fn])
     return await self.compute_federated_aggregate(packed)
 async def _compute_intrinsic_federated_sum(self, arg):
     """Sums a clients-placed value via the federated aggregate intrinsic.

     A scalar zero, a `tf.add` operator and a lambda identity function are
     embedded in this executor. `arg` is then packed as
     `[arg, zero, add_op, add_op, identity_fn]` and passed to
     `self._compute_intrinsic_federated_aggregate`.

     Args:
       arg: A value whose type signature is federated and placed at
         `tff.CLIENTS`.

     Returns:
       The result of `self._compute_intrinsic_federated_aggregate`.

     Raises:
       TypeError: If `arg` is not placed at `tff.CLIENTS`.
     """
     type_analysis.check_federated_type(
         arg.type_signature, placement=placement_literals.CLIENTS)
     member = arg.type_signature.member
     zero, add_op, identity_fn = await asyncio.gather(
         executor_utils.embed_tf_scalar_constant(self, member, 0),
         executor_utils.embed_tf_binary_operator(self, member, tf.add),
         self.create_value(
             computation_factory.create_lambda_identity(member),
             type_factory.unary_op(member)))
     packed = await self.create_tuple([arg, zero, add_op, add_op, identity_fn])
     return await self._compute_intrinsic_federated_aggregate(packed)
Пример #7
0
 async def _compute_intrinsic_federated_collect(self, arg):
     """Gathers every client's value into one sequence held by the server.

     Args:
       arg: A federated value placed at `tff.CLIENTS` whose internal
         representation is a list of per-client values, each supporting
         `compute()`.

     Returns:
       A `FederatingExecutorValue` wrapping a single sequence value. The
       sequence is created on the first `SERVER` target executor and typed
       as an all-equal `SequenceType(member)` placed at `tff.SERVER`.
     """
     signature = arg.type_signature
     py_typecheck.check_type(signature, computation_types.FederatedType)
     type_analysis.check_federated_type(
         signature, placement=placement_literals.CLIENTS)
     per_client = arg.internal_representation
     py_typecheck.check_type(per_client, list)
     element_type = signature.member
     server_child = self._target_executors[placement_literals.SERVER][0]
     # Materialize every client value before embedding them as one sequence.
     items = await asyncio.gather(*(value.compute() for value in per_client))
     sequence_value = await server_child.create_value(
         items, computation_types.SequenceType(element_type))
     return FederatingExecutorValue(
         [sequence_value],
         computation_types.FederatedType(
             computation_types.SequenceType(element_type),
             placement_literals.SERVER,
             all_equal=True))
 async def compute_federated_zip_at_server(
     self,
     arg: FederatedComposingStrategyValue) -> FederatedComposingStrategyValue:
   """Zips a pair of server-placed values into one server-placed struct.

   Args:
     arg: A two-element struct whose elements are all-equal federated values
       placed at `tff.SERVER`.

   Returns:
     A `FederatedComposingStrategyValue` whose representation is a struct
     created on `self._server_executor` from the two elements. It is typed
     as the struct of the two member types, placed at the server.

   Raises:
     TypeError: If an element is not an all-equal federated value placed at
       `tff.SERVER`.
   """
   signature = arg.type_signature
   representation = arg.internal_representation
   py_typecheck.check_type(signature, computation_types.StructType)
   py_typecheck.check_len(signature, 2)
   py_typecheck.check_type(representation, structure.Struct)
   py_typecheck.check_len(representation, 2)
   first_type, second_type = signature[0], signature[1]
   for element_type in (first_type, second_type):
     type_analysis.check_federated_type(
         element_type, placement=placement_literals.SERVER, all_equal=True)
   zipped = await self._server_executor.create_struct(
       [representation[0], representation[1]])
   zipped_type = computation_types.at_server(
       computation_types.StructType([first_type.member, second_type.member]))
   return FederatedComposingStrategyValue(zipped, zipped_type)
 async def compute_federated_apply(
     self,
     arg: FederatedComposingStrategyValue) -> FederatedComposingStrategyValue:
   """Applies a function to a server-placed value on the server executor.

   Args:
     arg: A two-element struct. Element 0 is a `pb.Computation` whose type is
       a `FunctionType`. Element 1 is an `ExecutorValue` typed as an
       all-equal federated value at `tff.SERVER` whose member is the
       function's parameter type.

   Returns:
     A `FederatedComposingStrategyValue` with the call result, typed as the
     function's result placed at the server.

   Raises:
     TypeError: If the value is not an all-equal federated value at
       `tff.SERVER` with the function's parameter type as its member.
   """
   representation = arg.internal_representation
   py_typecheck.check_type(representation, structure.Struct)
   py_typecheck.check_len(representation, 2)
   fn_type = arg.type_signature[0]
   py_typecheck.check_type(fn_type, computation_types.FunctionType)
   val_type = arg.type_signature[1]
   type_analysis.check_federated_type(
       val_type, fn_type.parameter, placement_literals.SERVER, all_equal=True)
   fn_proto, server_val = representation[0], representation[1]
   py_typecheck.check_type(fn_proto, pb.Computation)
   py_typecheck.check_type(server_val, executor_value_base.ExecutorValue)
   server_ex = self._server_executor
   embedded_fn = await server_ex.create_value(fn_proto, fn_type)
   result = await server_ex.create_call(embedded_fn, server_val)
   return FederatedComposingStrategyValue(
       result, computation_types.at_server(fn_type.result))
    async def compute_federated_zip_at_clients(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Zips two clients-placed values element-wise on every child executor.

        Each target executor embeds the `FEDERATED_ZIP_AT_CLIENTS` intrinsic
        and calls it on a struct of its own two operand values. The struct
        is keyed by the element names (if any) of `arg.type_signature`.

        Args:
          arg: A two-element struct of federated values placed at
            `tff.CLIENTS`. Each element's internal representation is a list
            with one entry per target executor.

        Returns:
          A `FederatedComposingStrategyValue` holding one zipped value per
          target executor, typed as the struct of the two member types placed
          at the clients.
        """
        signature = arg.type_signature
        py_typecheck.check_type(signature, computation_types.StructType)
        py_typecheck.check_len(signature, 2)
        py_typecheck.check_type(arg.internal_representation,
                                anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(arg.internal_representation, 2)
        names = [name for name, _ in anonymous_tuple.to_elements(signature)]
        child_values = [arg.internal_representation[i] for i in (0, 1)]
        operand_types = [signature[i] for i in (0, 1)]
        for i in (0, 1):
            type_analysis.check_federated_type(
                operand_types[i], placement=placement_literals.CLIENTS)
            # Re-derive the type via `at_clients` from the member type alone.
            operand_types[i] = type_factory.at_clients(operand_types[i].member)
            py_typecheck.check_type(child_values[i], list)
            py_typecheck.check_len(child_values[i],
                                   len(self._target_executors))

        def _maybe_named(name, value):
            # Unnamed struct elements are kept positional.
            return (name, value) if name else value

        item_type = computation_types.StructType(
            [_maybe_named(names[i], operand_types[i].member) for i in (0, 1)])
        result_type = type_factory.at_clients(item_type)
        zip_type = computation_types.FunctionType(
            computation_types.StructType(
                [_maybe_named(names[i], operand_types[i]) for i in (0, 1)]),
            result_type)
        zip_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

        async def _zip_on_child(child, left, right):
            py_typecheck.check_type(left, executor_value_base.ExecutorValue)
            py_typecheck.check_type(right, executor_value_base.ExecutorValue)
            zip_fn = await child.create_value(zip_comp, zip_type)
            zip_arg = await child.create_struct(
                anonymous_tuple.AnonymousTuple([(names[0], left),
                                                (names[1], right)]))
            return await child.create_call(zip_fn, zip_arg)

        zipped = await asyncio.gather(*[
            _zip_on_child(child, left, right) for child, left, right in zip(
                self._target_executors, child_values[0], child_values[1])
        ])
        return FederatedComposingStrategyValue(zipped, result_type)
Пример #11
0
 async def _compute_intrinsic_federated_zip_at_server(self, arg):
     """Zips a pair of server-placed values into one server-placed tuple.

     Args:
       arg: A two-element named tuple whose elements are all-equal federated
         values placed at `tff.SERVER`.

     Returns:
       A `CompositeValue` whose representation is a tuple created on
       `self._parent_executor`. It is typed as the tuple of the two member
       types, placed at the server.
     """
     signature = arg.type_signature
     representation = arg.internal_representation
     py_typecheck.check_type(signature, computation_types.NamedTupleType)
     py_typecheck.check_len(signature, 2)
     py_typecheck.check_type(representation, anonymous_tuple.AnonymousTuple)
     py_typecheck.check_len(representation, 2)
     for index in range(2):
         type_analysis.check_federated_type(
             signature[index],
             placement=placement_literals.SERVER,
             all_equal=True)
     zipped = await self._parent_executor.create_tuple(
         [representation[index] for index in range(2)])
     zipped_type = type_factory.at_server(
         computation_types.NamedTupleType(
             [signature[0].member, signature[1].member]))
     return CompositeValue(zipped, zipped_type)
Пример #12
0
 async def _compute_intrinsic_federated_apply(self, arg):
     """Applies a function to a server-placed value on the parent executor.

     Args:
       arg: A two-element value. Element 0 is a `pb.Computation` whose type
         is a `FunctionType`. Element 1 is an `ExecutorValue` typed as an
         all-equal federated value at `tff.SERVER` whose member is the
         function's parameter type.

     Returns:
       A `CompositeValue` wrapping the call result, typed as the function's
       result placed at the server.
     """
     representation = arg.internal_representation
     py_typecheck.check_type(representation, anonymous_tuple.AnonymousTuple)
     py_typecheck.check_len(representation, 2)
     fn_type = arg.type_signature[0]
     py_typecheck.check_type(fn_type, computation_types.FunctionType)
     type_analysis.check_federated_type(arg.type_signature[1],
                                        fn_type.parameter,
                                        placement_literals.SERVER,
                                        all_equal=True)
     fn_proto, operand = representation[0], representation[1]
     py_typecheck.check_type(fn_proto, pb.Computation)
     py_typecheck.check_type(operand, executor_value_base.ExecutorValue)
     parent = self._parent_executor
     embedded_fn = await parent.create_value(fn_proto, fn_type)
     result = await parent.create_call(embedded_fn, operand)
     return CompositeValue(result, type_factory.at_server(fn_type.result))
Пример #13
0
 async def compute_federated_secure_sum(self, arg):
     """Sums clients' tensors through a Paillier encrypt/sum/decrypt pipeline.

     The steps are:
       1. Encrypt the summands on `tff.CLIENTS`.
       2. Move the ciphertexts to the aggregator and sum them with
          `_compute_paillier_sum`.
       3. Move the encrypted total to `tff.SERVER` and decrypt it there.

     Inputs whose rank is not 2 are first reshaped to `[1, num_elements]`.
     The decrypted result is always reshaped back to the input's shape and
     exported with the input's dtype. `_paillier_setup()` runs only while
     `self._requires_setup` is true; the flag is cleared afterwards.

     Args:
       arg: A two-element structure whose element 0 is a tensor value placed
         at `tff.CLIENTS`. Only element 0 is selected here.

     Returns:
       The decrypted sum, as produced by `_compute_reshape_on_tensor` with
       the input tensor's shape.
     """
     self._check_arg_is_structure(arg)
     py_typecheck.check_len(arg.internal_representation, 2)
     value_type = arg.type_signature[0]
     type_analysis.check_federated_type(value_type, placement=tff.CLIENTS)
     py_typecheck.check_type(value_type.member, tff.TensorType)
     # Remembered so decryption can export the caller's dtype.
     original_dtype = value_type.member.dtype
     if self._requires_setup:
         await self._paillier_setup()
         self._requires_setup = False
     original_shape = value_type.member.shape
     needs_matrix_form = len(original_shape) != 2
     summands = await self._executor.create_selection(arg, index=0)
     if needs_matrix_form:
         summands = await self._compute_reshape_on_tensor(
             summands, output_shape=[1, original_shape.num_elements()])
     ciphertexts = await self._compute_paillier_encryption(
         self.encryption_key_clients, summands)
     ciphertexts = await self._move(ciphertexts, tff.CLIENTS,
                                    paillier_placement.AGGREGATOR)
     encrypted_total = await self._compute_paillier_sum(
         self.encryption_key_paillier, ciphertexts)
     encrypted_total = await self._move(encrypted_total,
                                        paillier_placement.AGGREGATOR,
                                        tff.SERVER)
     plaintext_total = await self._compute_paillier_decryption(
         self.decryption_key,
         self.encryption_key_server,
         encrypted_total,
         export_dtype=original_dtype)
     return await self._compute_reshape_on_tensor(
         plaintext_total, output_shape=original_shape.as_list())
Пример #14
0
 def test_raises_type_error(self):
     """Checks that `check_federated_type` accepts matches and rejects mismatches.

     A `None` constraint is left unchecked. Each of the mismatched member
     type, placement and `all_equal` cases below must raise `TypeError`.
     """
     type_spec = computation_types.FederatedType(
         tf.int32, placement_literals.CLIENTS, False)
     # (member, placement, all_equal) constraints that match `type_spec`.
     accepted = [
         (tf.int32, placement_literals.CLIENTS, False),
         (tf.int32, None, None),
         (None, placement_literals.CLIENTS, None),
         (None, None, False),
     ]
     for member, placement, all_equal in accepted:
         type_analysis.check_federated_type(type_spec, member, placement,
                                            all_equal)
     # Each case mismatches exactly one constraint.
     rejected = [
         (tf.bool, None, None),
         (None, placement_literals.SERVER, None),
         (None, None, True),
     ]
     for member, placement, all_equal in rejected:
         with self.assertRaises(TypeError):
             type_analysis.check_federated_type(type_spec, member, placement,
                                                all_equal)
Пример #15
0
 async def _decrypt_values_on_singleton(self, val, sender, receiver):
     """Decrypts a struct of ciphertexts on the receiver's only child executor.

     Element `i` of `val` is decrypted with element `i` of the sender's
     public key and the receiver's single secret key.

     Relies on `self._input_type_cache` having been set before this call;
     its `dtype` is passed to the decryptor factory (presumably set by an
     earlier encryption call; verify for every caller).

     Args:
       val: A `tff.StructType`-typed value whose element 0 is an all-equal
         federated value placed at `receiver` with a struct member type.
       sender: Placement whose public key is looked up. Its number of child
         executors must match the number of public keys.
       receiver: Placement with exactly one child executor that holds the
         secret key.

     Returns:
       A `FederatedResolvingStrategyValue` whose representation is a
       structure of decrypted values. Its type is a `tff.StructType` of
       all-equal `Fed(decryptor result, receiver)` entries, one per value.
     """
     public_key = self.key_references.get_public_key(sender)
     secret_key = self.key_references.get_secret_key(receiver)
     # Both keys are expected to already live at the receiver placement.
     for key in (public_key, secret_key):
         type_analysis.check_federated_type(key.type_signature,
                                            placement=receiver)
     pk_type = public_key.type_signature.member
     sk_type = secret_key.type_signature.member
     sender_executors = self.strategy._get_child_executors(sender)
     receiver_executors = self.strategy._get_child_executors(receiver)
     py_typecheck.check_len(receiver_executors, 1)
     receiver_ex = receiver_executors[0]
     py_typecheck.check_len(public_key.internal_representation,
                            len(sender_executors))
     py_typecheck.check_len(secret_key.internal_representation, 1)
     py_typecheck.check_type(val.type_signature, tff.StructType)
     type_analysis.check_federated_type(val.type_signature[0],
                                        placement=receiver,
                                        all_equal=True)
     # Each ciphertext is itself a struct describing one value to decrypt.
     ciphertext_type = val.type_signature[0].member
     py_typecheck.check_type(ciphertext_type, tff.StructType)
     py_typecheck.check_type(pk_type, tff.StructType)
     py_typecheck.check_len(val.type_signature, len(pk_type))
     # The decryptor is specialized on the first public key's type.
     dec_proto, dec_type = utils.materialize_computation_from_cache(
         sodium_comp.make_decryptor,
         self._decryptor_cache,
         (ciphertext_type, pk_type[0], sk_type),
         orig_tensor_dtype=self._input_type_cache.dtype)
     ciphertexts = val.internal_representation
     sk = secret_key.internal_representation[0]
     decryptor = await receiver_ex.create_value(dec_proto, dec_type)
     dec_args = await asyncio.gather(*[
         receiver_ex.create_struct([ct, pk, sk])
         for ct, pk in zip(ciphertexts, public_key.internal_representation)
     ])
     plaintexts = await asyncio.gather(*[
         receiver_ex.create_call(decryptor, dec_arg) for dec_arg in dec_args
     ])
     packed = structure.from_container(plaintexts)
     result_type = tff.StructType([
         tff.FederatedType(dec_type.result, receiver, all_equal=True)
         for _ in plaintexts
     ])
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         packed, result_type)
Пример #16
0
 async def _encrypt_values_on_singleton(self, val, sender, receiver):
     """Encrypts a single sender value once per receiver public key.

     The sender must have exactly one child executor; the original comments
     note this holds when the receiver is CLIENTS. The shapes this path is
     written for are:
       plaintext:   Fed(Tensor, sender, all_equal=True)
       pk_receiver: Fed(Tuple(Tensor), sender, all_equal=True)
       sk_sender:   Fed(Tensor, sender, all_equal=True)
     On the sender's only child, the plaintext is paired with the sender's
     secret key and each receiver public key in turn, and the (cached)
     encryptor computation is called on every pair.

     Side effect: stores `val`'s member type in `self._input_type_cache`,
     which the decryption paths read to recover the original tensor dtype.

     Args:
       val: Federated value placed at `sender` with exactly one entry.
       sender: Placement holding the plaintext and both keys.
       receiver: Placement whose number of child executors must equal the
         number of public keys in `pk_receiver`.

     Returns:
       A `FederatedResolvingStrategyValue` whose representation is a
       structure with one ciphertext per receiver public key. Its type is a
       `tff.StructType` of `Fed(encryptor result, sender, all_equal=False)`
       entries.

     Raises:
       TypeError: If either key or `val` is not placed at `sender`.
     """
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     # Explicit checks rather than `assert`: asserts are stripped under
     # `python -O`, which would silently skip validating key placement.
     for key_name, key in (('sk_sender', sk_sender),
                           ('pk_receiver', pk_receiver)):
         key_placement = key.type_signature.placement
         if key_placement is not sender:
             raise TypeError(
                 f'Expected {key_name} to be placed at {sender}, but it is '
                 f'placed at {key_placement}.')
     ### Check placement cardinalities
     rcv_children = self.strategy._get_child_executors(receiver)
     snd_children = self.strategy._get_child_executors(sender)
     py_typecheck.check_len(snd_children, 1)
     snd_child = snd_children[0]
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature,
                                        placement=sender)
     py_typecheck.check_len(val.internal_representation, 1)
     py_typecheck.check_type(pk_receiver.type_signature.member,
                             tff.StructType)
     py_typecheck.check_len(pk_receiver.internal_representation,
                            len(rcv_children))
     py_typecheck.check_len(sk_sender.internal_representation, 1)
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     self._input_type_cache = input_type
     pk_rcv_type = pk_receiver.type_signature.member
     sk_snd_type = sk_sender.type_signature.member
     # The encryptor is specialized on the first public key's type and is
     # reused for every key in `pk_receiver`.
     pk_element_type = pk_rcv_type[0]
     encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Prepare encryption arguments
     v = val.internal_representation[0]
     sk = sk_sender.internal_representation[0]
     ### Encrypt values and return them
     encryptor_fn = await snd_child.create_value(encryptor_proto,
                                                 encryptor_type)
     encryptor_args = await asyncio.gather(*[
         snd_child.create_struct([v, this_pk, sk])
         for this_pk in pk_receiver.internal_representation
     ])
     encrypted_values = await asyncio.gather(*[
         snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args
     ])
     encrypted_value_types = [encryptor_type.result] * len(encrypted_values)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         structure.from_container(encrypted_values),
         tff.StructType([
             tff.FederatedType(evt, sender, all_equal=False)
             for evt in encrypted_value_types
         ]))