def test_raises_type_error(self):
  """Checks the weighted-mean argument validator on float32 and int32 pairs.

  A two-element type built from `at_clients(tf.float32)` has to pass the
  check without raising; the same shape built from `at_clients(tf.int32)`
  has to raise `TypeError`.
  """
  validate = (
      type_analysis.check_valid_federated_weighted_mean_argument_tuple_type)

  # Accepted case: this call sits outside `assertRaises`, so any exception
  # raised here fails the test.
  float_pair_type = computation_types.to_type(
      [type_factory.at_clients(tf.float32)] * 2)
  validate(float_pair_type)

  # Rejected case: the type is constructed inside the context manager so
  # that exactly the same expressions are covered by `assertRaises`.
  with self.assertRaises(TypeError):
    int_pair_type = computation_types.to_type(
        [type_factory.at_clients(tf.int32)] * 2)
    validate(int_pair_type)
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
  """Evaluates the federated weighted mean intrinsic using `executor`.

  Every function type, plus the two TensorFlow operators built from
  `tf.multiply` and `tf.divide`, is constructed up front. The executor calls
  are then arranged as a tree of nested coroutines in which sibling inputs
  are awaited together through `asyncio.gather`:

    * zip the two client-placed components of `arg`;
    * map the multiply operator over the zipped pairs;
    * sum the products, and separately the component of `arg` at index 1
      (the weights);
    * zip both sums at the server and apply the divide operator to them.

  Args:
    executor: The executor in which every intermediate value is created.
    arg: The argument embedded in `executor`. Its type signature is validated
      before any executor call is issued.

  Returns:
    The value produced by the final `FEDERATED_APPLY` call, embedded in
    `executor`.
  """
  arg_type = arg.type_signature
  type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
      arg_type)
  value_member = arg_type[0].member
  weight_member = arg_type[1].member

  # (at_clients(value), at_clients(weight)) -> at_clients(struct(value, weight))
  clients_zip_type = computation_types.FunctionType(
      computation_types.StructType([
          computation_types.at_clients(value_member),
          computation_types.at_clients(weight_member),
      ]),
      computation_types.at_clients(
          computation_types.StructType([value_member, weight_member])))
  zipped_type = clients_zip_type.result

  multiply_op = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      zipped_type.member, tf.multiply)
  multiply_op_type = multiply_op.type_signature

  # (multiply operator, zipped pairs) -> at_clients(product)
  map_type = computation_types.FunctionType(
      computation_types.StructType([multiply_op_type, zipped_type]),
      computation_types.at_clients(multiply_op_type.result))
  product_member = map_type.result.member

  # The two sums share a shape: at_clients(T) -> at_server(T).
  products_sum_type = computation_types.FunctionType(
      computation_types.at_clients(product_member),
      computation_types.at_server(product_member))
  weights_sum_type = computation_types.FunctionType(
      computation_types.at_clients(weight_member),
      computation_types.at_server(weight_member))

  summed_products_type = products_sum_type.result
  summed_weights_type = weights_sum_type.result
  server_zip_type = computation_types.FunctionType(
      computation_types.StructType(
          [summed_products_type, summed_weights_type]),
      computation_types.at_server(
          computation_types.StructType(
              [summed_products_type.member, summed_weights_type.member])))

  divide_op = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      server_zip_type.result.member, tf.divide)

  async def _multiply_fn_value():
    return await executor.create_value(multiply_op.proto,
                                       multiply_op.type_signature)

  async def _zipped_clients_value():
    zip_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
                                     clients_zip_type)
    # The zip function has to exist before it can be called, so these two
    # executor calls are sequential.
    return await executor.create_call(
        await executor.create_value(zip_comp, clients_zip_type), arg)

  async def _map_fn_value():
    map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP, map_type)
    return await executor.create_value(map_comp, map_type)

  async def _map_arg_value():
    fn_and_pairs = await asyncio.gather(_multiply_fn_value(),
                                        _zipped_clients_value())
    # `gather` returns a list; this struct is created from a tuple.
    return await executor.create_struct(tuple(fn_and_pairs))

  async def _products_value():
    return await executor.create_call(
        *(await asyncio.gather(_map_fn_value(), _map_arg_value())))

  async def _total_weight_value():
    sum_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                     weights_sum_type)
    return await executor.create_call(*(await asyncio.gather(
        executor.create_value(sum_comp, weights_sum_type),
        executor.create_selection(arg, index=1))))

  async def _summed_products_value():
    sum_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                     products_sum_type)
    return await executor.create_call(*(await asyncio.gather(
        executor.create_value(sum_comp, products_sum_type),
        _products_value())))

  async def _server_zip_fn_value():
    zip_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
                                     server_zip_type)
    return await executor.create_value(zip_comp, server_zip_type)

  async def _server_zip_arg_value():
    # The list returned by `gather` is [sum_of_products, total_weight].
    return await executor.create_struct(await asyncio.gather(
        _summed_products_value(), _total_weight_value()))

  async def _divide_fn_value():
    return await executor.create_value(divide_op.proto,
                                       divide_op.type_signature)

  async def _divide_arg_value():
    return await executor.create_call(*(await asyncio.gather(
        _server_zip_fn_value(), _server_zip_arg_value())))

  async def _apply_fn_value():
    divide_op_type = divide_op.type_signature
    apply_type = computation_types.FunctionType(
        computation_types.StructType(
            [divide_op_type, server_zip_type.result]),
        computation_types.at_server(divide_op_type.result))
    apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                       apply_type)
    return await executor.create_value(apply_comp, apply_type)

  async def _apply_arg_value():
    # The list returned by `gather` is [divide_fn, divide_arg].
    return await executor.create_struct(await asyncio.gather(
        _divide_fn_value(), _divide_arg_value()))

  # Root of the tree: apply the divide operator to the zipped sums.
  return await executor.create_call(
      *(await asyncio.gather(_apply_fn_value(), _apply_arg_value())))
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor,
    arg: executor_value_base.ExecutorValue,
    # NOTE: the default factory instance is created once, when this function
    # is defined, and is shared by every call that omits the argument.
    local_computation_factory: local_computation_factory_base
    .LocalComputationFactory = tensorflow_computation_factory
    .TensorFlowComputationFactory()
) -> executor_value_base.ExecutorValue:
  """Evaluates the federated weighted mean intrinsic using `executor`.

  The computation is expressed through other intrinsics, all created and
  called on `executor`: the two components of `arg` are zipped at the clients,
  a scalar-multiply operator is mapped over the zipped pairs, the products and
  the component of `arg` at index 1 (the weights) are each summed to the
  server, the two sums are zipped at the server, and a `tf.divide` operator
  is applied to that pair. Inputs that do not depend on each other are
  awaited together with `asyncio.gather`.

  Args:
    executor: The executor in which every intermediate value is created.
    arg: The argument embedded in `executor`. Its type signature is validated
      before any executor call is issued.
    local_computation_factory: The `LocalComputationFactory` whose
      `create_scalar_multiply_operator` builds the multiply operator mapped
      over the zipped pairs. The divide operator is still built with
      `building_block_factory.create_tensorflow_binary_operator_with_upcast`.
      Defaults to a `TensorFlowComputationFactory` instance.

  Returns:
    The value produced by the final `FEDERATED_APPLY` call, embedded in
    `executor`.
  """
  arg_type = arg.type_signature
  type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
      arg_type)
  values_member = arg_type[0].member
  weights_member = arg_type[1].member

  # ---- Types and local operators ------------------------------------------

  zip_clients_type = computation_types.FunctionType(
      computation_types.StructType([
          computation_types.at_clients(values_member),
          computation_types.at_clients(weights_member),
      ]),
      computation_types.at_clients(
          computation_types.StructType([values_member, weights_member])))
  zipped_clients_type = zip_clients_type.result

  # Element 0 of the zipped pair is the operand, element 1 the scalar.
  pair_type = zipped_clients_type.member
  multiply_pb, multiply_pb_type = (
      local_computation_factory.create_scalar_multiply_operator(
          pair_type[0], pair_type[1]))
  multiply_comp = building_blocks.CompiledComputation(
      multiply_pb, type_signature=multiply_pb_type)

  map_type = computation_types.FunctionType(
      computation_types.StructType(
          [multiply_comp.type_signature, zipped_clients_type]),
      computation_types.at_clients(multiply_comp.type_signature.result))
  product_member = map_type.result.member

  sum_products_type = computation_types.FunctionType(
      computation_types.at_clients(product_member),
      computation_types.at_server(product_member))
  sum_weights_type = computation_types.FunctionType(
      computation_types.at_clients(weights_member),
      computation_types.at_server(weights_member))

  zip_server_type = computation_types.FunctionType(
      computation_types.StructType(
          [sum_products_type.result, sum_weights_type.result]),
      computation_types.at_server(
          computation_types.StructType([
              sum_products_type.result.member,
              sum_weights_type.result.member,
          ])))

  divide_comp = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      zip_server_type.result.member, tf.divide)

  # ---- Client side: zip, then map the multiply operator -------------------

  async def _embed_multiply():
    return await executor.create_value(multiply_comp.proto,
                                       multiply_comp.type_signature)

  async def _zip_at_clients():
    zip_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
                                     zip_clients_type)
    embedded_zip = await executor.create_value(zip_comp, zip_clients_type)
    return await executor.create_call(embedded_zip, arg)

  async def _embed_map():
    map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP, map_type)
    return await executor.create_value(map_comp, map_type)

  async def _build_map_arg():
    embedded_multiply, zipped = await asyncio.gather(_embed_multiply(),
                                                     _zip_at_clients())
    return await executor.create_struct((embedded_multiply, zipped))

  async def _map_products():
    embedded_map, map_arg = await asyncio.gather(_embed_map(),
                                                 _build_map_arg())
    return await executor.create_call(embedded_map, map_arg)

  # ---- Clients -> server: the two sums -------------------------------------

  async def _sum_weights():
    sum_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                     sum_weights_type)
    embedded_sum, weights = await asyncio.gather(
        executor.create_value(sum_comp, sum_weights_type),
        executor.create_selection(arg, 1))
    return await executor.create_call(embedded_sum, weights)

  async def _sum_products():
    sum_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                     sum_products_type)
    embedded_sum, mapped = await asyncio.gather(
        executor.create_value(sum_comp, sum_products_type), _map_products())
    return await executor.create_call(embedded_sum, mapped)

  # ---- Server side: zip the sums, then apply the divide operator ----------

  async def _embed_server_zip():
    zip_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
                                     zip_server_type)
    return await executor.create_value(zip_comp, zip_server_type)

  async def _build_server_zip_arg():
    products_total, weights_total = await asyncio.gather(
        _sum_products(), _sum_weights())
    return await executor.create_struct([products_total, weights_total])

  async def _embed_divide():
    return await executor.create_value(divide_comp.proto,
                                       divide_comp.type_signature)

  async def _zip_at_server():
    embedded_zip, zip_arg = await asyncio.gather(_embed_server_zip(),
                                                 _build_server_zip_arg())
    return await executor.create_call(embedded_zip, zip_arg)

  async def _embed_apply():
    apply_type = computation_types.FunctionType(
        computation_types.StructType(
            [divide_comp.type_signature, zip_server_type.result]),
        computation_types.at_server(divide_comp.type_signature.result))
    apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                       apply_type)
    return await executor.create_value(apply_comp, apply_type)

  async def _build_apply_arg():
    embedded_divide, zipped_sums = await asyncio.gather(
        _embed_divide(), _zip_at_server())
    return await executor.create_struct([embedded_divide, zipped_sums])

  embedded_apply, apply_arg = await asyncio.gather(_embed_apply(),
                                                   _build_apply_arg())
  return await executor.create_call(embedded_apply, apply_arg)
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
  """Computes a federated weighted mean on the given `executor`.

  The result is assembled step by step from other intrinsics. Each step
  creates an intrinsic with `create_intrinsic_comp`, embeds it with
  `executor.create_value` and invokes it with `executor.create_call`:

    1. `FEDERATED_ZIP_AT_CLIENTS` over the two components of `arg`.
    2. `FEDERATED_MAP` of a `tf.multiply` operator over the zipped pairs.
    3. `FEDERATED_SUM` of the products.
    4. `FEDERATED_SUM` of the component of `arg` at index 1 (the weights).
    5. `FEDERATED_ZIP_AT_SERVER` of the two sums.
    6. `FEDERATED_APPLY` of a `tf.divide` operator to the zipped sums.

  The type used at each step is derived from the type signature of the value
  returned by the previous executor call, so the steps are awaited mostly
  sequentially. `asyncio.gather` is used only where a function value and its
  argument can be created independently.

  Args:
    executor: The executor to use.
    arg: The argument embedded in `executor`. Its type signature is checked
      with `check_valid_federated_weighted_mean_argument_tuple_type` before
      any executor call is made.

  Returns:
    The result embedded in `executor`.
  """
  type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
      arg.type_signature)

  # Step 1: zip the two client-placed components of `arg`. The function type
  # is (at_clients(a), at_clients(b)) -> at_clients(tuple(a, b)).
  zip1_type = computation_types.FunctionType(
      computation_types.NamedTupleType([
          type_factory.at_clients(arg.type_signature[0].member),
          type_factory.at_clients(arg.type_signature[1].member)
      ]),
      type_factory.at_clients(
          computation_types.NamedTupleType(
              [arg.type_signature[0].member, arg.type_signature[1].member])))
  zip1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
                                    zip1_type)
  zipped_arg = await executor.create_call(
      await executor.create_value(zip1_comp, zip1_type), arg)

  # Step 2: map a multiply operator over the zipped pairs.
  # TODO(b/134543154): Replace with something that produces a section of
  # plain TensorFlow code instead of constructing a lambda (so that this
  # can be executed directly on top of a plain TensorFlow-based executor).
  multiply_blk = building_block_factory.create_binary_operator_with_upcast(
      zipped_arg.type_signature.member, tf.multiply)
  map_type = computation_types.FunctionType(
      computation_types.NamedTupleType(
          [multiply_blk.type_signature, zipped_arg.type_signature]),
      type_factory.at_clients(multiply_blk.type_signature.result))
  map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP, map_type)
  # Arguments are evaluated left to right: the map intrinsic is embedded
  # first, then the multiply operator, then the (operator, zipped_arg) tuple.
  products = await executor.create_call(
      await executor.create_value(map_comp, map_type),
      await executor.create_tuple([
          await executor.create_value(multiply_blk.proto,
                                      multiply_blk.type_signature),
          zipped_arg
      ]))

  # Step 3: sum the client-placed products to the server.
  sum1_type = computation_types.FunctionType(
      type_factory.at_clients(products.type_signature.member),
      type_factory.at_server(products.type_signature.member))
  sum1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum1_type)
  sum_of_products = await executor.create_call(
      await executor.create_value(sum1_comp, sum1_type), products)

  # Step 4: sum the component of `arg` at index 1. Embedding the sum
  # intrinsic and selecting that component do not depend on each other, so
  # both are awaited together.
  sum2_type = computation_types.FunctionType(
      type_factory.at_clients(arg.type_signature[1].member),
      type_factory.at_server(arg.type_signature[1].member))
  sum2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum2_type)
  total_weight = await executor.create_call(
      *(await asyncio.gather(executor.create_value(sum2_comp, sum2_type),
                             executor.create_selection(arg, index=1))))

  # Step 5: zip the two server-placed sums into one server-placed pair,
  # ordered (sum_of_products, total_weight).
  zip2_type = computation_types.FunctionType(
      computation_types.NamedTupleType(
          [sum_of_products.type_signature, total_weight.type_signature]),
      type_factory.at_server(
          computation_types.NamedTupleType([
              sum_of_products.type_signature.member,
              total_weight.type_signature.member
          ])))
  zip2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
                                    zip2_type)
  divide_arg = await executor.create_call(*(await asyncio.gather(
      executor.create_value(zip2_comp, zip2_type),
      executor.create_tuple([sum_of_products, total_weight]))))

  # Step 6: apply a divide operator to the zipped sums on the server.
  divide_blk = building_block_factory.create_binary_operator_with_upcast(
      divide_arg.type_signature.member, tf.divide)
  apply_type = computation_types.FunctionType(
      computation_types.NamedTupleType(
          [divide_blk.type_signature, divide_arg.type_signature]),
      type_factory.at_server(divide_blk.type_signature.result))
  apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY, apply_type)
  # The inner `await` on the divide operator completes while the argument
  # list of `gather` is still being built, i.e. before `gather` itself is
  # called; only embedding the apply intrinsic and creating the tuple are
  # awaited together.
  return await executor.create_call(*(await asyncio.gather(
      executor.create_value(apply_comp, apply_type),
      executor.create_tuple([
          await executor.create_value(divide_blk.proto,
                                      divide_blk.type_signature),
          divide_arg
      ]))))