def test_iterative_process_type_mismatch(self):
  """Constructing an `IterativeProcess` fails when init/next types disagree."""
  # Case 1: `next_fn` takes float32 state. `initialize` is defined elsewhere
  # in this file; this case assumes its result is not float32.
  with self.assertRaisesRegex(
      TypeError, r'The return type of initialize_fn should match.*'):

    @computations.federated_computation([tf.float32, tf.float32])
    def add_float32(current, val):
      return current + val

    computation_utils.IterativeProcess(
        initialize_fn=initialize, next_fn=add_float32)

  # Case 2: `next_fn` accepts int32 but returns a float.
  with self.assertRaisesRegex(
      TypeError, 'The return type of next_fn should match the first parameter'):

    @computations.federated_computation(tf.int32)
    def add_bad_result(_):
      return 0.0

    computation_utils.IterativeProcess(
        initialize_fn=initialize, next_fn=add_bad_result)

  # Case 3: `next_fn` accepts int32 but returns a (float, int) pair whose
  # leading element does not match the int32 parameter.
  with self.assertRaisesRegex(
      TypeError, 'The return type of next_fn should match the first parameter'):

    @computations.federated_computation(tf.int32)
    def add_bad_multi_result(_):
      return 0.0, 0

    computation_utils.IterativeProcess(
        initialize_fn=initialize, next_fn=add_bad_multi_result)
def test_iterative_process_initialize_bad_type(self):
  """`initialize_fn` must be a zero-argument computation."""
  # A non-computation value is rejected outright.
  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    computation_utils.IterativeProcess(initialize_fn=None, next_fn=add_int32)

  # A computation that takes a parameter is also rejected.
  with self.assertRaisesRegex(
      TypeError, r'initialize_fn must be a no-arg tff.Computation'):

    @tff.federated_computation(tf.int32)
    def one_arg_initialize(one_arg):
      del one_arg  # unused
      return tff.to_value(0)

    computation_utils.IterativeProcess(
        initialize_fn=one_arg_initialize, next_fn=add_int32)
def get_iterative_process_for_sum_example_with_no_server_state():
  """Returns an iterative process for a sum example with empty server state.

  The server state is an empty structure (`[]`) that `next_fn` ignores. Each
  client contributes the constant pair `[1, 1]`. The first element is summed
  with `federated_sum`; the second with `federated_secure_sum` (bitwidth 8).
  The zipped sums become the server output. The new server state stays empty
  and each client emits an empty output.
  """

  @computations.federated_computation
  def init_fn():
    """The `init` function for `computation_utils.IterativeProcess`."""
    return intrinsics.federated_value([], placements.SERVER)

  @computations.tf_computation(tf.int32)
  def work(client_data):
    del client_data  # Unused
    # (updates for the two aggregation paths, empty per-client output)
    return [1, 1], []

  @computations.tf_computation([tf.int32, tf.int32])
  def update(global_update):
    # (empty new server state, aggregated pair as server output)
    return [], global_update

  @computations.federated_computation([
      computation_types.FederatedType([], placements.SERVER),
      computation_types.FederatedType(tf.int32, placements.CLIENTS),
  ])
  def next_fn(server_state, client_data):
    """The `next` function for `computation_utils.IterativeProcess`."""
    del server_state  # Unused
    client_updates, client_output = intrinsics.federated_map(
        work, client_data)
    unsecure_update = intrinsics.federated_sum(client_updates[0])
    # Second argument is the secure-sum bitwidth.
    secure_update = intrinsics.federated_secure_sum(client_updates[1], 8)
    s5 = intrinsics.federated_zip([unsecure_update, secure_update])
    new_server_state, server_output = intrinsics.federated_map(update, s5)
    return new_server_state, server_output, client_output

  return computation_utils.IterativeProcess(init_fn, next_fn)
def get_iterative_process_for_concise_sum_example():
  """Returns a concise iterative process for a sum example.

  The server state is an int32 pair, initialized to `[0, 0]`. `next_fn`
  broadcasts the state, zips it with each client's data, and maps `work`
  (which ignores both inputs and emits `[1, 1]`) over the clients. It then
  sums the first element with `federated_sum` and the second with
  `federated_secure_sum` (bitwidth 8). The zipped sums are returned directly
  as the new state with an empty server output. No server-side TF
  computation is used, and there is no client output.
  """
  @computations.federated_computation
  def init_fn():
    """The `init` function for `computation_utils.IterativeProcess`."""
    return intrinsics.federated_value([0, 0], placements.SERVER)

  @computations.tf_computation(tf.int32, [tf.int32, tf.int32])
  def work(client_data, client_input):
    del client_data  # Unused
    del client_input  # Unused
    return [1, 1]

  @computations.federated_computation([
      computation_types.FederatedType([tf.int32, tf.int32], placements.SERVER),
      computation_types.FederatedType(tf.int32, placements.CLIENTS),
  ])
  def next_fn(server_state, client_data):
    """The `next` function for `computation_utils.IterativeProcess`."""
    client_input = intrinsics.federated_broadcast(server_state)
    c3 = intrinsics.federated_zip([client_data, client_input])
    client_updates = intrinsics.federated_map(work, c3)
    unsecure_update = intrinsics.federated_sum(client_updates[0])
    # Second argument is the secure-sum bitwidth.
    secure_update = intrinsics.federated_secure_sum(client_updates[1], 8)
    new_server_state = intrinsics.federated_zip(
        [unsecure_update, secure_update])
    server_output = intrinsics.federated_value([], placements.SERVER)
    return new_server_state, server_output

  return computation_utils.IterativeProcess(init_fn, next_fn)
def get_iterative_process_for_canonical_form_example():
  """Construct a simple `IterativeProcess` compatible with `CanonicalForm`.

  The computation itself is nonsensical, but it demonstrates the type
  signatures required by `CanonicalForm`. The float32 server state is
  broadcast and added to each client's int32 value (cast to float32). The
  client results are averaged with `federated_mean` to form the new state.
  A constant list is emitted as server output.

  Returns:
    An `IterativeProcess` compatible with `CanonicalForm`.
  """

  @computations.tf_computation(tf.int32, tf.float32)
  def add_two(x_int, y_float):
    return tf.cast(x_int, tf.float32) + y_float

  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value(1.234, placements.SERVER)

  @computations.federated_computation([
      computation_types.FederatedType(tf.float32, placements.SERVER),
      computation_types.FederatedType(tf.int32, placements.CLIENTS)
  ])
  def next_fn(server_val, client_val):
    """Defines a series of federated computations compatible with CanonicalForm."""
    broadcast_val = intrinsics.federated_broadcast(server_val)
    values_on_clients = intrinsics.federated_zip(
        (client_val, broadcast_val))
    result_on_clients = intrinsics.federated_map(add_two, values_on_clients)
    aggregated_result = intrinsics.federated_mean(result_on_clients)
    side_output = intrinsics.federated_value([1, 2, 3, 4, 5],
                                             placements.SERVER)
    # (new server state, server output)
    return aggregated_result, side_output

  return computation_utils.IterativeProcess(init_fn, next_fn)
def test_returns_canonical_form_with_no_broadcast(self):
  """Converts a process whose `next_fn` never broadcasts the server state."""

  @computations.tf_computation(tf.int32)
  @tf.function
  def map_fn(client_val):
    del client_val  # unused
    return 1

  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value(False, placements.SERVER)

  @computations.federated_computation(
      computation_types.FederatedType(tf.bool, placements.SERVER),
      computation_types.FederatedType(tf.int32, placements.CLIENTS))
  def next_fn(server_val, client_val):
    # The server value is dropped, so nothing is sent to clients.
    del server_val  # Unused
    result_on_clients = intrinsics.federated_map(map_fn, client_val)
    aggregated_result = intrinsics.federated_sum(result_on_clients)
    side_output = intrinsics.federated_value(False, placements.SERVER)
    # The bool constant occupies the state slot (matching the tf.bool state
    # type); the client sum is the server output.
    return side_output, aggregated_result

  ip = computation_utils.IterativeProcess(init_fn, next_fn)
  cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)
  self.assertIsInstance(cf, canonical_form.CanonicalForm)
def test_iterative_process_state_tuple_arg(self):
  """`next` receives (state, value) and accumulates the values into state."""
  process = computation_utils.IterativeProcess(initialize, add_int32)
  num_rounds = 10
  current = process.initialize()
  for round_value in range(num_rounds):
    current = process.next(current, round_value)
  self.assertEqual(current, sum(range(num_rounds)))
def test_iterative_process_state_multiple_return_values(self):
  """`next` may return extra values alongside the updated state."""
  process = computation_utils.IterativeProcess(initialize, add_mul_int32)
  num_rounds = 10
  state = process.initialize()
  for value in range(num_rounds):
    state, product = process.next(state, value)
  last_value = num_rounds - 1
  self.assertEqual(state, sum(range(num_rounds)))
  # Presumably `add_mul_int32` reports (incoming state * value); for the final
  # round that is the sum of all earlier values times the last value.
  self.assertEqual(product, sum(range(last_value)) * last_value)
def test_iterative_process_state_only(self):
  """`next` may take and return only the state."""
  num_rounds = 10
  process = computation_utils.IterativeProcess(initialize, count_int32)
  state = process.initialize()
  for _ in range(num_rounds):
    # TODO(b/122321354): remove the .item() call on `state` once numpy.int32
    # type is supported.
    state = process.next(state.item())
  self.assertEqual(state, num_rounds)
def test_tensor_computation_fails_well(self):
  """Conversion fails cleanly for a `next` with a non-tuple parameter."""
  cf = test_utils.get_temperature_sensor_example()
  it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
  init_result = it.initialize.type_signature.result
  # Identity lambda over the init result type only. Unlike a regular `next`,
  # it takes a single value rather than a (state, client data) tuple.
  lam = building_blocks.Lambda('x', init_result,
                               building_blocks.Reference('x', init_result))
  bad_it = computation_utils.IterativeProcess(
      it.initialize,
      computation_wrapper_instances.building_block_to_computation(lam))
  with self.assertRaisesRegex(TypeError, 'instances of `tff.NamedTupleType`.'):
    canonical_form_utils.get_canonical_form_for_iterative_process(bad_it)
def get_iterative_process_for_canonical_form(cf):
  """Creates `tff.utils.IterativeProcess` from a canonical form.

  Args:
    cf: An instance of `tff.backends.mapreduce.CanonicalForm`.

  Returns:
    An instance of `tff.utils.IterativeProcess` that corresponds to `cf`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(cf, canonical_form.CanonicalForm)

  @computations.federated_computation
  def init_computation():
    return intrinsics.federated_value(cf.initialize(), placements.SERVER)

  @computations.federated_computation(
      init_computation.type_signature.result,
      computation_types.FederatedType(cf.work.type_signature.parameter[0],
                                      placements.CLIENTS))
  def next_computation(arg):
    """The logic of a single MapReduce processing round."""
    s1 = arg[0]  # Server state.
    c1 = arg[1]  # Client data.
    s2 = intrinsics.federated_map(cf.prepare, s1)
    c2 = intrinsics.federated_broadcast(s2)
    c3 = intrinsics.federated_zip([c1, c2])
    c4 = intrinsics.federated_map(cf.work, c3)
    # `work` yields (updates, client output); updates is itself a pair of
    # (input to federated_aggregate, input to federated_secure_sum).
    c5 = c4[0]
    c6 = c5[0]
    c7 = c5[1]
    c8 = c4[1]
    s3 = intrinsics.federated_aggregate(c6, cf.zero(), cf.accumulate, cf.merge,
                                        cf.report)
    s4 = intrinsics.federated_secure_sum(c7, cf.bitwidth())
    s5 = intrinsics.federated_zip([s3, s4])
    s6 = intrinsics.federated_zip([s1, s5])
    # `update` consumes (old state, aggregates) and yields
    # (new state, server output).
    s7 = intrinsics.federated_map(cf.update, s6)
    s8 = s7[0]
    s9 = s7[1]
    return s8, s9, c8

  return computation_utils.IterativeProcess(init_computation, next_computation)
def get_unused_tf_computation_arg_iterative_process():
  """Returns an iterative process with a @tf.function with an unused arg.

  Both the client count and the server output come from TF computations
  that accept an argument but never read it. The client count is the sum of
  a constant 1 per client. The server output is the sum of `tf.timestamp()`
  evaluated on each client that received the broadcast state.
  """
  server_state_type = computation_types.NamedTupleType([('num_clients',
                                                         tf.int32)])

  def _bind_tf_function(unused_input, tf_func):
    """Maps no-arg `tf_func` over federated `unused_input`, ignoring values."""
    tf_wrapper = tf.function(lambda _: tf_func())
    input_federated_type = unused_input.type_signature
    wrapper = computations.tf_computation(tf_wrapper,
                                          input_federated_type.member)
    return intrinsics.federated_map(wrapper, unused_input)

  def count_clients_federated(client_data):
    """Sums a constant int32 1 contributed by every client."""

    @tf.function
    def client_ones_fn():
      return tf.ones(shape=[], dtype=tf.int32)

    client_ones = _bind_tf_function(client_data, client_ones_fn)
    return intrinsics.federated_sum(client_ones)

  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value(
        collections.OrderedDict([('num_clients', 0)]), placements.SERVER)

  @computations.federated_computation([
      computation_types.FederatedType(server_state_type, placements.SERVER),
      computation_types.FederatedType(
          computation_types.SequenceType(tf.string), placements.CLIENTS)
  ])
  def next_fn(server_state, client_val):
    """`next` function for `computation_utils.IterativeProcess`."""
    server_update = intrinsics.federated_zip(
        collections.OrderedDict([('num_clients',
                                  count_clients_federated(client_val))]))
    # The broadcast state is only the (ignored) argument of the wrapped
    # `tf.timestamp`; the per-client timestamps are summed on the server.
    server_output = intrinsics.federated_sum(
        _bind_tf_function(
            intrinsics.federated_broadcast(server_state), tf.timestamp))
    return server_update, server_output

  return computation_utils.IterativeProcess(init_fn, next_fn)
def get_unused_lambda_arg_iterative_process():
  """Returns an iterative process having a Lambda not referencing its arg."""
  server_state_type = computation_types.NamedTupleType([('num_clients',
                                                         tf.int32)])

  def _bind_federated_value(unused_input, input_type, federated_output_value):
    """Calls a federated lambda on `unused_input` that returns a fixed value.

    The wrapped lambda takes a CLIENTS-placed `input_type` argument and
    ignores it, returning `federated_output_value` unchanged.
    """
    federated_input_type = computation_types.FederatedType(
        input_type, placements.CLIENTS)
    wrapper = computations.federated_computation(
        lambda _: federated_output_value, federated_input_type)
    return wrapper(unused_input)

  def count_clients_federated(client_data):
    """Sums a constant 1 per client, routed through an unused-arg lambda."""
    client_ones = intrinsics.federated_value(1, placements.CLIENTS)
    client_ones = _bind_federated_value(
        client_data, computation_types.SequenceType(tf.string), client_ones)
    return intrinsics.federated_sum(client_ones)

  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value(
        collections.OrderedDict([('num_clients', 0)]), placements.SERVER)

  @computations.federated_computation([
      computation_types.FederatedType(server_state_type, placements.SERVER),
      computation_types.FederatedType(
          computation_types.SequenceType(tf.string), placements.CLIENTS)
  ])
  def next_fn(server_state, client_val):
    """`next` function for `computation_utils.IterativeProcess`."""
    server_update = intrinsics.federated_zip(
        collections.OrderedDict([('num_clients',
                                  count_clients_federated(client_val))]))
    # The empty server value is what the unused-arg lambda below returns; the
    # broadcast state is its ignored argument.
    server_output = intrinsics.federated_value((), placements.SERVER)
    server_output = _bind_federated_value(
        intrinsics.federated_broadcast(server_state), server_state_type,
        server_output)
    return server_update, server_output

  return computation_utils.IterativeProcess(init_fn, next_fn)
def get_iterative_process_for_canonical_form(cf):
  """Creates `tff.utils.IterativeProcess` from a canonical form.

  Args:
    cf: An instance of `tff.backends.mapreduce.CanonicalForm`.

  Returns:
    An instance of `tff.utils.IterativeProcess` that corresponds to `cf`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(cf, canonical_form.CanonicalForm)

  @tff.federated_computation
  def init_computation():
    return tff.federated_value(cf.initialize(), tff.SERVER)

  @tff.federated_computation(init_computation.type_signature.result,
                             tff.FederatedType(
                                 cf.work.type_signature.parameter[0],
                                 tff.CLIENTS))
  def next_computation(arg):
    """The logic of a single MapReduce processing round."""
    s1 = arg[0]  # Server state.
    c1 = arg[1]  # Client data.
    s2 = tff.federated_apply(cf.prepare, s1)
    c2 = tff.federated_broadcast(s2)
    c3 = tff.federated_zip([c1, c2])
    c4 = tff.federated_map(cf.work, c3)
    # `work` yields (input to federated_aggregate, client output).
    c5 = c4[0]
    c6 = c4[1]
    s3 = tff.federated_aggregate(c5, cf.zero(), cf.accumulate, cf.merge,
                                 cf.report)
    s4 = tff.federated_zip([s1, s3])
    # `update` consumes (old state, aggregate) and yields
    # (new state, server output).
    s5 = tff.federated_apply(cf.update, s4)
    s6 = s5[0]
    s7 = s5[1]
    return s6, s7, c6

  return computation_utils.IterativeProcess(init_computation, next_computation)
def get_iterative_process_for_sum_example_with_no_aggregation():
  """Returns an iterative process for a sum example with no aggregation.

  The server state is an int32 pair, initialized to `[0, 0]`. It is passed
  through `prepare`, broadcast, and zipped with each client's data before
  `work` runs. The client updates produced by `work` are discarded. In their
  place, constant server-placed values of 1 stand in for the unsecure and
  secure aggregates. `update` returns those constants as the new state with
  an empty server output.
  """
  @computations.federated_computation
  def init_fn():
    """The `init` function for `computation_utils.IterativeProcess`."""
    return intrinsics.federated_value([0, 0], placements.SERVER)

  @computations.tf_computation([tf.int32, tf.int32])
  def prepare(server_state):
    return server_state

  @computations.tf_computation(tf.int32, [tf.int32, tf.int32])
  def work(client_data, client_input):
    del client_data  # Unused
    del client_input  # Unused
    # (updates, empty per-client output)
    return [1, 1], []

  @computations.tf_computation([tf.int32, tf.int32], [tf.int32, tf.int32])
  def update(server_state, global_update):
    del server_state  # Unused
    # (new server state, empty server output)
    return global_update, []

  @computations.federated_computation([
      computation_types.FederatedType([tf.int32, tf.int32], placements.SERVER),
      computation_types.FederatedType(tf.int32, placements.CLIENTS),
  ])
  def next_fn(server_state, client_data):
    """The `next` function for `computation_utils.IterativeProcess`."""
    s2 = intrinsics.federated_map(prepare, server_state)
    client_input = intrinsics.federated_broadcast(s2)
    c3 = intrinsics.federated_zip([client_data, client_input])
    # Client updates are intentionally dropped: nothing is aggregated.
    _, client_output = intrinsics.federated_map(work, c3)
    unsecure_update = intrinsics.federated_value(1, placements.SERVER)
    secure_update = intrinsics.federated_value(1, placements.SERVER)
    s6 = intrinsics.federated_zip(
        [server_state, [unsecure_update, secure_update]])
    new_server_state, server_output = intrinsics.federated_map(update, s6)
    return new_server_state, server_output, client_output

  return computation_utils.IterativeProcess(init_fn, next_fn)
def test_returns_canonical_form_with_next_fn_returning_call_directly(self):
  """Converts a process whose `next` body is just a call to another comp."""

  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value(42, placements.SERVER)

  @computations.federated_computation(
      computation_types.FederatedType(tf.int32, placements.SERVER),
      computation_types.FederatedType(
          computation_types.SequenceType(tf.float32), placements.CLIENTS))
  def next_fn(server_state, client_data):
    broadcast_state = intrinsics.federated_broadcast(server_state)

    @computations.tf_computation(tf.int32, computation_types.SequenceType(
        tf.float32))
    @tf.function
    def some_transform(x, y):
      del y  # Unused
      return x + 1

    client_update = intrinsics.federated_map(
        some_transform, (broadcast_state, client_data))
    aggregate_update = intrinsics.federated_sum(client_update)
    server_output = intrinsics.federated_value(1234, placements.SERVER)
    return aggregate_update, server_output

  # Wraps `next_fn` so that the traced `next` returns the call result directly
  # rather than a locally constructed tuple.
  @computations.federated_computation(
      computation_types.FederatedType(tf.int32, placements.SERVER),
      computation_types.FederatedType(
          computation_types.SequenceType(tf.float32), placements.CLIENTS))
  def nested_next_fn(server_state, client_data):
    return next_fn(server_state, client_data)

  iterative_process = computation_utils.IterativeProcess(
      init_fn, nested_next_fn)
  cf = canonical_form_utils.get_canonical_form_for_iterative_process(
      iterative_process)
  self.assertIsInstance(cf, canonical_form.CanonicalForm)
def test_broadcast_dependent_on_aggregate_fails_well(self):
  """Conversion rejects a `next` that chains two rounds together.

  The state from the first round (itself produced from an aggregate) is fed
  back into a second round, which broadcasts it. That makes a broadcast
  depend on an aggregate, which cannot be expressed as a single
  `CanonicalForm` round.
  """
  cf = test_utils.get_temperature_sensor_example()
  it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
  next_comp = test_utils.computation_to_building_block(it.next)
  top_level_param = building_blocks.Reference(next_comp.parameter_name,
                                              next_comp.parameter_type)
  first_result = building_blocks.Call(next_comp, top_level_param)
  # Second round input: (state from round one, original client data).
  middle_param = building_blocks.Tuple([
      building_blocks.Selection(first_result, index=0),
      building_blocks.Selection(top_level_param, index=1)
  ])
  second_result = building_blocks.Call(next_comp, middle_param)
  not_reducible = building_blocks.Lambda(next_comp.parameter_name,
                                         next_comp.parameter_type,
                                         second_result)
  not_reducible_it = computation_utils.IterativeProcess(
      it.initialize,
      computation_wrapper_instances.building_block_to_computation(
          not_reducible))

  with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
    canonical_form_utils.get_canonical_form_for_iterative_process(
        not_reducible_it)
def test_iterative_process_next_bad_type(self):
  """A `next_fn` that is not a computation (here `None`) is rejected."""
  expected_message = r'Expected .*\.Computation, .*'
  with self.assertRaisesRegex(TypeError, expected_message):
    computation_utils.IterativeProcess(initialize_fn=initialize, next_fn=None)