def test_returns_trees_with_one_federated_aggregate_and_one_federated_secure_sum_for_federated_secure_sum_first(
    self):
  """Splits on secure_sum and aggregate when secure_sum is listed first."""
  aggregate_call = compiler_test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  secure_sum_call = (
      compiler_test_utils.create_dummy_called_federated_secure_sum())
  comp = building_blocks.Lambda(
      'd', tf.int32, building_blocks.Struct([aggregate_call, secure_sum_call]))
  # The URI list names secure_sum before aggregate, i.e. the reverse of the
  # order in which the calls appear in the struct above.
  intrinsic_uris = [
      intrinsic_defs.FEDERATED_SECURE_SUM.uri,
      intrinsic_defs.FEDERATED_AGGREGATE.uri,
  ]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, intrinsic_uris)

  # Each half must be a lambda that no longer calls any split intrinsic.
  for part in (before, after):
    self.assertIsInstance(part, building_blocks.Lambda)
    self.assertFalse(
        tree_analysis.contains_called_intrinsic(part, intrinsic_uris))
def test_handles_federated_broadcasts_nested_in_tuple(self):
  """Splits a call-dominant tree whose broadcast reads from a struct."""
  first_broadcast_call = (
      compiler_test_utils.create_dummy_called_federated_broadcast())
  server_value = building_blocks.Data(
      'a',
      computation_types.FederatedType(
          computation_types.TensorType(tf.int32), placements.SERVER))
  # The first broadcast sits inside a struct; the second broadcast consumes a
  # selection from that same struct.
  packed = building_blocks.Struct([server_value, first_broadcast_call])
  second_broadcast_call = building_block_factory.create_federated_broadcast(
      building_blocks.Selection(packed, index=0))
  call_dominant_body, _ = compiler_transformations.transform_to_call_dominant(
      second_broadcast_call)
  comp = building_blocks.Lambda('a', tf.int32, call_dominant_body)
  broadcast_uris = [intrinsic_defs.FEDERATED_BROADCAST.uri]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, broadcast_uris)

  # Each half must be a lambda that no longer calls federated_broadcast.
  for part in (before, after):
    self.assertIsInstance(part, building_blocks.Lambda)
    self.assertFalse(
        tree_analysis.contains_called_intrinsic(part, broadcast_uris))
def test_cannot_split_on_chained_intrinsic(self):
  """Splitting fails when one federated_map consumes another's output."""
  int_type = computation_types.TensorType(tf.int32)
  client_int_type = computation_types.at_clients(int_type)

  def int_ref(name):
    return building_blocks.Reference(name, int_type)

  def client_int_ref(name):
    return building_blocks.Reference(name, client_int_type)

  def identity_map(param_name, arg_name):
    # federated_map of an identity lambda over the client value `arg_name`.
    return building_block_factory.create_federated_map(
        building_blocks.Lambda(param_name, int_type, int_ref(param_name)),
        client_int_ref(arg_name))

  # `b` maps over the result of `a`, chaining the two intrinsic calls.
  body = building_blocks.Block(
      [('a', identity_map('p1', 'param')), ('b', identity_map('p2', 'a'))],
      client_int_ref('b'))
  comp = building_blocks.Lambda('param', int_type, body)

  with self.assertRaises(transformations._NonAlignableAlongIntrinsicError):  # pylint: disable=protected-access
    transformations.force_align_and_split_by_intrinsics(
        comp, [building_block_factory.create_null_federated_map()])
def test_returns_trees_with_one_federated_secure_sum(self):
  """Splitting on a lone secure_sum yields two intrinsic-free lambdas."""
  secure_sum_call = (
      compiler_test_utils.create_whimsy_called_federated_secure_sum())
  comp = building_blocks.Lambda(
      'a', tf.int32, building_blocks.Struct([secure_sum_call]))
  secure_sum_uris = [intrinsic_defs.FEDERATED_SECURE_SUM.uri]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, secure_sum_uris)

  for part in (before, after):
    self.assertIsInstance(part, building_blocks.Lambda)
    self.assertFalse(
        tree_analysis.contains_called_intrinsic(part, secure_sum_uris))
def test_returns_trees_with_one_federated_broadcast(self):
  """Splitting on a lone broadcast yields two intrinsic-free lambdas."""
  broadcast_call = (
      compiler_test_utils.create_dummy_called_federated_broadcast())
  comp = building_blocks.Lambda(
      'a', tf.int32, building_blocks.Tuple([broadcast_call]))
  broadcast_uris = [intrinsic_defs.FEDERATED_BROADCAST.uri]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, broadcast_uris)

  for part in (before, after):
    self.assertIsInstance(part, building_blocks.Lambda)
    self.assertFalse(
        tree_analysis.contains_called_intrinsic(part, broadcast_uris))
def _split_ast_on_broadcast(bb):
  """Splits the AST `bb` on the `broadcast` intrinsic.

  Args:
    bb: An AST of arbitrary shape, potentially containing a broadcast.

  Returns:
    The result of `_untuple_broadcast_only_before_after` applied to the split:
    two ASTs, the first of which maps `bb`'s input to the argument of
    broadcast, and the second of which maps `bb`'s input and broadcast's output
    to `bb`'s output.
  """
  broadcast_call = building_block_factory.create_null_federated_broadcast()
  split_parts = transformations.force_align_and_split_by_intrinsics(
      bb, [broadcast_call])
  before, after = split_parts
  return _untuple_broadcast_only_before_after(before, after)
def test_returns_tree(self):
  """Checks the synthesized before/after aggregate for a no-aggregate process."""
  ip = get_iterative_process_for_sum_example_with_no_federated_aggregate()
  next_bb = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access

  before_agg, after_agg = (
      canonical_form_utils
      ._create_before_and_after_aggregate_for_no_federated_aggregate(next_bb))
  # Reference split computed directly on secure_sum, used for comparison.
  before_ss, after_ss = transformations.force_align_and_split_by_intrinsics(
      next_bb, [intrinsic_defs.FEDERATED_SECURE_SUM.uri])

  # `before_agg` must return a 2-struct: placeholder aggregate arguments and
  # the secure_sum arguments.
  self.assertIsInstance(before_agg, building_blocks.Lambda)
  self.assertIsInstance(before_agg.result, building_blocks.Tuple)
  self.assertLen(before_agg.result, 2)
  # pyformat: disable
  self.assertEqual(
      before_agg.result[0].formatted_representation(),
      '<\n'
      ' federated_value_at_clients(<>),\n'
      ' <>,\n'
      ' (_var1 -> <>),\n'
      ' (_var2 -> <>),\n'
      ' (_var3 -> <>)\n'
      '>'
  )
  # pyformat: enable
  self.assertEqual(
      before_agg.result[1].formatted_representation(),
      before_ss.result.formatted_representation())

  # `after_agg` must wrap a call to the secure_sum `after` computation.
  self.assertIsInstance(after_agg, building_blocks.Lambda)
  self.assertIsInstance(after_agg.result, building_blocks.Call)
  actual_tree, _ = tree_transformations.uniquify_reference_names(
      after_agg.result.function)
  expected_tree, _ = tree_transformations.uniquify_reference_names(after_ss)
  self.assertEqual(actual_tree.formatted_representation(),
                   expected_tree.formatted_representation())
  # pyformat: disable
  self.assertEqual(
      after_agg.result.argument.formatted_representation(),
      '<\n'
      ' _var4[0],\n'
      ' _var4[1][1]\n'
      '>'
  )
  # pyformat: enable
def assert_splits_on(self, comp, calls):
  """Asserts that `force_align_and_split_by_intrinsics` removes intrinsics.

  Args:
    comp: The `building_blocks.Lambda` to split. It must contain a call to at
      least one of the intrinsics named by `calls`, as reported by
      `tree_analysis.contains_called_intrinsic`.
    calls: A single called-intrinsic building block, or a list or tuple of
      them. Each element's `function.uri` names an intrinsic to split on.
  """
  # Accept a lone call as well as any list/tuple of calls. A tuple used to be
  # wrapped as a single "call" and then fail on `.function`.
  if not isinstance(calls, (list, tuple)):
    calls = [calls]
  calls = list(calls)
  uris = [call.function.uri for call in calls]

  # Removal isn't interesting to test for if it wasn't there to begin with.
  # Check this precondition before splitting so a bad fixture fails clearly.
  self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uris))

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, calls)

  # Ensure that the resulting computations no longer contain the split
  # intrinsics.
  self.assertFalse(tree_analysis.contains_called_intrinsic(before, uris))
  self.assertFalse(tree_analysis.contains_called_intrinsic(after, uris))
  self.assert_types_equivalent(comp.parameter_type, before.parameter_type)

  # `before` must return one intrinsic argument for each intrinsic in `calls`.
  before.type_signature.result.check_struct()
  self.assertLen(before.type_signature.result, len(calls))

  # Check that `after`'s parameter is a structure like:
  # {
  #   'original_arg': comp.parameter_type,
  #   'intrinsic_results': [...],
  # }
  after.parameter_type.check_struct()
  self.assertLen(after.parameter_type, 2)
  self.assert_types_equivalent(comp.parameter_type,
                               after.parameter_type.original_arg)
  # There must be one result for each intrinsic in `calls`.
  self.assertLen(after.parameter_type.intrinsic_results, len(calls))

  # Check that each pair of (param, result) is a valid type substitution
  # for the intrinsic in question.
  for i, call in enumerate(calls):
    concrete_signature = computation_types.FunctionType(
        before.type_signature.result[i],
        after.parameter_type.intrinsic_results[i])
    abstract_signature = call.function.intrinsic_def().type_signature
    # `force_align_and_split_by_intrinsics` loses all-equal data due to
    # zipping and unzipping. This is okay because the resulting computations
    # are not used together directly, but are compiled into unplaced TF code.
    abstract_signature = _remove_client_all_equals_from_type(
        abstract_signature)
    concrete_signature = _remove_client_all_equals_from_type(
        concrete_signature)
    type_analysis.check_concrete_instance_of(concrete_signature,
                                             abstract_signature)
def _split_ast_on_aggregate(bb):
  """Splits the AST `bb` on reduced aggregation intrinsics.

  Args:
    bb: An AST containing `federated_aggregate` or
      `federated_secure_sum_bitwidth` aggregations.

  Returns:
    The `(before, after)` pair from
    `transformations.force_align_and_split_by_intrinsics`. `before` maps `bb`'s
    input to the arguments to `federated_aggregate` and
    `federated_secure_sum_bitwidth`. `after` maps `bb`'s input and the output
    of `federated_aggregate` and `federated_secure_sum_bitwidth` to `bb`'s
    output.
  """
  null_aggregate = building_block_factory.create_null_federated_aggregate()
  null_secure_sum_bitwidth = (
      building_block_factory.create_null_federated_secure_sum_bitwidth())
  split_on = [null_aggregate, null_secure_sum_bitwidth]
  return transformations.force_align_and_split_by_intrinsics(bb, split_on)
def get_canonical_form_for_iterative_process(ip):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  The `initialize` and `next` computations of `ip` are deserialized. Their
  intrinsics are reduced, and `next` is split around broadcast and then around
  aggregation. Each piece is extracted and type-checked, and the pieces are
  assembled into a `tff.backends.mapreduce.CanonicalForm`.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither `federated_aggregate` nor
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  init_bb = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_bb = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(init_bb, next_bb)

  init_bb = _replace_intrinsics_with_bodies(init_bb)
  next_bb = _replace_intrinsics_with_bodies(next_bb)
  tree_analysis.check_contains_only_reducible_intrinsics(init_bb)
  tree_analysis.check_contains_only_reducible_intrinsics(next_bb)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

  # Split `next` around broadcast. A helper builds the split when `next` has
  # no broadcast.
  has_broadcast = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_BROADCAST.uri)
  if has_broadcast:
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_bb, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_bb))

  has_aggregate = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  has_secure_sum = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not has_aggregate and not has_secure_sum:
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # At least one aggregation intrinsic is present here. When only one is
  # present, the matching `_create_before_and_after_aggregate_for_no_*` helper
  # handles the missing one.
  if not has_aggregate:
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  elif not has_secure_sum:
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))
  else:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))

  type_info = _get_type_info(init_bb, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  initialize = transformations.consolidate_and_extract_local_processing(
      init_bb)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])
  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])
  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  for extracted, type_key in ((zero, 'zero_type'),
                              (accumulate, 'accumulate_type'),
                              (merge, 'merge_type'),
                              (report, 'report_type')):
    _check_type_equal(extracted.type_signature, type_info[type_key])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])
  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  to_computation = computation_wrapper_instances.building_block_to_computation
  return canonical_form.CanonicalForm(*[
      to_computation(part)
      for part in (initialize, prepare, work, zero, accumulate, merge, report,
                   bitwidth, update)
  ])
def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
  r"""Creates before and after aggregate computations for the given `tree`.

  The two returned computations have these shapes:

          Lambda
            |
          Struct
            |
      [Comp, Struct]
                |
      [federated_value_at_clients(<>), <>]

          Lambda(x)
             |
           Call
          /    \
       Comp    Struct
                 |
        [Sel(0),      Sel(0)]
          /             /
       Ref(x)        Sel(1)
                       /
                    Ref(x)

  In the first AST, `Comp` is the result of the `before` computation obtained
  by force aligning and splitting `tree` on
  `intrinsic_defs.FEDERATED_AGGREGATE.uri`. The second element is a
  placeholder argument for the secure sum intrinsic: an empty struct placed at
  `CLIENTS`, paired with an empty struct used as the bitwidth. This makes the
  first AST's type signature match what a before-aggregate computation is
  expected to return.

  In the second AST, `Comp` is the `after` computation from the same split. The
  new `Lambda` takes `x` of type `<original_arg, <aggregate_result, s4>>`, where
  `s4` is an empty struct placed at `SERVER` that stands in for the missing
  secure sum result. `Comp` is called with `<x[0], x[1][0]>`, which
  intentionally drops `s4` on the floor.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  aggregate computations for `tree` when there is no
  `intrinsic_defs.FEDERATED_SECURE_SUM` in `tree`. It neither asserts that
  absence nor validates the structure of `tree`; the caller is expected to
  perform these checks.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair `(before, after)`, where each of `before` and `after` is a
    `tff_framework.ComputationBuildingBlock` that represents a part of the
    result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  # Created before any new references so generated names avoid `tree`'s names.
  name_generator = building_block_factory.unique_name_generator(tree)

  before_agg, after_agg = transformations.force_align_and_split_by_intrinsics(
      tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri])

  # Placeholder secure-sum argument: <federated_value_at_clients(<>), <>>.
  # The same empty struct instance serves as both value and bitwidth.
  empty_struct = building_blocks.Struct([])
  placeholder_value = building_block_factory.create_federated_value(
      empty_struct, placements.CLIENTS)
  secure_sum_args = building_blocks.Struct([placeholder_value, empty_struct])
  before_result = building_blocks.Struct([before_agg.result, secure_sum_args])
  new_before = building_blocks.Lambda(before_agg.parameter_name,
                                      before_agg.parameter_type, before_result)

  # The new `after` takes <original_arg, <aggregate_result, s4>> and forwards
  # only <original_arg, aggregate_result> to the split's `after`.
  param_name = next(name_generator)
  s4_type = computation_types.FederatedType([], placements.SERVER)
  param_type = computation_types.StructType([
      after_agg.parameter_type[0],
      computation_types.StructType([after_agg.parameter_type[1], s4_type]),
  ])
  param_ref = building_blocks.Reference(param_name, param_type)
  original_arg = building_blocks.Selection(param_ref, index=0)
  intrinsic_results = building_blocks.Selection(param_ref, index=1)
  aggregate_result = building_blocks.Selection(intrinsic_results, index=0)
  forwarded = building_blocks.Struct([original_arg, aggregate_result])
  new_after = building_blocks.Lambda(
      param_ref.name, param_ref.type_signature,
      building_blocks.Call(after_agg, forwarded))
  return new_before, new_after
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  The `initialize` and `next` computations of `iterative_process` are
  deserialized and their intrinsics reduced. `next` is then split around
  broadcast and aggregate, and each extracted piece is type-checked before
  being wrapped into a `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process,
                          computation_utils.IterativeProcess)

  init_bb = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
  next_bb = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(init_bb, next_bb)

  # A `next` whose result has two elements is rewritten by
  # `_create_next_with_fake_client_output` (presumably adding a placeholder
  # client output).
  if len(next_bb.type_signature.result) == 2:
    next_bb = _create_next_with_fake_client_output(next_bb)

  init_bb = _replace_intrinsics_with_bodies(init_bb)
  next_bb = _replace_intrinsics_with_bodies(next_bb)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(init_bb)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_bb)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

  has_broadcast = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_BROADCAST.uri)
  if has_broadcast:
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_bb, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_bb))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  type_info = _get_type_info(init_bb, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  initialize = transformations.consolidate_and_extract_local_processing(
      init_bb)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])
  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])
  work = _extract_work(before_aggregate, after_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_aggregate_functions(
      before_aggregate)
  for extracted, type_key in ((zero, 'zero_type'),
                              (accumulate, 'accumulate_type'),
                              (merge, 'merge_type'),
                              (report, 'report_type')):
    _check_type_equal(extracted.type_signature, type_info[type_key])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  to_computation = computation_wrapper_instances.building_block_to_computation
  return canonical_form.CanonicalForm(*[
      to_computation(part)
      for part in (initialize, prepare, work, zero, accumulate, merge, report,
                   update)
  ])
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types, or if `next` does not
      both take and return `tff.NamedTupleType` values.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process,
                          computation_utils.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter,
                     computation_types.NamedTupleType) and
          isinstance(next_comp.type_signature.result,
                     computation_types.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    next_comp = _create_next_with_fake_client_output(next_comp)

  initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
  next_comp = replace_intrinsics_with_bodies(next_comp)

  tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  # Each stage's packed type info feeds the next check, building up the full
  # set of canonical form types.
  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)
  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)
  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))
  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))
  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  # Report the same type the check compares against. Previously the message
  # printed `next_comp.type_signature.parameter[0]` instead.
  expected_initialize_type = canonical_form_types['initialize_type'].member
  if initialize.type_signature.result != expected_initialize_type:
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`building_blocks.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_type, type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)
  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)
  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)
  update = extract_update(after_aggregate, canonical_form_types)

  return canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(
          zero_noarg_function),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(update))
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  The `initialize` and `next` computations of `ip` are deserialized. Their
  intrinsics are reduced, and `next` is converted to call-dominant form and
  split around broadcast and aggregation. Each piece is extracted (optionally
  with Grappler), type-checked, and wrapped into a
  `tff.backends.mapreduce.CanonicalForm`.

  Args:
    ip: An instance of `tff.templates.IterativeProcess` that is compatible with
      canonical form. Iterative processes are only compatible if:
        - `initialize_fn` returns a single federated value placed at `SERVER`.
        - `next` takes exactly two arguments. The first must be the state value
          placed at `SERVER`.
        - `next` returns exactly two values.
    grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the TensorFlow graphs backing the
      resulting `tff.backends.mapreduce.CanonicalForm`. These options are
      combined with a set of defaults that aggressively configure Grappler. If
      `None`, Grappler is bypassed.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither `federated_aggregate` nor
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  if grappler_config is not None:
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    # Layer the caller's options on top of a copy of the module defaults.
    merged_config = tf.compat.v1.ConfigProto()
    merged_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
    merged_config.MergeFrom(grappler_config)
    grappler_config = merged_config

  init_bb = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_bb = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(init_bb, next_bb)

  init_bb = _replace_intrinsics_with_bodies(init_bb)
  next_bb = _replace_intrinsics_with_bodies(next_bb)
  next_bb = _replace_lambda_body_with_call_dominant_form(next_bb)

  tree_analysis.check_contains_only_reducible_intrinsics(init_bb)
  tree_analysis.check_contains_only_reducible_intrinsics(next_bb)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_bb)

  has_broadcast = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_BROADCAST.uri)
  if has_broadcast:
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_bb, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_bb))

  has_aggregate = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  has_secure_sum = tree_analysis.contains_called_intrinsic(
      next_bb, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not has_aggregate and not has_secure_sum:
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # At least one aggregation intrinsic is present here. When only one is
  # present, the matching `_create_before_and_after_aggregate_for_no_*` helper
  # handles the missing one.
  if not has_aggregate:
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  elif not has_secure_sum:
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))
  else:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))

  type_info = _get_type_info(init_bb, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  initialize = transformations.consolidate_and_extract_local_processing(
      init_bb, grappler_config)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])
  prepare = _extract_prepare(before_broadcast, grappler_config)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])
  work = _extract_work(before_aggregate, grappler_config)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate, grappler_config)
  for extracted, type_key in ((zero, 'zero_type'),
                              (accumulate, 'accumulate_type'),
                              (merge, 'merge_type'),
                              (report, 'report_type')):
    _check_type_equal(extracted.type_signature, type_info[type_key])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate,
                                                     grappler_config)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])
  update = _extract_update(after_aggregate, grappler_config)
  _check_type_equal(update.type_signature, type_info['update_type'])

  # The two parameter names of `next` become the canonical form's labels.
  server_state_label, client_data_label = (
      name for name, _ in structure.iter_elements(
          ip.next.type_signature.parameter))

  to_computation = computation_wrapper_instances.building_block_to_computation
  return canonical_form.CanonicalForm(
      *[
          to_computation(part)
          for part in (initialize, prepare, work, zero, accumulate, merge,
                       report, bitwidth, update)
      ],
      server_state_label=server_state_label,
      client_data_label=client_data_label)
def test_returns_type_info(self):
  # NOTE(review): another method named `test_returns_type_info` is defined in
  # this file; if both live in the same class, the later definition replaces
  # this one and this test never runs.
  sum_process = get_iterative_process_for_sum_example()

  # Deserialize both computations, then inline the bodies of their intrinsics.
  initialize_tree, next_tree = [
      building_blocks.ComputationBuildingBlock.from_proto(
          comp._computation_proto)
      for comp in (sum_process.initialize, sum_process.next)
  ]
  initialize_tree, next_tree = [
      canonical_form_utils._replace_intrinsics_with_bodies(tree)
      for tree in (initialize_tree, next_tree)
  ]

  # Split first on broadcast, then split the remainder on aggregate.
  split = mapreduce_transformations.force_align_and_split_by_intrinsics
  before_broadcast, after_broadcast = split(
      next_tree, [intrinsic_defs.FEDERATED_BROADCAST.uri])
  before_aggregate, after_aggregate = split(
      after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri])

  # NOTE(review): this call passes `next_tree` in addition to the five trees
  # passed by the other `_get_type_info` calls visible in this file; confirm it
  # matches the signature of the `_get_type_info` under test.
  type_info = canonical_form_utils._get_type_info(
      initialize_tree, next_tree, before_broadcast, after_broadcast,
      before_aggregate, after_aggregate)

  # Plain dicts on both sides, so the comparison below ignores key order.
  actual = dict(
      (label, signature.compact_representation())
      for label, signature in type_info.items())

  # pyformat: disable
  expected = {
      'accumulate_type': '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
      'c1_type': '{int32}@CLIENTS',
      'c2_type': '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@CLIENTS',
      'c3_type': '{<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>>}@CLIENTS',
      'c4_type': '{<<int32,int32,int32,int32,int32,int32>,<>>}@CLIENTS',
      'c5_type': '{<int32,int32,int32,int32,int32,int32>}@CLIENTS',
      'c6_type': '{<>}@CLIENTS',
      'initialize_type': '( -> <int32,int32>)',
      'merge_type': '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
      'prepare_type': '(<int32,int32> -> <<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>)',
      'report_type': '(<int32,int32,int32,int32,int32,int32> -> <int32,int32,int32,int32,int32,int32>)',
      's1_type': '<int32,int32>@SERVER',
      's2_type': '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@SERVER',
      's3_type': '<int32,int32,int32,int32,int32,int32>@SERVER',
      's4_type': '<<int32,int32>,<int32,int32,int32,int32,int32,int32>>@SERVER',
      's5_type': '<<int32,int32>,<>>@SERVER',
      's6_type': '<int32,int32>@SERVER',
      's7_type': '<>@SERVER',
      'update_type': '(<<int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <<int32,int32>,<>>)',
      'work_type': '(<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>> -> <<int32,int32,int32,int32,int32,int32>,<>>)',
      'zero_type': '( -> <int32,int32,int32,int32,int32,int32>)'
  }
  # pyformat: enable
  self.assertEqual(actual, expected)
def test_returns_type_info(self):
  # NOTE(review): another method named `test_returns_type_info` is defined
  # earlier in this file; if both live in the same class, this later definition
  # is the one that runs.
  sum_process = get_iterative_process_for_sum_example()

  # Deserialize both computations, then inline the bodies of their intrinsics.
  initialize_tree, next_tree = [
      building_blocks.ComputationBuildingBlock.from_proto(
          comp._computation_proto)
      for comp in (sum_process.initialize, sum_process.next)
  ]
  initialize_tree, next_tree = [
      canonical_form_utils._replace_intrinsics_with_bodies(tree)
      for tree in (initialize_tree, next_tree)
  ]

  # Split first on broadcast, then split the remainder on aggregate.
  split = mapreduce_transformations.force_align_and_split_by_intrinsics
  before_broadcast, after_broadcast = split(
      next_tree, [intrinsic_defs.FEDERATED_BROADCAST.uri])
  before_aggregate, after_aggregate = split(
      after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri])

  type_info = canonical_form_utils._get_type_info(
      initialize_tree, before_broadcast, after_broadcast, before_aggregate,
      after_aggregate)

  # Both sides are OrderedDicts, so the final comparison is order-sensitive.
  actual = collections.OrderedDict(
      (label, signature.compact_representation())
      for label, signature in type_info.items())

  # Note: THE CONTENTS OF THIS DICTIONARY ARE NOT IMPORTANT. This test does not
  # exist to pin the exact value returned by
  # `canonical_form_utils._get_type_info`; it acts as a signal when refactoring
  # the code involved in compiling a `tff.utils.IterativeProcess` into a
  # `tff.backends.mapreduce.CanonicalForm`.
  # pyformat: disable
  expected = collections.OrderedDict(
      initialize_type='( -> <int32,int32>)',
      s1_type='<int32,int32>@SERVER',
      c1_type='{int32}@CLIENTS',
      s2_type=
      '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@SERVER',
      prepare_type=
      '(<int32,int32> -> <<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>)',
      c2_type=
      '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@CLIENTS',
      c3_type=
      '{<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>>}@CLIENTS',
      c4_type='{<<int32,int32,int32,int32,int32,int32>,<>>}@CLIENTS',
      work_type=
      '(<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>> -> <<int32,int32,int32,int32,int32,int32>,<>>)',
      c5_type='{<int32,int32,int32,int32,int32,int32>}@CLIENTS',
      c6_type='{<>}@CLIENTS',
      zero_type='( -> <int32,int32,int32,int32,int32,int32>)',
      accumulate_type=
      '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
      merge_type=
      '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
      report_type=
      '(<int32,int32,int32,int32,int32,int32> -> <int32,int32,int32,int32,int32,int32>)',
      s3_type='<int32,int32,int32,int32,int32,int32>@SERVER',
      s4_type=
      '<<int32,int32>,<int32,int32,int32,int32,int32,int32>>@SERVER',
      s5_type='<<int32,int32>,<>>@SERVER',
      update_type=
      '(<<int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <<int32,int32>,<>>)',
      s6_type='<int32,int32>@SERVER',
      s7_type='<>@SERVER',
  )
  # pyformat: enable
  self.assertEqual(actual, expected)
def test_returns_type_info_for_sum_example(self):
  ip = get_iterative_process_for_sum_example()
  initialize_tree = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)
  next_tree = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)
  initialize_tree = canonical_form_utils._replace_intrinsics_with_bodies(
      initialize_tree)
  next_tree = canonical_form_utils._replace_intrinsics_with_bodies(next_tree)
  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsics(
          next_tree, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  # Split on both aggregation intrinsics at once, mirroring the compiler's
  # handling when `federated_aggregate` and `federated_secure_sum` both occur.
  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          after_broadcast, [
              intrinsic_defs.FEDERATED_AGGREGATE.uri,
              intrinsic_defs.FEDERATED_SECURE_SUM.uri,
          ]))
  type_info = canonical_form_utils._get_type_info(initialize_tree,
                                                  before_broadcast,
                                                  after_broadcast,
                                                  before_aggregate,
                                                  after_aggregate)
  actual = collections.OrderedDict([
      (label, type_signature.compact_representation())
      for label, type_signature in type_info.items()
  ])
  # Note: THE CONTENTS OF THIS DICTIONARY ARE NOT IMPORTANT. This test does not
  # exist to pin the exact value returned by
  # `canonical_form_utils._get_type_info`; it acts as a signal when refactoring
  # the code involved in compiling a `tff.templates.IterativeProcess` into a
  # `tff.backends.mapreduce.CanonicalForm`. If you are sure this needs to be
  # updated, one recommendation is to print 'k=\'v\',' while iterating over
  # the k-v pairs of the ordereddict.
  # pyformat: disable
  expected = collections.OrderedDict(
      initialize_type='( -> <int32,int32>)',
      s1_type='<int32,int32>@SERVER',
      c1_type='{int32}@CLIENTS',
      prepare_type='(<int32,int32> -> <<int32,int32>>)',
      s2_type='<<int32,int32>>@SERVER',
      c2_type='<<int32,int32>>@CLIENTS',
      c3_type='{<int32,<<int32,int32>>>}@CLIENTS',
      work_type='(<int32,<<int32,int32>>> -> <<<int32>,<int32>>,<>>)',
      c4_type='{<<<int32>,<int32>>,<>>}@CLIENTS',
      c5_type='{<<int32>,<int32>>}@CLIENTS',
      c6_type='{<int32>}@CLIENTS',
      c7_type='{<int32>}@CLIENTS',
      c8_type='{<>}@CLIENTS',
      zero_type='( -> <int32>)',
      accumulate_type='(<<int32>,<int32>> -> <int32>)',
      merge_type='(<<int32>,<int32>> -> <int32>)',
      report_type='(<int32> -> <int32>)',
      s3_type='<int32>@SERVER',
      bitwidth_type='( -> <int32>)',
      s4_type='<int32>@SERVER',
      s5_type='<<int32>,<int32>>@SERVER',
      s6_type='<<int32,int32>,<<int32>,<int32>>>@SERVER',
      update_type='(<<int32,int32>,<<int32>,<int32>>> -> <<int32,int32>,<>>)',
      s7_type='<<int32,int32>,<>>@SERVER',
      s8_type='<int32,int32>@SERVER',
      s9_type='<>@SERVER',
  )
  # pyformat: enable
  # Compare entry by entry so a failure names the first mismatching label.
  items = zip(actual.items(), expected.items())
  for (actual_key, actual_value), (expected_key, expected_value) in items:
    self.assertEqual(actual_key, expected_key)
    self.assertEqual(
        actual_value, expected_value,
        'The value of \'{}\' is not equal to the expected value'.format(
            actual_key))
  # `zip` stops at the shorter of its inputs, so the loop above cannot notice
  # labels that are missing from, or extra in, `actual` after the common
  # prefix. Compare the complete ordered key lists to close that gap.
  self.assertEqual(
      list(actual), list(expected),
      'The labels returned by `_get_type_info` do not match the expected '
      'labels')
def test_returns_tree(self):
  # Compares the split produced by
  # `_create_before_and_after_aggregate_for_no_federated_secure_sum` against a
  # reference split of the same tree on `federated_aggregate` alone.
  ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum()
  next_tree = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)
  next_tree = canonical_form_utils._replace_intrinsics_with_bodies(
      next_tree)
  before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
      next_tree)
  # Reference: split the same tree on `federated_aggregate` only.
  before_federated_aggregate, after_federated_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  # `before_aggregate` must be a lambda returning a 2-element struct.
  self.assertIsInstance(before_aggregate, building_blocks.Lambda)
  self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
  self.assertLen(before_aggregate.result, 2)

  # trees_equal will fail if computations refer to unbound references, so we
  # create a new dummy computation to bind them.
  unbound_refs_in_before_agg_result = transformation_utils.get_map_of_unbound_references(
      before_aggregate.result[0])[before_aggregate.result[0]]
  unbound_refs_in_before_fed_agg_result = transformation_utils.get_map_of_unbound_references(
      before_federated_aggregate.result)[
          before_federated_aggregate.result]
  # Every unbound name in each tree is bound to the same placeholder `Data`
  # value, so that only the tree structure itself is compared.
  dummy_data = building_blocks.Data('data',
                                    computation_types.AbstractType('T'))
  blk_binding_refs_in_before_agg = building_blocks.Block(
      [(name, dummy_data) for name in unbound_refs_in_before_agg_result],
      before_aggregate.result[0])
  blk_binding_refs_in_before_fed_agg = building_blocks.Block(
      [(name, dummy_data) for name in unbound_refs_in_before_fed_agg_result],
      before_federated_aggregate.result)
  # The first struct element must equal the reference `before` result.
  self.assertTrue(
      tree_analysis.trees_equal(blk_binding_refs_in_before_agg,
                                blk_binding_refs_in_before_fed_agg))

  # The second struct element must be exactly the value shown below
  # (presumably standing in for the absent secure-sum inputs).
  # pyformat: disable
  self.assertEqual(
      before_aggregate.result[1].formatted_representation(),
      '<\n'
      ' federated_value_at_clients(<>),\n'
      ' <>\n'
      '>')
  # pyformat: enable

  # `after_aggregate` must be a lambda whose result calls the reference
  # `after` computation on an argument re-packed from its own parameter.
  self.assertIsInstance(after_aggregate, building_blocks.Lambda)
  self.assertIsInstance(after_aggregate.result, building_blocks.Call)
  self.assertTrue(
      tree_analysis.trees_equal(after_aggregate.result.function,
                                after_federated_aggregate))
  # The exact argument form depends on generated names such as `_var1`.
  # pyformat: disable
  self.assertEqual(
      after_aggregate.result.argument.formatted_representation(),
      '<\n'
      ' _var1[0],\n'
      ' _var1[1][0]\n'
      '>')