def test_raises_not_implemented_error_with_intrinsic_def_federated_secure_sum(
    self):
  """Calling the dummy `federated_secure_sum` intrinsic def is unsupported."""
  executor = create_test_executor()
  intrinsic, intrinsic_type = (
      executor_test_utils.create_dummy_intrinsic_def_federated_secure_sum())
  summands_type = computation_types.at_clients(tf.int32, all_equal=False)
  bitwidth_type = computation_types.TensorType(tf.int32)

  embedded_fn = self.run_sync(executor.create_value(intrinsic, intrinsic_type))
  embedded_summands = self.run_sync(
      executor.create_value([10, 11, 12], summands_type))
  embedded_bitwidth = self.run_sync(executor.create_value(10, bitwidth_type))
  embedded_args = self.run_sync(
      executor.create_struct([embedded_summands, embedded_bitwidth]))

  with self.assertRaises(NotImplementedError):
    self.run_sync(executor.create_call(embedded_fn, embedded_args))
def test_returns_value_with_unplaced_type_and_clients(self, executor):
  """`federated_value` at CLIENTS of an unplaced value is all-equal there."""
  raw_value, value_type = executor_test_utils.create_dummy_value_unplaced()
  embedded = self.run_sync(executor.create_value(raw_value, value_type))

  placed = self.run_sync(
      executor_utils.compute_intrinsic_federated_value(
          executor, embedded, placements.CLIENTS))

  self.assertIsInstance(placed, executor_value_base.ExecutorValue)
  want_type = computation_types.at_clients(value_type, all_equal=True)
  self.assertEqual(placed.type_signature.compact_representation(),
                   want_type.compact_representation())
  self.assertEqual(self.run_sync(placed.compute()), 10.0)
def _build_expected_broadcaster_next_signature():
  """Returns signature of the broadcaster used in multiple tests below.

  The parameter is an `OrderedDict` of server-placed state and server-placed
  model weights. The result is an `OrderedDict` of the same state type, the
  weights placed at clients, and empty server-placed measurements.
  """
  server_state = computation_types.at_server(
      computation_types.StructType([
          ('trainable', [()]),
          ('non_trainable', []),
      ]))
  server_weights = computation_types.at_server(
      model_utils.weights_type_from_model(TestModelQuant))
  client_weights = computation_types.at_clients(
      model_utils.weights_type_from_model(TestModelQuant))
  server_measurements = computation_types.at_server(())

  parameter = collections.OrderedDict(state=server_state, value=server_weights)
  result = collections.OrderedDict(
      state=server_state,
      result=client_weights,
      measurements=server_measurements)
  return computation_types.FunctionType(parameter=parameter, result=result)
def test_changing_cardinalities_across_calls(self):
  """One computation accepts a different number of clients on each call."""

  @computations.federated_computation(computation_types.at_clients(tf.int32))
  def comp(x):
    return x

  client_inputs = [list(range(5)), list(range(10))]
  executor = executor_stacks.local_executor_factory()
  with executor_test_utils.install_executor(executor):
    outputs = [comp(values) for values in client_inputs]

  for expected, actual in zip(client_inputs, outputs):
    self.assertEqual(actual, expected)
def test_returns_value_with_federated_type_at_server(
    self, executor, num_clients):
  """Broadcasting a server value yields an all-equal value at clients."""
  del num_clients  # Unused.
  server_value, server_type = (
      executor_test_utils.create_dummy_value_at_server())
  embedded = self.run_sync(executor.create_value(server_value, server_type))

  broadcast = self.run_sync(
      executor_utils.compute_intrinsic_federated_broadcast(executor, embedded))

  self.assertIsInstance(broadcast, executor_value_base.ExecutorValue)
  want_type = computation_types.at_clients(server_type.member, all_equal=True)
  self.assertEqual(broadcast.type_signature.compact_representation(),
                   want_type.compact_representation())
  self.assertEqual(self.run_sync(broadcast.compute()), 10.0)
def test_returns_value_with_intrinsic_def_federated_secure_sum(
    self, value, bitwidth, expected_result):
  """Calling the dummy secure-sum intrinsic yields `expected_result`."""
  executor = create_test_executor()
  intrinsic, intrinsic_type = (
      executor_test_utils.create_dummy_intrinsic_def_federated_secure_sum())
  value_type = computation_types.at_clients(tf.int32, all_equal=False)
  bitwidth_type = computation_types.TensorType(tf.int32)

  embedded_fn = self.run_sync(executor.create_value(intrinsic, intrinsic_type))
  embedded_value = self.run_sync(executor.create_value(value, value_type))
  embedded_bitwidth = self.run_sync(
      executor.create_value(bitwidth, bitwidth_type))
  embedded_args = self.run_sync(
      executor.create_struct([embedded_value, embedded_bitwidth]))
  result = self.run_sync(executor.create_call(embedded_fn, embedded_args))

  self.assertIsInstance(result, executor_value_base.ExecutorValue)
  self.assert_types_identical(result.type_signature, intrinsic_type.result)
  self.assertEqual(self.run_sync(result.compute()), expected_result)
def create(
    self, value_type: factory.ValueType
) -> aggregation_process.AggregationProcess:
  """Builds the aggregation process for `value_type`.

  Validates `value_type` with `_check_value_type`, creates the inner
  aggregation process and the function returned by `make_clip_fn`, and wires
  them into `init_fn_impl` / `next_fn_impl`.

  Args:
    value_type: The type of the values to aggregate; `next` receives it placed
      at clients.

  Returns:
    An `AggregationProcess` built from the resulting `init_fn` and `next_fn`.
  """
  _check_value_type(value_type)
  inner_process = inner_agg_factory.create(value_type)
  clipping_fn = make_clip_fn(value_type)

  @computations.federated_computation()
  def init_fn():
    return init_fn_impl(inner_process)

  @computations.federated_computation(init_fn.type_signature.result,
                                      computation_types.at_clients(value_type))
  def next_fn(state, value):
    return next_fn_impl(state, value, clipping_fn, inner_process)

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def test_federated_sum_reduces_to_aggregate(self):
  """`federated_sum` is replaced by a body built on `federated_aggregate`."""
  sum_uri = intrinsic_defs.FEDERATED_SUM.uri
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
  comp = building_blocks.Intrinsic(
      sum_uri,
      computation_types.FunctionType(
          computation_types.at_clients(tf.float32),
          computation_types.at_server(tf.float32)))

  sums_before = _count_intrinsics(comp, sum_uri)
  reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(comp)
  sums_after = _count_intrinsics(reduced, sum_uri)
  aggregates_after = _count_intrinsics(reduced, aggregate_uri)

  self.assertTrue(modified)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)
  self.assertGreater(sums_before, 0)
  self.assertEqual(sums_after, 0)
  self.assertGreater(aggregates_after, 0)
def _create_tff_parallel_clients_with_dataset_reduce():
  """Returns a federated computation reducing each client's int64 dataset.

  Each client dataset is reduced by addition inside a `tf.function`, starting
  from an initial value of 1 stored in a `tf.Variable`.
  """

  @tf.function
  def add(x, y):
    return x + y

  @tf.function
  def reduce_dataset(ds, initial_val):
    return ds.reduce(initial_val, add)

  @computations.tf_computation(computation_types.SequenceType(tf.int64))
  def reduce_client_dataset(ds):
    return reduce_dataset(ds, tf.Variable(np.int64(1.0)))

  @computations.federated_computation(
      computation_types.at_clients(computation_types.SequenceType(tf.int64)))
  def parallel_client_run(client_datasets):
    return intrinsics.federated_map(reduce_client_dataset, client_datasets)

  return parallel_client_run
def test_type_properties(self, value_type, inner_agg_factory):
  """Checks the type signatures of the differentially private process."""
  dp_agg_factory = dp_factory.DifferentiallyPrivateFactory(
      _test_dp_query, inner_agg_factory)
  self.assertIsInstance(dp_agg_factory, factory.UnweightedAggregationFactory)

  tff_value_type = computation_types.to_type(value_type)
  process = dp_agg_factory.create_unweighted(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  initial_query_state = _test_dp_query.initial_global_state()
  query_state_type = type_conversions.type_from_tensors(initial_query_state)
  query_metrics_type = type_conversions.type_from_tensors(
      _test_dp_query.derive_metrics(initial_query_state))

  # The inner state and inner measurements are both expected to be `tf.int32`
  # when an inner factory is given, and empty otherwise.
  inner_type = tf.int32 if inner_agg_factory else ()
  server_state_type = computation_types.at_server(
      (query_state_type, inner_type))

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=server_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  expected_measurements_type = computation_types.at_server(
      collections.OrderedDict(
          query_metrics=query_metrics_type, record_agg_process=inner_type))
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=server_state_type,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=server_state_type,
          result=computation_types.at_server(tff_value_type),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))
def _check_bound_process(bound_process: estimation_process.EstimationProcess,
                         name: str):
  """Validates the type signatures of an estimation process used for bounds.

  The process must be an `EstimationProcess` whose `next` has the signature
  (<state@SERVER, NORM_TF_TYPE@CLIENTS> -> state@SERVER) and whose `report`
  result is assignable from the `next` value argument's member placed at
  SERVER (i.e. NORM_TF_TYPE@SERVER for a conforming process).

  Args:
    bound_process: A process to check.
    name: A string name for formatting error messages.

  Raises:
    TypeError: If `next` does not take exactly two arguments, if its second
      argument is not assignable from NORM_TF_TYPE@CLIENTS, if its result is
      not assignable to the process state type, or if the `report` result is
      not assignable from the estimated value placed at SERVER.
  """
  py_typecheck.check_type(bound_process, estimation_process.EstimationProcess)

  next_signature = bound_process.next.type_signature
  next_params = next_signature.parameter
  if not next_params.is_struct() or len(next_params) != 2:
    raise TypeError(f'`{name}.next` must take two arguments but found:\n'
                    f'{next_params}')

  norm_at_clients = computation_types.at_clients(NORM_TF_TYPE)
  value_param = next_params[1]
  if not value_param.is_assignable_from(norm_at_clients):
    raise TypeError(
        f'Second argument of `{name}.next` must be assignable from '
        f'{norm_at_clients} but found {value_param}')

  next_result = next_signature.result
  state_type = bound_process.state_type
  if not state_type.is_assignable_from(next_result):
    raise TypeError(
        f'Result type of `{name}.next` must consist of state only '
        f'but found result type:\n{next_result}\n'
        f'while the state type is:\n{state_type}')

  report_result = bound_process.report.type_signature.result
  estimate_at_server = computation_types.at_server(value_param.member)
  if not report_result.is_assignable_from(estimate_at_server):
    raise TypeError(
        f'Report type of `{name}.report` must be assignable from '
        f'{estimate_at_server} but found {report_result}.')
def test_type_properties_with_inner_factory(self, value_type, weight_type):
  """Checks the weighted mean's type signatures when inner factories are set."""
  sum_factory = aggregators_test_utils.SumPlusOneFactory()
  mean_f = mean_factory.MeanFactory(
      value_sum_factory=sum_factory, weight_sum_factory=sum_factory)
  self.assertIsInstance(mean_f, factory.WeightedAggregationFactory)

  tff_value_type = computation_types.to_type(value_type)
  tff_weight_type = computation_types.to_type(weight_type)
  process = mean_f.create_weighted(tff_value_type, tff_weight_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State and measurements both hold one `tf.int32` per inner sum process.
  inner_process_types = collections.OrderedDict(
      value_sum_process=tf.int32, weight_sum_process=tf.int32)
  expected_state_type = computation_types.FederatedType(
      inner_process_types, placements.SERVER)
  expected_measurements_type = computation_types.FederatedType(
      inner_process_types, placements.SERVER)

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=expected_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.FederatedType(tff_value_type,
                                                placements.CLIENTS),
          weight=computation_types.at_clients(tff_weight_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.FederatedType(tff_value_type,
                                                 placements.SERVER),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))
def test_federated_aggregate_with_federated_zero_fails(self):
  """`federated_aggregate` rejects a `zero` that is a SERVER-placed value.

  The `zero` argument must be assignable to the unplaced accumulator type
  (int32 here); passing `int32@SERVER` must raise a `TypeError`.
  """
  zero = intrinsics.federated_value(0, placements.SERVER)

  # The operator to use during the first stage adds an element to the
  # accumulator.
  @computations.tf_computation([tf.int32, tf.int32])
  def accumulate(accu, elem):
    return accu + elem

  # The operator to use during the second stage adds two accumulators.
  @computations.tf_computation([tf.int32, tf.int32])
  def merge(x, y):
    return x + y

  # The operator to use during the final stage returns the accumulator as is.
  @computations.tf_computation(tf.int32)
  def report(accu):
    return accu

  x = _mock_data_of_type(computation_types.at_clients(tf.int32))
  with self.assertRaisesRegex(
      TypeError, 'Expected `zero` to be assignable to type int32, '
      'but was of incompatible type int32@SERVER'):
    intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
def test_returns_value_with_intrinsic_def_federated_secure_sum(
    self, client_values, bitwidth, expected_result):
  """Secure sum over types inferred from the inputs matches `expected_result`."""
  executor = create_test_executor()
  value_type = computation_types.at_clients(
      type_conversions.infer_type(client_values[0]))
  bitwidth_type = type_conversions.infer_type(bitwidth)
  intrinsic, intrinsic_type = create_intrinsic_def_federated_secure_sum(
      value_type.member, bitwidth_type)

  embedded_fn = self.run_sync(executor.create_value(intrinsic, intrinsic_type))
  embedded_values = self.run_sync(
      executor.create_value(client_values, value_type))
  embedded_bitwidth = self.run_sync(
      executor.create_value(bitwidth, bitwidth_type))
  embedded_args = self.run_sync(
      executor.create_struct([embedded_values, embedded_bitwidth]))
  result = self.run_sync(executor.create_call(embedded_fn, embedded_args))

  self.assertIsInstance(result, executor_value_base.ExecutorValue)
  self.assert_types_identical(result.type_signature, intrinsic_type.result)
  computed = self.run_sync(result.compute())
  if not isinstance(expected_result, structure.Struct):
    self.assertEqual(computed, expected_result)
    return
  # Struct-valued results are compared element by element.
  structure.map_structure(self.assertAllEqual, computed, expected_result)
def test_type_properties(self, encoder_fn):
  """Checks the type signatures of the process built by `EncodedSumFactory`."""
  encoded_f = encoded_factory.EncodedSumFactory(encoder_fn)
  self.assertIsInstance(encoded_f, factory.UnweightedAggregationFactory)

  process = encoded_f.create_unweighted(_test_struct_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  initialize_signature = process.initialize.type_signature
  self.assertIsNone(initialize_signature.parameter)
  server_state_type = initialize_signature.result
  # The state structure should have one element per aggregated tensor.
  self.assertLen(server_state_type.member, 2)
  self.assertEqual(placements.SERVER, server_state_type.placement)

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=server_state_type,
          value=computation_types.at_clients(_test_struct_type)),
      result=measured_process.MeasuredProcessOutput(
          state=server_state_type,
          result=computation_types.at_server(_test_struct_type),
          measurements=computation_types.at_server(())))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))
def test_with_temperature_sensor_example(self, executor):
  """Mean of per-client above-threshold counts, weighted by reading counts."""
  float_sequence = computation_types.SequenceType(tf.float32)

  @computations.tf_computation(float_sequence, tf.float32)
  def count_over(ds, t):
    return ds.reduce(
        np.float32(0),
        lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))

  @computations.tf_computation(float_sequence)
  def count_total(ds):
    return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

  @computations.federated_computation(
      computation_types.at_clients(float_sequence),
      computation_types.at_server(tf.float32))
  def comp(temperatures, threshold):
    threshold_at_clients = intrinsics.federated_broadcast(threshold)
    counts_over = intrinsics.federated_map(
        count_over,
        intrinsics.federated_zip([temperatures, threshold_at_clients]))
    totals = intrinsics.federated_map(count_total, temperatures)
    return intrinsics.federated_mean(counts_over, totals)

  def to_float(x):
    return tf.cast(x, tf.float32)

  with executor_test_utils.install_executor(executor):
    temperatures = [
        tf.data.Dataset.range(size).map(to_float) for size in (10, 20, 30)
    ]
    result = comp(temperatures, 15.0)

  self.assertAlmostEqual(result, 8.333, places=3)
def test_at_clients(self):
  """`at_clients` builds a `FederatedType` placed at `CLIENTS`."""
  member = computation_types.TensorType(tf.bool)
  self.assertEqual(
      computation_types.at_clients(member),
      computation_types.FederatedType(member, placements.CLIENTS))
def create_whimsy_called_federated_collect(value_type=tf.int32):
  """Returns a `federated_collect` call on whimsy `Data` placed at clients.

  Args:
    value_type: The member type of the clients-placed `Data` block.
  """
  clients_data = building_blocks.Data(
      'data', computation_types.at_clients(value_type))
  return building_block_factory.create_federated_collect(clients_data)
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor,
    arg: executor_value_base.ExecutorValue,
    local_computation_factory: local_computation_factory_base.
    LocalComputationFactory = tensorflow_computation_factory.
    TensorFlowComputationFactory()
) -> executor_value_base.ExecutorValue:
  """Computes a federated weighted mean on the given `executor`.

  The mean is `sum(value * weight) / sum(weight)`, assembled from intrinsic
  calls on `executor`:

  1. `federated_zip` at clients pairs each value with its weight.
  2. `federated_map` multiplies each pair with a scalar-multiply operator built
     by `local_computation_factory`.
  3. `federated_sum` sums the products and, separately, the weights.
  4. `federated_zip` at server pairs the two sums.
  5. `federated_apply` divides the summed products by the summed weights.

  Independent sub-steps are awaited concurrently via `asyncio.gather`.

  Args:
    executor: The executor to use.
    arg: The argument, already embedded in `executor`: a (value, weight) pair
      whose type is validated by
      `type_analysis.check_valid_federated_weighted_mean_argument_tuple_type`.
    local_computation_factory: An instance of `LocalComputationFactory` used to
      construct the local multiply operator that is mapped over clients.
      Defaults to a TensorFlow computation factory that generates TensorFlow
      code. Note that the default instance is created once, when this function
      is defined, and is shared by every call that omits the argument.

  Returns:
    The result embedded in `executor`.
  """
  type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
      arg.type_signature)
  # <value@CLIENTS, weight@CLIENTS> -> <value, weight>@CLIENTS
  zip1_type = computation_types.FunctionType(
      computation_types.StructType([
          computation_types.at_clients(arg.type_signature[0].member),
          computation_types.at_clients(arg.type_signature[1].member)
      ]),
      computation_types.at_clients(
          computation_types.StructType(
              [arg.type_signature[0].member, arg.type_signature[1].member])))

  operand_type = zip1_type.result.member[0]
  scalar_type = zip1_type.result.member[1]
  # Local computation multiplying a value (operand) by its weight (scalar).
  multiply_comp_pb, multiply_comp_type = local_computation_factory.create_scalar_multiply_operator(
      operand_type, scalar_type)
  multiply_blk = building_blocks.CompiledComputation(
      multiply_comp_pb, type_signature=multiply_comp_type)

  # <multiply_fn, <value, weight>@CLIENTS> -> product@CLIENTS
  map_type = computation_types.FunctionType(
      computation_types.StructType(
          [multiply_blk.type_signature, zip1_type.result]),
      computation_types.at_clients(multiply_blk.type_signature.result))

  # product@CLIENTS -> product@SERVER
  sum1_type = computation_types.FunctionType(
      computation_types.at_clients(map_type.result.member),
      computation_types.at_server(map_type.result.member))

  # weight@CLIENTS -> weight@SERVER
  sum2_type = computation_types.FunctionType(
      computation_types.at_clients(arg.type_signature[1].member),
      computation_types.at_server(arg.type_signature[1].member))

  # <product@SERVER, weight@SERVER> -> <product, weight>@SERVER
  zip2_type = computation_types.FunctionType(
      computation_types.StructType([sum1_type.result, sum2_type.result]),
      computation_types.at_server(
          computation_types.StructType(
              [sum1_type.result.member, sum2_type.result.member])))

  # NOTE(review): unlike the multiply operator, the divide operator is always
  # built as TensorFlow and does not use `local_computation_factory`.
  divide_blk = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      zip2_type.result.member, tf.divide)

  # Embeds the multiply operator in `executor`.
  async def _compute_multiply_fn():
    return await executor.create_value(multiply_blk.proto,
                                       multiply_blk.type_signature)

  # Zips values with weights at clients.
  async def _compute_multiply_arg():
    zip1_comp = create_intrinsic_comp(
        intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip1_type)
    zip_fn = await executor.create_value(zip1_comp, zip1_type)
    return await executor.create_call(zip_fn, arg)

  # Embeds the `federated_map` intrinsic.
  async def _compute_product_fn():
    map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP, map_type)
    return await executor.create_value(map_comp, map_type)

  # Builds the <multiply_fn, zipped values> argument of `federated_map`.
  async def _compute_product_arg():
    multiply_fn, multiply_arg = await asyncio.gather(
        _compute_multiply_fn(), _compute_multiply_arg())
    return await executor.create_struct((multiply_fn, multiply_arg))

  # Per-client products of value and weight, placed at clients.
  async def _compute_products():
    product_fn, product_arg = await asyncio.gather(_compute_product_fn(),
                                                   _compute_product_arg())
    return await executor.create_call(product_fn, product_arg)

  # Sum of the weights (element 1 of `arg`), placed at server.
  async def _compute_total_weight():
    sum2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum2_type)
    sum2_fn, sum2_arg = await asyncio.gather(
        executor.create_value(sum2_comp, sum2_type),
        executor.create_selection(arg, 1))
    return await executor.create_call(sum2_fn, sum2_arg)

  # Sum of the per-client products, placed at server.
  async def _compute_sum_of_products():
    sum1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum1_type)
    sum1_fn, products = await asyncio.gather(
        executor.create_value(sum1_comp, sum1_type), _compute_products())
    return await executor.create_call(sum1_fn, products)

  # Embeds the server-side `federated_zip` intrinsic.
  async def _compute_zip2_fn():
    zip2_comp = create_intrinsic_comp(
        intrinsic_defs.FEDERATED_ZIP_AT_SERVER, zip2_type)
    return await executor.create_value(zip2_comp, zip2_type)

  # Builds the <sum_of_products, total_weight> argument of the server zip.
  async def _compute_zip2_arg():
    sum_of_products, total_weight = await asyncio.gather(
        _compute_sum_of_products(), _compute_total_weight())
    return await executor.create_struct([sum_of_products, total_weight])

  # Embeds the divide operator in `executor`.
  async def _compute_divide_fn():
    return await executor.create_value(divide_blk.proto,
                                       divide_blk.type_signature)

  # The two sums zipped at server; this is the divide operator's input.
  async def _compute_divide_arg():
    zip_fn, zip_arg = await asyncio.gather(_compute_zip2_fn(),
                                           _compute_zip2_arg())
    return await executor.create_call(zip_fn, zip_arg)

  # Embeds `federated_apply` typed as
  # <divide_fn, <product, weight>@SERVER> -> quotient@SERVER.
  async def _compute_apply_fn():
    apply_type = computation_types.FunctionType(
        computation_types.StructType(
            [divide_blk.type_signature, zip2_type.result]),
        computation_types.at_server(divide_blk.type_signature.result))
    apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                       apply_type)
    return await executor.create_value(apply_comp, apply_type)

  # Builds the <divide_fn, zipped sums> argument of `federated_apply`.
  async def _compute_apply_arg():
    divide_fn, divide_arg = await asyncio.gather(_compute_divide_fn(),
                                                 _compute_divide_arg())
    return await executor.create_struct([divide_fn, divide_arg])

  # Applies the division at server, yielding the weighted mean.
  async def _compute_divided():
    apply_fn, apply_arg = await asyncio.gather(_compute_apply_fn(),
                                               _compute_apply_arg())
    return await executor.create_call(apply_fn, apply_arg)

  return await _compute_divided()
def create_dummy_value_at_clients_all_equal():
  """Returns `10.0` and the all-equal `tf.float32@CLIENTS` federated type."""
  return 10.0, computation_types.at_clients(tf.float32, all_equal=True)
def create_dummy_value_at_clients(number_of_clients: int = 3):
  """Returns per-client floats and their federated type at clients.

  Args:
    number_of_clients: How many client values to produce; the values are
      `10.0, 11.0, ...`, one per client.
  """
  client_values = list(map(float, range(10, 10 + number_of_clients)))
  return client_values, computation_types.at_clients(tf.float32)
def create_dummy_intrinsic_def_federated_value_at_clients():
  """Returns `FEDERATED_VALUE_AT_CLIENTS` and a matching function type.

  The type maps an unplaced `tf.float32` to an all-equal `tf.float32@CLIENTS`.
  """
  fn_type = computation_types.FunctionType(
      tf.float32, computation_types.at_clients(tf.float32, all_equal=True))
  return intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS, fn_type
def create_dummy_intrinsic_def_federated_sum():
  """Returns `FEDERATED_SUM` and a `float32@CLIENTS -> float32@SERVER` type."""
  fn_type = computation_types.FunctionType(
      computation_types.at_clients(tf.float32),
      computation_types.at_server(tf.float32))
  return intrinsic_defs.FEDERATED_SUM, fn_type
def create_dummy_intrinsic_def_federated_eval_at_clients():
  """Returns `FEDERATED_EVAL_AT_CLIENTS` and a matching function type.

  The type maps a no-argument function returning `tf.float32` to a
  `tf.float32@CLIENTS` value.
  """
  producer_type = computation_types.FunctionType(None, tf.float32)
  fn_type = computation_types.FunctionType(
      producer_type, computation_types.at_clients(tf.float32))
  return intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS, fn_type
def create_dummy_intrinsic_def_federated_broadcast():
  """Returns `FEDERATED_BROADCAST` and a matching function type.

  The type maps `tf.float32@SERVER` to an all-equal `tf.float32@CLIENTS`.
  """
  fn_type = computation_types.FunctionType(
      computation_types.at_server(tf.float32),
      computation_types.at_clients(tf.float32, all_equal=True))
  return intrinsic_defs.FEDERATED_BROADCAST, fn_type
async def compute_federated_secure_sum(
    self, arg: federated_resolving_strategy.FederatedResolvingStrategyValue
) -> federated_resolving_strategy.FederatedResolvingStrategyValue:
  """Emulates `tff.federated_secure_sum` without any cryptography.

  The computation proceeds as follows:

  1. The client summands are cast to an integer type wide enough for the
     bitwidth, extended by `self._compute_extra_bits_for_secagg()` bits and
     clamped to 64.
  2. The summands are reduced modulo the mask for that bitwidth.
  3. They are summed with `compute_federated_sum`.
  4. The sum is reduced modulo the same mask.

  This emulates SecAgg's modular arithmetic without the random masks.

  Args:
    arg: A value whose internal representation is a 2-element
      `structure.Struct` of (clients-placed summands, bitwidth).

  Returns:
    The modular sum placed at the server, embedded in `self._executor` with
    the summands' member type.

  Raises:
    TypeError: If the summands are not integer tensors (or structures of
      them), or if struct-typed summands and the bitwidth do not have the same
      structure.
  """
  logging.warning(
      'The implementation of the `tff.federated_secure_sum` intrinsic '
      'provided by the `tff.backends.test` runtime uses no cryptography.')
  py_typecheck.check_type(arg.internal_representation, structure.Struct)
  py_typecheck.check_len(arg.internal_representation, 2)
  summands, bitwidth = await asyncio.gather(
      self.ingest_value(arg.internal_representation[0],
                        arg.type_signature[0]).compute(),
      self.ingest_value(arg.internal_representation[1],
                        arg.type_signature[1]).compute())
  summands_type = arg.type_signature[0].member
  if not type_analysis.is_structure_of_integers(summands_type):
    raise TypeError(
        'Cannot compute `federated_secure_sum` on summands that are not '
        'TensorType or StructType of TensorType. Got {t}'.format(
            t=repr(summands_type)))
  if (summands_type.is_struct() and
      not structure.is_same_structure(summands_type, bitwidth)):
    # `bitwidth` is the plain value returned by `.compute()`, not a typed
    # executor value, so report the declared bitwidth type from `arg`.
    raise TypeError(
        'Cannot compute `federated_secure_sum` if summands and bitwidth are '
        'not the same structure. Got summands={s}, bitwidth={b}'.format(
            s=repr(summands_type), b=repr(arg.type_signature[1])))

  num_additional_bits = await self._compute_extra_bits_for_secagg()
  # Clamp to 64 bits, otherwise we can't represent the mask in TensorFlow.
  extended_bitwidth = _map_numpy_or_structure(
      bitwidth, fn=lambda b: min(b.numpy() + num_additional_bits, 64))
  logging.debug('Emulated secure sum effective bitwidth: %s',
                extended_bitwidth)
  # Now we need to cast the summands into the integral type that is large
  # enough to represent the sum and the mask.
  summation_type_spec = _compute_summation_type_for_bitwidth(
      extended_bitwidth, summands_type)
  # `summands` is a list of all clients' summands. `tf.nest.map_structure`
  # applies `_extract_numpy_arrays` pointwise to each client's summand.
  summand_tensors = tf.nest.map_structure(_extract_numpy_arrays, summands)
  # Dtype conversion trick: pull the summand values out, and push them back
  # into the executor using the new dtypes decided based on bitwidth.
  casted_summands = await self._executor.create_value(
      summand_tensors, computation_types.at_clients(summation_type_spec))
  # To emulate SecAgg without the random masks, we must mask the summands to
  # the effective bitwidth. This isn't strictly necessary because we also
  # mask the sum result and modulus operator is distributive, but this more
  # accurately reflects the system.
  mask = await self._embed_tf_secure_sum_mask_value(summation_type_spec,
                                                    extended_bitwidth)
  masked_summands = await self._compute_modulus(casted_summands, mask)
  logging.debug('Computed masked modular summands as: %s', await
                masked_summands.compute())
  # Then perform the sum and modulo operation (using powers of 2 bitmasking)
  # on the sum, using the computed effective bitwidth.
  sum_result = await self.compute_federated_sum(masked_summands)
  modular_sums = await self._compute_modulus(sum_result, mask)
  # Dtype conversion trick again, pull the modular sum values out, and push
  # them back into the executor using the dtype from the summands.
  modular_sum_values = _extract_numpy_arrays(await modular_sums.compute())
  logging.debug('Computed modular sums as: %s', modular_sum_values)
  return await self._executor.create_value(
      modular_sum_values, computation_types.at_server(summands_type))
def _encoded_next_fn(server_state_type, value_type, encoders):
  """Creates `next_fn` for the process returned by `EncodedSumFactory`.

  The structure of the implementation is roughly as follows:
  * Extract params for encoding/decoding from state (`get_params_fn`).
  * Encode values to be aggregated, placed at clients (`encode_fn`).
  * Call `federated_aggregate` operator, with decoding of the part which does
    not commute with sum, placed in its `accumulate_fn` arg.
  * Finish decoding the summed value placed at server (`decode_after_sum_fn`).
  * Update the state placed at server (`update_state_fn`).

  Args:
    server_state_type: A `tff.Type` of the expected state placed at server.
    value_type: An unplaced `tff.Type` of the value to be aggregated.
    encoders: A collection of `GatherEncoder` objects.

  Returns:
    A `tff.Computation` for `EncodedSumFactory`, with the type signature of
    `(server_state_type, value_type@CLIENTS) ->
    MeasuredProcessOutput(server_state_type, value_type@SERVER, ()@SERVER)`
  """

  # Splits each encoder's params (computed from its state) into the three
  # groups consumed by encode, decode-before-sum and decode-after-sum.
  @computations.tf_computation(server_state_type.member)
  def get_params_fn(state):
    params = tree.map_structure_up_to(encoders, lambda e, s: e.get_params(s),
                                      encoders, state)
    encode_params = _slice(encoders, params, 0)
    decode_before_sum_params = _slice(encoders, params, 1)
    decode_after_sum_params = _slice(encoders, params, 2)
    return encode_params, decode_before_sum_params, decode_after_sum_params

  encode_params_type = get_params_fn.type_signature.result[0]
  decode_before_sum_params_type = get_params_fn.type_signature.result[1]
  decode_after_sum_params_type = get_params_fn.type_signature.result[2]

  # TODO(b/139844355): Get rid of decode_before_sum_params.
  # We pass decode_before_sum_params to the encode method, because TFF currently
  # does not have a mechanism to make a tff.SERVER placed value available inside
  # of intrinsics.federated_aggregate - in production, this could mean an
  # intermediary aggregator node. So currently, we send the params to clients,
  # and ask them to send them back as part of the encoded structure.
  @computations.tf_computation(value_type, encode_params_type,
                               decode_before_sum_params_type)
  def encode_fn(x, encode_params, decode_before_sum_params):
    encoded_structure = tree.map_structure_up_to(
        encoders, lambda e, *args: e.encode(*args), encoders, x, encode_params)
    encoded_x = _slice(encoders, encoded_structure, 0)
    state_update_tensors = _slice(encoders, encoded_structure, 1)
    return encoded_x, decode_before_sum_params, state_update_tensors

  state_update_tensors_type = encode_fn.type_signature.result[2]

  # This is not a @computations.tf_computation because it will be used below
  # when building the computations.tf_computations that will compose a
  # intrinsics.federated_aggregate...
  def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
    part_decoded_x = tree.map_structure_up_to(
        encoders, lambda e, *args: e.decode_before_sum(*args), encoders,
        encoded_x, decode_before_sum_params)
    # A per-client count of 1; once summed with the values it becomes the
    # `num_summands` unpacked in `decode_after_sum_fn`.
    one = tf.constant((1,), tf.int32)
    return part_decoded_x, one

  # ...however, result type is needed to build the subsequent tf_computations.
  @computations.tf_computation(encode_fn.type_signature.result[0:2])
  def tmp_decode_before_sum_fn(encoded_x, decode_before_sum_params):
    return decode_before_sum_tf_function(encoded_x, decode_before_sum_params)

  part_decoded_x_type = tmp_decode_before_sum_fn.type_signature.result
  del tmp_decode_before_sum_fn  # Only needed for result type.

  @computations.tf_computation(part_decoded_x_type,
                               decode_after_sum_params_type)
  def decode_after_sum_fn(summed_values, decode_after_sum_params):
    part_decoded_aggregated_x, num_summands = summed_values
    return tree.map_structure_up_to(
        encoders,
        lambda e, x, params: e.decode_after_sum(x, params, num_summands),
        encoders, part_decoded_aggregated_x, decode_after_sum_params)

  @computations.tf_computation(server_state_type.member,
                               state_update_tensors_type)
  def update_state_fn(state, state_update_tensors):
    return tree.map_structure_up_to(encoders,
                                    lambda e, *args: e.update_state(*args),
                                    encoders, state, state_update_tensors)

  # Computations for intrinsics.federated_aggregate.
  # The accumulator is an OrderedDict of the partially decoded values (with
  # the summand count) and the accumulated state update tensors.
  def _accumulator_value(values, state_update_tensors):
    return collections.OrderedDict(
        values=values, state_update_tensors=state_update_tensors)

  # All-zeros accumulator matching the partially decoded value and state
  # update tensor types.
  @computations.tf_computation
  def zero_fn():
    values = tf.nest.map_structure(
        lambda s: tf.zeros(s.shape, s.dtype),
        type_conversions.type_to_tf_tensor_specs(part_decoded_x_type))
    state_update_tensors = tf.nest.map_structure(
        lambda s: tf.zeros(s.shape, s.dtype),
        type_conversions.type_to_tf_tensor_specs(state_update_tensors_type))
    return _accumulator_value(values, state_update_tensors)

  accumulator_type = zero_fn.type_signature.result
  state_update_aggregation_modes = tf.nest.map_structure(
      lambda e: tuple(e.state_update_aggregation_modes), encoders)

  # Decodes the non-sum-commuting part of one client's encoded value and adds
  # it (and its state update tensors) into the accumulator.
  @computations.tf_computation(accumulator_type,
                               encode_fn.type_signature.result)
  def accumulate_fn(acc, encoded_x):
    value, params, state_update_tensors = encoded_x
    part_decoded_value = decode_before_sum_tf_function(value, params)
    new_values = tf.nest.map_structure(tf.add, acc['values'],
                                       part_decoded_value)
    new_state_update_tensors = tf.nest.map_structure(
        _accmulate_state_update_tensor, acc['state_update_tensors'],
        state_update_tensors, state_update_aggregation_modes)
    return _accumulator_value(new_values, new_state_update_tensors)

  # Combines two accumulators element-wise.
  @computations.tf_computation(accumulator_type, accumulator_type)
  def merge_fn(acc1, acc2):
    new_values = tf.nest.map_structure(tf.add, acc1['values'], acc2['values'])
    new_state_update_tensors = tf.nest.map_structure(
        _accmulate_state_update_tensor, acc1['state_update_tensors'],
        acc2['state_update_tensors'], state_update_aggregation_modes)
    return _accumulator_value(new_values, new_state_update_tensors)

  # Identity report; decoding after the sum happens in `next_fn`.
  @computations.tf_computation(accumulator_type)
  def report_fn(acc):
    return acc

  @computations.federated_computation(server_state_type,
                                      computation_types.at_clients(value_type))
  def next_fn(state, value):
    encode_params, decode_before_sum_params, decode_after_sum_params = (
        intrinsics.federated_map(get_params_fn, state))
    encode_params = intrinsics.federated_broadcast(encode_params)
    decode_before_sum_params = intrinsics.federated_broadcast(
        decode_before_sum_params)

    encoded_values = intrinsics.federated_map(
        encode_fn, [value, encode_params, decode_before_sum_params])

    aggregated_values = intrinsics.federated_aggregate(encoded_values,
                                                       zero_fn(), accumulate_fn,
                                                       merge_fn, report_fn)

    decoded_values = intrinsics.federated_map(
        decode_after_sum_fn,
        [aggregated_values.values, decode_after_sum_params])

    updated_state = intrinsics.federated_map(
        update_state_fn, [state, aggregated_values.state_update_tensors])

    empty_metrics = intrinsics.federated_value((), placements.SERVER)
    return measured_process.MeasuredProcessOutput(
        state=updated_state, result=decoded_values, measurements=empty_metrics)

  return next_fn
@computations.federated_computation() def test_init_fn(): return intrinsics.federated_value(0, placements.SERVER) test_state_type = test_init_fn.type_signature.result @computations.tf_computation def sum_sequence(s): spec = s.element_spec return s.reduce(tf.zeros(spec.shape, spec.dtype), lambda s, t: tf.nest.map_structure(tf.add, s, t)) ClientIntSequenceType = computation_types.at_clients( computation_types.SequenceType(tf.int32)) def build_next_fn(server_init_fn): @computations.federated_computation(server_init_fn.type_signature.result, ClientIntSequenceType) def next_fn(state, client_values): metrics = intrinsics.federated_map(sum_sequence, client_values) metrics = intrinsics.federated_sum(metrics) return LearningProcessOutput(state, metrics) return next_fn def build_report_fn(server_init_fn): @computations.tf_computation(server_init_fn.type_signature.result.member)
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
  """Computes a federated weighted mean on the given `executor`.

  The mean is assembled from lower-level intrinsics, each embedded in and
  invoked through `executor`:

    1. `FEDERATED_ZIP_AT_CLIENTS` pairs each client value with its weight.
    2. `FEDERATED_MAP` applies an upcasting TensorFlow multiply to each pair.
    3. Two `FEDERATED_SUM`s total the products and the weights at the server.
    4. `FEDERATED_ZIP_AT_SERVER` pairs the two totals.
    5. `FEDERATED_APPLY` runs an upcasting TensorFlow divide on that pair.

  Sub-computations that do not depend on each other are issued concurrently
  with `asyncio.gather`.

  Args:
    executor: The executor to use.
    arg: The argument, already embedded in `executor`. Its type signature is
      validated with
      `type_analysis.check_valid_federated_weighted_mean_argument_tuple_type`.
      Element 0 supplies the values and element 1 the weights.

  Returns:
    The result embedded in `executor`.
  """
  type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
      arg.type_signature)
  value_member = arg.type_signature[0].member
  weight_member = arg.type_signature[1].member

  # <value@CLIENTS, weight@CLIENTS> -> <value, weight>@CLIENTS
  pair_at_clients_type = computation_types.FunctionType(
      computation_types.StructType([
          computation_types.at_clients(value_member),
          computation_types.at_clients(weight_member),
      ]),
      computation_types.at_clients(
          computation_types.StructType([value_member, weight_member])))
  multiply_op = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      pair_at_clients_type.result.member, tf.multiply)
  map_products_type = computation_types.FunctionType(
      computation_types.StructType(
          [multiply_op.type_signature, pair_at_clients_type.result]),
      computation_types.at_clients(multiply_op.type_signature.result))
  product_member = map_products_type.result.member
  sum_products_type = computation_types.FunctionType(
      computation_types.at_clients(product_member),
      computation_types.at_server(product_member))
  sum_weights_type = computation_types.FunctionType(
      computation_types.at_clients(weight_member),
      computation_types.at_server(weight_member))
  # <product_total@SERVER, weight_total@SERVER> -> <...>@SERVER
  pair_at_server_type = computation_types.FunctionType(
      computation_types.StructType(
          [sum_products_type.result, sum_weights_type.result]),
      computation_types.at_server(
          computation_types.StructType([
              sum_products_type.result.member,
              sum_weights_type.result.member,
          ])))
  divide_op = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      pair_at_server_type.result.member, tf.divide)
  apply_quotient_type = computation_types.FunctionType(
      computation_types.StructType(
          [divide_op.type_signature, pair_at_server_type.result]),
      computation_types.at_server(divide_op.type_signature.result))

  async def _embed_intrinsic(intrinsic, fn_type):
    # Builds the intrinsic computation for `fn_type` and embeds it.
    comp = create_intrinsic_comp(intrinsic, fn_type)
    return await executor.create_value(comp, fn_type)

  async def _embed_block(block):
    # Embeds a TensorFlow building block via its proto and type signature.
    return await executor.create_value(block.proto, block.type_signature)

  async def _zipped_at_clients():
    zip_fn = await _embed_intrinsic(intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
                                    pair_at_clients_type)
    return await executor.create_call(zip_fn, arg)

  async def _map_arg():
    multiply_fn, zipped = await asyncio.gather(
        _embed_block(multiply_op), _zipped_at_clients())
    return await executor.create_struct((multiply_fn, zipped))

  async def _client_products():
    map_fn, map_arg = await asyncio.gather(
        _embed_intrinsic(intrinsic_defs.FEDERATED_MAP, map_products_type),
        _map_arg())
    return await executor.create_call(map_fn, map_arg)

  async def _summed_products():
    sum_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                     sum_products_type)
    sum_fn, products = await asyncio.gather(
        executor.create_value(sum_comp, sum_products_type),
        _client_products())
    return await executor.create_call(sum_fn, products)

  async def _summed_weights():
    sum_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                     sum_weights_type)
    sum_fn, weights = await asyncio.gather(
        executor.create_value(sum_comp, sum_weights_type),
        executor.create_selection(arg, index=1))
    return await executor.create_call(sum_fn, weights)

  async def _server_totals():
    products_total, weights_total = await asyncio.gather(
        _summed_products(), _summed_weights())
    return await executor.create_struct([products_total, weights_total])

  async def _zipped_at_server():
    zip_fn, totals = await asyncio.gather(
        _embed_intrinsic(intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
                         pair_at_server_type),
        _server_totals())
    return await executor.create_call(zip_fn, totals)

  async def _apply_arg():
    divide_fn, zipped = await asyncio.gather(
        _embed_block(divide_op), _zipped_at_server())
    return await executor.create_struct([divide_fn, zipped])

  apply_fn, apply_arg = await asyncio.gather(
      _embed_intrinsic(intrinsic_defs.FEDERATED_APPLY, apply_quotient_type),
      _apply_arg())
  return await executor.create_call(apply_fn, apply_arg)
def _clipped_sum(clip=2.0): return clipping_factory.ClippingFactory(clip, sum_factory.SumFactory()) def _zeroed_mean(clip=2.0, norm_order=2.0): return clipping_factory.ZeroingFactory(clip, mean_factory.MeanFactory(), norm_order) def _zeroed_sum(clip=2.0, norm_order=2.0): return clipping_factory.ZeroingFactory(clip, sum_factory.SumFactory(), norm_order) _float_at_server = computation_types.at_server(tf.float32) _float_at_clients = computation_types.at_clients(tf.float32) @computations.federated_computation() def _test_init_fn(): return intrinsics.federated_value(1., placements.SERVER) @computations.federated_computation(_float_at_server, _float_at_clients) def _test_next_fn(state, value): del value return intrinsics.federated_map( computations.tf_computation(lambda x: x + 1., tf.float32), state) @computations.federated_computation(_float_at_server)