def test_removes_federated_types_under_function(self):
  """Checks `strip_placement` on nested map-or-apply calls over a server value.

  An identity lambda is passed through `create_federated_map_or_apply` twice on
  a SERVER-placed reference; after stripping, no intrinsics or federated types
  may remain and the transformation must report that it modified the tree.
  """
  tensor_type = tf.int32
  server_type = computation_types.at_server(tensor_type)
  identity_fn = building_blocks.Lambda(
      'x', tensor_type, building_blocks.Reference('x', tensor_type))
  server_value = building_blocks.Reference('x', server_type)
  inner_call = building_block_factory.create_federated_map_or_apply(
      identity_fn, server_value)
  outer_call = building_block_factory.create_federated_map_or_apply(
      identity_fn, inner_call)

  stripped, was_modified = tree_transformations.strip_placement(outer_call)

  self.assertTrue(was_modified)
  self.assert_has_no_intrinsics_nor_federated_types(stripped)
def test_strip_placement_removes_federated_maps(self):
  """Checks nested CLIENTS-placed maps are stripped down to plain calls.

  Verifies the tree is reported as modified and no intrinsics or federated
  types remain. Checks both the type signature and the compact representation
  before and after stripping.
  """
  tensor_type = computation_types.TensorType(tf.int32)
  clients_type = computation_types.at_clients(tensor_type)
  identity_fn = building_blocks.Lambda(
      'x', tensor_type, building_blocks.Reference('x', tensor_type))
  clients_value = building_blocks.Reference('x', clients_type)
  inner_map = building_block_factory.create_federated_map_or_apply(
      identity_fn, clients_value)
  outer_map = building_block_factory.create_federated_map_or_apply(
      identity_fn, inner_map)

  stripped, was_modified = tree_transformations.strip_placement(outer_map)

  self.assertTrue(was_modified)
  self.assert_has_no_intrinsics_nor_federated_types(stripped)
  # Placement must disappear from the result type, not just from the AST.
  type_test_utils.assert_types_identical(outer_map.type_signature,
                                         clients_type)
  type_test_utils.assert_types_identical(stripped.type_signature, tensor_type)
  self.assertEqual(
      outer_map.compact_representation(),
      'federated_map(<(x -> x),federated_map(<(x -> x),x>)>)')
  self.assertEqual(stripped.compact_representation(), '(x -> x)((x -> x)(x))')
def test_reduces_federated_apply_to_equivalent_function(self):
  """Checks extracted local processing behaves like the mapped lambda.

  An identity lambda is passed to `create_federated_map_or_apply` over a
  CLIENTS-placed reference. The extracted computation must be a
  `CompiledComputation` that agrees with `lam` on the inputs 0..9.

  NOTE(review): the argument is CLIENTS-placed, so this presumably exercises
  the federated *map* path despite the test name mentioning "apply".
  """
  lam = building_blocks.Lambda('x', tf.int32,
                               building_blocks.Reference('x', tf.int32))
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  arg = building_blocks.Reference('arg', clients_int)
  mapped_fn = building_block_factory.create_federated_map_or_apply(lam, arg)

  extracted_tf = (
      mapreduce_transformations.consolidate_and_extract_local_processing(
          mapped_fn))
  self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation)

  to_computation = computation_wrapper_instances.building_block_to_computation
  executable_tf = to_computation(extracted_tf)
  executable_lam = to_computation(lam)
  for value in range(10):
    self.assertEqual(executable_tf(value), executable_lam(value))
def _construct_selection_from_federated_tuple(
    federated_tuple: building_blocks.ComputationBuildingBlock, index: int,
    name_generator) -> building_blocks.ComputationBuildingBlock:
  """Selects element `index` from the struct member of `federated_tuple`.

  Builds the lambda `(p -> p[index])` over the member type of
  `federated_tuple`. It passes that lambda, together with `federated_tuple`,
  to `building_block_factory.create_federated_map_or_apply`.

  Args:
    federated_tuple: A building block with a federated type signature whose
      member type is a struct.
    index: Positional index of the struct element to select.
    name_generator: An iterator of fresh names (e.g. one created by
      `building_block_factory.unique_name_generator`). Exactly one name is
      consumed, for the lambda's parameter.

  Returns:
    The building block produced by `create_federated_map_or_apply` for the
    selection lambda and `federated_tuple`.

  Raises:
    TypeError: NOTE(review): presumably raised by `check_federated()` if
      `federated_tuple` is not federated, or by `check_struct()` if its member
      type is not a struct. Confirm the exact exception type.
  """
  federated_tuple.type_signature.check_federated()
  member_type = federated_tuple.type_signature.member
  member_type.check_struct()
  param_name = next(name_generator)
  selecting_function = building_blocks.Lambda(
      param_name, member_type,
      building_blocks.Selection(
          building_blocks.Reference(param_name, member_type),
          index=index,
      ))
  return building_block_factory.create_federated_map_or_apply(
      selecting_function, federated_tuple)
def test_reduces_federated_apply_to_equivalent_function(self):
  """Checks compiled local processing of a CLIENTS map agrees with `lam`.

  The map of an identity lambda over a CLIENTS-placed parameter is wrapped in
  a lambda and compiled with `DEFAULT_GRAPPLER_CONFIG`. The result must be a
  `CompiledComputation` that agrees with `lam` on the inputs 0..9.
  """
  lam = building_blocks.Lambda('x', tf.int32,
                               building_blocks.Reference('x', tf.int32))
  arg_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  mapping_fn = building_blocks.Lambda(
      'arg', arg_type,
      building_block_factory.create_federated_map_or_apply(
          lam, building_blocks.Reference('arg', arg_type)))

  extracted_tf = compiler.consolidate_and_extract_local_processing(
      mapping_fn, DEFAULT_GRAPPLER_CONFIG)
  self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation)

  from_building_block = computation_impl.ConcreteComputation.from_building_block
  executable_tf = from_building_block(extracted_tf)
  executable_lam = from_building_block(lam)
  for value in range(10):
    self.assertEqual(executable_tf(value), executable_lam(value))
def _extract_update(after_aggregate):
  """Builds the `update` computation out of `after_aggregate`.

  Only `get_canonical_form_for_iterative_process` is expected to call this.
  The caller is responsible for validating the structure of `after_aggregate`;
  no such validation is repeated here.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Keep result elements 0 and 1 of `after_aggregate`, then federated-zip them.
  selected_output = transformations.select_output_from_lambda(
      after_aggregate, [0, 1])
  zipped_output = building_blocks.Lambda(
      selected_output.parameter_name, selected_output.parameter_type,
      building_block_factory.create_federated_zip(selected_output.result))
  # Re-express the zipped lambda as a function of the parameter paths
  # [0, 0, 0], [1, 0] and [1, 1] (presumably s1, s3, s4 per the TODO below).
  s6_to_s7_computation = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          zipped_output, [[0, 0, 0], [1, 0], [1, 1]]).result.function)

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to pack the type signature
  # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
  name_generator = building_block_factory.unique_name_generator(
      s6_to_s7_computation)
  flat_member = s6_to_s7_computation.parameter_type.member

  packed_name = next(name_generator)
  packed_type = computation_types.StructType([
      flat_member[0],
      computation_types.StructType([flat_member[1], flat_member[2]]),
  ])
  packed_ref = building_blocks.Reference(packed_name, packed_type)
  nested = building_blocks.Selection(packed_ref, index=1)
  flattened = building_blocks.Struct([
      building_blocks.Selection(packed_ref, index=0),
      building_blocks.Selection(nested, index=0),
      building_blocks.Selection(nested, index=1),
  ])
  pack_fn = building_blocks.Lambda(packed_ref.name, packed_ref.type_signature,
                                   flattened)

  # fn = v -> s6_to_s7_computation(federated_map_or_apply(pack_fn, v)), with
  # `v` placed at SERVER.
  server_name = next(name_generator)
  server_ref = building_blocks.Reference(
      server_name,
      computation_types.FederatedType(packed_type, placements.SERVER))
  fn = building_blocks.Lambda(
      server_ref.name, server_ref.type_signature,
      building_blocks.Call(
          s6_to_s7_computation,
          building_block_factory.create_federated_map_or_apply(
              pack_fn, server_ref)))
  return transformations.consolidate_and_extract_local_processing(fn)
def _extract_update(after_aggregate, grappler_config):
  """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_map_reduce_form_for_iterative_process` only. As a result, this function
  does not assert that `after_aggregate` has the expected structure, the
  caller is expected to perform these checks before calling this function.

  `after_aggregate` is first rewritten as a function of only these
  subparameters, in this order:
  - `original_arg.original_arg[0]` (the server state);
  - the `federated_aggregate`, `federated_secure_sum_bitwidth`,
    `federated_secure_sum` and `federated_secure_modular_sum` intrinsic
    results.

  The returned `update` takes a single SERVER-placed argument of the form
  `<server_state, <aggregation_results...>>`.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.
    grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
      Grappler graph optimization.

  Returns:
    `update` as specified by `forms.MapReduceForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
      type.
  """
  after_aggregate_zipped = building_blocks.Lambda(
      after_aggregate.parameter_name, after_aggregate.parameter_type,
      building_block_factory.create_federated_zip(after_aggregate.result))
  # `create_federated_zip` doesn't have unique reference names, but we need
  # them for `_as_function_of_some_federated_subparameters`.
  after_aggregate_zipped, _ = tree_transformations.uniquify_reference_names(
      after_aggregate_zipped)
  server_state_index = ('original_arg', 'original_arg', 0)
  aggregate_result_index = ('intrinsic_results', 'federated_aggregate_result')
  secure_sum_bitwidth_result_index = ('intrinsic_results',
                                      'federated_secure_sum_bitwidth_result')
  secure_sum_result_index = ('intrinsic_results', 'federated_secure_sum_result')
  secure_modular_sum_result_index = ('intrinsic_results',
                                     'federated_secure_modular_sum_result')
  update_with_flat_inputs = _as_function_of_some_federated_subparameters(
      after_aggregate_zipped, (
          server_state_index,
          aggregate_result_index,
          secure_sum_bitwidth_result_index,
          secure_sum_result_index,
          secure_modular_sum_result_index,
      ))

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to transform the input from
  # <server_state, <aggregation_results...>> into
  # <server_state, aggregation_results...>
  # unpack = <v, <...>> -> <v, ...>
  name_generator = building_block_factory.unique_name_generator(
      update_with_flat_inputs)

  unpack_param_name = next(name_generator)
  original_param_type = update_with_flat_inputs.parameter_type.member
  unpack_param_type = computation_types.StructType([
      original_param_type[0],
      computation_types.StructType(original_param_type[1:]),
  ])
  unpack_param_ref = building_blocks.Reference(unpack_param_name,
                                               unpack_param_type)

  def _select(comp, index):
    """Returns the positional selection `comp[index]`."""
    return building_blocks.Selection(comp, index=index)

  # Everything after the server state is an aggregation result.
  num_aggregation_results = len(original_param_type) - 1
  unpack = building_blocks.Lambda(
      unpack_param_name, unpack_param_type,
      building_blocks.Struct([_select(unpack_param_ref, 0)] + [
          _select(_select(unpack_param_ref, 1), i)
          for i in range(num_aggregation_results)
      ]))

  # update = v -> update_with_flat_inputs(federated_map(unpack, v))
  param_name = next(name_generator)
  param_type = computation_types.at_server(unpack_param_type)
  param_ref = building_blocks.Reference(param_name, param_type)
  update = building_blocks.Lambda(
      param_name, param_type,
      building_blocks.Call(
          update_with_flat_inputs,
          building_block_factory.create_federated_map_or_apply(
              unpack, param_ref)))
  return compiler.consolidate_and_extract_local_processing(
      update, grappler_config)
def force_align_and_split_by_intrinsics(
    comp: building_blocks.Lambda,
    intrinsic_defaults: List[building_blocks.Call],
) -> Tuple[building_blocks.Lambda, building_blocks.Lambda]:
  """Divides `comp` into before-and-after of calls to one or more intrinsics.

  The input computation `comp` must have the following properties:

  1. The computation `comp` is completely self-contained, i.e., there are no
     references to arguments introduced in a scope external to `comp`.

  2. `comp`'s return value must not contain uncalled lambdas.

  3. None of the calls to intrinsics in `intrinsic_defaults` may be
     within a lambda passed to another external function (intrinsic or
     compiled computation).

  4. No argument passed to an intrinsic in `intrinsic_defaults` may be
     dependent on the result of a call to an intrinsic in
     `intrinsic_defaults`.

  5. All intrinsics in `intrinsic_defaults` must have "merge-able" arguments.
     Structs will be merged element-wise, federated values will be zipped, and
     functions will be composed:
       `f = lambda f1_arg, f2_arg: (f1(f1_arg), f2(f2_arg))`

  6. All intrinsics in `intrinsic_defaults` must return a single federated
     value whose member is the merged result of any merged calls, i.e.:
       `f(merged_arg).member = (f1(f1_arg).member, f2(f2_arg).member)`

  Under these conditions, (and assuming `comp` is a computation with non-`None`
  argument), this function will return two `building_blocks.Lambda`s `before`
  and `after` such that `comp` is semantically equivalent to the following
  expression*:

  ```
  (arg -> (let
    x=before(arg),
    y=intrinsic1(x[0]),
    z=intrinsic2(x[1]),
    ...
   in after(<arg, <y,z,...>>)))
  ```

  If `comp` is a no-arg computation, the returned computations will be
  equivalent (in the same sense as above) to:
  ```
  ( -> (let
    x=before(),
    y=intrinsic1(x[0]),
    z=intrinsic2(x[1]),
    ...
   in after(<y,z,...>)))
  ```

  *Note that these expressions may not be entirely equivalent under
  nondeterminism. There is no way in this case to handle computations in which
  `before` creates a random variable that is then used in `after`, because the
  only way for state to pass from `before` to `after` is for it to travel
  through one of the intrinsics.

  In this expression, there is only a single call to `intrinsic` that results
  from consolidating all occurrences of this intrinsic in the original `comp`.
  All logic in `comp` that produced inputs to any of these intrinsic calls is
  now consolidated and jointly encapsulated in `before`, which produces a
  combined argument to all the original calls. All the remaining logic in
  `comp`, including that which consumed the outputs of the intrinsic calls,
  must have been encapsulated into `after`.

  If the original computation `comp` had type `(T -> U)`, then `before` and
  `after` would be `(T -> X)` and `(<T,Y> -> U)`, respectively, where `X` is
  the type of the argument to the single combined intrinsic call above. Note
  that `after` takes the output of the call to the intrinsic as well as the
  original argument to `comp`, as it may be dependent on both.

  Concretely:
  - The result of `before` is a struct with one element per merged
    intrinsic, named `{uri}_param`.
  - The parameter of `after` is a struct with an element named
    `original_arg` (omitted for no-arg `comp`) and an element named
    `intrinsic_results`.
  - `intrinsic_results` is itself a struct with elements named
    `{uri}_result`.

  Args:
    comp: The instance of `building_blocks.Lambda` that serves as the input to
      this transformation, as described above.
    intrinsic_defaults: A list of intrinsics with which to split the
      computation, provided as a list of `Call`s to insert if no intrinsic with
      a matching URI is found. Intrinsics in this list will be merged, and
      `comp` will be split across them.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after` is
    a `building_blocks.Lambda` instance that represents a part of the result as
    specified above.

  Raises:
    ValueError: If `before` or `after` ends up containing non-unique reference
      names.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(intrinsic_defaults, list)
  comp_repr = comp.compact_representation()

  # Flatten `comp` to call-dominant form so that we're working with just a
  # linear list of intrinsic calls with no indirection via tupling, selection,
  # blocks, called lambdas, or references.
  comp = to_call_dominant(comp)

  # CDF can potentially return blocks if there are variables not dependent on
  # the top-level parameter. We normalize these away.
  if not comp.is_lambda():
    comp.check_block()
    comp.result.check_lambda()
    if comp.result.result.is_block():
      additional_locals = comp.result.result.locals
      result = comp.result.result.result
    else:
      additional_locals = []
      result = comp.result.result
    # Note: without uniqueness, a local in `comp.locals` could potentially
    # shadow `comp.result.parameter_name`. However, `to_call_dominant`
    # above ensures that names are unique, as it ends in a call to
    # `uniquify_reference_names`.
    comp = building_blocks.Lambda(
        comp.result.parameter_name, comp.result.parameter_type,
        building_blocks.Block(comp.locals + additional_locals, result))
  comp.check_lambda()

  # Simple computations with no intrinsic calls won't have a block.
  # Normalize these as well.
  if not comp.result.is_block():
    comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                  building_blocks.Block([], comp.result))
  comp.result.check_block()

  name_generator = building_block_factory.unique_name_generator(comp)

  intrinsic_uris = {call.function.uri for call in intrinsic_defaults}
  deps = _compute_intrinsic_dependencies(intrinsic_uris, comp.parameter_name,
                                         comp.result.locals, comp_repr)
  merged_intrinsics = _compute_merged_intrinsics(intrinsic_defaults,
                                                 deps.uri_to_locals,
                                                 name_generator)

  # Note: the outputs are labeled as `{uri}_param` for convenience, e.g.
  # `federated_secure_sum_param: ...`.
  before = building_blocks.Lambda(
      comp.parameter_name, comp.parameter_type,
      building_blocks.Block(
          deps.locals_not_dependent_on_intrinsics,
          building_blocks.Struct([(f'{merged.uri}_param', merged.args)
                                  for merged in merged_intrinsics])))

  after_param_name = next(name_generator)
  intrinsic_results_type = computation_types.StructType([
      (f'{merged.uri}_result', merged.return_type)
      for merged in merged_intrinsics
  ])
  if comp.parameter_type is not None:
    # TODO(b/147499373): If None-arguments were uniformly represented as empty
    # tuples, we would be able to avoid this (and related) ugly casing.
    after_param_elements = [
        ('original_arg', comp.parameter_type),
        ('intrinsic_results', intrinsic_results_type),
    ]
  else:
    after_param_elements = [('intrinsic_results', intrinsic_results_type)]
  after_param_type = computation_types.StructType(after_param_elements)
  after_param_ref = building_blocks.Reference(after_param_name, after_param_type)
  if comp.parameter_type is not None:
    original_arg_bindings = [
        (comp.parameter_name,
         building_blocks.Selection(after_param_ref, name='original_arg'))
    ]
  else:
    original_arg_bindings = []

  # For merged intrinsics whose result must be unpacked into several original
  # locals, bind each local to a map/apply that selects its element from the
  # merged intrinsic's result.
  unzip_bindings = []
  for merged in merged_intrinsics:
    if merged.unpack_to_locals:
      intrinsic_result = building_blocks.Selection(
          building_blocks.Selection(after_param_ref, name='intrinsic_results'),
          name=f'{merged.uri}_result')
      select_param_type = intrinsic_result.type_signature.member
      for i, binding_name in enumerate(merged.unpack_to_locals):
        select_param_name = next(name_generator)
        select_param_ref = building_blocks.Reference(select_param_name,
                                                     select_param_type)
        selected = building_block_factory.create_federated_map_or_apply(
            building_blocks.Lambda(
                select_param_name, select_param_type,
                building_blocks.Selection(select_param_ref, index=i)),
            intrinsic_result)
        unzip_bindings.append((binding_name, selected))

  after = building_blocks.Lambda(
      after_param_name,
      after_param_type,
      building_blocks.Block(
          original_arg_bindings +
          # Note that we must duplicate `locals_not_dependent_on_intrinsics`
          # across both the `before` and `after` computations since both can
          # rely on them, and there's no way to plumb results from `before`
          # through to `after` except via one of the intrinsics being split
          # upon. In MapReduceForm, this limitation is caused by the fact that
          # `prepare` has no output which serves as an input to `report`.
          deps.locals_not_dependent_on_intrinsics + unzip_bindings +
          deps.locals_dependent_on_intrinsics,
          comp.result.result))

  try:
    tree_analysis.check_has_unique_names(before)
    tree_analysis.check_has_unique_names(after)
  except tree_analysis.NonuniqueNameError as e:
    raise ValueError(
        f'nonunique names in result of splitting\n{comp}') from e
  return before, after