def get_example_cf_compatible_iterative_processes():
  """Builds `(name, iterative_process)` pairs for parameterized tests.

  Every builder below is invoked on each call, so each caller receives newly
  constructed process objects.

  Returns:
    A tuple of 2-tuples `(test_name, process)`, in a fixed order.
  """
  # pyformat: disable
  named_processes = [
      ('sum_example',
       get_iterative_process_for_sum_example()),
      ('sum_example_with_no_prepare',
       get_iterative_process_for_sum_example_with_no_prepare()),
      ('sum_example_with_no_broadcast',
       get_iterative_process_for_sum_example_with_no_broadcast()),
      ('sum_example_with_no_federated_aggregate',
       get_iterative_process_for_sum_example_with_no_federated_aggregate()),
      ('sum_example_with_no_federated_secure_sum_bitwidth',
       get_iterative_process_for_sum_example_with_no_federated_secure_sum_bitwidth()),
      ('sum_example_with_no_update',
       get_iterative_process_for_sum_example_with_no_update()),
      ('sum_example_with_no_server_state',
       get_iterative_process_for_sum_example_with_no_server_state()),
      ('minimal_sum_example',
       get_iterative_process_for_minimal_sum_example()),
      ('example_with_unused_lambda_arg',
       mapreduce_test_utils.get_iterative_process_for_example_with_unused_lambda_arg()),
      ('example_with_unused_tf_computation_arg',
       mapreduce_test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()),
  ]
  # pyformat: enable
  return tuple(named_processes)
class GetCanonicalFormForIterativeProcessTest(CanonicalFormTestCase,
                                              parameterized.TestCase):
  """Tests `canonical_form_utils.get_canonical_form_for_iterative_process`."""

  def test_next_computation_returning_tensor_fails_well(self):
    cf = test_utils.get_temperature_sensor_example()
    it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    init_result = it.initialize.type_signature.result
    # Replace `next` with an identity function over the server-state type;
    # conversion to canonical form is expected to reject it with a TypeError.
    lam = building_blocks.Lambda('x', init_result,
                                 building_blocks.Reference('x', init_result))
    bad_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(lam))
    with self.assertRaises(TypeError):
      canonical_form_utils.get_canonical_form_for_iterative_process(bad_it)

  def test_broadcast_dependent_on_aggregate_fails_well(self):
    cf = test_utils.get_temperature_sensor_example()
    it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    next_comp = test_utils.computation_to_building_block(it.next)
    top_level_param = building_blocks.Reference(next_comp.parameter_name,
                                                next_comp.parameter_type)
    # Chain two calls of `next`: the second call's state argument is element 0
    # of the first call's result, while element 1 of the original parameter
    # is passed through unchanged.
    first_result = building_blocks.Call(next_comp, top_level_param)
    middle_param = building_blocks.Tuple([
        building_blocks.Selection(first_result, index=0),
        building_blocks.Selection(top_level_param, index=1)
    ])
    second_result = building_blocks.Call(next_comp, middle_param)
    not_reducible = building_blocks.Lambda(next_comp.parameter_name,
                                           next_comp.parameter_type,
                                           second_result)
    not_reducible_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(
            not_reducible))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      canonical_form_utils.get_canonical_form_for_iterative_process(
          not_reducible_it)

  def test_constructs_canonical_form_from_mnist_training_example(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  def test_temperature_example_round_trip(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state = new_it.initialize()
    self.assertLen(state, 1)
    self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
    self.assertEqual(state[0], 0)

    # Round 1: two clients contributing 1 and 3 readings respectively.
    state, metrics, stats = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertLen(state, 1)
    self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
    self.assertEqual(state[0], 1)
    self.assertLen(metrics, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics), ['ratio_over_threshold'])
    self.assertEqual(metrics[0], 0.5)
    self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                          [1, 3])

    # Round 2: four clients with one reading each. `num_readings` is
    # evaluated the same way as in round 1 so both rounds are checked alike.
    state, metrics, stats = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    self.assertAllEqual(state, (2,))
    self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
    self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                          [1, 1, 1, 1])

    # The round trip must not change the number of TF variables in `next`.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))

  def test_mnist_training_round_trip(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state1 = it.initialize()
    state2 = new_it.initialize()
    # Compare names and values structurally rather than via `str()`: numpy
    # summarizes arrays above its print threshold (1000 elements by default)
    # with '...', so a string comparison skips most entries of large tensors.
    self.assertAllEqual(
        anonymous_tuple.name_list(state1), anonymous_tuple.name_list(state2))
    self.assertAllClose(state1, state2)

    dummy_x = np.array([[0.5] * 784], dtype=np.float32)
    dummy_y = np.array([1], dtype=np.int32)
    client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]

    round_1 = it.next(state1, [client_data])
    state = round_1[0]
    metrics = round_1[1]
    alt_round_1 = new_it.next(state2, [client_data])
    alt_state = alt_round_1[0]
    alt_metrics = alt_round_1[1]

    self.assertAllEqual(
        anonymous_tuple.name_list(state), anonymous_tuple.name_list(alt_state))
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics),
        anonymous_tuple.name_list(alt_metrics))
    self.assertAllClose(state, alt_state)
    self.assertAllClose(metrics, alt_metrics)

    # The round trip must not change the number of TF variables in `next`.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))

  # Decorator arguments are evaluated at class-definition time, so every
  # process below is built once when this class is created.
  # pyformat: disable
  @parameterized.named_parameters(
      ('sum_example',
       get_iterative_process_for_sum_example()),
      ('sum_example_with_no_prepare',
       get_iterative_process_for_sum_example_with_no_prepare()),
      ('sum_example_with_no_broadcast',
       get_iterative_process_for_sum_example_with_no_broadcast()),
      ('sum_example_with_no_client_output',
       get_iterative_process_for_sum_example_with_no_client_output()),
      ('sum_example_with_no_federated_aggregate',
       get_iterative_process_for_sum_example_with_no_federated_aggregate()),
      ('sum_example_with_no_federated_secure_sum',
       get_iterative_process_for_sum_example_with_no_federated_secure_sum()),
      ('sum_example_with_no_update',
       get_iterative_process_for_sum_example_with_no_update()),
      ('sum_example_with_no_server_state',
       get_iterative_process_for_sum_example_with_no_server_state()),
      ('minimal_sum_example',
       get_iterative_process_for_minimal_sum_example()),
      ('example_with_unused_lambda_arg',
       test_utils.get_iterative_process_for_example_with_unused_lambda_arg()),
      ('example_with_unused_tf_computation_arg',
       test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()),
  )
  # pyformat: enable
  def test_returns_canonical_form(self, ip):
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  def test_raises_value_error_for_sum_example_with_no_aggregation(self):
    ip = get_iterative_process_for_sum_example_with_no_aggregation()
    with self.assertRaises(ValueError):
      canonical_form_utils.get_canonical_form_for_iterative_process(ip)
class GetCanonicalFormForIterativeProcessTest(CanonicalFormTestCase,
                                              parameterized.TestCase):
  """Tests `canonical_form_utils.get_canonical_form_for_iterative_process`."""

  def test_next_computation_returning_tensor_fails_well(self):
    cf = test_utils.get_temperature_sensor_example()
    it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    init_result = it.initialize.type_signature.result
    # Replace `next` with an identity function over the server-state type;
    # conversion to canonical form is expected to reject it with a TypeError.
    lam = building_blocks.Lambda(
        'x', init_result, building_blocks.Reference('x', init_result))
    bad_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(lam))
    with self.assertRaises(TypeError):
      canonical_form_utils.get_canonical_form_for_iterative_process(bad_it)

  def test_broadcast_dependent_on_aggregate_fails_well(self):
    cf = test_utils.get_temperature_sensor_example()
    it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    next_comp = test_utils.computation_to_building_block(it.next)
    top_level_param = building_blocks.Reference(next_comp.parameter_name,
                                                next_comp.parameter_type)
    # Chain two calls of `next`: the second call's state argument is element 0
    # of the first call's result, while element 1 of the original parameter
    # is passed through unchanged.
    first_result = building_blocks.Call(next_comp, top_level_param)
    middle_param = building_blocks.Tuple([
        building_blocks.Selection(first_result, index=0),
        building_blocks.Selection(top_level_param, index=1)
    ])
    second_result = building_blocks.Call(next_comp, middle_param)
    not_reducible = building_blocks.Lambda(next_comp.parameter_name,
                                           next_comp.parameter_type,
                                           second_result)
    not_reducible_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(
            not_reducible))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      canonical_form_utils.get_canonical_form_for_iterative_process(
          not_reducible_it)

  def test_constructs_canonical_form_from_mnist_training_example(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  def test_temperature_example_round_trip(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state = new_it.initialize()
    self.assertLen(state, 1)
    self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
    self.assertEqual(state[0], 0)

    # Round 1: two clients contributing 1 and 3 readings respectively.
    state, metrics, stats = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertLen(state, 1)
    self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
    self.assertEqual(state[0], 1)
    self.assertLen(metrics, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics), ['ratio_over_threshold'])
    self.assertEqual(metrics[0], 0.5)
    self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                          [1, 3])

    # Round 2: four clients with one reading each. `num_readings` is
    # evaluated the same way as in round 1 so both rounds are checked alike.
    state, metrics, stats = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    self.assertAllEqual(state, (2,))
    self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
    self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                          [1, 1, 1, 1])

    # The round trip must not change the number of TF variables in `next`.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))

  def test_mnist_training_round_trip(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state1 = it.initialize()
    state2 = new_it.initialize()
    # Compare names and values structurally rather than via `str()`: numpy
    # summarizes arrays above its print threshold (1000 elements by default)
    # with '...', so a string comparison skips most entries of large tensors.
    self.assertAllEqual(
        anonymous_tuple.name_list(state1), anonymous_tuple.name_list(state2))
    self.assertAllClose(state1, state2)

    dummy_x = np.array([[0.5] * 784], dtype=np.float32)
    dummy_y = np.array([1], dtype=np.int32)
    client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]

    round_1 = it.next(state1, [client_data])
    state = round_1[0]
    metrics = round_1[1]
    alt_round_1 = new_it.next(state2, [client_data])
    alt_state = alt_round_1[0]
    alt_metrics = alt_round_1[1]

    self.assertAllEqual(
        anonymous_tuple.name_list(state), anonymous_tuple.name_list(alt_state))
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics),
        anonymous_tuple.name_list(alt_metrics))
    self.assertAllClose(state, alt_state)
    self.assertAllClose(metrics, alt_metrics)

    # The round trip must not change the number of TF variables in `next`.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))

  def test_canonical_form_from_tff_learning_structure_type_spec(self):
    it = test_utils.construct_example_training_comp()
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    work_type_spec = cf.work.type_signature
    # This type spec test actually carries the meaning that TFF's vanilla path
    # to canonical form will broadcast and aggregate exactly one copy of the
    # parameters. So the type test below in fact functions as a regression test
    # for the TFF compiler pipeline.
    # pyformat: disable
    expected_type_string = '(<<x=float32[?,2],y=int32[?,1]>*,<<trainable=<float32[2,1],float32[1]>,non_trainable=<>>>> -> <<<<<float32[2,1],float32[1]>,float32>,<float32,float32>,<float32,float32>,<float32>>,<>>,<>>)'
    # pyformat: enable
    self.assertEqual(work_type_spec.compact_representation(),
                     expected_type_string)

  def test_returns_canonical_form_from_tff_learning_structure(self):
    it = test_utils.construct_example_training_comp()
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)
    self.assertEqual(it.initialize.type_signature,
                     new_it.initialize.type_signature)
    # Notice next type_signatures need not be equal, since we may have appended
    # an empty tuple as client side-channel outputs if none existed
    self.assertEqual(it.next.type_signature.parameter,
                     new_it.next.type_signature.parameter)
    self.assertEqual(it.next.type_signature.result[0],
                     new_it.next.type_signature.result[0])
    self.assertEqual(it.next.type_signature.result[1],
                     new_it.next.type_signature.result[1])

    state1 = it.initialize()
    state2 = new_it.initialize()

    sample_batch = collections.OrderedDict(
        x=np.array([[1., 1.]], dtype=np.float32),
        y=np.array([[0]], dtype=np.int32))
    client_data = [sample_batch]

    round_1 = it.next(state1, [client_data])
    state = round_1[0]
    state_names = anonymous_tuple.name_list(state)
    state_arrays = anonymous_tuple.flatten(state)
    metrics = round_1[1]
    metrics_names = [x[0] for x in anonymous_tuple.iter_elements(metrics)]
    metrics_arrays = anonymous_tuple.flatten(metrics)

    alt_round_1 = new_it.next(state2, [client_data])
    alt_state = alt_round_1[0]
    alt_state_names = anonymous_tuple.name_list(alt_state)
    alt_state_arrays = anonymous_tuple.flatten(alt_state)
    alt_metrics = alt_round_1[1]
    alt_metrics_names = [
        x[0] for x in anonymous_tuple.iter_elements(alt_metrics)
    ]
    alt_metrics_arrays = anonymous_tuple.flatten(alt_metrics)

    self.assertEmpty(state.delta_aggregate_state)
    self.assertEmpty(state.model_broadcast_state)
    self.assertAllEqual(state_names, alt_state_names)
    self.assertAllEqual(metrics_names, alt_metrics_names)
    self.assertAllClose(state_arrays, alt_state_arrays)
    self.assertAllClose(metrics_arrays[:2], alt_metrics_arrays[:2])
    # Final metric is execution time
    self.assertAlmostEqual(metrics_arrays[2], alt_metrics_arrays[2], delta=1e-3)

  # Decorator arguments are evaluated at class-definition time, so every
  # process below is built once when this class is created.
  # pyformat: disable
  @parameterized.named_parameters(
      ('sum_example',
       get_iterative_process_for_sum_example()),
      ('sum_example_with_no_prepare',
       get_iterative_process_for_sum_example_with_no_prepare()),
      ('sum_example_with_no_broadcast',
       get_iterative_process_for_sum_example_with_no_broadcast()),
      ('sum_example_with_no_client_output',
       get_iterative_process_for_sum_example_with_no_client_output()),
      ('sum_example_with_no_federated_aggregate',
       get_iterative_process_for_sum_example_with_no_federated_aggregate()),
      ('sum_example_with_no_federated_secure_sum',
       get_iterative_process_for_sum_example_with_no_federated_secure_sum()),
      ('sum_example_with_no_update',
       get_iterative_process_for_sum_example_with_no_update()),
      ('sum_example_with_no_server_state',
       get_iterative_process_for_sum_example_with_no_server_state()),
      ('minimal_sum_example',
       get_iterative_process_for_minimal_sum_example()),
      ('example_with_unused_lambda_arg',
       test_utils.get_iterative_process_for_example_with_unused_lambda_arg()),
      ('example_with_unused_tf_computation_arg',
       test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()),
  )
  # pyformat: enable
  def test_returns_canonical_form(self, ip):
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  def test_raises_value_error_for_sum_example_with_no_aggregation(self):
    ip = get_iterative_process_for_sum_example_with_no_aggregation()
    with self.assertRaises(ValueError):
      canonical_form_utils.get_canonical_form_for_iterative_process(ip)
def _named_iterative_processes_for_canonical_form_tests():
  """Returns freshly built `(name, iterative_process)` test parameters.

  Used by both parameterized tests of the class defined right after this
  helper, so their parameter lists cannot drift apart. Each call invokes every
  builder again, so each decorated test receives its own process objects.
  """
  # pyformat: disable
  return (
      ('sum_example',
       get_iterative_process_for_sum_example()),
      ('sum_example_with_no_prepare',
       get_iterative_process_for_sum_example_with_no_prepare()),
      ('sum_example_with_no_broadcast',
       get_iterative_process_for_sum_example_with_no_broadcast()),
      ('sum_example_with_no_federated_aggregate',
       get_iterative_process_for_sum_example_with_no_federated_aggregate()),
      ('sum_example_with_no_federated_secure_sum',
       get_iterative_process_for_sum_example_with_no_federated_secure_sum()),
      ('sum_example_with_no_update',
       get_iterative_process_for_sum_example_with_no_update()),
      ('sum_example_with_no_server_state',
       get_iterative_process_for_sum_example_with_no_server_state()),
      ('minimal_sum_example',
       get_iterative_process_for_minimal_sum_example()),
      ('example_with_unused_lambda_arg',
       test_utils.get_iterative_process_for_example_with_unused_lambda_arg()),
      ('example_with_unused_tf_computation_arg',
       test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()),
  )
  # pyformat: enable


class GetCanonicalFormForIterativeProcessTest(CanonicalFormTestCase,
                                              parameterized.TestCase):
  """Tests `canonical_form_utils.get_canonical_form_for_iterative_process`."""

  def test_next_computation_returning_tensor_fails_well(self):
    cf = test_utils.get_temperature_sensor_example()
    it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    init_result = it.initialize.type_signature.result
    # Replace `next` with an identity function over the server-state type;
    # conversion to canonical form is expected to reject it with a TypeError.
    lam = building_blocks.Lambda(
        'x', init_result, building_blocks.Reference('x', init_result))
    bad_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(lam))
    with self.assertRaises(TypeError):
      canonical_form_utils.get_canonical_form_for_iterative_process(bad_it)

  def test_broadcast_dependent_on_aggregate_fails_well(self):
    cf = test_utils.get_temperature_sensor_example()
    it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    next_comp = test_utils.computation_to_building_block(it.next)
    top_level_param = building_blocks.Reference(next_comp.parameter_name,
                                                next_comp.parameter_type)
    # Chain two calls of `next`: the second call's state argument is element 0
    # of the first call's result, while element 1 of the original parameter
    # is passed through unchanged.
    first_result = building_blocks.Call(next_comp, top_level_param)
    middle_param = building_blocks.Struct([
        building_blocks.Selection(first_result, index=0),
        building_blocks.Selection(top_level_param, index=1)
    ])
    second_result = building_blocks.Call(next_comp, middle_param)
    not_reducible = building_blocks.Lambda(next_comp.parameter_name,
                                           next_comp.parameter_type,
                                           second_result)
    not_reducible_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(
            not_reducible))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      canonical_form_utils.get_canonical_form_for_iterative_process(
          not_reducible_it)

  def test_constructs_canonical_form_from_mnist_training_example(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  def test_temperature_example_round_trip(self):
    # NOTE: the roundtrip through CanonicalForm->IterProc->CanonicalForm seems
    # to lose the python container annotations on the StructType.
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state = new_it.initialize()
    self.assertEqual(state.num_rounds, 0)

    state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertEqual(state.num_rounds, 1)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.5))

    state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    # The round counter must keep advancing, as checked after round 1.
    self.assertEqual(state.num_rounds, 2)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.75))

    # The round trip must not change the number of TF variables in `next`.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))

  def test_mnist_training_round_trip(self):
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state1 = it.initialize()
    state2 = new_it.initialize()
    self.assertAllClose(state1, state2)

    dummy_x = np.array([[0.5] * 784], dtype=np.float32)
    dummy_y = np.array([1], dtype=np.int32)
    client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]

    round_1 = it.next(state1, [client_data])
    state = round_1[0]
    metrics = round_1[1]
    alt_round_1 = new_it.next(state2, [client_data])
    alt_state = alt_round_1[0]
    self.assertAllClose(state, alt_state)
    alt_metrics = alt_round_1[1]
    self.assertAllClose(metrics, alt_metrics)

    # The round trip must not change the number of TF variables in `next`.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))

  # The helper is called once per decorator (at class-definition time), so
  # each parameterized test gets its own freshly built processes.
  @parameterized.named_parameters(
      *_named_iterative_processes_for_canonical_form_tests())
  def test_returns_canonical_form(self, ip):
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  @parameterized.named_parameters(
      *_named_iterative_processes_for_canonical_form_tests())
  def test_returns_canonical_form_with_grappler_disabled(self, ip):
    # Per the test name, a `None` second argument disables Grappler; presumably
    # it is the grappler config parameter -- confirm against the callee.
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(
        ip, None)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)

  def test_raises_value_error_for_sum_example_with_no_aggregation(self):
    ip = get_iterative_process_for_sum_example_with_no_aggregation()
    with self.assertRaisesRegex(
        ValueError,
        r'Expected .* containing at least one `federated_aggregate` or '
        r'`federated_secure_sum`'):
      canonical_form_utils.get_canonical_form_for_iterative_process(ip)

  def test_returns_canonical_form_with_indirection_to_intrinsic(self):
    # `skipTest` raises, so the lines below do not run; they record the
    # intended check until b/160865930 is resolved.
    self.skipTest('b/160865930')
    ip = test_utils.get_iterative_process_for_example_with_lambda_returning_aggregation(
    )
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip)
    self.assertIsInstance(cf, canonical_form.CanonicalForm)