def test_ok_on_nested_lambdas_with_different_variable_name(self):
  """Nested lambdas binding the distinct names `x` and `y` are accepted."""
  x_ref = building_blocks.Reference('x', tf.int32)
  inner_fn = building_blocks.Lambda('x', tf.int32, x_ref)
  outer_fn = building_blocks.Lambda('y', tf.int32, inner_fn)
  # The test passes as long as the check does not raise.
  tree_analysis.check_has_unique_names(outer_fn)
def _create_before_and_after_aggregate_for_no_federated_secure_sum(tree):
  r"""Builds before/after aggregate computations when there is no secure sum.

  This function returns the two ASTs:

  Lambda
  |
  Tuple
  |
  [Comp, Tuple]
         |
         [Tuple, []]
          |
          []

       Lambda(x)
       |
       Call
      /    \
  Comp      Tuple
            |
            [Sel(0),      Sel(0)]
            /            /
         Ref(x)    Sel(1)
                  /
            Ref(x)

  In the first AST, `Comp` is the result of the before aggregate obtained by
  force aligning and splitting `tree` by
  `intrinsic_defs.FEDERATED_AGGREGATE.uri`. The second element returned by
  `Lambda` is an empty structure standing in for the argument to the secure
  sum intrinsic, so the first AST has a type signature satisfying the
  requirements of before aggregate.

  In the second AST, `Comp` is the after aggregate obtained from the same
  split. `Lambda` has a type signature satisfying the requirements of after
  aggregate, and the argument passed to `Comp` is assembled from selections on
  the parameter of `Lambda`, intentionally dropping `s4` on the floor.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  aggregate computations for the given `tree` when there is no
  `intrinsic_defs.FEDERATED_SECURE_SUM` in `tree`. It does not assert that
  there is no `intrinsic_defs.FEDERATED_SECURE_SUM` in `tree`, nor that `tree`
  has the expected structure; the caller is expected to perform these checks
  before calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `tff_framework.ComputationBuildingBlock` that represents a part of the
    result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  # The generator is created from the original `tree`, before it is split.
  name_generator = building_block_factory.unique_name_generator(tree)

  split_before, split_after = (
      transformations.force_align_and_split_by_intrinsics(
          tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  # Before aggregate: append an empty (value, bitwidth) pair to the result.
  empty_struct = building_blocks.Struct([])
  clients_value = building_block_factory.create_federated_value(
      empty_struct, placements.CLIENTS)
  bitwidth = empty_struct
  secure_sum_args = building_blocks.Struct([clients_value, bitwidth])
  before_result = building_blocks.Struct(
      [split_before.result, secure_sum_args])
  before_aggregate = building_blocks.Lambda(split_before.parameter_name,
                                            split_before.parameter_type,
                                            before_result)

  # After aggregate: widen the parameter with an `s4` slot, then forward only
  # the original argument and `s3` to the split-off computation.
  param_name = next(name_generator)
  s4_type = computation_types.FederatedType([], placements.SERVER)
  param_type = computation_types.StructType([
      split_after.parameter_type[0],
      computation_types.StructType([
          split_after.parameter_type[1],
          s4_type,
      ]),
  ])
  param_ref = building_blocks.Reference(param_name, param_type)
  original_arg = building_blocks.Selection(param_ref, index=0)
  aggregation_results = building_blocks.Selection(param_ref, index=1)
  s3 = building_blocks.Selection(aggregation_results, index=0)
  forwarded = building_blocks.Struct([original_arg, s3])
  call = building_blocks.Call(split_after, forwarded)
  after_aggregate = building_blocks.Lambda(param_ref.name,
                                           param_ref.type_signature, call)

  return before_aggregate, after_aggregate
def test_raises_type_error_with_int_excluding(self):
  """Passing an int as the second argument raises `TypeError`."""
  param = building_blocks.Reference('a', tf.int32)
  identity_fn = building_blocks.Lambda(param.name, param.type_signature, param)
  with self.assertRaises(TypeError):
    tree_analysis.contains_no_unbound_references(identity_fn, 1)
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  If the result of `next` has exactly two elements, a third element (an empty
  federated value placed at `CLIENTS`) is appended to it before compilation.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter,
                     computation_types.NamedTupleType) and
          isinstance(next_comp.type_signature.result,
                     computation_types.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    # Pad a two-element result with an empty CLIENTS-placed value so that the
    # rest of the pipeline always sees three elements.
    next_result = next_comp.result
    if isinstance(next_result, building_blocks.Tuple):
      first, second = next_result[0], next_result[1]
    else:
      first = building_blocks.Selection(next_result, index=0)
      second = building_blocks.Selection(next_result, index=1)
    dummy_clients_metrics_appended = building_blocks.Tuple([
        first,
        second,
        intrinsics.federated_value([], placements.CLIENTS)._comp  # pylint: disable=protected-access
    ])
    next_comp = building_blocks.Lambda(next_comp.parameter_name,
                                       next_comp.parameter_type,
                                       dummy_clients_metrics_appended)

  initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
  next_comp = replace_intrinsics_with_bodies(next_comp)

  tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around the broadcast, then split the remainder around the
  # aggregate.
  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, intrinsic_defs.FEDERATED_AGGREGATE.uri))

  # Each packing step consumes the information packed by the previous one.
  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)

  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)

  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))

  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))

  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)

  # Report the same type that the check below compares against, so that the
  # error message matches the condition that actually failed.
  expected_initialize_type = canonical_form_types['initialize_type'].member
  if not (isinstance(initialize, building_blocks.CompiledComputation) and
          initialize.type_signature.result == expected_initialize_type):
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`building_blocks.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_type, type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)

  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)

  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)

  update = extract_update(after_aggregate, canonical_form_types)

  cf = canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(
          zero_noarg_function),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(update))
  return cf
def _create_before_and_after_broadcast_for_no_broadcast(tree):
  r"""Builds before/after broadcast computations when there is no broadcast.

  This function returns the two ASTs:

  Lambda
  |
  Tuple
  |
  []

       Lambda(x)
       |
       Call
      /    \
  Comp      Sel(0)
           /
     Ref(x)

  The first AST returns an empty structure (as a federated value placed at
  `SERVER`) and has a type signature satisfying the requirements of before
  broadcast.

  In the second AST, `Comp` is `tree`; `Lambda` has a type signature satisfying
  the requirements of after broadcast; and the argument passed to `Comp` is a
  selection from the parameter of `Lambda` which intentionally drops `c2` on
  the floor.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create before and after
  broadcast computations for the given `tree` when there is no
  `intrinsic_defs.FEDERATED_BROADCAST` in `tree`. It does not assert that there
  is no `intrinsic_defs.FEDERATED_BROADCAST` in `tree`, nor that `tree` has the
  expected structure; the caller is expected to perform these checks before
  calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`, where each of `before` and `after`
    is a `tff_framework.ComputationBuildingBlock` that represents a part of the
    result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  name_generator = building_block_factory.unique_name_generator(tree)

  # Before broadcast: ignore the parameter and return an empty server value.
  before_param_name = next(name_generator)
  empty_struct = building_blocks.Struct([])
  server_value = building_block_factory.create_federated_value(
      empty_struct, placements.SERVER)
  before_broadcast = building_blocks.Lambda(before_param_name,
                                            tree.type_signature.parameter,
                                            server_value)

  # After broadcast: accept (original parameter, broadcast value) and forward
  # only the original parameter to `tree`.
  after_param_name = next(name_generator)
  broadcast_type = computation_types.FederatedType(
      before_broadcast.type_signature.result.member, placements.CLIENTS)
  after_param_type = computation_types.StructType(
      [tree.type_signature.parameter, broadcast_type])
  param_ref = building_blocks.Reference(after_param_name, after_param_type)
  original_arg = building_blocks.Selection(param_ref, index=0)
  call = building_blocks.Call(tree, original_arg)
  after_broadcast = building_blocks.Lambda(param_ref.name,
                                           param_ref.type_signature, call)

  return before_broadcast, after_broadcast
def _group_by_intrinsics_in_top_level_lambda(comp):
  """Groups the intrinsics in the first block local in the result of `comp`.

  This transformation creates an AST by replacing the tuple of called
  intrinsics found as the first local in the `building_blocks.Block` returned
  by the top level lambda with two new computations. The first computation is
  a tuple of tuples of called intrinsics, representing the original tuple of
  called intrinsics grouped by URI. The second computation is a tuple of
  selections from the first computation, representing the original tuple of
  called intrinsics.

  It is necessary to group intrinsics before it is possible to merge them.

  Args:
    comp: The `building_blocks.Lambda` to transform.

  Returns:
    A pair `(fn, modified)`. If every called intrinsic has a distinct URI,
    `fn` is `comp` itself and `modified` is `False`. Otherwise `fn` is a new
    `building_blocks.Lambda` that returns a `building_blocks.Block` whose first
    local is a tuple of called intrinsics grouped by URI (groups of size one
    are left unwrapped) and whose second local rebuilds the original tuple via
    selections, and `modified` is `True`.

  Raises:
    ValueError: If an element of the first local in the
      `building_blocks.Block` referenced by the top level lambda is not a
      called intrinsic.
  """
  # NOTE(review): `py_typecheck.check_type` presumably raises `TypeError` on a
  # mismatch -- confirm against `py_typecheck`.
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.result, building_blocks.Block)
  tree_analysis.check_has_unique_names(comp)

  name_generator = building_block_factory.unique_name_generator(comp)

  name, first_local = comp.result.locals[0]
  py_typecheck.check_type(first_local, building_blocks.Struct)
  for element in first_local:
    if not building_block_analysis.is_called_intrinsic(element):
      raise ValueError(
          'Expected all the elements of the `building_blocks.Struct` to be '
          'called intrinsics, but found: \n{}'.format(element))

  # Group the called intrinsics by URI in a single pass.
  #
  # group_index_by_uri maps each URI to the index of its group; groups are
  #   numbered in order of first occurrence in the original tuple.
  # packed_groups[i] is the list of called intrinsics in group `i`, in their
  #   original relative order.
  # packed_indexes[k] is `(group index, index within group)` for the k-th
  #   called intrinsic of the original tuple, i.e. an implicit mapping of
  #   packed indexes keyed by unpacked index.
  group_index_by_uri = {}
  packed_groups = []
  packed_indexes = []
  for called_intrinsic in first_local:
    uri = called_intrinsic.function.uri
    group_index = group_index_by_uri.setdefault(uri, len(packed_groups))
    if group_index == len(packed_groups):
      packed_groups.append([])
    group = packed_groups[group_index]
    packed_indexes.append((group_index, len(group)))
    group.append(called_intrinsic)

  # If there are no duplicates, return early.
  if len(packed_groups) == len(first_local):
    return comp, False

  # A group with a single member is kept as the bare called intrinsic.
  packed_elements = [
      building_blocks.Struct(group) if len(group) > 1 else group[0]
      for group in packed_groups
  ]
  packed_comp = building_blocks.Struct(packed_elements)

  packed_ref_name = next(name_generator)
  packed_ref_type = computation_types.to_type(packed_comp.type_signature)
  packed_ref = building_blocks.Reference(packed_ref_name, packed_ref_type)

  unpacked_elements = []
  for group_index, intrinsic_index in packed_indexes:
    sel = building_blocks.Selection(packed_ref, index=group_index)
    # Only multi-member groups were wrapped in a `Struct`, so only they need
    # the second selection.
    if len(packed_groups[group_index]) > 1:
      sel = building_blocks.Selection(sel, index=intrinsic_index)
    unpacked_elements.append(sel)
  unpacked_comp = building_blocks.Struct(unpacked_elements)

  # Build a fresh list of locals rather than editing the list obtained from
  # `comp.result.locals` in place, so `comp` cannot be affected through
  # aliasing.
  variables = [(packed_ref_name, packed_comp), (name, unpacked_comp)]
  variables.extend(comp.result.locals[1:])
  block = building_blocks.Block(variables, comp.result.result)
  fn = building_blocks.Lambda(comp.parameter_name, comp.parameter_type, block)
  return fn, True
def _call_function(arg_type):
  """Creates `lambda x: x()` with argument type `arg_type`."""
  # `name_generator` is not defined in this block; it is taken from the
  # surrounding scope.
  param_name = next(name_generator)
  param_ref = building_blocks.Reference(param_name, arg_type)
  invocation = building_blocks.Call(param_ref, None)
  fn = building_blocks.Lambda(param_name, arg_type, invocation)
  return fn
def test_returns_true_for_lambdas_representing_identical_functions(self):
  """Identity lambdas differing only in parameter name compare equal."""
  a_ref = building_blocks.Reference('a', tf.int32)
  identity_over_a = building_blocks.Lambda('a', a_ref.type_signature, a_ref)
  b_ref = building_blocks.Reference('b', tf.int32)
  identity_over_b = building_blocks.Lambda('b', b_ref.type_signature, b_ref)
  are_equal = tree_analysis.trees_equal(identity_over_a, identity_over_b)
  self.assertTrue(are_equal)
def test_returns_false_for_lambdas_with_different_parameter_types(self):
  """Identity lambdas over `int32` and `float32` do not compare equal."""
  int_ref = building_blocks.Reference('a', tf.int32)
  int_fn = building_blocks.Lambda(int_ref.name, int_ref.type_signature,
                                  int_ref)
  float_ref = building_blocks.Reference('a', tf.float32)
  float_fn = building_blocks.Lambda(float_ref.name, float_ref.type_signature,
                                    float_ref)
  are_equal = tree_analysis.trees_equal(int_fn, float_fn)
  self.assertFalse(are_equal)
def test_returns_true_with_excluded_reference(self):
  """An unbound reference named by `excluding` is not counted as unbound."""
  unbound_ref = building_blocks.Reference('a', tf.int32)
  # The lambda binds `b`, so `a` stays unbound inside its body.
  fn = building_blocks.Lambda('b', tf.int32, unbound_ref)
  no_unbound = tree_analysis.contains_no_unbound_references(
      fn, excluding='a')
  self.assertTrue(no_unbound)
def test_returns_false(self):
  """A lambda whose body references a name it does not bind is reported."""
  unbound_ref = building_blocks.Reference('a', tf.int32)
  # The lambda binds `b`, so `a` stays unbound inside its body.
  fn = building_blocks.Lambda('b', tf.int32, unbound_ref)
  no_unbound = tree_analysis.contains_no_unbound_references(fn)
  self.assertFalse(no_unbound)
def test_ok_lambda_binding_of_new_variable(self):
  """A lambda binding `y` inside a block binding `x` is accepted."""
  y_reference = building_blocks.Reference('y', tf.int32)
  inner_fn = building_blocks.Lambda('y', tf.int32, y_reference)
  x_value = building_blocks.Data('x', tf.int32)
  enclosing_block = building_blocks.Block([('x', x_value)], inner_fn)
  # The test passes as long as the check does not raise.
  tree_analysis.check_has_unique_names(enclosing_block)
def test_ok_block_binding_of_new_variable(self):
  """A block binding `x` inside a lambda binding `y` is accepted."""
  x_value = building_blocks.Data('x', tf.int32)
  inner_block = building_blocks.Block([('x', x_value)], x_value)
  enclosing_fn = building_blocks.Lambda('y', tf.int32, inner_block)
  # The test passes as long as the check does not raise.
  tree_analysis.check_has_unique_names(enclosing_fn)
def test_raises_block_rebinding_of_lambda_variable(self):
  """A block re-binding the enclosing lambda's parameter `x` is rejected."""
  x_value = building_blocks.Data('x', tf.int32)
  inner_block = building_blocks.Block([('x', x_value)], x_value)
  enclosing_fn = building_blocks.Lambda('x', tf.int32, inner_block)
  with self.assertRaises(tree_analysis.NonuniqueNameError):
    tree_analysis.check_has_unique_names(enclosing_fn)
def transform_postorder(comp, transform):
  """Traverses `comp` recursively postorder and replaces its constituents.

  Viewing `comp` as an expression tree, `transform` is applied to the children
  of every node before it is applied to the node itself. `transform` should
  act as an identity function on the kinds of building blocks it does not care
  to transform.

  Children are visited left-to-right, in the order in which they are listed in
  the building block constructors. A parent is rebuilt from its (possibly
  transformed) children only if at least one of them was modified, and is then
  passed to `transform`.

  NOTE: In particular, in `Call(f,x)`, both `f` and `x` are arguments to
  `Call`. Therefore, `f` is transformed into `f'`, next `x` into `x'` and
  finally, `Call(f',x')` is transformed at the end.

  Args:
    comp: A `computation_building_block.ComputationBuildingBlock` to traverse
      and transform bottom-up.
    transform: The transformation to apply locally to each building block in
      `comp`. It is a Python function that accepts a building block at input,
      and should return a (building block, bool) tuple as output, where the
      building block is a `computation_building_block.ComputationBuildingBlock`
      representing either the original building block or a transformed
      building block and the bool is a flag indicating if the building block
      was modified.

  Returns:
    The result of applying `transform` to parts of `comp` in a bottom-up
    fashion, along with a Boolean with the value `True` if `comp` was
    transformed and `False` if it was not.

  Raises:
    TypeError: If the arguments are of the wrong computation_types.
    NotImplementedError: If the argument is a kind of computation building
      block that is currently not recognized.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

  # Leaves have no children; hand them to `transform` directly.
  leaf_types = (
      building_blocks.CompiledComputation,
      building_blocks.Data,
      building_blocks.Intrinsic,
      building_blocks.Placement,
      building_blocks.Reference,
  )
  if isinstance(comp, leaf_types):
    return transform(comp)

  if isinstance(comp, building_blocks.Selection):
    new_source, source_changed = transform_postorder(comp.source, transform)
    if source_changed:
      comp = building_blocks.Selection(new_source, comp.name, comp.index)
    comp, self_changed = transform(comp)
    return comp, self_changed or source_changed

  if isinstance(comp, building_blocks.Tuple):
    new_elements = []
    any_element_changed = False
    for key, value in anonymous_tuple.iter_elements(comp):
      new_value, value_changed = transform_postorder(value, transform)
      new_elements.append((key, new_value))
      any_element_changed = any_element_changed or value_changed
    if any_element_changed:
      comp = building_blocks.Tuple(new_elements)
    comp, self_changed = transform(comp)
    return comp, self_changed or any_element_changed

  if isinstance(comp, building_blocks.Call):
    new_fn, fn_changed = transform_postorder(comp.function, transform)
    # A call may have no argument, in which case there is nothing to visit.
    new_arg, arg_changed = None, False
    if comp.argument is not None:
      new_arg, arg_changed = transform_postorder(comp.argument, transform)
    if fn_changed or arg_changed:
      comp = building_blocks.Call(new_fn, new_arg)
    comp, self_changed = transform(comp)
    return comp, self_changed or fn_changed or arg_changed

  if isinstance(comp, building_blocks.Lambda):
    new_result, result_changed = transform_postorder(comp.result, transform)
    if result_changed:
      comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    new_result)
    comp, self_changed = transform(comp)
    return comp, self_changed or result_changed

  if isinstance(comp, building_blocks.Block):
    # Locals are visited in declaration order, then the block's result.
    new_locals = []
    any_local_changed = False
    for key, value in comp.locals:
      new_value, value_changed = transform_postorder(value, transform)
      new_locals.append((key, new_value))
      any_local_changed = any_local_changed or value_changed
    new_result, result_changed = transform_postorder(comp.result, transform)
    if any_local_changed or result_changed:
      comp = building_blocks.Block(new_locals, new_result)
    comp, self_changed = transform(comp)
    return comp, self_changed or any_local_changed or result_changed

  raise NotImplementedError(
      'Unrecognized computation building block: {}'.format(str(comp)))
def test_returns_true_for_lambdas_referring_to_same_unbound_variables(self):
  """Two lambdas whose bodies are the same unbound reference compare equal."""
  unbound_x = building_blocks.Reference('x', tf.int32)
  first_fn = building_blocks.Lambda('a', tf.int32, unbound_x)
  second_fn = building_blocks.Lambda('a', tf.int32, unbound_x)
  are_equal = tree_analysis.trees_equal(first_fn, second_fn)
  self.assertTrue(are_equal)
def consolidate_and_extract_local_processing(comp, grappler_config_proto):
  """Consolidates all the local processing in `comp`.

  The input computation `comp` must have the following properties:

  1. The output of `comp` may be of a federated type or unplaced. We refer to
     the placement `p` of that type as the placement of `comp`. There is no
     placement anywhere in the body of `comp` different than `p`. If `comp`
     is of a functional type, and has a parameter, the type of that parameter
     is a federated type placed at `p` as well, or unplaced if the result of
     the function is unplaced.

  2. The only intrinsics that may appear in the body of `comp` are those that
     manipulate data locally within the same placement. The exact set of these
     intrinsics will be gradually updated. At the moment, we support only the
     following:

     * Either `federated_apply` or `federated_map`, depending on whether `comp`
       is `SERVER`- or `CLIENTS`-placed. `federated_map_all_equal` is also
       allowed in the `CLIENTS`-placed case.

     * Either `federated_value_at_server` or `federated_value_at_clients`,
       likewise placement-dependent.

     * Either `federated_zip_at_server` or `federated_zip_at_clients`, again
       placement-dependent.

     Anything else, including `sequence_*` operators, should have been reduced
     already prior to calling this function.

  3. There are no lambdas in the body of `comp` except for `comp` itself being
     possibly a (top-level) lambda. All other lambdas must have been reduced.
     This requirement may eventually be relaxed by embedding lambda reducer
     into this helper method.

  4. If `comp` is of a functional type, it is either an instance of
     `building_blocks.CompiledComputation`, in which case there is nothing for
     us to do here, or a `building_blocks.Lambda`.

  5. There is at most one unbound reference under `comp`, and this is only
     allowed in the case that `comp` is not of a functional type.

  Aside from the intrinsics specified above, and the possibility of allowing
  lambdas, blocks, and references given the constraints above, the remaining
  constructs in `comp` include a combination of tuples, selections, calls, and
  sections of TensorFlow (as `CompiledComputation`s). This helper function does
  contain the logic to consolidate these constructs.

  The output of this transformation is always a single section of TensorFlow,
  which we henceforth refer to as `result`, the exact form of which depends on
  the placement of `comp` and the presence or absence of an argument.

  a. If there is no argument in `comp`, and `comp` is `SERVER`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     federated_value_at_server(result())
     ```

  b. If there is no argument in `comp`, and `comp` is `CLIENTS`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     federated_value_at_clients(result())
     ```

  c. If there is an argument in `comp`, and `comp` is `SERVER`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     (arg -> federated_apply(<result, arg>))
     ```

  d. If there is an argument in `comp`, and `comp` is `CLIENTS`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     (arg -> federated_map(<result, arg>))
     ```

  If the type of `comp` is `T@p` (thus `comp` is non-functional), the type of
  `result` is `T`, where `p` is the specific (concrete) placement of `comp`.

  If the type of `comp` is `(T@p -> U@p)`, then the type of `result` must be
  `(T -> U)`, where `p` is again a specific placement.

  Args:
    comp: An instance of `building_blocks.ComputationBuildingBlock` that serves
      as the input to this transformation, as described above.
    grappler_config_proto: An instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the generated TensorFlow graph.
      If `None`, Grappler is bypassed.

  Returns:
    An instance of `building_blocks.CompiledComputation` that holds the
    TensorFlow section produced by this extraction step, as described above.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

  def _to_tensorflow(block):
    """Parses `block` into TensorFlow and checks the extraction result."""
    tf_block = parse_tff_to_tf(block, grappler_config_proto)
    check_extraction_result(block, tf_block)
    return tf_block

  def _is_apply_or_map(call):
    """Returns whether `call` invokes `federated_apply` or `federated_map`."""
    return call.function.uri in (intrinsic_defs.FEDERATED_APPLY.uri,
                                 intrinsic_defs.FEDERATED_MAP.uri)

  if comp.type_signature.is_function():
    if comp.is_compiled_computation():
      return comp
    if not comp.is_lambda():
      # We normalize on lambdas for ease of calling unwrap_placement below.
      # The constructed lambda here simply forwards its argument to `comp`.
      param_ref = building_blocks.Reference(
          next(building_block_factory.unique_name_generator(comp)),
          comp.type_signature.parameter)
      forwarded_call = building_blocks.Call(comp, param_ref)
      comp = building_blocks.Lambda(param_ref.name, param_ref.type_signature,
                                    forwarded_call)

    if not comp.type_signature.result.is_federated():
      return _to_tensorflow(comp)

    unwrapped, _ = tree_transformations.unwrap_placement(comp.result)
    # Unwrapped can be a call to `federated_value_at_P`, or
    # `federated_apply/map`.
    if _is_apply_or_map(unwrapped):
      return _to_tensorflow(unwrapped.argument[0])

    # Re-bind the unwrapped body under the member type of the parameter.
    member_type = None
    if comp.parameter_type is not None:
      member_type = comp.parameter_type.member
    rebound = building_blocks.Lambda(comp.parameter_name, member_type,
                                     unwrapped.argument)
    return _to_tensorflow(rebound)

  if comp.type_signature.is_federated():
    unwrapped, _ = tree_transformations.unwrap_placement(comp)
    # Unwrapped can be a call to `federated_value_at_P`, or
    # `federated_apply/map`.
    if _is_apply_or_map(unwrapped):
      return _to_tensorflow(unwrapped.argument[0])
    return _to_tensorflow(unwrapped.argument).function

  # Unplaced, non-functional `comp`: return the function of the parsed result.
  return _to_tensorflow(comp).function
def test_returns_true_for_lambdas(self):
  """Two separately built but identical identity lambdas compare equal."""
  first_ref = building_blocks.Reference('a', tf.int32)
  first_fn = building_blocks.Lambda(first_ref.name, first_ref.type_signature,
                                    first_ref)
  second_ref = building_blocks.Reference('a', tf.int32)
  second_fn = building_blocks.Lambda(second_ref.name,
                                     second_ref.type_signature, second_ref)
  are_equal = tree_analysis.trees_equal(first_fn, second_fn)
  self.assertTrue(are_equal)
def _identity_function(arg_type):
  """Creates `lambda x: x` with argument type `arg_type`."""
  # `name_generator` is not defined in this block; it is taken from the
  # surrounding scope.
  param_name = next(name_generator)
  param_ref = building_blocks.Reference(param_name, arg_type)
  return building_blocks.Lambda(param_name, arg_type, param_ref)
def test_propogates_dependence_up_through_lambda(self):
  """A lambda whose body is the matched intrinsic is itself reported."""
  intrinsic = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
  enclosing_fn = building_blocks.Lambda('x', tf.int32, intrinsic)
  consumers = tree_analysis.extract_nodes_consuming(
      enclosing_fn, dummy_intrinsic_predicate)
  self.assertIn(enclosing_fn, consumers)
def create_nested_syntax_tree():
  r"""Constructs computation with explicit ordering for testing traversals.

  The goal of this computation is to exercise each switch
  in transform_postorder_with_symbol_bindings, at least all those that recurse.

  The computation this function constructs can be represented as below.

  Notice that the body of the Lambda *does not depend on the Lambda's
  parameter*, so that if we were actually executing this call the argument will
  be thrown away.

  All leaf nodes are instances of `building_blocks.Data`.

                            Call
                           /    \
                 Lambda('arg')   Data('k')
                     |
                   Block('y','z')-------------
                  /                          |
  ['y'=Data('a'),'z'=Data('b')]              |
                                           Tuple
                                         /       \
                                   Block('v')     Block('x')-------
                                     / \              |            |
                       ['v'=Selection]  Data('g') ['x'=Data('h')]  |
                             |                                     |
                             |                                     |
                             |                                 Block('w')
                             |                                   /   \
                           Tuple ------           ['w'=Data('i')]     Data('j')
                         /              \
                 Block('t')             Block('u')
                  /     \              /          \
    ['t'=Data('c')]    Data('d') ['u'=Data('e')]  Data('f')

  Postorder traversals:
  If we are reading Data URIs, results of a postorder traversal should be:
  [a, b, c, d, e, f, g, h, i, j, k]

  If we are reading locals declarations, results of a postorder traversal
  should be:
  [t, u, v, w, x, y, z]

  And if we are reading both in an interleaved fashion, results of a postorder
  traversal should be:
  [a, b, c, d, t, e, f, u, g, v, h, i, j, w, x, y, z, k]

  Preorder traversals:
  If we are reading Data URIs, results of a preorder traversal should be:
  [a, b, c, d, e, f, g, h, i, j, k]

  If we are reading locals declarations, results of a preorder traversal
  should be:
  [y, z, v, t, u, x, w]

  And if we are reading both in an interleaved fashion, results of a preorder
  traversal should be:
  [y, z, a, b, v, t, c, d, u, e, f, g, x, h, w, i, j, k]

  Since we are also exposing the ability to hook into variable declarations,
  it is worthwhile considering the order in which variables are assigned in
  this tree. Notice that this order maps neither to preorder nor to postorder
  when purely considering the nodes of the tree above. This would be:
  [arg, y, z, t, u, v, x, w]

  Returns:
    An instance of `building_blocks.ComputationBuildingBlock`
    satisfying the description above.
  """

  def _float_data(uri):
    """Returns a `float32` `building_blocks.Data` leaf with the given URI."""
    return building_blocks.Data(uri, tf.float32)

  # Left subtree: Block('v') binds a selection from the inner tuple.
  block_t = building_blocks.Block([('t', _float_data('c'))], _float_data('d'))
  block_u = building_blocks.Block([('u', _float_data('e'))], _float_data('f'))
  inner_struct = building_blocks.Struct([block_t, block_u])
  first_of_inner = building_blocks.Selection(inner_struct, index=0)
  block_v = building_blocks.Block([('v', first_of_inner)], _float_data('g'))

  # Right subtree: Block('x') whose result is Block('w').
  block_w = building_blocks.Block([('w', _float_data('i'))], _float_data('j'))
  # 'h' is the only leaf typed `int32`; every other leaf is `float32`.
  h_leaf = building_blocks.Data('h', tf.int32)
  block_x = building_blocks.Block([('x', h_leaf)], block_w)

  # Outer block, the lambda wrapping it, and the final call.
  outer_result = building_blocks.Struct([block_v, block_x])
  outer_locals = [('y', _float_data('a')), ('z', _float_data('b'))]
  outer_block = building_blocks.Block(outer_locals, outer_result)
  outer_fn = building_blocks.Lambda('arg', tf.float32, outer_block)
  return building_blocks.Call(outer_fn, _float_data('k'))
def test_raises_on_non_tuple_parameter(self):
    """A lambda whose parameter is a plain `tf.int32` must raise TypeError."""
    int_ref = building_blocks.Reference('x', tf.int32)
    identity_fn = building_blocks.Lambda('x', tf.int32, int_ref)
    selections = [[0]]
    with self.assertRaises(TypeError):
        mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
            identity_fn, selections)
def remove_duplicate_called_graphs(comp):
    """Collapses repeated calls to equal compiled graphs into shared locals.

    Each call to a `building_blocks.CompiledComputation` found in `comp` (or in
    the result of `comp`, when `comp` is a lambda) is replaced by a reference
    to a block local. Calls that compare equal under
    `tree_analysis.trees_equal` share a single local. The resulting block is
    then passed to `create_tensorflow_representing_block`.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` whose called
        graphs should be deduplicated. To be eligible, `comp` must either be a
        lambda whose body contains no lambdas or blocks, or some other
        computation containing no lambdas or blocks. The restriction exists
        because this function does not check whether it pulls references out
        of their defining scope, except when `comp` itself is a lambda. If the
        restriction is violated, a warning is logged and `comp` is returned
        unchanged. Additionally, `comp` must contain only computations which
        can be represented in TensorFlow, i.e. which satisfy the type
        restriction in `type_utils.is_tensorflow_compatible_type`.

    Returns:
      A pair `(result, modified)`. On success `result` is either a called
      instance of `building_blocks.CompiledComputation` or a
      `building_blocks.CompiledComputation` itself, depending on whether `comp`
      is of non-functional or functional type respectively, and `modified` is
      `True`. On early exit the pair is `(comp, False)`. The boolean matches
      the `transformation_utils.TransformSpec` pattern.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    tree_analysis.check_has_unique_names(comp)
    name_generator = building_block_factory.unique_name_generator(comp)

    # For a lambda only the body is inspected and rewritten; the parameter is
    # re-attached at the end.
    comp_is_lambda = isinstance(comp, building_blocks.Lambda)
    body = comp.result if comp_is_lambda else comp

    disallowed_types = (building_blocks.Lambda, building_blocks.Block)
    if tree_analysis.count_types(body, disallowed_types) > 0:
        logging.warning(
            'The preprocessors have failed to remove called lambdas '
            'and blocks; falling back to less efficient, but '
            'guaranteed, TensorFlow generation with computation %s.', comp)
        return comp, False

    # Ordered `(local_name, call)` pairs, one per distinct called graph.
    unique_calls = []

    def _reference_for_called_graph(inner_comp):
        """Swaps a called graph for a reference into `unique_calls`."""
        if (not isinstance(inner_comp, building_blocks.Call) or
                not isinstance(inner_comp.function,
                               building_blocks.CompiledComputation)):
            return inner_comp, False
        # Reuse the name of the first equal call already recorded, if any.
        bound_name = next(
            (name for name, seen_call in unique_calls
             if tree_analysis.trees_equal(seen_call, inner_comp)), None)
        if bound_name is None:
            bound_name = next(name_generator)
            unique_calls.append((bound_name, inner_comp))
        reference = building_blocks.Reference(bound_name,
                                              inner_comp.type_signature)
        return reference, True

    deduplicated_body, _ = transformation_utils.transform_postorder(
        body, _reference_for_called_graph)
    block_of_calls = building_blocks.Block(unique_calls, deduplicated_body)
    tf_block, _ = create_tensorflow_representing_block(block_of_calls)
    if not comp_is_lambda:
        return tf_block, True

    # Re-wrap in the original parameter and parse the lambda to TensorFlow.
    rebuilt_lambda = building_blocks.Lambda(comp.parameter_name,
                                            comp.parameter_type, tf_block)
    tf_parser_callable = tree_to_cc_transformations.TFParser()
    with_identities, _ = (
        tree_transformations.insert_called_tf_identity_at_leaves(
            rebuilt_lambda))
    tf_generated, _ = transformation_utils.transform_postorder(
        with_identities, tf_parser_callable)
    return tf_generated, True
def test_raises_on_selection_from_non_tuple(self):
    """Selecting `[0, 0]` from `<int32>` must raise 'nonexistent index'."""
    single_int_ref = building_blocks.Reference('x', [tf.int32])
    identity_fn = building_blocks.Lambda('x', [tf.int32], single_int_ref)
    # The second index reaches into the `tf.int32` element, which has no
    # elements to select from.
    selections = [[0, 0]]
    with self.assertRaisesRegex(TypeError, 'nonexistent index'):
        mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
            identity_fn, selections)
def _create_empty_function(type_elements):
    """Builds a lambda over a struct of `type_elements` returning `<>`.

    The parameter name is drawn from `name_generator`, which is not defined in
    this function and must be provided by the enclosing scope.

    Args:
      type_elements: The elements passed to `computation_types.StructType` to
        form the parameter type of the lambda.

    Returns:
      A `building_blocks.Lambda` whose result is an empty
      `building_blocks.Struct`.
    """
    param = building_blocks.Reference(
        next(name_generator), computation_types.StructType(type_elements))
    return building_blocks.Lambda(param.name, param.type_signature,
                                  building_blocks.Struct([]))
def test_raises_on_non_federated_selection(self):
    """Selecting the unplaced `tf.int32` element must raise TypeError."""
    single_int_ref = building_blocks.Reference('x', [tf.int32])
    identity_fn = building_blocks.Lambda('x', [tf.int32], single_int_ref)
    selections = [[0]]
    with self.assertRaises(TypeError):
        mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
            identity_fn, selections)
def _extract_update(after_aggregate, grappler_config):
    """Extracts `update` from `after_aggregate`.

    Intended for use by `get_canonical_form_for_iterative_process` only. This
    function therefore does not verify that `after_aggregate` has the expected
    structure; the caller must do so beforehand.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `update` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we extract an AST of
        the wrong type. (Nothing in this function raises it directly;
        presumably it comes from one of the `transformations` helpers --
        confirm against their implementation.)
    """
    # Keep result elements 0 and 1 of `after_aggregate` and zip them.
    s7_result_indices = [0, 1]
    s7_selected = transformations.select_output_from_lambda(
        after_aggregate, s7_result_indices)
    s7_zipped_fn = building_blocks.Lambda(
        s7_selected.parameter_name, s7_selected.parameter_type,
        building_block_factory.create_federated_zip(s7_selected.result))

    # Rebind the lambda so it takes the selected parameter elements directly.
    s6_parameter_paths = [[0, 0, 0], [1, 0], [1, 1]]
    rebound_fn = transformations.zip_selection_as_argument_to_lower_level_lambda(
        s7_zipped_fn, s6_parameter_paths)
    s6_to_s7_computation = rebound_fn.result.function

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support
    # selecting from nested structures, therefore we need to pack the type
    # signature `<s1, s3, s4>` as `<s1, <s3, s4>>`.
    name_generator = building_block_factory.unique_name_generator(
        s6_to_s7_computation)
    pack_ref_name = next(name_generator)
    flat_member_types = s6_to_s7_computation.parameter_type.member
    nested_pair_type = computation_types.StructType(
        [flat_member_types[1], flat_member_types[2]])
    pack_ref_type = computation_types.StructType(
        [flat_member_types[0], nested_pair_type])
    pack_ref = building_blocks.Reference(pack_ref_name, pack_ref_type)

    # `pack_fn` maps the nested `<s1, <s3, s4>>` to the flat `<s1, s3, s4>`.
    sel_s1 = building_blocks.Selection(pack_ref, index=0)
    nested_pair = building_blocks.Selection(pack_ref, index=1)
    flattened = building_blocks.Struct([
        sel_s1,
        building_blocks.Selection(nested_pair, index=0),
        building_blocks.Selection(nested_pair, index=1),
    ])
    pack_fn = building_blocks.Lambda(pack_ref.name, pack_ref.type_signature,
                                     flattened)

    # Outer lambda: takes the nested value placed at SERVER, flattens it with
    # `pack_fn`, and feeds the result to `s6_to_s7_computation`.
    server_ref_name = next(name_generator)
    server_ref = building_blocks.Reference(
        server_ref_name,
        computation_types.FederatedType(pack_ref_type, placements.SERVER))
    flattened_args = building_block_factory.create_federated_map_or_apply(
        pack_fn, server_ref)
    update_fn = building_blocks.Lambda(
        server_ref.name, server_ref.type_signature,
        building_blocks.Call(s6_to_s7_computation, flattened_args))
    return transformations.consolidate_and_extract_local_processing(
        update_fn, grappler_config)
def sequence_reduce(value, zero, op):
    """Reduces a TFF sequence `value` given a `zero` and reduction operator `op`.

    This method reduces a set of elements of a TFF sequence `value`, using a
    given `zero` in the algebra (i.e., the result of reducing an empty
    sequence) of some type `U`, and a reduction operator `op` with type
    signature `(<U,T> -> U)` that incorporates a single `T`-typed element of
    `value` into the `U`-typed result of partial reduction. In the special
    case of `T` equal to `U`, this corresponds to the classical notion of
    reduction of a set using a commutative associative binary operator. The
    generalized reduction (with `T` not equal to `U`) requires that repeated
    application of `op` to reduce a set of `T` always yields the same
    `U`-typed result, regardless of the order in which elements of `T` are
    processed in the course of the reduction.

    One can also invoke `sequence_reduce` on a federated sequence, in which
    case the reductions are performed pointwise; under the hood, we construct
    an expression of the form
    `federated_map(x -> sequence_reduce(x, zero, op), value)`. See also the
    discussion on `sequence_map`.

    Note: When applied to a federated value this function does the reduce
    point-wise. In that case `zero` must be federated as well, because its
    member type is used to build the mapped function.

    Args:
      value: A value that is either a TFF sequence, or a federated sequence.
      zero: The result of reducing a sequence with no elements. Must be
        federated when `value` is federated.
      op: An operator with type signature `(<U,T> -> U)`, where `T` is the
        type of the elements of the sequence, and `U` is the type of `zero`
        to be used in performing the reduction.

    Returns:
      The `U`-typed result of reducing elements in the sequence, or if the
      `value` is federated, a federated `U` that represents the result of
      locally reducing each member constituent of `value`.

    Raises:
      TypeError: If the arguments are not of the types specified above.
    """
    value = value_impl.to_value(value, None)
    zero = value_impl.to_value(zero, None)
    op = value_impl.to_value(op, None)

    # Plain (non-federated) sequence: reduce directly.
    if not value.type_signature.is_federated():
        value.type_signature.check_sequence()
        comp = building_block_factory.create_sequence_reduce(
            value.comp, zero.comp, op.comp)
        comp = _bind_comp_as_reference(comp)
        return value_impl.Value(comp)

    # Federated sequence: reduce each member under a `federated_map`.
    value_member_type = value.type_signature.member
    value_member_type.check_sequence()
    # Without this check a non-federated `zero` would fail below with an
    # AttributeError on `.member` instead of the documented TypeError.
    if not zero.type_signature.is_federated():
        raise TypeError(
            'Expected `zero` to be federated when `value` is federated, '
            'found `zero` of type {}.'.format(zero.type_signature))
    zero_member_type = zero.type_signature.member

    # Build `arg -> sequence_reduce(arg[0], arg[1], op)` over the member types.
    ref_type = computation_types.StructType(
        [value_member_type, zero_member_type])
    ref = building_blocks.Reference('arg', ref_type)
    arg1 = building_blocks.Selection(ref, index=0)
    arg2 = building_blocks.Selection(ref, index=1)
    call = building_block_factory.create_sequence_reduce(arg1, arg2, op.comp)
    fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
    fn_value_impl = value_impl.Value(fn)
    args = building_blocks.Struct([value.comp, zero.comp])
    return federated_map(fn_value_impl, args)
def test_returns_false(self):
    """A lambda whose body references an unbound name must yield False."""
    ref = building_blocks.Reference('a', tf.int32)
    # The lambda binds 'b', not 'a', so the reference to 'a' in its body is
    # unbound. Binding `ref.name` here would make the reference bound and the
    # expected result True, contradicting the name of this test.
    fn = building_blocks.Lambda('b', tf.int32, ref)
    self.assertFalse(tree_analysis.contains_no_unbound_references(fn))
def test_raises_on_nested_lambdas_with_same_variable_name(self):
    """Two nested lambdas that both bind 'x' must raise NonuniqueNameError."""
    inner_lambda = building_blocks.Lambda(
        'x', tf.int32, building_blocks.Reference('x', tf.int32))
    outer_lambda = building_blocks.Lambda('x', tf.int32, inner_lambda)
    with self.assertRaises(tree_analysis.NonuniqueNameError):
        tree_analysis.check_has_unique_names(outer_lambda)