async def compute_intrinsic_federated_broadcast(
    executor: executor_base.Executor,
    arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
  """Broadcasts an all-equal `tff.SERVER` value to `tff.CLIENTS`.

  The value is materialized with `arg.compute()` and then re-embedded in
  `executor` under the same member type, placed at `tff.CLIENTS` with
  `all_equal=True`.

  Args:
    executor: The executor in which the broadcast result is embedded.
    arg: The value to broadcast. It must already be embedded in `executor`
      and have a federated type placed at `tff.SERVER` with `all_equal=True`.

  Returns:
    The broadcast value, embedded in `executor`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(executor, executor_base.Executor)
  py_typecheck.check_type(arg, executor_value_base.ExecutorValue)
  server_type = arg.type_signature
  type_analysis.check_federated_type(
      server_type, placement=placement_literals.SERVER, all_equal=True)
  # Pull the concrete value out before re-embedding it with a clients type.
  materialized = await arg.compute()
  clients_type = computation_types.FederatedType(
      server_type.member, placement_literals.CLIENTS, all_equal=True)
  return await executor.create_value(materialized, clients_type)
async def compute_federated_mean(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Computes the mean of a `tff.CLIENTS`-placed value, placed at `tff.SERVER`.

  The mean is formed on the server executor as
  `tf.multiply(federated_sum(arg), 1.0 / count)`. Here `count` is the sum of
  the values returned by `self._get_cardinalities()`.

  Args:
    arg: A value whose federated type is placed at `tff.CLIENTS`.

  Returns:
    A `FederatedComposingStrategyValue` holding the mean, typed as the
    member type of `arg` placed at `tff.SERVER`.
  """
  type_analysis.check_federated_type(
      arg.type_signature, placement=placement_literals.CLIENTS)
  member_type = arg.type_signature.member

  async def _embed_sum_at_server():
    # Materialize the composed sum and re-embed it in the server executor so
    # it can be fed to the server-side multiply.
    summed = await (await self.compute_federated_sum(arg)).compute()
    return await self._server_executor.create_value(summed, member_type)

  async def _embed_reciprocal_count():
    total_clients = sum(await self._get_cardinalities())
    return await executor_utils.embed_tf_scalar_constant(
        self._server_executor, member_type, float(1.0 / total_clients))

  async def _embed_product_operands():
    summed_value, scale = await asyncio.gather(_embed_sum_at_server(),
                                               _embed_reciprocal_count())
    return await self._server_executor.create_struct([summed_value, scale])

  multiply_fn, multiply_operands = await asyncio.gather(
      executor_utils.embed_tf_binary_operator(self._server_executor,
                                              member_type, tf.multiply),
      _embed_product_operands())
  mean_value = await self._server_executor.create_call(
      multiply_fn, multiply_operands)
  return FederatedComposingStrategyValue(mean_value,
                                         type_factory.at_server(member_type))
async def _encrypt_values_on_clients(self, val, sender, receiver):
  """Encrypts each sender child's value for a single-child `receiver`.

  Case 2: sender=CLIENTS
    plaintext:   Fed(Tensor, CLIENTS, all_equal=False)
    pk_receiver: Fed(Tensor, CLIENTS, all_equal=True)
    sk_sender:   Fed(Tensor, CLIENTS, all_equal=False)
  Returns:
    encrypted_values: Fed(Tensor, CLIENTS, all_equal=False)

  Each sender child encrypts its own plaintext entry with its own entries of
  the receiver public key and the sender secret key.

  Args:
    val: Value to encrypt. It must be federated, placed at `sender`, and have
      one internal entry per sender child executor.
    sender: Placement holding the plaintext, the sender secret key, and
      (already moved there) the receiver public key.
    receiver: Placement whose public key is used. It must map to exactly one
      child executor.

  Returns:
    A `FederatedResolvingStrategyValue` with one ciphertext per sender child,
    typed `FederatedType(encryptor_result, sender, all_equal=val.all_equal)`.

  Raises:
    TypeError: If `val` or either key is not federated at `sender`.
    Cardinality mismatches are reported by `py_typecheck.check_len`.
  """
  ### Check proper key placement
  sk_sender = self.key_references.get_secret_key(sender)
  pk_receiver = self.key_references.get_public_key(receiver)
  type_analysis.check_federated_type(sk_sender.type_signature, placement=sender)
  # The receiver's public key must already live at `sender`. Validate it
  # explicitly instead of with `assert`, which `python -O` strips.
  type_analysis.check_federated_type(
      pk_receiver.type_signature, placement=sender)

  ### Check placement cardinalities
  snd_children = self.strategy._get_child_executors(sender)
  rcv_children = self.strategy._get_child_executors(receiver)
  py_typecheck.check_len(rcv_children, 1)

  ### Check value cardinalities
  type_analysis.check_federated_type(val.type_signature, placement=sender)
  federated_value_internals = [
      val.internal_representation, pk_receiver.internal_representation,
      sk_sender.internal_representation
  ]
  for v in federated_value_internals:
    py_typecheck.check_len(v, len(snd_children))

  ### Materialize encryptor function definition & type spec
  input_type = val.type_signature.member
  # Remembered so the decryption methods can read
  # `self._input_type_cache.dtype`.
  self._input_type_cache = input_type
  pk_rcv_type = pk_receiver.type_signature.member
  sk_snd_type = sk_sender.type_signature.member
  pk_element_type = pk_rcv_type
  encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
  encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
      sodium_comp.make_encryptor, self._encryptor_cache, encryptor_arg_spec)

  ### Encrypt values and return them
  # Both gathers are created before awaiting, so embedding the encryptor and
  # building the argument structs are scheduled concurrently.
  encryptor_fns = asyncio.gather(*[
      snd_child.create_value(encryptor_proto, encryptor_type)
      for snd_child in snd_children
  ])
  encryptor_args = asyncio.gather(*[
      snd_child.create_struct([v, pk, sk])
      for v, pk, sk, snd_child in zip(*federated_value_internals, snd_children)
  ])
  encryptor_fns, encryptor_args = await asyncio.gather(
      encryptor_fns, encryptor_args)
  encrypted_values = [
      snd_child.create_call(encryptor, arg) for encryptor, arg, snd_child in
      zip(encryptor_fns, encryptor_args, snd_children)
  ]
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      await asyncio.gather(*encrypted_values),
      tff.FederatedType(encryptor_type.result, sender,
                        all_equal=val.type_signature.all_equal))
async def _decrypt_values_on_clients(self, val, sender, receiver):
  """Decrypts per-child ciphertexts on each `receiver` child executor.

  Each receiver child decrypts its own ciphertext entry, using its entry of
  the sender public key and its entry of the receiver secret key.

  Args:
    val: Federated ciphertext value whose `internal_representation` has one
      entry per `receiver` child executor. Its member type is indexed with
      `[0]`, which must be a `tff.TensorType`.
    sender: Placement whose public key (already placed at `receiver`) is used.
    receiver: Placement holding the ciphertexts and the secret key.

  Returns:
    A `FederatedResolvingStrategyValue` of decrypted values, typed
    `FederatedType(decryptor_result, receiver, all_equal=val.all_equal)`.

  Raises:
    TypeError: If either key is not federated at `receiver`, or a
      `py_typecheck.check_type` check fails. Cardinality mismatches are
      reported by `py_typecheck.check_len`.
  """
  ### Check proper key placement
  pk_sender = self.key_references.get_public_key(sender)
  sk_receiver = self.key_references.get_secret_key(receiver)
  # Both keys must be placed at the receiver by this point.
  type_analysis.check_federated_type(pk_sender.type_signature,
                                     placement=receiver)
  type_analysis.check_federated_type(sk_receiver.type_signature,
                                     placement=receiver)
  pk_snd_type = pk_sender.type_signature.member
  sk_rcv_type = sk_receiver.type_signature.member

  ### Check value cardinalities
  # NOTE(review): the placement of `val` itself is not checked here, only the
  # length of its internal representation. Confirm callers guarantee it.
  rcv_children = self.strategy._get_child_executors(receiver)
  federated_value_internals = [
      val.internal_representation, pk_sender.internal_representation,
      sk_receiver.internal_representation
  ]
  for fv in federated_value_internals:
    py_typecheck.check_len(fv, len(rcv_children))

  ### Materialize decryptor type_spec & function definition
  input_type = val.type_signature.member
  # input_type[0] is a tff.TensorType, thus input_type represents the
  # tuple needed for a single value to be decrypted.
  py_typecheck.check_type(input_type[0], tff.TensorType)
  py_typecheck.check_type(pk_snd_type, tff.TensorType)
  input_element_type = input_type
  pk_element_type = pk_snd_type
  decryptor_arg_spec = (input_element_type, pk_element_type, sk_rcv_type)
  # `self._input_type_cache` is assigned by the encryption methods. If no
  # encryption has run on this instance, this lookup presumably fails
  # (unless the attribute is initialized elsewhere -- TODO confirm).
  decryptor_proto, decryptor_type = utils.materialize_computation_from_cache(
      sodium_comp.make_decryptor,
      self._decryptor_cache,
      decryptor_arg_spec,
      orig_tensor_dtype=self._input_type_cache.dtype)

  ### Decrypt values and return them
  # Both gathers are created before awaiting, so embedding the decryptor and
  # building the argument structs are scheduled concurrently.
  decryptor_fns = asyncio.gather(*[
      rcv_child.create_value(decryptor_proto, decryptor_type)
      for rcv_child in rcv_children
  ])
  decryptor_args = asyncio.gather(*[
      rcv_child.create_struct([v, pk, sk])
      for v, pk, sk, rcv_child in zip(*federated_value_internals, rcv_children)
  ])
  decryptor_fns, decryptor_args = await asyncio.gather(
      decryptor_fns, decryptor_args)
  decrypted_values = [
      rcv_child.create_call(decryptor, arg) for decryptor, arg, rcv_child in
      zip(decryptor_fns, decryptor_args, rcv_children)
  ]
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      await asyncio.gather(*decrypted_values),
      tff.FederatedType(decryptor_type.result, receiver,
                        all_equal=val.type_signature.all_equal))
async def compute_federated_sum(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Sums a `tff.CLIENTS`-placed value by delegating to federated aggregate.

  The zero (`0`), the `tf.add` operator and an identity computation are
  embedded in `self._executor`. Then
  `compute_federated_aggregate([arg, zero, plus, plus, identity])` is
  called, with `tf.add` used in both operator slots.

  Args:
    arg: A value whose federated type is placed at `tff.CLIENTS`.

  Returns:
    Whatever `compute_federated_aggregate` returns for those arguments.
  """
  type_analysis.check_federated_type(
      arg.type_signature, placement=placement_literals.CLIENTS)
  member_type = arg.type_signature.member
  identity_proto, identity_type = (
      tensorflow_computation_factory.create_identity(member_type))
  zero, plus, identity = await asyncio.gather(
      executor_utils.embed_tf_constant(self._executor, member_type, 0),
      executor_utils.embed_tf_binary_operator(self._executor, member_type,
                                              tf.add),
      self._executor.create_value(identity_proto, identity_type))
  operands = [arg, zero, plus, plus, identity]
  aggregate_args = await self._executor.create_struct(operands)
  return await self.compute_federated_aggregate(aggregate_args)
async def _compute_intrinsic_federated_sum(self, arg):
  """Sums a `tff.CLIENTS`-placed value via federated aggregate.

  The zero (`0`), `tf.add` and a lambda identity are embedded in this
  executor. Then `_compute_intrinsic_federated_aggregate` is called on the
  tuple `[arg, zero, plus, plus, identity]`.

  Args:
    arg: A value whose federated type is placed at `tff.CLIENTS`.

  Returns:
    The result of `_compute_intrinsic_federated_aggregate`.
  """
  type_analysis.check_federated_type(
      arg.type_signature, placement=placement_literals.CLIENTS)
  member_type = arg.type_signature.member
  pending = [
      executor_utils.embed_tf_scalar_constant(self, member_type, 0),
      executor_utils.embed_tf_binary_operator(self, member_type, tf.add),
      self.create_value(
          computation_factory.create_lambda_identity(member_type),
          type_factory.unary_op(member_type)),
  ]
  zero, plus, identity = await asyncio.gather(*pending)
  aggregate_args = await self.create_tuple([arg, zero, plus, plus, identity])
  return await self._compute_intrinsic_federated_aggregate(aggregate_args)
async def _compute_intrinsic_federated_collect(self, arg):
  """Collects every client's value into a single sequence on the server.

  Each client value is computed, and the resulting list is embedded in the
  first `tff.SERVER` target executor as a `SequenceType` of the member type.

  Args:
    arg: A federated value placed at `tff.CLIENTS` whose internal
      representation is a list of per-client executor values.

  Returns:
    A `FederatingExecutorValue` holding the collected sequence, typed as an
    all-equal `SequenceType(member)` placed at `tff.SERVER`.
  """
  fed_type = arg.type_signature
  py_typecheck.check_type(fed_type, computation_types.FederatedType)
  type_analysis.check_federated_type(
      fed_type, placement=placement_literals.CLIENTS)
  client_values = arg.internal_representation
  py_typecheck.check_type(client_values, list)
  member_type = fed_type.member
  server_child = self._target_executors[placement_literals.SERVER][0]
  materialized = await asyncio.gather(*(v.compute() for v in client_values))
  collected = await server_child.create_value(
      materialized, computation_types.SequenceType(member_type))
  server_type = computation_types.FederatedType(
      computation_types.SequenceType(member_type),
      placement_literals.SERVER,
      all_equal=True)
  return FederatingExecutorValue([collected], server_type)
async def compute_federated_zip_at_server(
    self, arg: FederatedComposingStrategyValue) -> FederatedComposingStrategyValue:
  """Zips a pair of all-equal `tff.SERVER` values into one server struct.

  Args:
    arg: A 2-element struct whose elements are federated values placed at
      `tff.SERVER` with `all_equal=True`.

  Returns:
    A `FederatedComposingStrategyValue` wrapping a server-executor struct of
    the two values, typed `at_server(StructType([member0, member1]))`.
  """
  pair_type = arg.type_signature
  pair_value = arg.internal_representation
  py_typecheck.check_type(pair_type, computation_types.StructType)
  py_typecheck.check_len(pair_type, 2)
  py_typecheck.check_type(pair_value, structure.Struct)
  py_typecheck.check_len(pair_value, 2)
  for element_type in (pair_type[0], pair_type[1]):
    type_analysis.check_federated_type(
        element_type, placement=placement_literals.SERVER, all_equal=True)
  zipped = await self._server_executor.create_struct(
      [pair_value[0], pair_value[1]])
  zipped_type = computation_types.at_server(
      computation_types.StructType(
          [pair_type[0].member, pair_type[1].member]))
  return FederatedComposingStrategyValue(zipped, zipped_type)
async def compute_federated_apply(
    self, arg: FederatedComposingStrategyValue) -> FederatedComposingStrategyValue:
  """Applies a function to an all-equal value placed at `tff.SERVER`.

  Args:
    arg: A 2-element struct `(fn, value)`. `fn` is a `pb.Computation` with a
      `FunctionType`, and `value` is an executor value placed at `tff.SERVER`
      whose member type matches the function's parameter.

  Returns:
    A `FederatedComposingStrategyValue` with the call result, typed as the
    function's result type placed at `tff.SERVER`.
  """
  parts = arg.internal_representation
  py_typecheck.check_type(parts, structure.Struct)
  py_typecheck.check_len(parts, 2)
  fn_type = arg.type_signature[0]
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  type_analysis.check_federated_type(arg.type_signature[1], fn_type.parameter,
                                     placement_literals.SERVER, all_equal=True)
  fn_proto, operand = parts[0], parts[1]
  py_typecheck.check_type(fn_proto, pb.Computation)
  py_typecheck.check_type(operand, executor_value_base.ExecutorValue)
  embedded_fn = await self._server_executor.create_value(fn_proto, fn_type)
  call_result = await self._server_executor.create_call(embedded_fn, operand)
  return FederatedComposingStrategyValue(
      call_result, computation_types.at_server(fn_type.result))
async def compute_federated_zip_at_clients(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Zips two `tff.CLIENTS`-placed values by delegating to each child.

  Each child executor in `self._target_executors` runs the
  `FEDERATED_ZIP_AT_CLIENTS` intrinsic on its own slice of both inputs.
  Element names of `arg` are preserved when present.

  Args:
    arg: A 2-element `AnonymousTuple` of federated values placed at
      `tff.CLIENTS`. Each element's internal representation is a list with
      one entry per target executor.

  Returns:
    A `FederatedComposingStrategyValue` with the per-child zip results,
    typed `at_clients(StructType([...member types...]))`.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.StructType)
  py_typecheck.check_len(arg.type_signature, 2)
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)
  names = [name for name, _ in anonymous_tuple.to_elements(arg.type_signature)]
  values = [arg.internal_representation[i] for i in (0, 1)]
  fed_types = [arg.type_signature[i] for i in (0, 1)]
  for i in (0, 1):
    type_analysis.check_federated_type(
        fed_types[i], placement=placement_literals.CLIENTS)
    # Normalize to the non-all-equal clients form expected by the intrinsic.
    fed_types[i] = type_factory.at_clients(fed_types[i].member)
    py_typecheck.check_type(values[i], list)
    py_typecheck.check_len(values[i], len(self._target_executors))

  def _maybe_named(i, element_type):
    # Keep the original element name if there is one.
    return (names[i], element_type) if names[i] else element_type

  item_type = computation_types.StructType(
      [_maybe_named(i, fed_types[i].member) for i in (0, 1)])
  result_type = type_factory.at_clients(item_type)
  zip_type = computation_types.FunctionType(
      computation_types.StructType(
          [_maybe_named(i, fed_types[i]) for i in (0, 1)]), result_type)
  zip_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

  async def _zip_on_child(child, first, second):
    py_typecheck.check_type(first, executor_value_base.ExecutorValue)
    py_typecheck.check_type(second, executor_value_base.ExecutorValue)
    zip_fn = await child.create_value(zip_comp, zip_type)
    zip_arg = await child.create_struct(
        anonymous_tuple.AnonymousTuple([(names[0], first),
                                        (names[1], second)]))
    return await child.create_call(zip_fn, zip_arg)

  zipped = await asyncio.gather(*[
      _zip_on_child(child, first, second) for child, first, second in zip(
          self._target_executors, values[0], values[1])
  ])
  return FederatedComposingStrategyValue(zipped, result_type)
async def _compute_intrinsic_federated_zip_at_server(self, arg):
  """Zips two all-equal `tff.SERVER` values into one tuple on the parent.

  Args:
    arg: A 2-element `AnonymousTuple` whose elements are federated values
      placed at `tff.SERVER` with `all_equal=True`.

  Returns:
    A `CompositeValue` wrapping a parent-executor tuple of both values,
    typed `at_server(NamedTupleType([member0, member1]))`.
  """
  py_typecheck.check_type(arg.type_signature,
                          computation_types.NamedTupleType)
  py_typecheck.check_len(arg.type_signature, 2)
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)
  member_types = []
  for idx in range(2):
    element_type = arg.type_signature[idx]
    type_analysis.check_federated_type(
        element_type, placement=placement_literals.SERVER, all_equal=True)
    member_types.append(element_type.member)
  parts = [arg.internal_representation[idx] for idx in range(2)]
  zipped = await self._parent_executor.create_tuple(parts)
  return CompositeValue(
      zipped,
      type_factory.at_server(computation_types.NamedTupleType(member_types)))
async def _compute_intrinsic_federated_apply(self, arg):
  """Applies a function to an all-equal `tff.SERVER` value on the parent.

  Args:
    arg: A 2-element `AnonymousTuple` `(fn, value)`. `fn` is a
      `pb.Computation` with a `FunctionType`, and `value` is an executor
      value placed at `tff.SERVER` matching the function's parameter.

  Returns:
    A `CompositeValue` with the call result, typed as the function's result
    type placed at `tff.SERVER`.
  """
  parts = arg.internal_representation
  py_typecheck.check_type(parts, anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(parts, 2)
  fn_type = arg.type_signature[0]
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  operand_type = arg.type_signature[1]
  type_analysis.check_federated_type(operand_type, fn_type.parameter,
                                     placement_literals.SERVER, all_equal=True)
  fn_proto, operand = parts[0], parts[1]
  py_typecheck.check_type(fn_proto, pb.Computation)
  py_typecheck.check_type(operand, executor_value_base.ExecutorValue)
  embedded_fn = await self._parent_executor.create_value(fn_proto, fn_type)
  applied = await self._parent_executor.create_call(embedded_fn, operand)
  return CompositeValue(applied, type_factory.at_server(fn_type.result))
async def compute_federated_secure_sum(self, arg):
  """Securely sums a `tff.CLIENTS`-placed tensor using the Paillier helpers.

  Steps, in order:
    1. Run `_paillier_setup()` once per instance (guarded by
       `self._requires_setup`).
    2. Reshape non-rank-2 inputs to `[1, num_elements]`.
    3. Encrypt on `tff.CLIENTS`.
    4. Move the ciphertexts to `paillier_placement.AGGREGATOR` and sum them.
    5. Move the sum to `tff.SERVER`.
    6. Decrypt with `export_dtype` set to the input dtype.
    7. Reshape back to the input shape.

  Args:
    arg: A 2-element structure. Element 0 is a federated value placed at
      `tff.CLIENTS` with a `tff.TensorType` member. Element 1 is
      length-checked but not otherwise read by this method (presumably
      consumed elsewhere -- TODO confirm).

  Returns:
    The decrypted sum, reshaped to the input tensor's shape by
    `_compute_reshape_on_tensor`.
  """
  self._check_arg_is_structure(arg)
  py_typecheck.check_len(arg.internal_representation, 2)
  value_type = arg.type_signature[0]
  type_analysis.check_federated_type(value_type, placement=tff.CLIENTS)
  py_typecheck.check_type(value_type.member, tff.TensorType)
  # Stash input dtype for later
  # (passed as `export_dtype` to decryption).
  input_tensor_dtype = value_type.member.dtype

  # Paillier setup phase
  # (runs only on the first call for this instance).
  if self._requires_setup:
    await self._paillier_setup()
    self._requires_setup = False

  # Stash input shape, and reshape input tensor to matrix-form
  # (non-rank-2 inputs are flattened to a single row [1, num_elements]).
  # NOTE(review): `len()` on a shape of unknown rank, or a `None` from
  # `num_elements()` for a partially-known shape, presumably fails here.
  input_tensor_shape = value_type.member.shape
  if len(input_tensor_shape) != 2:
    clients_value = await self._compute_reshape_on_tensor(
        await self._executor.create_selection(arg, index=0),
        output_shape=[1, input_tensor_shape.num_elements()])
  else:
    clients_value = await self._executor.create_selection(arg, index=0)

  # Encrypt summands on tff.CLIENTS
  encrypted_values = await self._compute_paillier_encryption(
      self.encryption_key_clients, clients_value)

  # Perform Paillier sum on ciphertexts
  # (after moving them from tff.CLIENTS to the aggregator placement).
  encrypted_values = await self._move(encrypted_values, tff.CLIENTS,
                                      paillier_placement.AGGREGATOR)
  encrypted_sum = await self._compute_paillier_sum(
      self.encryption_key_paillier, encrypted_values)

  # Move to server and decrypt the result
  encrypted_sum = await self._move(encrypted_sum,
                                   paillier_placement.AGGREGATOR, tff.SERVER)
  decrypted_result = await self._compute_paillier_decryption(
      self.decryption_key,
      self.encryption_key_server,
      encrypted_sum,
      export_dtype=input_tensor_dtype)
  # Restore the original input shape (this also runs for rank-2 inputs).
  return await self._compute_reshape_on_tensor(
      decrypted_result, output_shape=input_tensor_shape.as_list())
def test_raises_type_error(self):
  """`check_federated_type` accepts matching constraints, rejects mismatches."""
  type_spec = computation_types.FederatedType(tf.int32,
                                              placement_literals.CLIENTS, False)
  # (member, placement, all_equal) combinations that match, where a `None`
  # entry means "unconstrained".
  accepted = (
      (tf.int32, placement_literals.CLIENTS, False),
      (tf.int32, None, None),
      (None, placement_literals.CLIENTS, None),
      (None, None, False),
  )
  for member, placement, all_equal in accepted:
    type_analysis.check_federated_type(type_spec, member, placement, all_equal)
  # Each combination below mismatches exactly one constraint.
  rejected = (
      (tf.bool, None, None),
      (None, placement_literals.SERVER, None),
      (None, None, True),
  )
  for member, placement, all_equal in rejected:
    with self.assertRaises(TypeError):
      type_analysis.check_federated_type(type_spec, member, placement,
                                         all_equal)
async def _decrypt_values_on_singleton(self, val, sender, receiver):
  """Decrypts a tuple of ciphertexts on the single `receiver` child executor.

  The i-th ciphertext in `val` is decrypted with the i-th entry of the
  sender public key, using the receiver's single secret key.

  Args:
    val: A struct of federated ciphertexts placed at `receiver`. Element 0 is
      checked to be all-equal at `receiver` with a `tff.StructType` member.
    sender: Placement whose public keys (already placed at `receiver`) are
      used. It provides one key entry per sender child executor.
    receiver: Placement holding the ciphertexts and the secret key. It must
      map to exactly one child executor.

  Returns:
    A `FederatedResolvingStrategyValue` wrapping a struct of decrypted
    values, each typed `FederatedType(decryptor_result, receiver,
    all_equal=True)`.

  Raises:
    TypeError: If a key or `val[0]` is not federated at `receiver` as
      required, or a `py_typecheck.check_type` check fails. Cardinality
      mismatches are reported by `py_typecheck.check_len`.
  """
  ### Check proper key placement
  pk_sender = self.key_references.get_public_key(sender)
  sk_receiver = self.key_references.get_secret_key(receiver)
  # Both keys must be placed at the receiver by this point.
  type_analysis.check_federated_type(pk_sender.type_signature,
                                     placement=receiver)
  type_analysis.check_federated_type(sk_receiver.type_signature,
                                     placement=receiver)
  pk_snd_type = pk_sender.type_signature.member
  sk_rcv_type = sk_receiver.type_signature.member

  ### Check placement cardinalities
  snd_children = self.strategy._get_child_executors(sender)
  rcv_children = self.strategy._get_child_executors(receiver)
  py_typecheck.check_len(rcv_children, 1)
  rcv_child = rcv_children[0]

  ### Check value cardinalities
  # One public-key entry per sender child, but a single receiver secret key.
  py_typecheck.check_len(pk_sender.internal_representation, len(snd_children))
  py_typecheck.check_len(sk_receiver.internal_representation, 1)

  ### Materialize decryptor type_spec & function definition
  py_typecheck.check_type(val.type_signature, tff.StructType)
  # NOTE(review): only element 0 is checked for placement/all_equal and used
  # for the member type. The remaining elements are assumed to match.
  type_analysis.check_federated_type(val.type_signature[0],
                                     placement=receiver,
                                     all_equal=True)
  input_type = val.type_signature[0].member
  # each input_type is a tuple needed for one value to be decrypted
  py_typecheck.check_type(input_type, tff.StructType)
  py_typecheck.check_type(pk_snd_type, tff.StructType)
  # One ciphertext per public-key entry.
  py_typecheck.check_len(val.type_signature, len(pk_snd_type))
  input_element_type = input_type
  # Assumes all public-key entries share the type of the first one.
  pk_element_type = pk_snd_type[0]
  decryptor_arg_spec = (input_element_type, pk_element_type, sk_rcv_type)
  # `self._input_type_cache` is assigned by the encryption methods. If no
  # encryption has run on this instance, this lookup presumably fails
  # (unless the attribute is initialized elsewhere -- TODO confirm).
  decryptor_proto, decryptor_type = utils.materialize_computation_from_cache(
      sodium_comp.make_decryptor,
      self._decryptor_cache,
      decryptor_arg_spec,
      orig_tensor_dtype=self._input_type_cache.dtype)

  ### Decrypt values and return them
  vals = val.internal_representation
  sk = sk_receiver.internal_representation[0]
  decryptor_fn = await rcv_child.create_value(decryptor_proto, decryptor_type)
  # Pair each ciphertext with its corresponding sender public-key entry.
  decryptor_args = await asyncio.gather(*[
      rcv_child.create_struct([v, pk, sk])
      for v, pk in zip(vals, pk_sender.internal_representation)
  ])
  decrypted_values = await asyncio.gather(*[
      rcv_child.create_call(decryptor_fn, arg) for arg in decryptor_args
  ])
  decrypted_value_types = [decryptor_type.result] * len(decrypted_values)
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      structure.from_container(decrypted_values),
      tff.StructType([
          tff.FederatedType(dvt, receiver, all_equal=True)
          for dvt in decrypted_value_types
      ]))
async def _encrypt_values_on_singleton(self, val, sender, receiver):
  """Encrypts one singleton value once per `receiver` public-key entry.

  We can safely assume sender has cardinality=1 when receiver is CLIENTS.

  Case 1: receiver=CLIENTS
    plaintext:   Fed(Tensor, sender, all_equal=True)
    pk_receiver: Fed(Tuple(Tensor), sender, all_equal=True)
    sk_sender:   Fed(Tensor, sender, all_equal=True)
  Returns:
    encrypted_values: Tuple(Fed(Tensor, sender, ...))

  NOTE(review): an earlier description said the returned elements are
  all_equal=True, but the code below builds them with `all_equal=False`.
  Confirm which is intended.

  Args:
    val: Value to encrypt. It must be federated at `sender` with exactly one
      internal entry.
    sender: Placement holding the plaintext and the secret key. It must map
      to exactly one child executor.
    receiver: Placement whose public keys (already moved to `sender`) are
      used. There is one key entry per receiver child executor.

  Returns:
    A `FederatedResolvingStrategyValue` wrapping a struct with one ciphertext
    per receiver public-key entry, each typed
    `FederatedType(encryptor_result, sender, all_equal=False)`.

  Raises:
    TypeError: If `val` or either key is not federated at `sender`, or the
      public-key member type is not a `tff.StructType`. Cardinality
      mismatches are reported by `py_typecheck.check_len`.
  """
  ### Check proper key placement
  sk_sender = self.key_references.get_secret_key(sender)
  pk_receiver = self.key_references.get_public_key(receiver)
  type_analysis.check_federated_type(sk_sender.type_signature, placement=sender)
  # The receiver's public keys must already live at `sender`. Validate them
  # explicitly instead of with `assert`, which `python -O` strips.
  type_analysis.check_federated_type(
      pk_receiver.type_signature, placement=sender)

  ### Check placement cardinalities
  rcv_children = self.strategy._get_child_executors(receiver)
  snd_children = self.strategy._get_child_executors(sender)
  py_typecheck.check_len(snd_children, 1)
  snd_child = snd_children[0]

  ### Check value cardinalities
  type_analysis.check_federated_type(val.type_signature, placement=sender)
  py_typecheck.check_len(val.internal_representation, 1)
  py_typecheck.check_type(pk_receiver.type_signature.member, tff.StructType)
  py_typecheck.check_len(pk_receiver.internal_representation,
                         len(rcv_children))
  py_typecheck.check_len(sk_sender.internal_representation, 1)

  ### Materialize encryptor function definition & type spec
  input_type = val.type_signature.member
  # Remembered so the decryption methods can read
  # `self._input_type_cache.dtype`.
  self._input_type_cache = input_type
  pk_rcv_type = pk_receiver.type_signature.member
  sk_snd_type = sk_sender.type_signature.member
  # Assumes every receiver public-key entry shares the first entry's type.
  pk_element_type = pk_rcv_type[0]
  encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
  encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
      sodium_comp.make_encryptor, self._encryptor_cache, encryptor_arg_spec)

  ### Prepare encryption arguments
  v = val.internal_representation[0]
  sk = sk_sender.internal_representation[0]

  ### Encrypt values and return them
  encryptor_fn = await snd_child.create_value(encryptor_proto, encryptor_type)
  # The same plaintext is encrypted separately for each receiver's key.
  encryptor_args = await asyncio.gather(*[
      snd_child.create_struct([v, this_pk, sk])
      for this_pk in pk_receiver.internal_representation
  ])
  encrypted_values = await asyncio.gather(*[
      snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args
  ])
  encrypted_value_types = [encryptor_type.result] * len(encrypted_values)
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      structure.from_container(encrypted_values),
      tff.StructType([
          tff.FederatedType(evt, sender, all_equal=False)
          for evt in encrypted_value_types
      ]))