コード例 #1
0
  async def compute_federated_secure_sum(
      self, arg: federated_resolving_strategy.FederatedResolvingStrategyValue
  ) -> federated_resolving_strategy.FederatedResolvingStrategyValue:
    """Computes a federated secure sum without any cryptography.

    Test-only implementation: performs a plain federated sum and then keeps
    only the low `bitwidth` bits of the result on the server.

    Args:
      arg: A value whose internal representation is a two-element
        `structure.Struct` of `(value, bitwidth)`, with matching two-element
        type signature.

    Returns:
      The masked sum, produced by `compute_federated_map_all_equal` over a
      server-placed argument.
    """
    py_typecheck.check_type(arg.internal_representation, structure.Struct)
    py_typecheck.check_len(arg.internal_representation, 2)
    logging.warning(
        'The implementation of the `tff.federated_secure_sum` intrinsic '
        'provided by the `tff.backends.test` runtime uses no cryptography.')

    server_ex = self._target_executors[placement_literals.SERVER][0]
    value = federated_resolving_strategy.FederatedResolvingStrategyValue(
        arg.internal_representation[0], arg.type_signature[0])
    bitwidth = await arg.internal_representation[1].compute()
    bitwidth_type = arg.type_signature[1]
    # `2**bitwidth - 1` is an all-ones mask over the low `bitwidth` bits, so
    # `sum & mask` keeps exactly those bits (a reduction modulo 2**bitwidth).
    # Using the mask as a *modulus* (`tf.math.mod`) would instead reduce
    # modulo `2**bitwidth - 1`; for bitwidth 1 that is always 0.
    sum_result, mask, fn = await asyncio.gather(
        self.compute_federated_sum(value),
        executor_utils.embed_tf_constant(self._executor, bitwidth_type,
                                         2**bitwidth - 1),
        executor_utils.embed_tf_binary_operator(self._executor, bitwidth_type,
                                                tf.bitwise.bitwise_and))
    fn_arg = await server_ex.create_struct([
        sum_result.internal_representation[0],
        mask.internal_representation,
    ])
    fn_arg_type = computation_types.FederatedType(
        fn_arg.type_signature, placement_literals.SERVER, all_equal=True)
    map_arg = federated_resolving_strategy.FederatedResolvingStrategyValue(
        structure.Struct([
            (None, fn.internal_representation),
            (None, [fn_arg]),
        ]), computation_types.StructType([fn.type_signature, fn_arg_type]))
    return await self.compute_federated_map_all_equal(map_arg)
コード例 #2
0
    async def _generate_keys(self, key_owner):
        """Generates a key pair on every child executor of `key_owner`.

        The key-generation computation is built lazily (once) and invoked on
        each child executor of `key_owner`. Element 0 of each result is
        recorded as a public key and element 1 as a secret key, both stored
        in `self.key_references` as federated values placed at `key_owner`.

        Args:
            key_owner: A `placement_literals.PlacementLiteral` whose child
                executors each generate a key pair.
        """
        py_typecheck.check_type(key_owner, placement_literals.PlacementLiteral)
        executors = self.strategy._get_child_executors(key_owner)
        if self._key_generator is None:
            # Build the key generator once; reuse it on later calls.
            self._key_generator = sodium_comp.make_keygen()
        keygen, keygen_type = utils.lift_to_computation_spec(
            self._key_generator)

        async def keygen_call(child):
            # Embed the keygen computation on `child` and invoke it.
            return await child.create_call(await child.create_value(
                keygen, keygen_type))

        key_generators = await asyncio.gather(
            *[keygen_call(executor) for executor in executors])
        public_keys = asyncio.gather(*[
            executor.create_selection(key_generator, 0)
            for executor, key_generator in zip(executors, key_generators)
        ])
        secret_keys = asyncio.gather(*[
            executor.create_selection(key_generator, 1)
            for executor, key_generator in zip(executors, key_generators)
        ])
        pk_vals, sk_vals = await asyncio.gather(public_keys, secret_keys)
        pk_type = pk_vals[0].type_signature
        sk_type = sk_vals[0].type_signature
        # all_equal whenever owner is non-CLIENTS singleton placement
        val_all_equal = len(executors) == 1 and key_owner != tff.CLIENTS
        pk_fed_val = federated_resolving_strategy.FederatedResolvingStrategyValue(
            pk_vals, tff.FederatedType(pk_type, key_owner, val_all_equal))
        sk_fed_val = federated_resolving_strategy.FederatedResolvingStrategyValue(
            sk_vals, tff.FederatedType(sk_type, key_owner, val_all_equal))
        self.key_references.update_keys(key_owner,
                                        public_key=pk_fed_val,
                                        secret_key=sk_fed_val)
コード例 #3
0
 async def _compute_paillier_decryption(
         self,
         decryption_key: federated_resolving_strategy.FederatedResolvingStrategyValue,
         encryption_key: federated_resolving_strategy.FederatedResolvingStrategyValue,
         value: federated_resolving_strategy.FederatedResolvingStrategyValue,
         export_dtype):
     """Decrypts a Paillier ciphertext on the first server child executor.

     Args:
       decryption_key: Federated decryption key.
       encryption_key: Federated encryption key.
       value: Federated ciphertext to decrypt.
       export_dtype: Forwarded as `export_dtype` to the cached decryptor
         factory.

     Returns:
       A `FederatedResolvingStrategyValue` wrapping the single decrypted value,
       typed `tff.FederatedType(<decryptor result>, tff.SERVER, True)`.
     """
     server = self._get_child_executors(tff.SERVER, index=0)
     operands = (decryption_key, encryption_key, value)
     proto, fn_type = utils.materialize_computation_from_cache(
         paillier_comp.make_decryptor,
         self._paillier_decryptor_cache,
         arg_spec=tuple(op.type_signature.member for op in operands),
         export_dtype=export_dtype)
     # Embed the decryptor and build its argument concurrently.
     fn_ref, arg_ref = await asyncio.gather(
         server.create_value(proto, fn_type),
         server.create_struct(
             tuple(op.internal_representation for op in operands)))
     plaintext = await server.create_call(fn_ref, arg_ref)
     result_type = tff.FederatedType(fn_type.result, tff.SERVER, True)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         [plaintext], result_type)
コード例 #4
0
 async def _encrypt_values_on_clients(self, val, sender, receiver):
     """Encrypts every sender-side value for a single-executor `receiver`.

     Expected layout (sender=CLIENTS):
       plaintext:   Fed(Tensor, CLIENTS, all_equal=False)
       pk_receiver: Fed(Tensor, CLIENTS, all_equal=True)
       sk_sender:   Fed(Tensor, CLIENTS, all_equal=False)

     Each sender child executor encrypts its own plaintext with its copy of
     the receiver's public key and its own secret key.

     Returns:
       A `FederatedResolvingStrategyValue` of ciphertexts placed at `sender`,
       with `all_equal` copied from `val.type_signature`.

     Side effect: stores `val`'s member type in `self._input_type_cache`
     (later read as `orig_tensor_dtype` when building decryptors).
     """
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     # Both keys must already be held at the sender. Explicit checks instead
     # of `assert` so the validation is not stripped under `python -O`.
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     type_analysis.check_federated_type(pk_receiver.type_signature,
                                        placement=sender)
     ### Check placement cardinalities
     snd_children = self.strategy._get_child_executors(sender)
     rcv_children = self.strategy._get_child_executors(receiver)
     # This path only supports a receiver backed by exactly one executor.
     py_typecheck.check_len(rcv_children, 1)
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature, placement=sender)
     federated_value_internals = [
         val.internal_representation, pk_receiver.internal_representation,
         sk_sender.internal_representation
     ]
     # One plaintext, one public key and one secret key per sender executor.
     for v in federated_value_internals:
         py_typecheck.check_len(v, len(snd_children))
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     self._input_type_cache = input_type
     encryptor_arg_spec = (input_type, pk_receiver.type_signature.member,
                           sk_sender.type_signature.member)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Encrypt values and return them
     encryptor_fns = asyncio.gather(*[
         snd_child.create_value(encryptor_proto, encryptor_type)
         for snd_child in snd_children
     ])
     encryptor_args = asyncio.gather(*[
         snd_child.create_struct([v, pk, sk]) for v, pk, sk, snd_child in
         zip(*federated_value_internals, snd_children)
     ])
     encryptor_fns, encryptor_args = await asyncio.gather(
         encryptor_fns, encryptor_args)
     encrypted_values = [
         snd_child.create_call(encryptor, arg) for encryptor, arg, snd_child
         in zip(encryptor_fns, encryptor_args, snd_children)
     ]
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         await asyncio.gather(*encrypted_values),
         tff.FederatedType(encryptor_type.result,
                           sender,
                           all_equal=val.type_signature.all_equal))
コード例 #5
0
  def test_raises_type_error_with_unembedded_federated_type(self):
    # Raw Python floats (not executor values) under a CLIENTS-placed type.
    unembedded = federated_resolving_strategy.FederatedResolvingStrategyValue(
        [10.0, 11.0, 12.0], type_factory.at_clients(tf.float32))

    with self.assertRaises(TypeError):
      self.run_sync(unembedded.compute())
コード例 #6
0
  def test_raises_runtime_error_with_unsupported_value_or_type(self):
    # A bare float is neither an executor value, a struct, nor a list.
    unsupported = federated_resolving_strategy.FederatedResolvingStrategyValue(
        10.0, computation_types.TensorType(tf.float32))

    with self.assertRaises(RuntimeError):
      self.run_sync(unsupported.compute())
コード例 #7
0
  def test_returns_value_with_embedded_value(self):
    embedded = eager_tf_executor.EagerValue(10.0, None, tf.float32)
    strategy_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
        embedded, computation_types.TensorType(tf.float32))

    computed = self.run_sync(strategy_value.compute())

    self.assertEqual(computed, 10.0)
コード例 #8
0
  def test_returns_value_with_federated_type_at_server(self):
    server_values = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
    strategy_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
        server_values, type_factory.at_server(tf.float32))

    computed = self.run_sync(strategy_value.compute())

    self.assertEqual(computed, 10.0)
コード例 #9
0
 async def _decrypt_values_on_singleton(self, val, sender, receiver):
     """Decrypts, on a single-executor `receiver`, one ciphertext per sender.

     `val` must be a `tff.StructType` whose first element is federated at
     `receiver` with `all_equal=True`. Element `i` of `val` is decrypted with
     the i-th public key of `sender` (held at `receiver`) and the receiver's
     only secret key.

     Returns:
       A `FederatedResolvingStrategyValue` over a `structure.Struct` of
       decrypted values, typed as a `tff.StructType` whose members are all
       `tff.FederatedType(<decryptor result>, receiver, all_equal=True)`.
     """
     # Keys: sender public keys and receiver secret key, both at `receiver`.
     pk_sender = self.key_references.get_public_key(sender)
     sk_receiver = self.key_references.get_secret_key(receiver)
     for key in (pk_sender, sk_receiver):
         type_analysis.check_federated_type(key.type_signature,
                                            placement=receiver)
     pk_snd_type = pk_sender.type_signature.member
     sk_rcv_type = sk_receiver.type_signature.member

     # Cardinalities: one receiver executor, one public key per sender
     # executor, and a single secret key.
     sender_executors = self.strategy._get_child_executors(sender)
     receiver_executors = self.strategy._get_child_executors(receiver)
     py_typecheck.check_len(receiver_executors, 1)
     target = receiver_executors[0]
     public_keys = pk_sender.internal_representation
     py_typecheck.check_len(public_keys, len(sender_executors))
     py_typecheck.check_len(sk_receiver.internal_representation, 1)

     # Decryptor spec: each struct element of `val` is itself a tuple holding
     # everything needed to decrypt one value.
     py_typecheck.check_type(val.type_signature, tff.StructType)
     type_analysis.check_federated_type(val.type_signature[0],
                                        placement=receiver,
                                        all_equal=True)
     ciphertext_type = val.type_signature[0].member
     py_typecheck.check_type(ciphertext_type, tff.StructType)
     py_typecheck.check_type(pk_snd_type, tff.StructType)
     py_typecheck.check_len(val.type_signature, len(pk_snd_type))
     decryptor_proto, decryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_decryptor,
         self._decryptor_cache,
         (ciphertext_type, pk_snd_type[0], sk_rcv_type),
         orig_tensor_dtype=self._input_type_cache.dtype)

     # Decrypt every ciphertext on the receiver's executor.
     secret_key = sk_receiver.internal_representation[0]
     decryptor_fn = await target.create_value(decryptor_proto, decryptor_type)
     decryptor_args = await asyncio.gather(*[
         target.create_struct([ciphertext, pk, secret_key])
         for ciphertext, pk in zip(val.internal_representation, public_keys)
     ])
     plaintexts = await asyncio.gather(
         *(target.create_call(decryptor_fn, a) for a in decryptor_args))
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         structure.from_container(plaintexts),
         tff.StructType([
             tff.FederatedType(decryptor_type.result, receiver, all_equal=True)
             for _ in plaintexts
         ]))
コード例 #10
0
 async def transfer(self, value):
     """Moves `value` from its placement to the other placement of the pair.

     The value goes through `self.send`, is materialized with `compute()`,
     is re-embedded on the receiver's child executors, and is then handed
     to `self.receive`.

     For a `tff.CLIENTS` receiver, a struct-typed message is split so each
     client executor gets one element (`all_equal=False`); any other message
     is replicated to every client (`all_equal=True`). For any other receiver
     the message is embedded on its first child executor; a struct message
     becomes a `tff.StructType` of per-element federated values.
     """
     _check_value_placement(value, self.placements)
     sender_placement = value.type_signature.placement
     receiver_placement = _get_other_placement(sender_placement,
                                               self.placements)
     sent = await self.send(value, sender_placement, receiver_placement)
     rcv_children = self.strategy._get_child_executors(receiver_placement)
     message = await sent.compute()
     message_type = type_conversions.infer_type(message)
     is_struct = isinstance(message_type, tff.StructType)

     if receiver_placement is tff.CLIENTS:
         if is_struct:
             # One struct element per client executor.
             triples = zip(rcv_children, message, message_type)
             member_type, all_equal = message_type[0], False
         else:
             # The same message replicated to every client executor.
             triples = ((child, message, message_type)
                        for child in rcv_children)
             member_type, all_equal = message_type, True
         embedded = await asyncio.gather(
             *[child.create_value(m, t) for child, m, t in triples])
         message_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
             embedded,
             tff.FederatedType(member_type, receiver_placement, all_equal))
     else:
         target = rcv_children[0]
         if is_struct:
             elements = await asyncio.gather(*[
                 target.create_value(m, t)
                 for m, t in zip(message, message_type)
             ])
             message_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
                 structure.from_container(elements),
                 tff.StructType([
                     tff.FederatedType(t, receiver_placement, True)
                     for t in message_type
                 ]))
         else:
             message_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
                 await target.create_value(message, message_type),
                 tff.FederatedType(message_type, receiver_placement,
                                   value.type_signature.all_equal))
     return await self.receive(message_value, sender_placement,
                               receiver_placement)
コード例 #11
0
  def test_returns_value_with_federated_type_at_server(self):
    scalar_type = computation_types.TensorType(tf.float32, shape=[])
    server_values = [eager_tf_executor.EagerValue(10.0, scalar_type)]
    strategy_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
        server_values, computation_types.at_server(tf.float32))

    computed = self.run_sync(strategy_value.compute())

    self.assertEqual(computed, 10.0)
コード例 #12
0
    def test_returns_value_with_federated_type_at_clients_all_equal(self):
        client_values = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
        all_equal_type = computation_types.at_clients(tf.float32,
                                                      all_equal=True)
        strategy_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
            client_values, all_equal_type)

        computed = self.run_sync(strategy_value.compute())

        self.assertEqual(computed, 10.0)
コード例 #13
0
 async def _decrypt_values_on_clients(self, val, sender, receiver):
     """Decrypts, on each receiver executor, the ciphertext it holds.

     Executor `i` decrypts its element of `val` using its copy of `sender`'s
     public key and its own `receiver` secret key.

     Returns:
       A `FederatedResolvingStrategyValue` of decrypted values placed at
       `receiver`, with `all_equal` copied from `val.type_signature`.
     """
     # Keys: the sender public key and receiver secret key, at `receiver`.
     pk_sender = self.key_references.get_public_key(sender)
     sk_receiver = self.key_references.get_secret_key(receiver)
     for key in (pk_sender, sk_receiver):
         type_analysis.check_federated_type(key.type_signature,
                                            placement=receiver)
     pk_snd_type = pk_sender.type_signature.member
     sk_rcv_type = sk_receiver.type_signature.member

     # Every per-executor input needs one entry per receiver executor.
     executors = self.strategy._get_child_executors(receiver)
     per_executor_inputs = (val.internal_representation,
                            pk_sender.internal_representation,
                            sk_receiver.internal_representation)
     for internal in per_executor_inputs:
         py_typecheck.check_len(internal, len(executors))

     # `input_type[0]` is a tensor, so `input_type` as a whole is the tuple
     # needed to decrypt a single value.
     input_type = val.type_signature.member
     py_typecheck.check_type(input_type[0], tff.TensorType)
     py_typecheck.check_type(pk_snd_type, tff.TensorType)
     decryptor_proto, decryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_decryptor,
         self._decryptor_cache,
         (input_type, pk_snd_type, sk_rcv_type),
         orig_tensor_dtype=self._input_type_cache.dtype)

     # Embed decryptors and their arguments concurrently, then call them.
     fns_future = asyncio.gather(*[
         ex.create_value(decryptor_proto, decryptor_type) for ex in executors
     ])
     args_future = asyncio.gather(*[
         ex.create_struct([v, pk, sk])
         for v, pk, sk, ex in zip(*per_executor_inputs, executors)
     ])
     fns, args = await asyncio.gather(fns_future, args_future)
     plaintexts = await asyncio.gather(*[
         ex.create_call(fn, arg) for fn, arg, ex in zip(fns, args, executors)
     ])
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         plaintexts,
         tff.FederatedType(decryptor_type.result,
                           receiver,
                           all_equal=val.type_signature.all_equal))
コード例 #14
0
  def test_returns_value_with_anonymous_tuple_value(self):
    # Three named fields, all holding the same embedded 10.0.
    field_names = ['a', 'b', 'c']
    leaf = eager_tf_executor.EagerValue(10.0, None, tf.float32)
    leaf_type = computation_types.TensorType(tf.float32)
    strategy_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
        anonymous_tuple.AnonymousTuple((name, leaf) for name in field_names),
        computation_types.NamedTupleType(
            (name, leaf_type) for name in field_names))

    computed = self.run_sync(strategy_value.compute())

    self.assertEqual(
        computed,
        anonymous_tuple.AnonymousTuple((name, 10.0) for name in field_names))
コード例 #15
0
 async def _paillier_setup(self):
     """Loads the Paillier key pair on the server and distributes the keys.

     Sets `self.encryption_key_server` (server copy),
     `self.encryption_key_clients` and `self.encryption_key_paillier` (the
     encryption key moved to `tff.CLIENTS` and
     `paillier_placement.AGGREGATOR`), and `self.decryption_key`, which
     stays on the server.
     """
     inputter = await self._executor.create_value(self._key_inputter)
     _check_key_inputter(inputter)
     evaluated = await self._eval(inputter, tff.SERVER, all_equal=True)
     key_pair = evaluated.internal_representation[0]
     server = self._get_child_executors(tff.SERVER, index=0)

     # Encryption key (element 0): broadcast from the server.
     ek_ref = await server.create_selection(key_pair, index=0)
     ek = federated_resolving_strategy.FederatedResolvingStrategyValue(
         ek_ref, tff.FederatedType(ek_ref.type_signature, tff.SERVER, True))
     ek_clients, ek_aggregator = await asyncio.gather(
         self._move(ek, tff.SERVER, tff.CLIENTS),
         self._move(ek, tff.SERVER, paillier_placement.AGGREGATOR))
     self.encryption_key_server = ek
     self.encryption_key_clients = ek_clients
     self.encryption_key_paillier = ek_aggregator

     # Decryption key (element 1): kept on the server with formal placement.
     dk_ref = await server.create_selection(key_pair, index=1)
     dk_type = tff.FederatedType(dk_ref.type_signature,
                                 tff.SERVER,
                                 all_equal=True)
     self.decryption_key = federated_resolving_strategy.FederatedResolvingStrategyValue(
         dk_ref, dk_type)
コード例 #16
0
  def test_returns_value_with_structure_value(self):
    # Three named fields, all holding the same embedded 10.0.
    scalar_type = computation_types.TensorType(tf.float32, shape=[])
    leaf = eager_tf_executor.EagerValue(10.0, scalar_type)
    leaf_type = computation_types.TensorType(tf.float32)
    field_names = ['a', 'b', 'c']
    strategy_value = federated_resolving_strategy.FederatedResolvingStrategyValue(
        structure.Struct((name, leaf) for name in field_names),
        computation_types.StructType(
            (name, leaf_type) for name in field_names))

    computed = self.run_sync(strategy_value.compute())

    self.assertEqual(
        computed, structure.Struct((name, 10.0) for name in field_names))
コード例 #17
0
 async def _compute_paillier_sum(
         self,
         encryption_key: federated_resolving_strategy.FederatedResolvingStrategyValue,
         values: federated_resolving_strategy.FederatedResolvingStrategyValue):
     """Sums encrypted values on the Paillier aggregator.

     Args:
       encryption_key: The encryption key; its member type is the first
         element of the summation computation's input type.
       values: A struct-typed value whose federated elements are summed.

     Returns:
       The encrypted sum, placed at `paillier_placement.AGGREGATOR` with
       `all_equal=True`.
     """
     aggregator = self._get_child_executors(paillier_placement.AGGREGATOR,
                                            index=0)
     key_member_type = encryption_key.type_signature.member
     summands_type = tff.StructType(
         [member_fed.member for member_fed in values.type_signature])
     sum_proto, sum_type = utils.lift_to_computation_spec(
         self._paillier_sequence_sum,
         input_arg_type=tff.StructType((key_member_type, summands_type)))
     summands = await aggregator.create_struct(values.internal_representation)
     sum_fn, sum_arg = await asyncio.gather(
         aggregator.create_value(sum_proto, sum_type),
         aggregator.create_struct(
             (encryption_key.internal_representation, summands)))
     encrypted_sum = await aggregator.create_call(sum_fn, sum_arg)
     result_type = tff.FederatedType(sum_type.result,
                                     paillier_placement.AGGREGATOR, True)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         encrypted_sum, result_type)
コード例 #18
0
    async def _compute_modulus(self, value, mask):
        """Maps `bitwise_and(x, mask)` over every constituent of `value`.

        Args:
            value: A federated value placed at `placements.SERVER` or
                `placements.CLIENTS`.
            mask: A non-federated value that is re-embedded at `value`'s
                placement (with `all_equal=True`) before zipping.

        Returns:
            The result of `compute_federated_map_all_equal` when `value` is
            all-equal, otherwise of `compute_federated_map`.

        Raises:
            TypeError: If `value` is placed somewhere other than SERVER or
                CLIENTS.
        """

        async def zip_with_placed_mask():
            # Place `mask` alongside `value`, then zip the two together.
            placement = value.type_signature.placement
            placed_mask = await self._executor.create_value(
                await mask.compute(),
                computation_types.FederatedType(mask.type_signature,
                                                placement,
                                                all_equal=True))
            zip_arg = await self._executor.create_struct([value, placed_mask])
            if placement == placements.SERVER:
                return await self.compute_federated_zip_at_server(zip_arg)
            if placement == placements.CLIENTS:
                return await self.compute_federated_zip_at_clients(zip_arg)
            raise TypeError(
                'Unknown placement [{p}], must be one of [CLIENTS, SERVER]'
                .format(p=placement))

        operator_spec = (
            tensorflow_computation_factory.create_binary_operator_with_upcast(
                computation_types.StructType(
                    [value.type_signature.member, mask.type_signature]),
                tf.bitwise.bitwise_and))
        and_fn, zipped = await asyncio.gather(
            self._executor.create_value(*operator_spec),
            zip_with_placed_mask())
        map_arg = federated_resolving_strategy.FederatedResolvingStrategyValue(
            structure.Struct([
                (None, and_fn.internal_representation),
                (None, zipped.internal_representation),
            ]),
            computation_types.StructType(
                [and_fn.type_signature, zipped.type_signature]))
        mapper = (self.compute_federated_map_all_equal
                  if value.type_signature.all_equal else
                  self.compute_federated_map)
        return await mapper(map_arg)
コード例 #19
0
 async def _compute_reshape_on_tensor(self, tensor, output_shape):
     """Reshapes every constituent of a federated tensor to `output_shape`.

     Args:
       tensor: A federated tensor value whose internal representation holds
         one entry per child executor of its placement.
       output_shape: Target shape; forwarded to
         `paillier_comp.make_reshape_tensor` and used in the result type.

     Returns:
       A `FederatedResolvingStrategyValue` with the same placement and
       `all_equal` setting as `tensor`, whose member type is
       `tff.TensorType(<original dtype>, output_shape)`.
     """
     tensor_type = tensor.type_signature.member
     reshaper_proto, reshaper_type = utils.materialize_computation_from_cache(
         paillier_comp.make_reshape_tensor,
         self._reshape_function_cache,
         arg_spec=(tensor_type, ),
         output_shape=output_shape)
     tensor_placement = tensor.type_signature.placement
     children = self._get_child_executors(tensor_placement)
     # One tensor constituent per child executor of the placement.
     py_typecheck.check_len(tensor.internal_representation, len(children))
     reshaper_fns = await asyncio.gather(*[
         ex.create_value(reshaper_proto, reshaper_type) for ex in children
     ])
     reshaped_tensors = await asyncio.gather(*[
         ex.create_call(fn, arg) for ex, fn, arg in zip(
             children, reshaper_fns, tensor.internal_representation)
     ])
     output_tensor_spec = tff.FederatedType(
         tff.TensorType(tensor_type.dtype, output_shape), tensor_placement,
         tensor.type_signature.all_equal)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         reshaped_tensors, output_tensor_spec)
コード例 #20
0
 async def _compute_paillier_encryption(
     self,
     client_encryption_keys: federated_resolving_strategy.FederatedResolvingStrategyValue,
     clients_value: federated_resolving_strategy.FederatedResolvingStrategyValue):
     """Paillier-encrypts each client's value with that client's key.

     Args:
       client_encryption_keys: One encryption key per `tff.CLIENTS` child
         executor.
       clients_value: One plaintext per `tff.CLIENTS` child executor.

     Returns:
       The ciphertexts, placed at `tff.CLIENTS` with `all_equal` copied from
       `clients_value`.
     """
     clients = self._get_child_executors(tff.CLIENTS)
     for federated in (client_encryption_keys, clients_value):
         py_typecheck.check_len(federated.internal_representation,
                                len(clients))
     encryptor_proto, encryptor_type = utils.lift_to_computation_spec(
         self._paillier_encryptor,
         input_arg_type=tff.StructType(
             (client_encryption_keys.type_signature.member,
              clients_value.type_signature.member)))
     # Embed per-client encryptors and (key, value) arguments concurrently.
     fns, args = await asyncio.gather(
         asyncio.gather(*[
             client.create_value(encryptor_proto, encryptor_type)
             for client in clients
         ]),
         asyncio.gather(*[
             client.create_struct((key, plaintext))
             for client, key, plaintext in zip(
                 clients, client_encryption_keys.internal_representation,
                 clients_value.internal_representation)
         ]))
     ciphertexts = await asyncio.gather(*[
         client.create_call(fn, arg)
         for client, fn, arg in zip(clients, fns, args)
     ])
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         ciphertexts,
         tff.FederatedType(encryptor_type.result, tff.CLIENTS,
                           clients_value.type_signature.all_equal))
コード例 #21
0
 async def _share_public_key(self, key_owner, key_receiver):
     """Copies `key_owner`'s public key onto `key_receiver`'s executors.

     Two layouts are supported:
       * many keys (`compute()` returned a list) onto a single receiver
         executor;
       * one key onto every receiver executor (typed `all_equal=True`).

     The embedded copy is stored via
     `self.key_references.update_keys(key_owner, public_key=...)`.
     """
     public_key = self.key_references.get_public_key(key_owner)
     children = self.strategy._get_child_executors(key_receiver)
     key_value = await public_key.compute()
     key_type = public_key.type_signature.member
     if not isinstance(key_value, list):
         # One key (a single tensor) replicated to every receiver executor.
         pending = [child.create_value(key_value, key_type)
                    for child in children]
         shared_type = tff.FederatedType(key_type,
                                         key_receiver,
                                         all_equal=True)
     else:
         # Many keys, all embedded on the single receiver executor.
         py_typecheck.check_len(children, 1)
         target = children[0]
         pending = [target.create_value(k, key_type) for k in key_value]
         shared_type = tff.FederatedType(
             type_conversions.infer_type(key_value), key_receiver)
     shared_key = federated_resolving_strategy.FederatedResolvingStrategyValue(
         await asyncio.gather(*pending), shared_type)
     self.key_references.update_keys(key_owner, public_key=shared_key)
コード例 #22
0
 async def _encrypt_values_on_singleton(self, val, sender, receiver):
     """Encrypts a single-executor sender's value once per receiver key.

     The sender placement must have exactly one child executor (checked
     below); this is the path used when the receiver is `tff.CLIENTS`.

     Expected layout:
       plaintext:   Fed(Tensor, sender) with a single constituent
       pk_receiver: Fed(Tuple(Tensor), sender), one key per receiver executor
       sk_sender:   Fed(Tensor, sender) with a single constituent

     Returns:
       A `FederatedResolvingStrategyValue` over a `structure.Struct` holding
       one ciphertext per receiver public key, typed as a `tff.StructType`
       whose members are `tff.FederatedType(<encryptor result>, sender,
       all_equal=False)`.
       NOTE(review): all_equal=False on a singleton placement is unusual;
       an earlier comment here claimed all_equal=True. Confirm intent.

     Side effect: stores `val`'s member type in `self._input_type_cache`
     (later read as `orig_tensor_dtype` when building decryptors).
     """
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     # Both keys must already be held at the sender. Explicit checks instead
     # of `assert` so the validation is not stripped under `python -O`.
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     type_analysis.check_federated_type(pk_receiver.type_signature,
                                        placement=sender)
     ### Check placement cardinalities
     rcv_children = self.strategy._get_child_executors(receiver)
     snd_children = self.strategy._get_child_executors(sender)
     py_typecheck.check_len(snd_children, 1)
     snd_child = snd_children[0]
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature, placement=sender)
     py_typecheck.check_len(val.internal_representation, 1)
     py_typecheck.check_type(pk_receiver.type_signature.member,
                             tff.StructType)
     # One public key per receiver executor.
     py_typecheck.check_len(pk_receiver.internal_representation,
                            len(rcv_children))
     py_typecheck.check_len(sk_sender.internal_representation, 1)
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     self._input_type_cache = input_type
     # All receiver keys are assumed to share the type of the first one.
     encryptor_arg_spec = (input_type, pk_receiver.type_signature.member[0],
                           sk_sender.type_signature.member)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Prepare encryption arguments
     v = val.internal_representation[0]
     sk = sk_sender.internal_representation[0]
     ### Encrypt values and return them
     encryptor_fn = await snd_child.create_value(encryptor_proto,
                                                 encryptor_type)
     encryptor_args = await asyncio.gather(*[
         snd_child.create_struct([v, this_pk, sk])
         for this_pk in pk_receiver.internal_representation
     ])
     encrypted_values = await asyncio.gather(*[
         snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args
     ])
     encrypted_value_types = [encryptor_type.result] * len(encrypted_values)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         structure.from_container(encrypted_values),
         tff.StructType([
             tff.FederatedType(evt, sender, all_equal=False)
             for evt in encrypted_value_types
         ]))