def test_next_computation_returning_tensor_fails_well(self):
    """A malformed `next` computation is rejected with `TypeError`.

    The replacement `next` is an identity lambda: it takes a value typed as
    the result of `initialize` and returns it unchanged. It does not return
    a (state, output) pair. Converting the resulting process to a
    `MapReduceForm` must raise `TypeError`.
    """
    temperature_mrf = mapreduce_test_utils.get_temperature_sensor_example()
    process = form_utils.get_iterative_process_for_map_reduce_form(
        temperature_mrf)
    state_type = process.initialize.type_signature.result
    identity_fn = building_blocks.Lambda(
        'x', state_type, building_blocks.Reference('x', state_type))
    identity_next = computation_impl.ConcreteComputation.from_building_block(
        identity_fn)
    malformed_process = iterative_process.IterativeProcess(
        process.initialize, identity_next)
    with self.assertRaises(TypeError):
        form_utils.get_map_reduce_form_for_iterative_process(malformed_process)
def get_map_reduce_form_for_client_to_server_fn(
        self, client_to_server_fn,
        client_data_type=tf.int32) -> forms.MapReduceForm:
    """Produces a `MapReduceForm` for the provided `client_to_server_fn`.

    Creates an `iterative_process.IterativeProcess` with two computations:
    - `initialize` places an empty tuple `()` on the server.
    - `next` returns that server state unchanged, together with the result
      of applying `client_to_server_fn` to the client-placed data.
    The process is then passed through
    `get_map_reduce_form_for_iterative_process`.

    Args:
      client_to_server_fn: A function from client-placed data to
        server-placed output.
      client_data_type: Member type of the client-placed argument handed to
        `client_to_server_fn`. Defaults to `tf.int32`.

    Returns:
      A `forms.MapReduceForm` which uses the embedded `client_to_server_fn`.
    """

    @federated_computation.federated_computation
    def init_fn():
        return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation([
        computation_types.at_server(()),
        computation_types.at_clients(client_data_type),
    ])
    def next_fn(server_state, client_data):
        # The server state is passed through untouched; only the second
        # output depends on the clients.
        server_output = client_to_server_fn(client_data)
        return server_state, server_output

    ip = iterative_process.IterativeProcess(init_fn, next_fn)
    return form_utils.get_map_reduce_form_for_iterative_process(ip)
def test_returns_map_reduce_form_with_indirection_to_intrinsic(self):
    """The lambda-returning-aggregation example converts to a MapReduceForm."""
    process = (
        mapreduce_test_utils
        .get_iterative_process_for_example_with_lambda_returning_aggregation())
    converted = form_utils.get_map_reduce_form_for_iterative_process(process)
    self.assertIsInstance(converted, forms.MapReduceForm)
def test_returns_canonical_form_with_grappler_disabled(self, ip):
    """Conversion of `ip` succeeds with the Grappler meta-optimizer disabled.

    `ip` is supplied by the test's parameterization (defined elsewhere).
    """
    no_grappler = tf.compat.v1.ConfigProto()
    no_grappler.graph_options.rewrite_options.disable_meta_optimizer = True
    converted = form_utils.get_map_reduce_form_for_iterative_process(
        ip, no_grappler)
    self.assertIsInstance(converted, forms.MapReduceForm)
def test_mnist_training_round_trip(self):
    """Round-trips the MNIST example with Grappler disabled and compares.

    The example goes MapReduceForm -> IterativeProcess -> MapReduceForm ->
    IterativeProcess. The original and round-tripped processes must agree on:
    - the initial state;
    - elements 0 and 1 of one `next` call's output;
    - the number of TensorFlow variables under `next`.

    NOTE(review): another method named `test_mnist_training_round_trip`
    appears elsewhere in this file. If both live on the same TestCase class,
    the later definition shadows the earlier one — confirm.
    """
    original_process = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_mnist_training_example())
    # TODO(b/208887729): We disable grappler to work around attempting to hoist
    # transformed functions of the same name into the eager context. When this
    # execution is C++-backed, this can go away.
    no_grappler = tf.compat.v1.ConfigProto()
    no_grappler.graph_options.rewrite_options.disable_meta_optimizer = True
    converted_mrf = form_utils.get_map_reduce_form_for_iterative_process(
        original_process, no_grappler)
    rebuilt_process = form_utils.get_iterative_process_for_map_reduce_form(
        converted_mrf)

    original_state = original_process.initialize()
    rebuilt_state = rebuilt_process.initialize()
    self.assertAllClose(original_state, rebuilt_state)

    # One client whose data is a one-element list holding a single batch with
    # one 784-feature example and one integer label.
    batch = collections.OrderedDict(
        x=np.array([[0.5] * 784], dtype=np.float32),
        y=np.array([1], dtype=np.int32))
    client_data = [batch]

    original_output = original_process.next(original_state, [client_data])
    rebuilt_output = rebuilt_process.next(rebuilt_state, [client_data])
    self.assertAllClose(original_output[0], rebuilt_output[0])
    self.assertAllClose(original_output[1], rebuilt_output[1])

    def _variable_count(process):
        # Counts TF variables reachable from the process's `next` building block.
        return tree_analysis.count_tensorflow_variables_under(
            process.next.to_building_block())

    self.assertEqual(
        _variable_count(original_process), _variable_count(rebuilt_process))
def _check_aggregated_scalar_count(self, aggregator, max_scalars,
                                   min_scalars=0):
    """Asserts bounds on how many tensors `work` outputs; returns the MRF.

    `aggregator` is first adapted by `_mrfify_aggregator` (defined elsewhere)
    and converted to a `MapReduceForm`. The `'parameters'` count that
    `type_analysis.count_tensors_in_type` reports for the result type of
    `mrf.work` must satisfy `min_scalars <= count < max_scalars`.

    Returns:
      The `MapReduceForm` that was built, for further inspection by callers.
    """
    mrf = form_utils.get_map_reduce_form_for_iterative_process(
        _mrfify_aggregator(aggregator))
    tensor_counts = type_analysis.count_tensors_in_type(
        mrf.work.type_signature.result)
    aggregated_count = tensor_counts['parameters']
    self.assertLess(aggregated_count, max_scalars)
    self.assertGreaterEqual(aggregated_count, min_scalars)
    return mrf
def test_broadcast_dependent_on_aggregate_fails_well(self):
    """Chaining `next` twice in one computation is rejected with ValueError.

    The second `next` call consumes output element 0 of the first call.
    Converting such a process must fail with a 'broadcast dependent on
    aggregate' error.
    """
    process = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_temperature_sensor_example())
    next_bb = process.next.to_building_block()
    param_name = next_bb.parameter_name
    param_type = next_bb.parameter_type
    param_ref = building_blocks.Reference(param_name, param_type)

    # First invocation of `next` on the original argument.
    first_call = building_blocks.Call(next_bb, param_ref)
    # Second invocation receives:
    # - element 0 of the first call's output (presumably the updated state);
    # - element 1 of the original argument (presumably the client data).
    second_arg = building_blocks.Struct([
        building_blocks.Selection(first_call, index=0),
        building_blocks.Selection(param_ref, index=1),
    ])
    second_call = building_blocks.Call(next_bb, second_arg)
    chained_next = building_blocks.Lambda(param_name, param_type, second_call)

    chained_process = iterative_process.IterativeProcess(
        process.initialize,
        computation_wrapper_instances.building_block_to_computation(
            chained_next))
    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
        form_utils.get_map_reduce_form_for_iterative_process(chained_process)
def test_temperature_example_round_trip(self):
    """Round-trips the temperature sensor example and runs two rounds.

    The example goes MapReduceForm -> IterativeProcess -> MapReduceForm ->
    IterativeProcess. The rebuilt process is checked for:
    - the `num_rounds` counter in its state after init and after each round;
    - the `ratio_over_threshold` metric for each round;
    - having the same number of TensorFlow variables under `next` as the
      original.
    """
    # NOTE: the roundtrip through MapReduceForm->IterProc->MapReduceForm seems
    # to lose the python container annotations on the StructType.
    it = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_temperature_sensor_example())
    mrf = form_utils.get_map_reduce_form_for_iterative_process(it)
    new_it = form_utils.get_iterative_process_for_map_reduce_form(mrf)

    state = new_it.initialize()
    self.assertEqual(state['num_rounds'], 0)

    state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertEqual(state['num_rounds'], 1)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.5))

    state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    # Also verify the round counter keeps advancing past the first round;
    # previously only the metrics of the second round were checked.
    self.assertEqual(state['num_rounds'], 2)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.75))

    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            it.next.to_building_block()),
        tree_analysis.count_tensorflow_variables_under(
            new_it.next.to_building_block()))
def test_mnist_training_round_trip(self):
    """Round-trips the MNIST training example and compares both processes.

    The example goes MapReduceForm -> IterativeProcess -> MapReduceForm ->
    IterativeProcess with the default conversion config. The original and
    round-tripped processes must agree on:
    - the initial state;
    - elements 0 and 1 of one `next` call's output;
    - the number of TensorFlow variables under `next`.

    NOTE(review): another method named `test_mnist_training_round_trip`
    appears elsewhere in this file. If both live on the same TestCase class,
    the later definition shadows the earlier one — confirm.
    """
    before = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_mnist_training_example())
    after = form_utils.get_iterative_process_for_map_reduce_form(
        form_utils.get_map_reduce_form_for_iterative_process(before))

    init_before = before.initialize()
    init_after = after.initialize()
    self.assertAllClose(init_before, init_after)

    # One client whose data is a one-element list holding a single batch with
    # one 784-feature example and one integer label.
    sample = collections.OrderedDict(
        x=np.array([[0.5] * 784], dtype=np.float32),
        y=np.array([1], dtype=np.int32))
    client_data = [sample]

    out_before = before.next(init_before, [client_data])
    out_after = after.next(init_after, [client_data])
    # Element 0 is compared first, then element 1.
    for position in (0, 1):
        self.assertAllClose(out_before[position], out_after[position])

    var_counts = [
        tree_analysis.count_tensorflow_variables_under(
            process.next.to_building_block())
        for process in (before, after)
    ]
    self.assertEqual(var_counts[0], var_counts[1])
def test_returns_map_reduce_form_for_sum_example_with_no_aggregation(self):
    """The sum example without any aggregation converts to a MapReduceForm."""
    self.assertIsInstance(
        form_utils.get_map_reduce_form_for_iterative_process(
            get_iterative_process_for_sum_example_with_no_aggregation()),
        forms.MapReduceForm)
def test_returns_map_reduce_form(self, ip):
    """Parameterized `ip` (supplied elsewhere) converts to a MapReduceForm."""
    self.assertIsInstance(
        form_utils.get_map_reduce_form_for_iterative_process(ip),
        forms.MapReduceForm)
def test_constructs_map_reduce_form_from_mnist_training_example(self):
    """MNIST example -> IterativeProcess -> MapReduceForm succeeds."""
    mnist_mrf = mapreduce_test_utils.get_mnist_training_example()
    process = form_utils.get_iterative_process_for_map_reduce_form(mnist_mrf)
    converted = form_utils.get_map_reduce_form_for_iterative_process(process)
    self.assertIsInstance(converted, forms.MapReduceForm)
def test_gets_map_reduce_form_for_nested_broadcast(self):
    """A process with nested broadcasts converts to a MapReduceForm."""
    self.assertIsInstance(
        form_utils.get_map_reduce_form_for_iterative_process(
            get_iterative_process_with_nested_broadcasts()),
        forms.MapReduceForm)