Esempio n. 1
0
def _split_by_intrinsics_in_top_level_lambda(comp):
    """Splits by the intrinsics in the first block local in the result of `comp`.

    This function splits `comp` into two computations `before` and `after` the
    called intrinsic or tuple of called intrinsics found as the first local in
    the `building_blocks.Block` returned by the top level lambda; and returns a
    Python tuple representing the pair of `before` and `after` computations.

    `before` is a `building_blocks.Lambda` over the original parameter that
    returns the argument of the called intrinsic (or a `building_blocks.Struct`
    of the arguments, when the first local is a struct of called intrinsics).

    `after` is a `building_blocks.Lambda` over a single fresh reference whose
    type is the struct `<original parameter type, type of the first local>`.
    Its body is a block that binds the original parameter name to element 0 of
    that reference and the first local's name to element 1, followed by the
    remaining original locals and the original block result.

    Args:
      comp: The `building_blocks.Lambda` to split. Its result must be a
        `building_blocks.Block` and all names in it must be unique.

    Returns:
      A pair of `building_blocks.ComputationBuildingBlock`s `(before, after)`.

    Raises:
      ValueError: If the `building_blocks.Block` returned by the top level
        lambda has no locals, or if its first local is not a called intrinsic
        or a `building_blocks.Struct` of called intrinsics.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(comp.result, building_blocks.Block)
    tree_analysis.check_has_unique_names(comp)

    name_generator = building_block_factory.unique_name_generator(comp)

    # Work on a copy so the locals of the input `comp` are never mutated below,
    # regardless of whether `Block.locals` hands out a fresh list.
    variables = list(comp.result.locals)
    if not variables:
        raise ValueError(
            'Expected the `building_blocks.Block` returned by the top level '
            'lambda to have at least one local, but found none.')

    name, first_local = variables[0]
    if building_block_analysis.is_called_intrinsic(first_local):
        result = first_local.argument
    elif first_local.is_struct():
        elements = []
        for element in first_local:
            if not building_block_analysis.is_called_intrinsic(element):
                raise ValueError(
                    'Expected all the elements of the `building_blocks.Struct` to be '
                    'called intrinsics, but found: \n{}'.format(element))
            elements.append(element.argument)
        result = building_blocks.Struct(elements)
    else:
        raise ValueError(
            'Expected either a called intrinsic or a `building_blocks.Struct` of '
            'called intrinsics, but found: \n{}'.format(first_local))

    # `before` computes only the intrinsic argument(s) from the original input.
    before = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    result)

    # `after` receives a pair: (original parameter, intrinsic result(s)).
    ref_name = next(name_generator)
    ref_type = computation_types.StructType(
        (comp.parameter_type, first_local.type_signature))
    ref = building_blocks.Reference(ref_name, ref_type)
    sel_after_arg_1 = building_blocks.Selection(ref, index=0)
    sel_after_arg_2 = building_blocks.Selection(ref, index=1)

    # Replace the intrinsic call with the value passed in to `after`, and
    # rebind the original parameter name so the remaining locals still resolve.
    variables[0] = (name, sel_after_arg_2)
    variables.insert(0, (comp.parameter_name, sel_after_arg_1))
    block = building_blocks.Block(variables, comp.result.result)
    after = building_blocks.Lambda(ref.name, ref.type_signature, block)
    return before, after
Esempio n. 2
0
def contains_called_intrinsic(tree, uri=None):
    """Checks whether `tree` contains at least one called intrinsic for `uri`.

    Args:
      tree: A `building_blocks.ComputationBuildingBlock` to search.
      uri: An optional URI or list of URIs, forwarded unchanged to
        `building_block_analysis.is_called_intrinsic`.

    Returns:
      `True` when `count` finds one or more matching called intrinsics in
      `tree`, `False` otherwise.
    """

    def _is_matching_called_intrinsic(comp):
        return building_block_analysis.is_called_intrinsic(comp, uri)

    num_matches = count(tree, _is_matching_called_intrinsic)
    return num_matches > 0
Esempio n. 3
0
 def _predicate(comp):
     """Returns whether `comp` is a called intrinsic for the enclosing `uri`."""
     is_match = building_block_analysis.is_called_intrinsic(comp, uri)
     return is_match
Esempio n. 4
0
 def _should_transform(comp):
     """Selects `comp` for transformation when it calls the enclosing `uri`."""
     should_transform = building_block_analysis.is_called_intrinsic(comp, uri)
     return should_transform
Esempio n. 5
0
def _group_by_intrinsics_in_top_level_lambda(comp):
    """Groups the intrinsics in the first block local in the result of `comp`.

    This transformation creates an AST by replacing the tuple of called
    intrinsics found as the first local in the `building_blocks.Block` returned
    by the top level lambda with two new computations. The first computation is
    a tuple of tuples of called intrinsics, representing the original tuple of
    called intrinsics grouped by URI. The second computation is a tuple of
    selections from the first computation, representing the original tuple of
    called intrinsics.

    A URI that occurs only once is not wrapped in an inner tuple; its group is
    the called intrinsic itself.

    It is necessary to group intrinsics before it is possible to merge them.

    Args:
      comp: The `building_blocks.Lambda` to transform. Its result must be a
        `building_blocks.Block` and all names in it must be unique.

    Returns:
      A pair `(fn, modified)`. If every called intrinsic in the first local has
      a distinct URI, `fn` is `comp` itself and `modified` is `False`.
      Otherwise `fn` is a new `building_blocks.Lambda` returning a
      `building_blocks.Block` whose first local is the tuple of called
      intrinsics grouped by URI, and `modified` is `True`.

    Raises:
      ValueError: If the `building_blocks.Block` returned by the top level
        lambda has no locals, or if any element of its first local is not a
        called intrinsic. (A first local that is not a
        `building_blocks.Struct` is rejected earlier by
        `py_typecheck.check_type`.)
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(comp.result, building_blocks.Block)
    tree_analysis.check_has_unique_names(comp)

    name_generator = building_block_factory.unique_name_generator(comp)

    # Work on a copy so the locals of the input `comp` are never mutated below,
    # regardless of whether `Block.locals` hands out a fresh list.
    variables = list(comp.result.locals)
    if not variables:
        raise ValueError(
            'Expected the `building_blocks.Block` returned by the top level '
            'lambda to have at least one local, but found none.')

    name, first_local = variables[0]
    py_typecheck.check_type(first_local, building_blocks.Struct)
    for element in first_local:
        if not building_block_analysis.is_called_intrinsic(element):
            raise ValueError(
                'Expected all the elements of the `building_blocks.Struct` to be '
                'called intrinsics, but found: \n{}'.format(element))

    # Describe how to pack the intrinsics into groups by URI and unpack them.
    #
    # packed_groups is a list of groups (lists of called intrinsics sharing a
    #   URI), ordered by the first occurrence of that URI in `first_local`.
    # group_index_by_uri maps each URI to the index of its group, giving O(1)
    #   lookups instead of repeated linear scans.
    # packed_indexes[i] is `(group index, index within group)` for the i-th
    #   called intrinsic in `first_local`; it is an implicit mapping of packed
    #   indexes, keyed by unpacked index.
    group_index_by_uri = {}
    packed_groups = []
    packed_indexes = []
    for called_intrinsic in first_local:
        uri = called_intrinsic.function.uri
        if uri not in group_index_by_uri:
            group_index_by_uri[uri] = len(packed_groups)
            packed_groups.append([])
        group_index = group_index_by_uri[uri]
        group = packed_groups[group_index]
        packed_indexes.append((group_index, len(group)))
        group.append(called_intrinsic)

    # If there are no duplicates, return early.
    if len(packed_groups) == len(first_local):
        return comp, False

    # Singleton groups are kept as the bare called intrinsic.
    packed_elements = [
        building_blocks.Struct(group) if len(group) > 1 else group[0]
        for group in packed_groups
    ]
    packed_comp = building_blocks.Struct(packed_elements)

    packed_ref_name = next(name_generator)
    packed_ref_type = computation_types.to_type(packed_comp.type_signature)
    packed_ref = building_blocks.Reference(packed_ref_name, packed_ref_type)

    unpacked_elements = []
    for group_index, intrinsic_index in packed_indexes:
        sel = building_blocks.Selection(packed_ref, index=group_index)
        # Only multi-element groups were wrapped in a Struct, so only they need
        # a second selection to reach the individual intrinsic.
        if len(packed_groups[group_index]) > 1:
            sel = building_blocks.Selection(sel, index=intrinsic_index)
        unpacked_elements.append(sel)
    unpacked_comp = building_blocks.Struct(unpacked_elements)

    # The original first local now unpacks from the new grouped local, which
    # is inserted ahead of it so it is bound first.
    variables[0] = (name, unpacked_comp)
    variables.insert(0, (packed_ref_name, packed_comp))
    block = building_blocks.Block(variables, comp.result.result)
    fn = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                block)
    return fn, True
Esempio n. 6
0
 def _update(comp):
     """Appends `comp` to `intrinsics` if it is a called intrinsic for `uri`.

     `comp` is always returned unchanged, paired with `False`.
     """
     if not building_block_analysis.is_called_intrinsic(comp, uri):
         return comp, False
     intrinsics.append(comp)
     return comp, False
Esempio n. 7
0
 def _update(comp):
     """Adds the URI of `comp` to `existing_uri` when it is a matching call.

     A match is a called intrinsic for the enclosing `uri`. `comp` is always
     returned unchanged, paired with `False`.
     """
     if not building_block_analysis.is_called_intrinsic(comp, uri):
         return comp, False
     existing_uri.add(comp.function.uri)
     return comp, False
Esempio n. 8
0
 def _find_federated_aggregate(comp):
     """Collects `comp` into `federated_agg` if it calls `federated_aggregate`.

     `comp` is always returned unchanged, paired with `False`.
     """
     aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
     is_aggregate = building_block_analysis.is_called_intrinsic(
         comp, aggregate_uri)
     if is_aggregate:
         federated_agg.append(comp)
     return comp, False
Esempio n. 9
0
 def _find_federated_broadcast(comp):
     """Collects `comp` into `federated_broadcast` if it calls a broadcast.

     `comp` is always returned unchanged, paired with `False`.
     """
     broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
     is_broadcast = building_block_analysis.is_called_intrinsic(
         comp, broadcast_uri)
     if is_broadcast:
         federated_broadcast.append(comp)
     return comp, False
Esempio n. 10
0
def _count_broadcasts(comp):
    """Counts the called `federated_broadcast` intrinsics found in `comp`.

    Args:
      comp: The building block to search.

    Returns:
      The number of nodes in `comp` for which `tree_analysis.count` reports a
      called intrinsic with the `FEDERATED_BROADCAST` URI.
    """
    broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri

    def _is_called_broadcast(node):
        return building_block_analysis.is_called_intrinsic(node, broadcast_uri)

    return tree_analysis.count(comp, _is_called_broadcast)