Esempio n. 1
0
def prepare_for_rebinding(comp):
    """Prepares `comp` so that rebound variables can be extracted from it.

    Reference names are first uniquified and called lambdas are replaced with
    blocks. Then a single postorder walk inlines blocks and collapses
    selections from tuples.

    This removes only one level of indirection. It does not necessarily
    guarantee that the result is free of called lambdas. In practice it has
    been sufficient for identifying variables that are about to be rebound in
    the top-level lambda, in particular when compiler components factor work
    out from a single function into multiple functions.

    Because no sufficiency guarantee is made, the caller is responsible for
    ensuring that no unbound variables are introduced during the rebinding.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` from which
        all occurrences of a given variable need to be extracted and rebound.

    Returns:
      The pair produced by
      `transformation_utils.transform_postorder_with_symbol_bindings`:
      - the transformed `building_blocks.ComputationBuildingBlock`;
      - a flag indicating whether the inlining / selection-collapsing walk
        modified it. Changes made by the preliminary passes are not reflected
        in this flag.
    """
    # TODO(b/146430051): Follow up here and consider removing or enforcing more
    # strict output invariants when `remove_called_lambdas_and_blocks` is moved
    # in here.
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    # Preliminary passes; their `modified` flags are intentionally discarded.
    comp, _ = tree_transformations.uniquify_reference_names(comp)
    comp, _ = tree_transformations.replace_called_lambda_with_block(comp)
    chained_transforms = (
        tree_transformations.InlineBlock(comp),
        tree_transformations.ReplaceSelectionFromTuple(),
    )

    def _apply_chained_transforms(node, bindings):
        """Applies each chained transform to `node`, in order.

        Transforms whose `global_transform` attribute is set also receive the
        symbol tree. Returns the rewritten node and whether any transform
        reported a change.
        """
        changed_any = False
        for transform in chained_transforms:
            call_args = (node, bindings) if transform.global_transform else (node,)
            node, changed = transform.transform(*call_args)
            changed_any = changed_any or changed
        return node, changed_any

    bindings_tree = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)
    return transformation_utils.transform_postorder_with_symbol_bindings(
        comp, _apply_chained_transforms, bindings_tree)
Esempio n. 2
0
  def _check_calls_are_concrete(comp):
    """Encodes condition for completeness of direct extraction of calls.

    After checking this condition, all functions which are semantically called
    (IE, functions which will be invoked eventually by running the computation)
    are called directly, and we can simply extract them by pattern-matching on
    `building_blocks.Call`.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` to check for
        condition that functional argument of `Call` constructs contains only
        the enumeration in the top-level docstring.

    Raises:
      ValueError: If `comp` fails this condition.
    """
    symbol_tree = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)

    def _check_for_call_arguments(comp_to_check, symbol_tree):
      """Validates the function called by `comp_to_check`, if it is a `Call`.

      This is a postorder-traversal callback used purely for its checks. It
      never rewrites the tree, so every non-raising path returns
      `(comp_to_check, False)`. The traversal unpacks that pair, so no path
      may fall through and return `None`.

      Allowed called functions are:
      - compiled computations;
      - intrinsics;
      - lambdas whose result type contains no function type;
      - references bound to lambda parameters, or with no binding in scope;
      - selections (possibly nested) from such references.

      Raises:
        ValueError: If `comp_to_check` calls any other kind of function.
      """
      if not comp_to_check.is_call():
        return comp_to_check, False
      functional_arg = comp_to_check.function
      if (functional_arg.is_compiled_computation() or
          functional_arg.is_intrinsic()):
        return comp_to_check, False
      elif functional_arg.is_lambda():
        if type_analysis.contains(functional_arg.type_signature.result,
                                  lambda x: x.is_function()):
          raise ValueError('Called higher-order functions are disallowed in '
                           'transforming to call-dominant form, as they may '
                           'break the reliance on pattern-matching to extract '
                           'called intrinsics. Encountered a call to the '
                           'lambda {l} with type signature {t}.'.format(
                               l=functional_arg,
                               t=functional_arg.type_signature))
        return comp_to_check, False
      elif functional_arg.is_reference():
        # This case and the following handle the possibility that a lambda
        # declares a functional parameter, and this parameter is invoked in its
        # body.
        payload = symbol_tree.get_payload_with_name(functional_arg.name)
        if payload is None:
          # No binding for this name is in scope; nothing to check.
          return comp_to_check, False
        if payload.value is not None:
          raise ValueError('Called references which are not bound to lambda '
                           'parameters are disallowed in transforming to '
                           'call-dominant form, as they may break the reliance '
                           'on pattern-matching to extract called intrinsics. '
                           'Encountered a call to the reference {r}, which is '
                           'bound to the value {v} in this computation.'.format(
                               r=functional_arg, v=payload.value))
        # A binding without a value is treated as a lambda parameter (see the
        # error above), and calling it is allowed.
        return comp_to_check, False
      elif functional_arg.is_selection():
        # Strip nested selections to find what is ultimately selected from.
        concrete_source = functional_arg.source
        while concrete_source.is_selection():
          concrete_source = concrete_source.source
        if concrete_source.is_reference():
          payload = symbol_tree.get_payload_with_name(concrete_source.name)
          if payload is None:
            return comp_to_check, False
          if payload.value is not None:
            raise ValueError('Called selections from references which are not '
                             'bound to lambda parameters are disallowed in '
                             'transforming to call-dominant form, as they may '
                             'break the reliance on pattern-matching to '
                             'extract called intrinsics. Encountered a call to '
                             'the reference {r}, which is bound to the value '
                             '{v} in this computation.'.format(
                                 r=functional_arg, v=payload.value))
          return comp_to_check, False
        else:
          raise ValueError('Called selections are only permitted in '
                           'transforming to call-dominant form in the case '
                           'that they select from lambda parameters; '
                           'encountered a call to selection {s}.'.format(
                               s=functional_arg))
      else:
        raise ValueError('During transformation to call-dominant form, we rely '
                         'on the assumption that all called functions are '
                         'either: compiled computations; intrinsics; lambdas '
                         'with nonfunctional return types; or selections from '
                         'lambda parameters. Encountered the called function '
                         '{f} of type {t}.'.format(
                             f=functional_arg, t=type(functional_arg)))

    transformation_utils.transform_postorder_with_symbol_bindings(
        comp, _check_for_call_arguments, symbol_tree)
Esempio n. 3
0
def extract_nodes_consuming(tree, predicate):
    """Returns the set of AST nodes which consume nodes matching `predicate`.

    By convention, a node that itself satisfies `predicate` counts as a
    consumer and is included in the result.

    Args:
      tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
        abstract syntax tree. The function constructs the set of nodes in this
        tree whose value depends on evaluating nodes matching `predicate`.
      predicate: One-arg callable, accepting arguments of type
        `building_blocks.ComputationBuildingBlock` and returning a `bool`
        indicating match or mismatch with the desired pattern.

    Returns:
      A `set` of `building_blocks.ComputationBuildingBlock` instances
      representing the nodes in `tree` dependent on nodes matching `predicate`.
    """
    py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_callable(predicate)

    # Dependent nodes, keyed by object identity so that membership never
    # relies on the building blocks' own equality or hashing. Holding the node
    # objects here keeps them alive, so their ids stay unique during the walk.
    consumers_by_id = {}

    def _is_tracked(node):
        return id(node) in consumers_by_id

    def _depends_on_tracked(comp, symbol_tree):
        """Returns whether a direct dependency of `comp` is already tracked."""
        if (comp.is_intrinsic() or comp.is_data() or comp.is_placement() or
                comp.is_compiled_computation()):
            return False
        if comp.is_lambda():
            return _is_tracked(comp.result)
        if comp.is_block():
            return (any(_is_tracked(local[1]) for local in comp.locals) or
                    _is_tracked(comp.result))
        if comp.is_struct():
            return any(_is_tracked(element) for element in comp)
        if comp.is_selection():
            return _is_tracked(comp.source)
        if comp.is_call():
            return _is_tracked(comp.function) or _is_tracked(comp.argument)
        if comp.is_reference():
            binding = symbol_tree.get_payload_with_name(comp.name)
            if binding is None:
                return False
            # The postorder traversal visits a binding's value before any
            # reference to that binding.
            return _is_tracked(binding.value)
        return False

    def _visit(comp, symbol_tree):
        """Records `comp` if it matches `predicate` or consumes a tracked node."""
        if predicate(comp) or _depends_on_tracked(comp, symbol_tree):
            consumers_by_id[id(comp)] = comp
        return comp, False

    bindings = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)
    transformation_utils.transform_postorder_with_symbol_bindings(
        tree, _visit, bindings)
    return set(consumers_by_id.values())
Esempio n. 4
0
def extract_nodes_consuming(tree, predicate):
    """Returns the set of AST nodes which consume nodes matching `predicate`.

    Notice we adopt the convention that a node which itself satisfies the
    predicate is in this set.

    Args:
      tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
        abstract syntax tree, and construct the set of nodes in this tree having
        a dependency on nodes matching `predicate`; that is, the set of nodes
        whose value depends on evaluating nodes matching `predicate`.
      predicate: One-arg callable, accepting arguments of type
        `building_blocks.ComputationBuildingBlock` and returning a `bool`
        indicating match or mismatch with the desired pattern.

    Returns:
      A `set` of `building_blocks.ComputationBuildingBlock` instances
      representing the nodes in `tree` dependent on nodes matching `predicate`.
    """
    py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_callable(predicate)
    # Track dependent nodes by object identity rather than by `==`/`hash`. The
    # question asked during the walk is whether *this particular node* was
    # recorded, and identity cannot conflate two distinct nodes that merely
    # compare equal. Storing the nodes keeps them alive, so their ids remain
    # unique for the duration of the traversal.
    dependent_nodes = {}

    def _is_dependent(comp):
        return id(comp) in dependent_nodes

    def _are_children_in_dependent_set(comp, symbol_tree):
        """Checks if the dependencies of `comp` are present in `dependent_nodes`."""
        if isinstance(
                comp,
            (building_blocks.Intrinsic, building_blocks.Data,
             building_blocks.Placement, building_blocks.CompiledComputation)):
            return False
        elif isinstance(comp, building_blocks.Lambda):
            return _is_dependent(comp.result)
        elif isinstance(comp, building_blocks.Block):
            return (any(_is_dependent(x[1]) for x in comp.locals) or
                    _is_dependent(comp.result))
        elif isinstance(comp, building_blocks.Tuple):
            return any(_is_dependent(x) for x in comp)
        elif isinstance(comp, building_blocks.Selection):
            return _is_dependent(comp.source)
        elif isinstance(comp, building_blocks.Call):
            return _is_dependent(comp.function) or _is_dependent(comp.argument)
        elif isinstance(comp, building_blocks.Reference):
            return _is_reference_dependent(comp, symbol_tree)
        return False

    def _is_reference_dependent(comp, symbol_tree):
        """Returns whether the value bound to reference `comp` is dependent."""
        payload = symbol_tree.get_payload_with_name(comp.name)
        if payload is None:
            return False
        # The postorder traversal ensures that we process any
        # bindings before we process the reference to those bindings
        return _is_dependent(payload.value)

    def _populate_dependent_set(comp, symbol_tree):
        """Populates `dependent_nodes` with all nodes dependent on `predicate`."""
        if predicate(comp) or _are_children_in_dependent_set(comp, symbol_tree):
            dependent_nodes[id(comp)] = comp
        return comp, False

    symbol_tree = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)
    transformation_utils.transform_postorder_with_symbol_bindings(
        tree, _populate_dependent_set, symbol_tree)
    return set(dependent_nodes.values())
Esempio n. 5
0
def remove_lambdas_and_blocks(comp):
  """Removes any called lambdas and blocks from `comp`.

  The cleanup runs in stages:
  1. Two rounds of dropping unused block locals, inlining selections from
     tuples and replacing called lambdas with blocks.
  2. A final unused-local cleanup.
  3. A single renaming pass that makes every variable name unique. It runs as
     its own walk because replacing called lambdas interacts with scope in
     delicate ways.
  4. One postorder walk that chains block inlining with collapsing the
     selection-from-tuple pattern.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` from which we
      want to remove called lambdas and blocks.

  Returns:
    A `(transformed_comp, modified)` pair:
    - `transformed_comp` is `comp` without called lambdas or blocks and without
      extraneous selections from tuples;
    - `modified` reports whether any stage changed the computation.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

  # TODO(b/146904968): In general, any bounded number of passes of these
  # transforms as currently implemented is insufficient in order to satisfy
  # the purpose of this function. Filing a new bug to followup if this becomes a
  # pressing issue.
  cleanup_passes = [
      transformations.remove_unused_block_locals,
      transformations.inline_selections_from_tuple,
      transformations.replace_called_lambda_with_block,
  ] * 2 + [
      transformations.remove_unused_block_locals,
      transformations.uniquify_reference_names,
  ]
  modified = False
  for cleanup_pass in cleanup_passes:
    comp, pass_modified = cleanup_pass(comp)
    modified = pass_modified or modified

  chained_transforms = [
      transformations.InlineBlock(comp),
      transformations.ReplaceSelectionFromTuple(),
  ]

  def _apply_chain(node, bindings):
    """Chains block inlining and selection collapsing on a single node.

    This is defined inline rather than factored out and parameterized by the
    transforms to apply, because chaining stateful transformations is
    delicate. These transforms should be safe when they appear first in the
    list, but the invariants they can rely on in this setting are hard to
    reason about, so no general helper exposes this logic.

    Args:
      node: Instance of `building_blocks.ComputationBuildingBlock` to check for
        inlining and collapsing of selections.
      bindings: The `transformation_utils.SymbolTree` describing the bindings
        available to `node`.

    Returns:
      A `(transformed_node, modified)` pair.
    """
    node_modified = False
    for transform in chained_transforms:
      if transform.global_transform:
        node, changed = transform.transform(node, bindings)
      else:
        node, changed = transform.transform(node)
      node_modified = node_modified or changed
    return node, node_modified

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformed_comp, walk_modified = (
      transformation_utils.transform_postorder_with_symbol_bindings(
          comp, _apply_chain, symbol_tree))
  return transformed_comp, modified or walk_modified
def uniquify_reference_names(comp, name_generator=None):
  """Replaces all the bound reference names in `comp` with unique names.

  Notice that `uniquify_reference_names` simply leaves alone any reference
  which is unbound under `comp`.

  Args:
    comp: The computation building block in which to perform the replacements.
    name_generator: An optional generator to use for creating unique names. If
      `name_generator` is not None, all existing bindings will be replaced.

  Returns:
    A `(transformed_comp, modified)` pair.

    In `transformed_comp` every bound variable name is unique:
    - When `name_generator` is None, a binding keeps its original name the
      first time that name is encountered and is renamed only if the name was
      already used.
    - Replacement names come from a generator seeded with `comp`, so they
      conflict with neither existing bindings nor unbound references in
      `comp`.

    `modified` reports whether any binding or reference name actually changed.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  # Passing `comp` to `unique_name_generator` here will ensure that the
  # generated names conflict with neither bindings in `comp` nor unbound
  # references in `comp`.
  if name_generator is None:
    name_generator = building_block_factory.unique_name_generator(comp)
    rename_all = False
  else:
    # If a `name_generator` was passed in, all bindings must be renamed since
    # we need to avoid duplication with an outer scope.
    rename_all = True
  used_names = set()

  class _RenameNode(transformation_utils.BoundVariableTracker):
    """transformation_utils.SymbolTree node for renaming References in ASTs."""

    def __init__(self, name, value):
      super().__init__(name, value)
      py_typecheck.check_type(name, str)
      # Keep the original name unless every binding must be renamed or the
      # name has already been handed out to an earlier binding.
      if rename_all or name in used_names:
        self.new_name = next(name_generator)
      else:
        self.new_name = name
      used_names.add(self.new_name)

    def __str__(self):
      return 'Value: {}, name: {}, new_name: {}'.format(self.value, self.name,
                                                        self.new_name)

  def _transform(comp, context_tree):
    """Renames References and bindings in `comp` to their unique names.

    Names are compared with `==` rather than `is`: two equal strings need not
    be the same object, and an identity check would report spurious
    modifications and rebuild nodes needlessly.
    """
    if comp.is_reference():
      payload = context_tree.get_payload_with_name(comp.name)
      if payload is None:
        # Unbound under `comp`: leave it alone.
        return comp, False
      new_name = payload.new_name
      if new_name == comp.name:
        return comp, False
      return building_blocks.Reference(new_name, comp.type_signature,
                                       comp.context), True
    elif comp.is_block():
      new_locals = []
      modified = False
      for name, val in comp.locals:
        # Step to this local's binding before looking up its payload.
        context_tree.walk_down_one_variable_binding()
        new_name = context_tree.get_payload_with_name(name).new_name
        modified = modified or new_name != name
        new_locals.append((new_name, val))
      if not modified:
        return comp, False
      return building_blocks.Block(new_locals, comp.result), True
    elif comp.is_lambda():
      if comp.parameter_type is None:
        return comp, False
      context_tree.walk_down_one_variable_binding()
      new_name = context_tree.get_payload_with_name(
          comp.parameter_name).new_name
      if new_name == comp.parameter_name:
        return comp, False
      return building_blocks.Lambda(new_name, comp.parameter_type,
                                    comp.result), True
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(_RenameNode)
  return transformation_utils.transform_postorder_with_symbol_bindings(
      comp, _transform, symbol_tree)