def check_contains_no_new_unbound_references(old_tree, new_tree):
  """Verifies that `new_tree` introduces no unbound references.

  The unbound references of `old_tree` form the allowed baseline; any name
  unbound in `new_tree` but not in `old_tree` is reported as an error.

  Args:
    old_tree: The tree whose unbound references are the baseline.
    new_tree: The tree whose unbound references are compared to the baseline.

  Raises:
    ValueError: If `new_tree` has an unbound reference that `old_tree` lacks.
  """
  get_unbound_map = transformation_utils.get_map_of_unbound_references
  baseline_names = get_unbound_map(old_tree)[old_tree]
  candidate_names = get_unbound_map(new_tree)[new_tree]
  introduced_names = candidate_names - baseline_names
  if not introduced_names:
    return
  raise ValueError('Expected no new unbounded references. '
                   f'Old tree:\n{old_tree}\nNew tree:\n{new_tree}\n'
                   f'New unbound references: {introduced_names}')
def _get_unbound_ref(block):
  """Builds a `Reference` for the sole unbound reference of `block`, if any.

  Args:
    block: The computation to inspect.

  Returns:
    `None` when `block` has no unbound references; otherwise a
    `building_blocks.Reference` carrying the unbound name together with the
    type signature of a reference of that name found while traversing `block`.

  Raises:
    ValueError: If `block` has more than one unbound reference.
  """
  unbound_names = transformation_utils.get_map_of_unbound_references(
      block)[block]
  unbound_count = len(unbound_names)
  if unbound_count > 1:
    raise ValueError('`create_tensorflow_representing_block` must be passed '
                     'a block with at most a single unbound reference; '
                     'encountered the block {} with {} unbound '
                     'references.'.format(block, unbound_count))
  if not unbound_names:
    return None

  ref_name = unbound_names.pop()
  # Single-slot holder written by the traversal callback below; the last
  # matching reference visited determines the type that is returned.
  found_type = [None]

  def _record_type_of_matching_ref(node):
    if node.is_reference() and node.name == ref_name:
      found_type[0] = node.type_signature
    # The traversal is used only for inspection, so nothing is ever replaced.
    return node, False

  transformation_utils.transform_postorder(block, _record_type_of_matching_ref)
  return building_blocks.Reference(ref_name, found_type[0])
def check_has_unique_names(comp):
  """Checks that every name in `comp` is bound at most once.

  Names that are unbound at the top level of `comp` are treated as already
  taken, so a binding inside `comp` that would shadow one of them is rejected
  as well.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock`.

  Raises:
    NonuniqueNameError: If a name is bound more than once, or a binding would
      shadow a reference that is unbound at the top level of `comp`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  # Seeding with the unbound names of `comp` makes any binding that masks a
  # name from the parent scope look like a duplicate.
  seen_names = transformation_utils.get_map_of_unbound_references(comp)[comp]

  def _names_bound_by(node):
    """Returns the names that `node` itself introduces."""
    if node.is_block():
      return [local_name for local_name, _ in node.locals]
    if node.is_lambda() and node.parameter_type is not None:
      return [node.parameter_name]
    return []

  def _record_bindings(node):
    for bound_name in _names_bound_by(node):
      if bound_name in seen_names:
        raise NonuniqueNameError(comp, bound_name)
      seen_names.add(bound_name)

  visit_postorder(comp, _record_bindings)
def transform(self, comp):
  """Drops locals of `comp` that its result does not need.

  Locals are scanned from last to first. A local is kept when its name is
  currently needed; the unbound references of its value then become needed
  too, and its own name stops being needed.

  Args:
    comp: The computation to transform; its `locals` and `result` are read
      once `self.should_transform(comp)` has accepted it.

  Returns:
    A `(computation, modified)` pair: `comp` itself and `False` when nothing
    changes, the bare result and `True` when no local survives, or a new
    `building_blocks.Block` with the surviving locals and `True`.
  """
  if not self.should_transform(comp):
    return comp, False

  result = comp.result
  needed_names = transformation_utils.get_map_of_unbound_references(
      result)[result]
  if not (needed_names and comp.locals):
    return result, True

  kept_in_reverse = []
  for local_name, local_value in reversed(comp.locals):
    if local_name not in needed_names:
      continue
    kept_in_reverse.append((local_name, local_value))
    value_refs = transformation_utils.get_map_of_unbound_references(
        local_value)[local_value]
    # `union` yields a fresh set, so the map's own sets are left untouched.
    needed_names = needed_names.union(value_refs)
    needed_names.discard(local_name)

  if len(kept_in_reverse) == len(comp.locals):
    return comp, False
  if not kept_in_reverse:
    return result, True
  return building_blocks.Block(reversed(kept_in_reverse), result), True
def __call__(self, proto: pb.Computation) -> Set[str]:
  """Returns the names of any unbound references in `proto`.

  Results are looked up in, and stored into, `self._evaluated_comps`, keyed by
  `_hash_proto`. On a miss the entry for every subtree of the deserialized
  computation is stored, not only the entry for `proto` itself.

  Args:
    proto: The `pb.Computation` to analyze.

  Returns:
    The unbound reference names recorded for `proto`.
  """
  py_typecheck.check_type(proto, pb.Computation)
  proto_key = _hash_proto(proto)
  cached_names = self._evaluated_comps.get(proto_key)
  if cached_names is not None:
    return cached_names

  tree = building_blocks.ComputationBuildingBlock.from_proto(proto)
  unbound_ref_map = transformation_utils.get_map_of_unbound_references(tree)
  fresh_entries = {}
  for subtree, names in unbound_ref_map.items():
    fresh_entries[_hash_proto(subtree.proto)] = names
  self._evaluated_comps.update(fresh_entries)
  return unbound_ref_map[tree]
def _inline_block_variables_required_to_align_intrinsics(comp, uri):
  """Inlines the block variables needed to align the intrinsics for `uri`.

  Only the block variables that called intrinsics matching `uri` depend on are
  inlined; other block variables are left alone. The process repeats for as
  long as the intrinsics cannot be extracted to the top-level lambda, so
  unbound references inside variables that were just inlined are inlined on a
  later pass.

  Args:
    comp: The `building_blocks.Lambda` to transform.
    uri: A Python `list` of `str` URIs of intrinsics.

  Returns:
    A new computation with the transformation applied, or the original `comp`
    if no inlining was required.

  Raises:
    tree_transformations.TransformationError: If the matching called
      intrinsics have no unbound references other than the parameter of
      `comp`, or if inlining the collected variable names leaves the AST
      unmodified.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(uri, list)
  for single_uri in uri:
    py_typecheck.check_type(single_uri, str)

  while True:
    if _can_extract_intrinsics_to_top_level_lambda(comp, uri):
      return comp

    unbound_by_node = transformation_utils.get_map_of_unbound_references(comp)
    names_to_inline = set()
    for called_intrinsic in _get_called_intrinsics(comp, uri):
      intrinsic_refs = unbound_by_node[called_intrinsic]
      # The parameter of the top-level lambda is not a block variable.
      intrinsic_refs.discard(comp.parameter_name)
      names_to_inline.update(intrinsic_refs)

    if not names_to_inline:
      raise tree_transformations.TransformationError(
          'Inlining `Block` variables has failed. Expected to find unbound '
          'references for called `Intrisic`s matching the URI: \'{}\', but '
          'none were found in the AST: \n{}'.format(
              uri, comp.formatted_representation()))

    comp, was_modified = tree_transformations.inline_block_locals(
        comp, variable_names=names_to_inline)
    if not was_modified:
      # Without progress the next iteration would repeat this one forever.
      raise tree_transformations.TransformationError(
          'Inlining `Block` variables has failed, this will result in an '
          'infinite loop. Expected to modify the AST by inlining the variable '
          'names: \'{}\', but no transformations to the AST: \n{}'.format(
              names_to_inline, comp.formatted_representation()))
    comp, _ = tree_transformations.uniquify_reference_names(comp)
def test_binding_single_arg_leaves_no_unbound_references(self):
  """Zipping selection `[0]` into a lower-level lambda leaves nothing unbound."""
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.NamedTupleType([clients_int, server_int])

  # Lambda of the form `x -> x[0]` over the two-element federated tuple.
  param_ref = building_blocks.Reference('x', param_type)
  select_first = building_blocks.Selection(param_ref, index=0)
  original_lambda = building_blocks.Lambda('x', param_type, select_first)

  rebound_lambda = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          original_lambda, [[0]]))

  unbound_by_node = transformation_utils.get_map_of_unbound_references(
      rebound_lambda)
  self.assertEmpty(unbound_by_node[rebound_lambda])
def test_single_element_selection_leaves_no_unbound_references(self):
  """Rewriting a lambda over sub-parameter `(0,)` leaves nothing unbound."""
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.StructType([clients_int, server_int])

  # Lambda of the form `x -> x[0]` over the two-element federated struct.
  param_ref = building_blocks.Reference('x', param_type)
  select_first = building_blocks.Selection(param_ref, index=0)
  original_lambda = building_blocks.Lambda('x', param_type, select_first)

  rewritten_lambda = form_utils._as_function_of_some_federated_subparameters(
      original_lambda, [(0,)])

  unbound_by_node = transformation_utils.get_map_of_unbound_references(
      rewritten_lambda)
  self.assertEmpty(unbound_by_node[rewritten_lambda])
def should_transform(self, comp):
  """Decides whether `comp` qualifies for this transformation.

  Args:
    comp: The computation to examine.

  Returns:
    `True` only when all of the following hold: the type of `comp` is
    TensorFlow compatible, or is a function type whose parameter and result
    are both TensorFlow compatible; `comp` is neither a compiled computation
    nor a call to one; `comp` has no unbound references; and `comp` contains
    no `building_blocks.Intrinsic`. Otherwise `False`.
  """
  type_spec = comp.type_signature
  if not type_analysis.is_tensorflow_compatible_type(type_spec):
    is_compatible_function = (
        type_spec.is_function() and
        type_analysis.is_tensorflow_compatible_type(type_spec.parameter) and
        type_analysis.is_tensorflow_compatible_type(type_spec.result))
    if not is_compatible_function:
      return False

  # These represent the final result of TF generation; no need to transform.
  if comp.is_compiled_computation():
    return False
  if comp.is_call() and comp.function.is_compiled_computation():
    return False

  # We cannot represent these captures without further information.
  if transformation_utils.get_map_of_unbound_references(comp)[comp]:
    return False

  return not tree_analysis.contains_types(comp, building_blocks.Intrinsic)
def compile_local_subcomputations_to_tensorflow(
    comp: building_blocks.ComputationBuildingBlock,
) -> building_blocks.ComputationBuildingBlock:
  """Compiles subcomputations to TensorFlow where possible.

  A subcomputation is compiled when it is "local" and has no unbound
  references. A computation counts as local when neither it nor any of its
  descendants is an intrinsic, data, a placement, or an XLA compiled
  computation, and none of them has a type signature containing federated
  types.

  Args:
    comp: The computation whose subcomputations should be compiled.

  Returns:
    The computation obtained after unpacking compiled computations in `comp`
    and replacing every eligible subcomputation, visited in preorder, with its
    TensorFlow compilation.
  """
  comp = unpack_compiled_computations(comp)
  local_cache = {}

  def _is_local(comp):
    cached = local_cache.get(comp)
    if cached is not None:
      return cached
    if (comp.is_intrinsic() or comp.is_data() or comp.is_placement() or
        type_analysis.contains_federated_types(comp.type_signature)):
      result = False
    elif (comp.is_compiled_computation() and
          comp.proto.WhichOneof('computation') == 'xla'):
      result = False
    else:
      # `all` stops at the first non-local child, like an early-exit loop.
      result = all(_is_local(child) for child in comp.children())
    # Memoize positive answers as well as negative ones. A local subtree that
    # is not compiled (because it has unbound references) is queried again for
    # each of its descendants during the preorder walk, and would otherwise be
    # re-traversed in full every time.
    local_cache[comp] = result
    return result

  unbound_ref_map = transformation_utils.get_map_of_unbound_references(comp)

  def _compile_if_local(comp):
    if _is_local(comp) and not unbound_ref_map[comp]:
      return compile_local_computation_to_tensorflow(comp), True
    return comp, False

  # Note: this transformation is preorder so that local subcomputations are not
  # first transformed to TensorFlow if they have a parent local computation
  # which could have instead been transformed into a larger single block of
  # TensorFlow.
  comp, _ = transformation_utils.transform_preorder(comp, _compile_if_local)
  return comp
def contains_no_unbound_references(tree, excluding=None):
  """Reports whether every reference in `tree` is bound within `tree`.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree.
    excluding: Optional `str`, or collection of `str`, naming references that
      are ignored by the test.

  Returns:
    `True` if `tree` has no unbound references once the names in `excluding`
    are disregarded, otherwise `False`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  remaining = transformation_utils.get_map_of_unbound_references(tree)[tree]
  if excluding is not None:
    # A bare string is a single name, not a collection of characters.
    ignored = {excluding} if isinstance(excluding, str) else set(excluding)
    remaining = remaining - ignored
  return not remaining
def test_returns_tree(self):
  """Compares the no-secure-sum before/after-aggregate split to a reference.

  The computations returned by
  `_create_before_and_after_aggregate_for_no_federated_secure_sum` are checked
  against the result of `force_align_and_split_by_intrinsics` applied to the
  same tree for the `FEDERATED_AGGREGATE` intrinsic.
  """
  # Build the `next` computation tree and replace intrinsics with their bodies.
  ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum(
  )
  next_tree = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)
  next_tree = canonical_form_utils._replace_intrinsics_with_bodies(
      next_tree)

  # Function under test.
  before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
      next_tree)

  # Reference split that the results are compared against.
  before_federated_aggregate, after_federated_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))
  self.assertIsInstance(before_aggregate, building_blocks.Lambda)
  self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
  self.assertLen(before_aggregate.result, 2)

  # trees_equal will fail if computations refer to unbound references, so we
  # create a new dummy computation to bind them.
  unbound_refs_in_before_agg_result = transformation_utils.get_map_of_unbound_references(
      before_aggregate.result[0])[before_aggregate.result[0]]
  unbound_refs_in_before_fed_agg_result = transformation_utils.get_map_of_unbound_references(
      before_federated_aggregate.result)[
          before_federated_aggregate.result]

  dummy_data = building_blocks.Data('data',
                                    computation_types.AbstractType('T'))

  # Wrap each result in a block that binds every unbound name to dummy data.
  blk_binding_refs_in_before_agg = building_blocks.Block(
      [(name, dummy_data) for name in unbound_refs_in_before_agg_result],
      before_aggregate.result[0])
  blk_binding_refs_in_before_fed_agg = building_blocks.Block(
      [(name, dummy_data) for name in unbound_refs_in_before_fed_agg_result],
      before_federated_aggregate.result)

  self.assertTrue(
      tree_analysis.trees_equal(blk_binding_refs_in_before_agg,
                                blk_binding_refs_in_before_fed_agg))

  # NOTE(review): the whitespace inside the expected strings below is kept
  # exactly as found; confirm it matches what `formatted_representation()`
  # actually emits.
  # pyformat: disable
  self.assertEqual(
      before_aggregate.result[1].formatted_representation(),
      '<\n'
      ' federated_value_at_clients(<>),\n'
      ' <>\n'
      '>')
  # pyformat: enable

  self.assertIsInstance(after_aggregate, building_blocks.Lambda)
  self.assertIsInstance(after_aggregate.result, building_blocks.Call)
  self.assertTrue(
      tree_analysis.trees_equal(after_aggregate.result.function,
                                after_federated_aggregate))

  # pyformat: disable
  self.assertEqual(
      after_aggregate.result.argument.formatted_representation(),
      '<\n'
      ' _var1[0],\n'
      ' _var1[1][0]\n'
      '>')