def prepare_for_rebinding(comp):
  """Prepares `comp` for extracting rebound variables.

  The preparation proceeds in three steps:
  1. Reference names are uniquified.
  2. Called lambdas are replaced by blocks.
  3. A single postorder walk inlines blocks and collapses selections from
     tuples.

  This does not necessarily guarantee that the resulting computation has no
  called lambdas; it merely removes one level of indirection. That reduction
  has proved sufficient for identifying variables which are about to be rebound
  in the top-level lambda, particularly when compiler components factor work
  out of a single function into multiple functions. Since this function makes
  no guarantees about sufficiency, it is the responsibility of the caller to
  ensure that no unbound variables are introduced during the rebinding.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` from which all
      occurrences of a given variable need to be extracted and rebound.

  Returns:
    The pair produced by
    `transformation_utils.transform_postorder_with_symbol_bindings`:
    - the transformed `building_blocks.ComputationBuildingBlock`, with called
      lambdas replaced by blocks, blocks inlined and selections from tuples
      collapsed;
    - a flag reporting whether that final walk modified anything.
  """
  # TODO(b/146430051): Follow up here and consider removing or enforcing more
  # strict output invariants when `remove_called_lambdas_and_blocks` is moved
  # in here.
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  comp, _ = tree_transformations.uniquify_reference_names(comp)
  comp, _ = tree_transformations.replace_called_lambda_with_block(comp)
  # The block inliner is built from the computation as it stands after the
  # renaming and lambda-replacement passes above.
  chained_transforms = (
      tree_transformations.InlineBlock(comp),
      tree_transformations.ReplaceSelectionFromTuple(),
  )

  def _inline_and_collapse(node, bindings):
    """Applies each chained transform to `node`, feeding outputs forward."""
    any_modified = False
    for transform in chained_transforms:
      # Global transforms additionally need the symbol bindings in scope.
      call_args = (node, bindings) if transform.global_transform else (node,)
      node, step_modified = transform.transform(*call_args)
      any_modified = any_modified or step_modified
    return node, any_modified

  return transformation_utils.transform_postorder_with_symbol_bindings(
      comp, _inline_and_collapse,
      transformation_utils.SymbolTree(transformation_utils.ReferenceCounter))
def _check_calls_are_concrete(comp):
  """Encodes condition for completeness of direct extraction of calls.

  After checking this condition, all functions which are semantically called
  (IE, functions which will be invoked eventually by running the computation)
  are called directly, and we can simply extract them by pattern-matching on
  `building_blocks.Call`.

  Concretely, the function of every `building_blocks.Call` in `comp` must be
  one of:
  - a compiled computation;
  - an intrinsic;
  - a lambda whose result type contains no functional types;
  - a reference that is either unbound within `comp` or bound to a lambda
    parameter;
  - a (possibly nested) selection whose innermost source is such a reference.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to check for
      the condition above.

  Raises:
    ValueError: If `comp` fails this condition.
  """
  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)

  def _check_for_call_arguments(comp_to_check, symbol_tree):
    # This is a pure check: every path either raises or hands the visited node
    # back unchanged. It never returns the outer `comp`, and it never returns
    # `None`, since the traversal unpacks a `(comp, modified)` pair.
    if not comp_to_check.is_call():
      return comp_to_check, False
    functional_arg = comp_to_check.function
    if (functional_arg.is_compiled_computation() or
        functional_arg.is_intrinsic()):
      return comp_to_check, False
    elif functional_arg.is_lambda():
      if type_analysis.contains(functional_arg.type_signature.result,
                                lambda x: x.is_function()):
        raise ValueError('Called higher-order functions are disallowed in '
                         'transforming to call-dominant form, as they may '
                         'break the reliance on pattern-matching to extract '
                         'called intrinsics. Encountered a call to the '
                         'lambda {l} with type signature {t}.'.format(
                             l=functional_arg,
                             t=functional_arg.type_signature))
      return comp_to_check, False
    elif functional_arg.is_reference():
      # This case and the following handle the possibility that a lambda
      # declares a functional parameter, and this parameter is invoked in its
      # body.
      payload = symbol_tree.get_payload_with_name(functional_arg.name)
      if payload is None:
        # Unbound within `comp`; nothing to check here.
        return comp_to_check, False
      if payload.value is not None:
        raise ValueError('Called references which are not bound to lambda '
                         'parameters are disallowed in transforming to '
                         'call-dominant form, as they may break the reliance '
                         'on pattern-matching to extract called intrinsics. '
                         'Encountered a call to the reference {r}, which is '
                         'bound to the value {v} in this computation.'.format(
                             r=functional_arg, v=payload.value))
      return comp_to_check, False
    elif functional_arg.is_selection():
      # Walk through nested selections to find what is ultimately selected
      # from.
      concrete_source = functional_arg.source
      while concrete_source.is_selection():
        concrete_source = concrete_source.source
      if concrete_source.is_reference():
        payload = symbol_tree.get_payload_with_name(concrete_source.name)
        if payload is None:
          return comp_to_check, False
        if payload.value is not None:
          raise ValueError('Called selections from references which are not '
                           'bound to lambda parameters are disallowed in '
                           'transforming to call-dominant form, as they may '
                           'break the reliance on pattern-matching to '
                           'extract called intrinsics. Encountered a call to '
                           'the reference {r}, which is bound to the value '
                           '{v} in this computation.'.format(
                               r=functional_arg, v=payload.value))
        return comp_to_check, False
      else:
        raise ValueError('Called selections are only permitted in '
                         'transforming to call-dominant form in the case that '
                         'they select from lambda parameters; encountered a '
                         'call to selection {s}.'.format(s=functional_arg))
    else:
      raise ValueError('During transformation to call-dominant form, we rely '
                       'on the assumption that all called functions are '
                       'either: compiled computations; intrinsics; lambdas '
                       'with nonfunctional return types; or selections from '
                       'lambda parameters. Encountered the called function '
                       '{f} of type {t}.'.format(
                           f=functional_arg, t=type(functional_arg)))

  transformation_utils.transform_postorder_with_symbol_bindings(
      comp, _check_for_call_arguments, symbol_tree)
def extract_nodes_consuming(tree, predicate):
  """Returns the set of AST nodes which consume nodes matching `predicate`.

  By convention, a node which itself satisfies `predicate` is included in the
  result. While walking the tree, dependency membership is tracked by object
  identity (`id`), so distinct node objects are never conflated.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree. The result is the set of nodes in this tree whose
      value depends on evaluating nodes matching `predicate`.
    predicate: One-arg callable, accepting arguments of type
      `building_blocks.ComputationBuildingBlock` and returning a `bool`
      indicating match or mismatch with the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances representing
    the nodes in `tree` dependent on nodes matching `predicate`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(predicate)

  # Maps `id(node)` -> node for every node found so far to be dependent.
  dependent_by_id = {}

  def _is_dependent(node):
    return id(node) in dependent_by_id

  def _has_dependent_inputs(comp, symbol_tree):
    """Whether any input consumed by `comp` has already been marked dependent."""
    if (comp.is_intrinsic() or comp.is_data() or comp.is_placement() or
        comp.is_compiled_computation()):
      return False
    if comp.is_lambda():
      return _is_dependent(comp.result)
    if comp.is_block():
      return (any(_is_dependent(value) for _, value in comp.locals) or
              _is_dependent(comp.result))
    if comp.is_struct():
      return any(_is_dependent(element) for element in comp)
    if comp.is_selection():
      return _is_dependent(comp.source)
    if comp.is_call():
      return _is_dependent(comp.function) or _is_dependent(comp.argument)
    if comp.is_reference():
      payload = symbol_tree.get_payload_with_name(comp.name)
      # The postorder traversal ensures that we process any bindings before we
      # process the references to those bindings.
      return payload is not None and _is_dependent(payload.value)
    return False

  def _record_if_dependent(comp, symbol_tree):
    if predicate(comp) or _has_dependent_inputs(comp, symbol_tree):
      dependent_by_id[id(comp)] = comp
    return comp, False

  transformation_utils.transform_postorder_with_symbol_bindings(
      tree, _record_if_dependent,
      transformation_utils.SymbolTree(transformation_utils.ReferenceCounter))
  return set(dependent_by_id.values())
def extract_nodes_consuming(tree, predicate):
  """Returns the set of AST nodes which consume nodes matching `predicate`.

  Notice we adopt the convention that a node which itself satisfies the
  predicate is in this set.

  While walking the tree, dependency membership is keyed on object identity
  (`id`) rather than on `==`/hashing. Two distinct nodes that happen to compare
  equal (for example, same-named references bound in different scopes) are
  therefore never mistaken for one another.

  NOTE(review): this module defines `extract_nodes_consuming` more than once.
  The later definition silently replaces the earlier one at import time, so
  one copy should probably be removed.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree, and construct the set of nodes in this tree having
      a dependency on nodes matching `predicate`; that is, the set of nodes
      whose value depends on evaluating nodes matching `predicate`.
    predicate: One-arg callable, accepting arguments of type
      `building_blocks.ComputationBuildingBlock` and returning a `bool`
      indicating match or mismatch with the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances representing
    the nodes in `tree` dependent on nodes matching `predicate`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(predicate)

  # Maps `id(node)` -> node for every dependent node found so far.
  dependent_nodes = {}

  def _contains(node):
    return id(node) in dependent_nodes

  def _are_children_in_dependent_set(comp, symbol_tree):
    """Checks if the dependencies of `comp` are present in `dependent_nodes`."""
    if (comp.is_intrinsic() or comp.is_data() or comp.is_placement() or
        comp.is_compiled_computation()):
      return False
    elif comp.is_lambda():
      return _contains(comp.result)
    elif comp.is_block():
      return (any(_contains(x[1]) for x in comp.locals) or
              _contains(comp.result))
    elif comp.is_struct():
      return any(_contains(x) for x in comp)
    elif comp.is_selection():
      return _contains(comp.source)
    elif comp.is_call():
      return _contains(comp.function) or _contains(comp.argument)
    elif comp.is_reference():
      return _is_reference_dependent(comp, symbol_tree)
    return False

  def _is_reference_dependent(comp, symbol_tree):
    payload = symbol_tree.get_payload_with_name(comp.name)
    if payload is None:
      return False
    # The postorder traversal ensures that we process any
    # bindings before we process the reference to those bindings
    return _contains(payload.value)

  def _populate_dependent_set(comp, symbol_tree):
    """Populates `dependent_nodes` with all nodes dependent on `predicate`."""
    if predicate(comp):
      dependent_nodes[id(comp)] = comp
    elif _are_children_in_dependent_set(comp, symbol_tree):
      dependent_nodes[id(comp)] = comp
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformation_utils.transform_postorder_with_symbol_bindings(
      tree, _populate_dependent_set, symbol_tree)
  return set(dependent_nodes.values())
def remove_lambdas_and_blocks(comp):
  """Removes any called lambdas and blocks from `comp`.

  The work happens in three phases, in this order:
  1. Two rounds of: removing unused block locals, inlining selections from
     tuples, and replacing called lambdas with blocks.
  2. One more removal of unused block locals, followed by renaming every bound
     variable to a unique name.
  3. A single postorder walk which chains block inlining with collapsing of the
     selection-from-tuple pattern.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` from which we
      want to remove called lambdas and blocks.

  Returns:
    A tuple `(transformed_comp, modified)`.
    - `transformed_comp` is a transformed version of `comp` intended to have no
      called lambdas or blocks, and no extraneous selections from tuples. The
      TODO below notes that a bounded number of passes is not sufficient in
      general.
    - `modified` reflects whether any pass reported a change.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  # TODO(b/146904968): In general, any bounded number of passes of these
  # transforms as currently implemented is insufficient in order to satisfy
  # the purpose of this function. Filing a new bug to followup if this becomes a
  # pressing issue.
  cleanup_passes = (
      transformations.remove_unused_block_locals,
      transformations.inline_selections_from_tuple,
      transformations.replace_called_lambda_with_block,
  ) * 2 + (
      transformations.remove_unused_block_locals,
      transformations.uniquify_reference_names,
  )
  modified = False
  for pass_fn in cleanup_passes:
    comp, pass_modified = pass_fn(comp)
    modified = pass_modified or modified

  # The inliner is constructed from the renamed computation produced above.
  chained_transforms = (
      transformations.InlineBlock(comp),
      transformations.ReplaceSelectionFromTuple(),
  )

  def _transform_fn(comp, symbol_tree):
    """Chains block inlining and collapsing of selections on one node.

    This is kept local rather than factored out and parameterized by the
    transforms to apply, because chaining transformations which rely on state
    is delicate. There is no exposed helper hoisting this logic, since the
    invariants these transforms can rely on here are hard to reason about.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` we wish to
        check for inlining and collapsing of selections.
      symbol_tree: Instance of `transformation_utils.SymbolTree` defining the
        bindings available to `comp`.

    Returns:
      A tuple `(transformed_comp, modified)` for this node.
    """
    node_modified = False
    for transform in chained_transforms:
      if transform.global_transform:
        comp, step_modified = transform.transform(comp, symbol_tree)
      else:
        comp, step_modified = transform.transform(comp)
      node_modified = node_modified or step_modified
    return comp, node_modified

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformed_comp, walk_modified = (
      transformation_utils.transform_postorder_with_symbol_bindings(
          comp, _transform_fn, symbol_tree))
  return transformed_comp, modified or walk_modified
def uniquify_reference_names(comp, name_generator=None):
  """Replaces all the bound reference names in `comp` with unique names.

  Notice that `uniquify_reference_names` simply leaves alone any reference
  which is unbound under `comp`.

  Args:
    comp: The computation building block in which to perform the replacements.
    name_generator: An optional generator to use for creating unique names. If
      `name_generator` is not None, all existing bindings will be replaced.

  Returns:
    A tuple `(transformed_comp, modified)`.
    - `transformed_comp` is a version of `comp` inside of which all bound
      variable names are guaranteed to be unique, and are guaranteed to not
      mask any unbound names referenced in the body of `comp`.
    - `modified` indicates whether any name was changed.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  # Passing `comp` to `unique_name_generator` here will ensure that the
  # generated names conflict with neither bindings in `comp` nor unbound
  # references in `comp`.
  if name_generator is None:
    name_generator = building_block_factory.unique_name_generator(comp)
    rename_all = False
  else:
    # If a `name_generator` was passed in, all bindings must be renamed since
    # we need to avoid duplication with an outer scope.
    rename_all = True
  used_names = set()

  class _RenameNode(transformation_utils.BoundVariableTracker):
    """transformation_utils.SymbolTree node for renaming References in ASTs."""

    def __init__(self, name, value):
      super().__init__(name, value)
      py_typecheck.check_type(name, str)
      # Keep the first occurrence of a name unless everything must be renamed;
      # later duplicates get a fresh generated name.
      if rename_all or name in used_names:
        self.new_name = next(name_generator)
      else:
        self.new_name = name
      used_names.add(self.new_name)

    def __str__(self):
      return 'Value: {}, name: {}, new_name: {}'.format(self.value, self.name,
                                                         self.new_name)

  def _transform(comp, context_tree):
    """Renames References in `comp` to unique names."""
    # Names are compared with `==`/`!=`: `is` only tests string object
    # identity, which is not a reliable test of whether a name changed.
    if comp.is_reference():
      payload = context_tree.get_payload_with_name(comp.name)
      if payload is None:
        return comp, False
      new_name = payload.new_name
      if new_name == comp.name:
        return comp, False
      return building_blocks.Reference(new_name, comp.type_signature,
                                       comp.context), True
    elif comp.is_block():
      new_locals = []
      modified = False
      for name, val in comp.locals:
        # Advance the symbol tree once per local before looking it up.
        context_tree.walk_down_one_variable_binding()
        new_name = context_tree.get_payload_with_name(name).new_name
        modified = modified or (new_name != name)
        new_locals.append((new_name, val))
      return building_blocks.Block(new_locals, comp.result), modified
    elif comp.is_lambda():
      if comp.parameter_type is None:
        return comp, False
      context_tree.walk_down_one_variable_binding()
      new_name = context_tree.get_payload_with_name(
          comp.parameter_name).new_name
      if new_name == comp.parameter_name:
        return comp, False
      return building_blocks.Lambda(new_name, comp.parameter_type,
                                    comp.result), True
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(_RenameNode)
  return transformation_utils.transform_postorder_with_symbol_bindings(
      comp, _transform, symbol_tree)