def inline_block_locals(comp, variable_names=None):
  """Inlines the block variables in `comp` selected by `variable_names`.

  Args:
    comp: The computation building block in which to perform the inlining.
      The names of lambda parameters and block variables in `comp` must be
      unique.
    variable_names: A Python list, tuple, or set naming the block variables to
      inline, or None if every block variable should be inlined.

  Returns:
    A new computation with the transformation applied or the original `comp`.

  Raises:
    ValueError: If `comp` contains variables with non-unique names.
  """
  py_typecheck.check_type(
      comp, computation_building_blocks.ComputationBuildingBlock)
  _check_has_unique_names(comp)
  if variable_names is not None:
    py_typecheck.check_type(variable_names, (list, tuple, set))

  def _is_selected(name):
    # A missing whitelist selects every variable.
    if variable_names is None:
      return True
    return name in variable_names

  def _inline_reference(ref, symbol_tree):
    # Returns the bound value for `ref`, or `ref` itself when nothing is bound.
    bound_value = symbol_tree.get_payload_with_name(ref.name).value
    # A non-None value identifies a variable bound by a Block rather than by a
    # Lambda; Lambda parameters are left untouched.
    if bound_value is None:
      return ref, False
    return bound_value, True

  def _prune_block(block):
    # Drops the selected locals, collapsing the block if none remain.
    remaining = [(name, value)
                 for name, value in block.locals
                 if not _is_selected(name)]
    if remaining:
      return computation_building_blocks.Block(remaining, block.result), True
    return block.result, True

  def _transform(comp, symbol_tree):
    """Returns a new transformed computation or `comp`."""
    if isinstance(comp, computation_building_blocks.Reference):
      if not _is_selected(comp.name):
        return comp, False
      return _inline_reference(comp, symbol_tree)
    if isinstance(comp, computation_building_blocks.Block):
      if not any(_is_selected(name) for name, _ in comp.locals):
        return comp, False
      return _prune_block(comp)
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  return transformation_utils.transform_postorder_with_symbol_bindings(
      comp, _transform, symbol_tree)
def uniquify_references(comp):
  """Gives globally unique names to locally scoped names under `comp`.

  Args:
    comp: Instance of `computation_building_blocks.ComputationBuildingBlock`,
      representing the root of the AST in which we are hoping to rename all
      references. It is validated with `py_typecheck.check_type` before any
      renaming happens.

  Returns:
    A transformed version of `comp` inside of which all variable names are
    guaranteed to be unique.
  """
  # Validate the input up front, matching `uniquify_reference_names`, so a bad
  # argument is rejected before the traversal starts.
  py_typecheck.check_type(
      comp, computation_building_blocks.ComputationBuildingBlock)
  int_sequence = itertools.count(start=1)

  class _RenameNode(transformation_utils.BoundVariableTracker):
    """transformation_utils.SymbolTree node for renaming References in ASTs."""

    def __init__(self, name, value):
      super(_RenameNode, self).__init__(name, value)
      py_typecheck.check_type(name, str)
      # Each bound variable draws the next number from the shared counter.
      self.new_name = '_variable{}'.format(six.next(int_sequence))

    def __str__(self):
      return 'Value: {}, name: {}, new_name: {}'.format(self.value, self.name,
                                                        self.new_name)

  def transform(comp, context_tree):
    """Renames References in `comp` to unique names."""
    if isinstance(comp, computation_building_blocks.Reference):
      new_name = context_tree.get_payload_with_name(comp.name).new_name
      return computation_building_blocks.Reference(new_name,
                                                   comp.type_signature,
                                                   comp.context)
    elif isinstance(comp, computation_building_blocks.Block):
      new_locals = []
      for name, val in comp.locals:
        # Step forward one binding per local, then read that local's new name.
        context_tree.walk_down_one_variable_binding()
        new_name = context_tree.get_payload_with_name(name).new_name
        new_locals.append((new_name, val))
      return computation_building_blocks.Block(new_locals, comp.result)
    elif isinstance(comp, computation_building_blocks.Lambda):
      context_tree.walk_down_one_variable_binding()
      new_name = context_tree.get_payload_with_name(
          comp.parameter_name).new_name
      return computation_building_blocks.Lambda(new_name, comp.parameter_type,
                                                comp.result)
    return comp

  rename_tree = transformation_utils.SymbolTree(_RenameNode)
  new_comp = transformation_utils.transform_postorder_with_symbol_bindings(
      comp, transform, rename_tree)
  return new_comp
def uniquify_reference_names(comp):
  """Renames every reference under `comp` so that all names are unique.

  Args:
    comp: The computation building block in which to perform the renaming.

  Returns:
    A transformed version of `comp` inside of which all variable names are
    guaranteed to be unique.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)
  fresh_names = computation_constructing_utils.unique_name_generator(None)

  class _RenameNode(transformation_utils.BoundVariableTracker):
    """transformation_utils.SymbolTree node for renaming References in ASTs."""

    def __init__(self, name, value):
      super(_RenameNode, self).__init__(name, value)
      py_typecheck.check_type(name, str)
      # Every tracked binding is assigned the next generated name.
      self.new_name = six.next(fresh_names)

    def __str__(self):
      return 'Value: {}, name: {}, new_name: {}'.format(
          self.value, self.name, self.new_name)

  def _renamed(tree, old_name):
    # Looks up the generated name recorded for the binding of `old_name`.
    return tree.get_payload_with_name(old_name).new_name

  def _transform(comp, context_tree):
    """Renames References in `comp` to unique names."""
    if isinstance(comp, computation_building_blocks.Reference):
      replacement = computation_building_blocks.Reference(
          _renamed(context_tree, comp.name), comp.type_signature,
          comp.context)
      return replacement, True
    if isinstance(comp, computation_building_blocks.Block):
      renamed_locals = []
      for local_name, local_value in comp.locals:
        # Step forward one binding per local, then read that local's new name.
        context_tree.walk_down_one_variable_binding()
        renamed_locals.append(
            (_renamed(context_tree, local_name), local_value))
      replacement = computation_building_blocks.Block(renamed_locals,
                                                      comp.result)
      return replacement, True
    if isinstance(comp, computation_building_blocks.Lambda):
      context_tree.walk_down_one_variable_binding()
      replacement = computation_building_blocks.Lambda(
          _renamed(context_tree, comp.parameter_name), comp.parameter_type,
          comp.result)
      return replacement, True
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(_RenameNode)
  return transformation_utils.transform_postorder_with_symbol_bindings(
      comp, _transform, symbol_tree)
def inline_block_locals(comp):
  """Inlines every block-local variable under `comp`.

  This transform is not necessarily safe, so it should only be called when all
  references under `comp` have unique names.

  Args:
    comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
      whose blocks we wish to inline.

  Returns:
    A possibly different `computation_building_blocks.ComputationBuildingBlock`
    containing the same logic as `comp`, but with all blocks inlined.

  Raises:
    ValueError: If `comp` has variables with non-unique names.
  """
  py_typecheck.check_type(
      comp, computation_building_blocks.ComputationBuildingBlock)
  if not transformation_utils.has_unique_names(comp):
    raise ValueError(
        '`inline_block_locals` should only be called after we '
        'have uniquified all '
        '`computation_building_blocks.Reference` names, since we '
        'may be moving computations with unbound references '
        'under constructs which bind those references.')

  def _transform(comp, symbol_tree):
    """Inline transform function."""
    if isinstance(comp, computation_building_blocks.Reference):
      bound_value = symbol_tree.get_payload_with_name(comp.name).value
      # A non-None value identifies a variable bound by a Block rather than
      # by a Lambda; Lambda parameters stay as references.
      return comp if bound_value is None else bound_value
    if isinstance(comp, computation_building_blocks.Block):
      # Its locals have already been substituted into their uses, so the block
      # is equivalent to its result.
      return comp.result
    return comp

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  return transformation_utils.transform_postorder_with_symbol_bindings(
      comp, _transform, symbol_tree)
def inline_block_locals(comp):
  """Inlines every block-local variable under `comp`.

  This transform is not necessarily safe, so it should only be called when all
  references under `comp` have unique names.

  Args:
    comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
      whose blocks we wish to inline.

  Returns:
    A possibly different `computation_building_blocks.ComputationBuildingBlock`
    containing the same logic as `comp`, but with all blocks inlined.

  Raises:
    ValueError: If `comp` has variables with non-unique names.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)
  _check_has_unique_names(comp)

  def _transform(comp, symbol_tree):
    """Inline transform function."""
    if isinstance(comp, computation_building_blocks.Reference):
      bound_value = symbol_tree.get_payload_with_name(comp.name).value
      # A non-None value identifies a variable bound by a Block rather than
      # by a Lambda; Lambda parameters stay as references.
      is_block_bound = bound_value is not None
      return (bound_value if is_block_bound else comp), is_block_bound
    if isinstance(comp, computation_building_blocks.Block):
      # The block's locals have been inlined, leaving only its result.
      return comp.result, True
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  return transformation_utils.transform_postorder_with_symbol_bindings(
      comp, _transform, symbol_tree)
def extract_nodes_consuming(tree, predicate):
  """Returns the set of AST nodes which consume nodes matching `predicate`.

  By convention, a node that itself satisfies the predicate is included in the
  returned set.

  Args:
    tree: Instance of `computation_building_blocks.ComputationBuildingBlock` to
      view as an abstract syntax tree, and construct the set of nodes in this
      tree having a dependency on nodes matching `predicate`; that is, the set
      of nodes whose value depends on evaluating nodes matching `predicate`.
    predicate: One-arg callable, accepting arguments of type
      `computation_building_blocks.ComputationBuildingBlock` and returning a
      `bool` indicating match or mismatch with the desired pattern.

  Returns:
    A `set` of `computation_building_blocks.ComputationBuildingBlock` instances
    representing the nodes in `tree` dependent on nodes matching `predicate`.
  """
  py_typecheck.check_type(tree,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(predicate)
  dependent_nodes = set()

  def _is_reference_dependent(comp, symbol_tree):
    # Unbound references have no payload and so cannot be dependent.
    try:
      payload = symbol_tree.get_payload_with_name(comp.name)
    except NameError:
      return False
    # The postorder traversal ensures that we process any
    # bindings before we process the reference to those bindings
    return payload.value in dependent_nodes

  def _are_children_in_dependent_set(comp, symbol_tree):
    """Checks if the dependencies of `comp` are present in `dependent_nodes`."""
    leaf_types = (computation_building_blocks.Intrinsic,
                  computation_building_blocks.Data,
                  computation_building_blocks.Placement,
                  computation_building_blocks.CompiledComputation)
    if isinstance(comp, leaf_types):
      return False
    if isinstance(comp, computation_building_blocks.Lambda):
      return comp.result in dependent_nodes
    if isinstance(comp, computation_building_blocks.Block):
      locals_dependent = any(
          value in dependent_nodes for _, value in comp.locals)
      return locals_dependent or comp.result in dependent_nodes
    if isinstance(comp, computation_building_blocks.Tuple):
      return any(element in dependent_nodes for element in comp)
    if isinstance(comp, computation_building_blocks.Selection):
      return comp.source in dependent_nodes
    if isinstance(comp, computation_building_blocks.Call):
      return (comp.function in dependent_nodes or
              comp.argument in dependent_nodes)
    if isinstance(comp, computation_building_blocks.Reference):
      return _is_reference_dependent(comp, symbol_tree)
    return False

  def _populate_dependent_set(comp, symbol_tree):
    """Populates `dependent_nodes` with all nodes dependent on `predicate`."""
    if predicate(comp) or _are_children_in_dependent_set(comp, symbol_tree):
      dependent_nodes.add(comp)
    # Never modify the AST; this pass only records nodes.
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformation_utils.transform_postorder_with_symbol_bindings(
      tree, _populate_dependent_set, symbol_tree)
  return dependent_nodes