def check_iterative_process_compatible_with_map_reduce_form(
    ip: iterative_process.IterativeProcess):
  """Tests compatibility with `tff.backends.mapreduce.MapReduceForm`.

  Note: the conditions here are specified in the documentation for
    `get_map_reduce_form_for_iterative_process`. Changes to this function should
    be propagated to that documentation.

  The following conditions are checked, in order:
    * `initialize` is a no-argument function returning a single federated
      value placed at `SERVER`.
    * `next` is a function whose parameter is a two-element struct and whose
      result is a two-element struct.
    * After intrinsics are replaced with their bodies (and, for `next`, the
      lambda body is converted to call-dominant form), both computations pass
      `tree_analysis.check_contains_only_reducible_intrinsics`, and `next`
      passes `tree_analysis.check_broadcast_not_dependent_on_aggregate`.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` to check for
      compatibility with `tff.backends.mapreduce.MapReduceForm`.

  Returns:
    TFF-internal building-blocks representing the validated and simplified
    `initialize` and `next` computations, as a tuple
    `(initialize_tree, next_tree)`.

  Raises:
    TypeError: If `ip` is not an `IterativeProcess`, or if the type signatures
      of `initialize` or `next` do not have the shapes described above.
    Any error raised by the `tree_analysis` checks is propagated unchanged.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  initialize_tree = ip.initialize.to_building_block()
  next_tree = ip.next.to_building_block()

  init_type = initialize_tree.type_signature
  _check_type_is_no_arg_fn(init_type, '`initialize`', TypeError)
  if (not init_type.result.is_federated() or
      init_type.result.placement != placements.SERVER):
    raise TypeError(
        'Expected `initialize` to return a single federated value '
        'placed at server (type `T@SERVER`), found return type:\n'
        f'{init_type.result}')

  next_type = next_tree.type_signature
  _check_type_is_fn(next_type, '`next`', TypeError)
  # Guard against a no-argument `next` (parameter `None`) so the caller gets
  # the documented TypeError rather than an AttributeError.
  if (next_type.parameter is None or not next_type.parameter.is_struct() or
      len(next_type.parameter) != 2):
    raise TypeError(
        'Expected `next` to take two arguments, found parameter '
        f'type:\n{next_type.parameter}')
  if not next_type.result.is_struct() or len(next_type.result) != 2:
    raise TypeError('Expected `next` to return two values, found result '
                    f'type:\n{next_type.result}')

  initialize_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
      initialize_tree)
  next_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(next_tree)
  next_tree = _replace_lambda_body_with_call_dominant_form(next_tree)

  tree_analysis.check_contains_only_reducible_intrinsics(initialize_tree)
  tree_analysis.check_contains_only_reducible_intrinsics(next_tree)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_tree)

  return initialize_tree, next_tree
def test_federated_secure_select(self):
  """`federated_secure_select` is only rewritten by the secure-intrinsic pass."""
  uri = intrinsic_defs.FEDERATED_SECURE_SELECT.uri
  comp = building_blocks.Intrinsic(
      uri,
      computation_types.FunctionType(
          [
              computation_types.at_clients(tf.int32),  # client_keys
              computation_types.at_server(tf.int32),  # max_key
              computation_types.at_server(tf.float32),  # server_state
              computation_types.FunctionType([tf.float32, tf.int32],
                                             tf.float32)  # select_fn
          ],
          computation_types.at_clients(
              computation_types.SequenceType(tf.float32))))
  self.assertGreater(_count_intrinsics(comp, uri), 0)

  # First without secure intrinsics shouldn't modify anything.
  reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(comp)
  self.assertFalse(modified)
  # Inspect the pass's *output*: counting in `comp` again would only repeat
  # the precondition above and could never detect a faulty pass.
  self.assertGreater(_count_intrinsics(reduced, uri), 0)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)

  # Now replace bodies including secure intrinsics.
  reduced, modified = (
      intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(comp))
  self.assertTrue(modified)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)
  self.assertGreater(
      _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SELECT.uri), 0)
def test_federated_secure_sum_bitwidth(self, value_dtype, bitwidth_type):
  """`federated_secure_sum_bitwidth` is only reduced by the secure pass.

  Args:
    value_dtype: Dtype of the summed client values (presumably supplied by a
      parameterization decorator not visible here).
    bitwidth_type: Type of the bitwidth argument, converted with
      `computation_types.to_type`.
  """
  uri = intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri
  comp = building_blocks.Intrinsic(
      uri,
      computation_types.FunctionType(
          parameter=[
              computation_types.at_clients(value_dtype),
              computation_types.to_type(bitwidth_type)
          ],
          result=computation_types.at_server(value_dtype)))

  # First without secure intrinsics shouldn't modify anything.
  reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(comp)
  self.assertFalse(modified)
  # Inspect the pass's *output*; counting in the untouched input `comp` would
  # pass regardless of what the reduction did.
  self.assertGreater(_count_intrinsics(reduced, uri), 0)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)

  # Now replace bodies including secure intrinsics.
  reduced, modified = (
      intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(comp))
  self.assertTrue(modified)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)
  self.assertGreater(
      _count_intrinsics(reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri), 0)
def test_federated_mean_reduces_to_aggregate(self):
  """`federated_mean` is replaced by a body built on `federated_aggregate`."""
  uri = intrinsic_defs.FEDERATED_MEAN.uri

  @computations.federated_computation(
      computation_types.FederatedType(tf.float32, placement_literals.CLIENTS))
  def foo(x):
    return intrinsics.federated_mean(x)

  foo_building_block = building_blocks.ComputationBuildingBlock.from_proto(
      foo._computation_proto)
  count_means_before_reduction = _count_intrinsics(foo_building_block, uri)

  reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
      foo_building_block)
  # Each count is computed exactly once (the original recomputed
  # `count_means_after_reduction` redundantly).
  count_means_after_reduction = _count_intrinsics(reduced, uri)
  count_aggregations = _count_intrinsics(
      reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri)

  self.assertTrue(modified)
  self.assert_types_identical(foo_building_block.type_signature,
                              reduced.type_signature)
  self.assertGreater(count_means_before_reduction, 0)
  self.assertEqual(count_means_after_reduction, 0)
  self.assertGreater(count_aggregations, 0)
def test_federated_reduce_not_reduced(self):
  """The reduction pass leaves every `federated_reduce` in place."""
  # Rewriting a federated_reduce as a federated_aggregate is not safe in
  # general: some reduction operators cannot be split into accumulate and
  # merge steps without writing fundamentally new logic (for instance,
  # 'return 1 if both arguments are 0').
  reduce_uri = intrinsic_defs.FEDERATED_REDUCE.uri

  @computations.tf_computation(tf.float32, tf.float32)
  def add(x, y):
    return x + y

  @computations.federated_computation(
      computation_types.FederatedType(tf.float32, placement_literals.CLIENTS))
  def foo(x):
    return intrinsics.federated_reduce(x, 0., add)

  original_tree = building_blocks.ComputationBuildingBlock.from_proto(
      foo._computation_proto)
  reduces_before = _count_intrinsics(original_tree, reduce_uri)

  reduction_result = intrinsic_reductions.replace_intrinsics_with_bodies(
      original_tree)
  reduced_tree = reduction_result[0]
  self.assert_types_identical(original_tree.type_signature,
                              reduced_tree.type_signature)

  reduces_after = _count_intrinsics(reduced_tree, reduce_uri)
  self.assertGreater(reduces_before, 0)
  self.assertEqual(reduces_after, reduces_before)
def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.BroadcastForm:
  """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

  Args:
    comp: An instance of `tff.Computation` that is compatible with broadcast
      form. Its type signature is validated by
      `_check_function_signature_compatible_with_broadcast_form`. Its
      parameter must have exactly two elements, whose names are used (in
      order) as the server-data and client-data labels of the result. The
      computation must return a single value placed at clients and must not
      contain any aggregations.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization of the Tensorflow graphs backing the
      resulting `tff.backends.mapreduce.BroadcastForm`. These options are
      combined with a set of defaults that aggressively configure Grappler. If
      `grappler_config` has
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler is
      bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
    provided `tff.Computation`.

  Raises:
    TypeError: If `comp` is not a `tff.Computation` or `grappler_config` is
      not a `tf.compat.v1.ConfigProto`.
    ValueError: If `comp` contains any aggregations.
  """
  py_typecheck.check_type(comp, computation_base.Computation)
  _check_function_signature_compatible_with_broadcast_form(comp.type_signature)
  py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
  grappler_config = _merge_grappler_config_with_default(grappler_config)

  bb = comp.to_building_block()
  bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(bb)
  bb = _replace_lambda_body_with_call_dominant_form(bb)

  tree_analysis.check_contains_only_reducible_intrinsics(bb)
  aggregations = tree_analysis.find_aggregations_in_tree(bb)
  if aggregations:
    # Every segment that interpolates a value needs the `f` prefix; the
    # aggregation list was previously emitted as a literal '{aggregations}'.
    raise ValueError(
        '`get_broadcast_form_for_computation` called with computation '
        f'containing {len(aggregations)} aggregations, but broadcast form '
        'does not allow aggregation. Full list of aggregations:\n'
        f'{aggregations}')

  before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
  compute_server_context = _extract_compute_server_context(
      before_broadcast, grappler_config)
  client_processing = _extract_client_processing(after_broadcast,
                                                 grappler_config)

  # Wrap both extracted building blocks back into `tff.Computation`s. The loop
  # variable is named `block` so it does not shadow the outer `bb`.
  compute_server_context, client_processing = (
      computation_wrapper_instances.building_block_to_computation(block)
      for block in (compute_server_context, client_processing))

  comp_param_names = structure.name_list_with_nones(
      comp.type_signature.parameter)
  server_data_label, client_data_label = comp_param_names
  return forms.BroadcastForm(
      compute_server_context,
      client_processing,
      server_data_label=server_data_label,
      client_data_label=client_data_label)
def test_generic_plus_reduces(self):
  """`generic_plus` is fully eliminated by `replace_intrinsics_with_bodies`."""
  plus_uri = intrinsic_defs.GENERIC_PLUS.uri
  plus_type = computation_types.FunctionType([tf.float32, tf.float32],
                                             tf.float32)
  plus_intrinsic = building_blocks.Intrinsic(plus_uri, plus_type)
  pluses_before = _count_intrinsics(plus_intrinsic, plus_uri)

  reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
      plus_intrinsic)
  pluses_after = _count_intrinsics(reduced, plus_uri)

  self.assertTrue(modified)
  self.assert_types_identical(plus_intrinsic.type_signature,
                              reduced.type_signature)
  self.assertGreater(pluses_before, 0)
  self.assertEqual(pluses_after, 0)
  # Additionally validate the reduced tree with the reducibility check.
  tree_analysis.check_contains_only_reducible_intrinsics(reduced)
def test_federated_sum_reduces_to_aggregate(self):
  """`federated_sum` is replaced by a body built on `federated_aggregate`."""
  sum_uri = intrinsic_defs.FEDERATED_SUM.uri
  sum_type = computation_types.FunctionType(
      computation_types.at_clients(tf.float32),
      computation_types.at_server(tf.float32))
  sum_intrinsic = building_blocks.Intrinsic(sum_uri, sum_type)
  sums_before = _count_intrinsics(sum_intrinsic, sum_uri)

  reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
      sum_intrinsic)
  sums_after = _count_intrinsics(reduced, sum_uri)
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
  aggregates = _count_intrinsics(reduced, aggregate_uri)

  self.assertTrue(modified)
  self.assert_types_identical(sum_intrinsic.type_signature,
                              reduced.type_signature)
  self.assertGreater(sums_before, 0)
  self.assertEqual(sums_after, 0)
  self.assertGreater(aggregates, 0)
def test_raises_on_none(self):
  """Passing `None` instead of a building block is rejected with TypeError."""
  self.assertRaises(TypeError,
                    intrinsic_reductions.replace_intrinsics_with_bodies, None)