async def _compute_paillier_decryption(
    self,
    decryption_key: federated_resolving_strategy.FederatedResolvingStrategyValue,
    encryption_key: federated_resolving_strategy.FederatedResolvingStrategyValue,
    value: federated_resolving_strategy.FederatedResolvingStrategyValue,
    export_dtype):
    """Runs the Paillier decryptor on the first SERVER child executor.

    The decryptor computation is produced by `paillier_comp.make_decryptor`
    and fetched through `utils.materialize_computation_from_cache`, keyed on
    the member types of the three operands plus `export_dtype`.

    Args:
      decryption_key: Federated value holding the decryption key. Its
        `internal_representation` is passed to the server child unchanged.
      encryption_key: Federated value holding the encryption key.
      value: Federated value holding the data to decrypt.
      export_dtype: Forwarded to `paillier_comp.make_decryptor` via the
        computation cache.

    Returns:
      A `FederatedResolvingStrategyValue` wrapping a one-element list with the
      decrypted result, typed `tff.FederatedType(<decryptor result>,
      tff.SERVER, True)`.
    """
    server = self._get_child_executors(tff.SERVER, index=0)
    operands = (decryption_key, encryption_key, value)
    # The decryptor is specialised on the member types of all three operands.
    arg_spec = tuple(op.type_signature.member for op in operands)
    proto, fn_type = utils.materialize_computation_from_cache(
        paillier_comp.make_decryptor,
        self._paillier_decryptor_cache,
        arg_spec=arg_spec,
        export_dtype=export_dtype)
    # Embed the function and pack its argument tuple concurrently.
    embedded_fn, embedded_arg = await asyncio.gather(
        server.create_value(proto, fn_type),
        server.create_struct(
            tuple(op.internal_representation for op in operands)))
    plaintext = await server.create_call(embedded_fn, embedded_arg)
    result_type = tff.FederatedType(fn_type.result, tff.SERVER, True)
    return federated_resolving_strategy.FederatedResolvingStrategyValue(
        [plaintext], result_type)
Esempio n. 2
0
 async def _encrypt_values_on_clients(self, val, sender, receiver):
     """Encrypts every sender child's value for the single receiver.

     Runs the computation built by `sodium_comp.make_encryptor` on each sender
     child executor. Its argument is that child's plaintext, its entry of the
     receiver public key, and its entry of the sender secret key.

     Expected types (case 2: sender=CLIENTS):
         plaintext: Fed(Tensor, CLIENTS, all_equal=False)
         pk_receiver: Fed(Tensor, CLIENTS, all_equal=True)
         sk_sender: Fed(Tensor, CLIENTS, all_equal=False)
       Returns:
         encrypted_values: Fed(Tensor, CLIENTS, all_equal=False)

     Args:
       val: Federated plaintext placed at `sender`, with one
         `internal_representation` entry per sender child.
       sender: Placement holding the plaintexts and both keys; its children
         run the encryptor.
       receiver: Placement whose public key is used; must have exactly one
         child executor.

     Returns:
       A `FederatedResolvingStrategyValue` placed at `sender` with one
       ciphertext per sender child. NOTE(review): the `all_equal` bit is copied
       from `val.type_signature`, whereas the expected types above list
       `all_equal=False`. Confirm which is intended.

     Raises:
       Whatever `type_analysis.check_federated_type` / `py_typecheck` raise
       when a placement, type or cardinality does not match.

     Side effects:
       Stores the plaintext member type in `self._input_type_cache`. The
       `_decrypt_values_on_*` methods read its `.dtype`.
     """
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     # Both keys must already live at the sender. These are explicit checks
     # rather than `assert`s, which would be stripped under `python -O`.
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     type_analysis.check_federated_type(pk_receiver.type_signature,
                                        placement=sender)
     ### Check placement cardinalities
     snd_children = self.strategy._get_child_executors(sender)
     rcv_children = self.strategy._get_child_executors(receiver)
     py_typecheck.check_len(rcv_children, 1)
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature,
                                        placement=sender)
     federated_value_internals = [
         val.internal_representation, pk_receiver.internal_representation,
         sk_sender.internal_representation
     ]
     for v in federated_value_internals:
         py_typecheck.check_len(v, len(snd_children))
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     # Remembered so the decryptors can pass the original dtype along.
     self._input_type_cache = input_type
     encryptor_arg_spec = (input_type, pk_receiver.type_signature.member,
                           sk_sender.type_signature.member)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Encrypt values and return them
     encryptor_fns = asyncio.gather(*[
         snd_child.create_value(encryptor_proto, encryptor_type)
         for snd_child in snd_children
     ])
     encryptor_args = asyncio.gather(*[
         snd_child.create_struct([v, pk, sk]) for v, pk, sk, snd_child in
         zip(*federated_value_internals, snd_children)
     ])
     encryptor_fns, encryptor_args = await asyncio.gather(
         encryptor_fns, encryptor_args)
     encrypted_values = [
         snd_child.create_call(encryptor, arg) for encryptor, arg, snd_child
         in zip(encryptor_fns, encryptor_args, snd_children)
     ]
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         await asyncio.gather(*encrypted_values),
         tff.FederatedType(encryptor_type.result,
                           sender,
                           all_equal=val.type_signature.all_equal))
Esempio n. 3
0
 async def _decrypt_values_on_singleton(self, val, sender, receiver):
     """Decrypts, on the single receiver child, one ciphertext per sender key.

     Each element of `val.internal_representation` is paired positionally
     (via `zip`) with an entry of the sender public-key representation. The
     computation built by `sodium_comp.make_decryptor` is then called on
     `[ciphertext, public_key, receiver_secret_key]` on the receiver child.

     Args:
       val: Struct of federated values. Element 0 must be placed at `receiver`
         with `all_equal=True`. Its member type, itself a struct, is used as
         the decryptor's ciphertext type. The struct length must equal the
         number of elements in the sender public-key member type.
       sender: Placement whose public keys are looked up. Its child count must
         match the length of the public-key representation.
       receiver: Placement holding both keys; must have exactly one child.

     Returns:
       A `FederatedResolvingStrategyValue` wrapping a `structure` of the
       decrypted values, typed as a `tff.StructType` of
       `Fed(<decryptor result>, receiver, all_equal=True)`, one per value.

     Note:
       Reads `self._input_type_cache.dtype`, which the encrypt methods set.
       Presumably an encryption must have run on this object first.
     """
     # Key lookup: both keys are expected to be placed at the receiver.
     sender_pks = self.key_references.get_public_key(sender)
     receiver_sk = self.key_references.get_secret_key(receiver)
     for key_ref in (sender_pks, receiver_sk):
         type_analysis.check_federated_type(key_ref.type_signature,
                                            placement=receiver)
     pks_type = sender_pks.type_signature.member
     sk_type = receiver_sk.type_signature.member

     # Executor cardinalities: one receiver child does all the decryption.
     sender_executors = self.strategy._get_child_executors(sender)
     receiver_executors = self.strategy._get_child_executors(receiver)
     py_typecheck.check_len(receiver_executors, 1)
     executor = receiver_executors[0]
     py_typecheck.check_len(sender_pks.internal_representation,
                            len(sender_executors))
     py_typecheck.check_len(receiver_sk.internal_representation, 1)

     # Shape of `val`: a struct whose members are ciphertext tuples.
     struct_type = val.type_signature
     py_typecheck.check_type(struct_type, tff.StructType)
     type_analysis.check_federated_type(struct_type[0],
                                        placement=receiver,
                                        all_equal=True)
     ciphertext_type = struct_type[0].member
     py_typecheck.check_type(ciphertext_type, tff.StructType)
     py_typecheck.check_type(pks_type, tff.StructType)
     py_typecheck.check_len(struct_type, len(pks_type))
     arg_spec = (ciphertext_type, pks_type[0], sk_type)
     proto, fn_type = utils.materialize_computation_from_cache(
         sodium_comp.make_decryptor,
         self._decryptor_cache,
         arg_spec,
         orig_tensor_dtype=self._input_type_cache.dtype)

     # Decrypt every ciphertext on the receiver child.
     ciphertexts = val.internal_representation
     secret_key = receiver_sk.internal_representation[0]
     embedded_fn = await executor.create_value(proto, fn_type)
     packed_args = await asyncio.gather(*(
         executor.create_struct([ciphertext, pk, secret_key])
         for ciphertext, pk in zip(ciphertexts,
                                   sender_pks.internal_representation)))
     plaintexts = await asyncio.gather(
         *(executor.create_call(embedded_fn, packed) for packed in packed_args))
     container = structure.from_container(plaintexts)
     result_type = tff.StructType([
         tff.FederatedType(fn_type.result, receiver, all_equal=True)
         for _ in plaintexts
     ])
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         container, result_type)
Esempio n. 4
0
 async def _decrypt_values_on_clients(self, val, sender, receiver):
     """Decrypts each receiver child's ciphertext on that same child.

     The computation built by `sodium_comp.make_decryptor` is embedded on every
     receiver child. Each child calls it on `[ciphertext, public_key,
     secret_key]` taken from its own entries of `val`, the sender public key
     and the receiver secret key.

     Args:
       val: Federated ciphertexts with one `internal_representation` entry per
         receiver child. The member type must be indexable, and its element 0
         must be a `tff.TensorType`.
       sender: Placement whose public key is looked up. The key itself must be
         placed at `receiver`, and its member type must be a `tff.TensorType`.
       receiver: Placement whose children perform the decryption.

     Returns:
       A `FederatedResolvingStrategyValue` with one decrypted value per
       receiver child, typed `Fed(<decryptor result>, receiver)`. The
       `all_equal` bit is copied from `val.type_signature`.

     Note:
       Reads `self._input_type_cache.dtype`, which the encrypt methods set.
       Presumably an encryption must have run on this object first.
     """
     # Key lookup: both keys are expected to be placed at the receiver.
     pk_ref = self.key_references.get_public_key(sender)
     sk_ref = self.key_references.get_secret_key(receiver)
     for key_ref in (pk_ref, sk_ref):
         type_analysis.check_federated_type(key_ref.type_signature,
                                            placement=receiver)
     pk_type = pk_ref.type_signature.member
     sk_type = sk_ref.type_signature.member

     # Every per-child input must have exactly one entry per receiver child.
     executors = self.strategy._get_child_executors(receiver)
     per_child_inputs = (val.internal_representation,
                         pk_ref.internal_representation,
                         sk_ref.internal_representation)
     for column in per_child_inputs:
         py_typecheck.check_len(column, len(executors))

     # ciphertext_type[0] is a tensor, so ciphertext_type as a whole is the
     # tuple consumed by one decryptor call.
     ciphertext_type = val.type_signature.member
     py_typecheck.check_type(ciphertext_type[0], tff.TensorType)
     py_typecheck.check_type(pk_type, tff.TensorType)
     proto, fn_type = utils.materialize_computation_from_cache(
         sodium_comp.make_decryptor,
         self._decryptor_cache,
         (ciphertext_type, pk_type, sk_type),
         orig_tensor_dtype=self._input_type_cache.dtype)

     # Embed the function and pack arguments on every child concurrently.
     fns_future = asyncio.gather(
         *(ex.create_value(proto, fn_type) for ex in executors))
     args_future = asyncio.gather(*(
         ex.create_struct([ciphertext, pk, sk])
         for ciphertext, pk, sk, ex in zip(*per_child_inputs, executors)))
     embedded_fns, packed_args = await asyncio.gather(fns_future, args_future)
     pending_calls = [
         ex.create_call(fn, arg)
         for fn, arg, ex in zip(embedded_fns, packed_args, executors)
     ]
     plaintexts = await asyncio.gather(*pending_calls)
     result_type = tff.FederatedType(fn_type.result,
                                     receiver,
                                     all_equal=val.type_signature.all_equal)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         plaintexts, result_type)
 async def _compute_reshape_on_tensor(self, tensor, output_shape):
     """Reshapes every member of a federated tensor to `output_shape`.

     A reshape computation built by `paillier_comp.make_reshape_tensor` is
     used. It is cached in `self._reshape_function_cache`, keyed on the member
     type and `output_shape`. The computation is embedded on each child
     executor at the tensor's placement and applied to that child's entry of
     the value.

     Args:
       tensor: A `FederatedResolvingStrategyValue` whose
         `internal_representation` holds one entry per child executor at its
         placement.
       output_shape: Target shape. It is forwarded to `make_reshape_tensor` and
         used as the shape of the resulting `tff.TensorType`.

     Returns:
       A `FederatedResolvingStrategyValue` with the reshaped members. It is
       typed as a federated tensor with the original dtype, placement and
       `all_equal` bit, and with shape `output_shape`.
     """
     # The previously computed `type_conversions.infer_type(output_shape)` was
     # never used and has been dropped.
     tensor_type = tensor.type_signature.member
     reshaper_proto, reshaper_type = utils.materialize_computation_from_cache(
         paillier_comp.make_reshape_tensor,
         self._reshape_function_cache,
         arg_spec=(tensor_type, ),
         output_shape=output_shape)
     tensor_placement = tensor.type_signature.placement
     children = self._get_child_executors(tensor_placement)
     # One representation entry per child executor is required for the zip.
     py_typecheck.check_len(tensor.internal_representation, len(children))
     reshaper_fns = await asyncio.gather(*[
         ex.create_value(reshaper_proto, reshaper_type) for ex in children
     ])
     reshaped_tensors = await asyncio.gather(*[
         ex.create_call(fn, arg) for ex, fn, arg in zip(
             children, reshaper_fns, tensor.internal_representation)
     ])
     output_tensor_spec = tff.FederatedType(
         tff.TensorType(tensor_type.dtype, output_shape), tensor_placement,
         tensor.type_signature.all_equal)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         reshaped_tensors, output_tensor_spec)
Esempio n. 6
0
 async def _encrypt_values_on_singleton(self, val, sender, receiver):
     """Encrypts the single sender's value once per receiver public key.

     Case 1: receiver=CLIENTS. The sender is assumed to have cardinality 1,
     which is enforced below via `check_len(snd_children, 1)`.
     Expected types:
         plaintext: Fed(Tensor, sender, all_equal=True)
         pk_receiver: Fed(Tuple(Tensor), sender, all_equal=True)
         sk_sender: Fed(Tensor, sender, all_equal=True)
       Returns (as built by the code below):
         encrypted_values: Tuple(Fed(Tensor, sender, all_equal=False))
     NOTE(review): an earlier design note listed the returned members as
     `all_equal=True`; the code produces `all_equal=False`. Confirm intent.

     The computation built by `sodium_comp.make_encryptor` is embedded once on
     the sender child. It is then called once per entry of the receiver
     public-key representation, with `[plaintext, public_key, secret_key]`.

     Args:
       val: Federated plaintext placed at `sender` with a single
         representation entry.
       sender: Singleton placement holding the plaintext and both keys.
       receiver: Placement whose public keys are used. The key representation
         must have one entry per receiver child.

     Returns:
       A `FederatedResolvingStrategyValue` wrapping a `structure` with one
       ciphertext per receiver public key, typed as a `tff.StructType` of
       `Fed(<encryptor result>, sender, all_equal=False)`.

     Raises:
       Whatever `type_analysis.check_federated_type` / `py_typecheck` raise
       when a placement, type or cardinality does not match.

     Side effects:
       Stores the plaintext member type in `self._input_type_cache`. The
       `_decrypt_values_on_*` methods read its `.dtype`.
     """
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     # Both keys must already live at the sender. These are explicit checks
     # rather than `assert`s, which would be stripped under `python -O`.
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     type_analysis.check_federated_type(pk_receiver.type_signature,
                                        placement=sender)
     ### Check placement cardinalities
     rcv_children = self.strategy._get_child_executors(receiver)
     snd_children = self.strategy._get_child_executors(sender)
     py_typecheck.check_len(snd_children, 1)
     snd_child = snd_children[0]
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature,
                                        placement=sender)
     py_typecheck.check_len(val.internal_representation, 1)
     pk_rcv_type = pk_receiver.type_signature.member
     py_typecheck.check_type(pk_rcv_type, tff.StructType)
     py_typecheck.check_len(pk_receiver.internal_representation,
                            len(rcv_children))
     py_typecheck.check_len(sk_sender.internal_representation, 1)
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     # Remembered so the decryptors can pass the original dtype along.
     self._input_type_cache = input_type
     # All receiver keys share one type, so element 0 specialises the encryptor.
     encryptor_arg_spec = (input_type, pk_rcv_type[0],
                           sk_sender.type_signature.member)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Prepare encryption arguments
     v = val.internal_representation[0]
     sk = sk_sender.internal_representation[0]
     ### Encrypt values and return them
     encryptor_fn = await snd_child.create_value(encryptor_proto,
                                                 encryptor_type)
     encryptor_args = await asyncio.gather(*[
         snd_child.create_struct([v, this_pk, sk])
         for this_pk in pk_receiver.internal_representation
     ])
     encrypted_values = await asyncio.gather(*[
         snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args
     ])
     encrypted_value_types = [encryptor_type.result] * len(encrypted_values)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         structure.from_container(encrypted_values),
         tff.StructType([
             tff.FederatedType(evt, sender, all_equal=False)
             for evt in encrypted_value_types
         ]))