def _split_by_intrinsics_in_top_level_lambda(comp):
  """Splits by the intrinsics in the first block local in the result of `comp`.

  This function splits `comp` into two computations `before` and `after` the
  called intrinsic or tuple of called intrinsics found as the first local in
  the `building_blocks.Block` returned by the top level lambda; and returns a
  Python tuple representing the pair of `before` and `after` computations.

  `before` takes the same parameter as `comp` and returns the argument of the
  called intrinsic, or a `building_blocks.Struct` of the arguments of each
  called intrinsic when the first local is a struct.

  `after` takes a single parameter: a two-element struct whose first element
  has the type of `comp`'s parameter and whose second element has the type of
  the first local. Its body is a block that rebinds `comp`'s parameter name to
  the first element and the first local's name to the second element, keeps
  the remaining locals of the original block, and returns the original block
  result.

  `comp` is first validated with `py_typecheck.check_type` (it must be a
  `building_blocks.Lambda` whose result is a `building_blocks.Block`) and with
  `tree_analysis.check_has_unique_names`; failures there propagate from those
  helpers.

  Args:
    comp: The `building_blocks.Lambda` to split.

  Returns:
    A pair of `building_blocks.ComputationBuildingBlock`s.

  Raises:
    ValueError: If the first local in the `building_blocks.Block` referenced by
      the top level lambda is not a called intrinsic or a
      `building_blocks.Struct` of called intrinsics.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.result, building_blocks.Block)
  tree_analysis.check_has_unique_names(comp)

  name_generator = building_block_factory.unique_name_generator(comp)

  name, first_local = comp.result.locals[0]
  if building_block_analysis.is_called_intrinsic(first_local):
    result = first_local.argument
  elif first_local.is_struct():
    elements = []
    for element in first_local:
      if not building_block_analysis.is_called_intrinsic(element):
        raise ValueError(
            'Expected all the elements of the `building_blocks.Struct` to be '
            'called intrinsics, but found: \n{}'.format(element))
      elements.append(element.argument)
    result = building_blocks.Struct(elements)
  else:
    raise ValueError(
        'Expected either a called intrinsic or a `building_blocks.Struct` of '
        'called intrinsics, but found: \n{}'.format(first_local))

  before = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                  result)

  # The parameter of `after` packs the original parameter together with the
  # value(s) produced by the intrinsic(s).
  ref_name = next(name_generator)
  ref_type = computation_types.StructType(
      (comp.parameter_type, first_local.type_signature))
  ref = building_blocks.Reference(ref_name, ref_type)
  sel_after_arg_1 = building_blocks.Selection(ref, index=0)
  sel_after_arg_2 = building_blocks.Selection(ref, index=1)

  # Copy the locals before editing them so the block owned by `comp` can never
  # be mutated, regardless of whether `Block.locals` returns a fresh list.
  variables = list(comp.result.locals)
  variables[0] = (name, sel_after_arg_2)
  variables.insert(0, (comp.parameter_name, sel_after_arg_1))
  block = building_blocks.Block(variables, comp.result.result)
  after = building_blocks.Lambda(ref.name, ref.type_signature, block)
  return before, after
def contains_called_intrinsic(tree, uri=None):
  """Returns whether `tree` contains a called intrinsic matching `uri`.

  Args:
    tree: A `building_blocks.ComputationBuildingBlock` to search.
    uri: An optional URI or list of URIs, forwarded unchanged to
      `building_block_analysis.is_called_intrinsic`.

  Returns:
    `True` when `count` reports at least one node of `tree` that is a called
    intrinsic for `uri`, and `False` otherwise.
  """

  def _is_matching_intrinsic(node):
    return building_block_analysis.is_called_intrinsic(node, uri)

  num_matches = count(tree, _is_matching_intrinsic)
  return num_matches > 0
def _predicate(comp):
  """Returns whether `comp` is a called intrinsic for `uri`.

  `uri` is a free variable, presumably bound in an enclosing scope that is not
  shown here.
  """
  is_match = building_block_analysis.is_called_intrinsic(comp, uri)
  return is_match
def _should_transform(comp):
  """Returns whether `comp` is a called intrinsic for `uri` (transform target).

  `uri` is a free variable, presumably bound in an enclosing scope that is not
  shown here.
  """
  targeted = building_block_analysis.is_called_intrinsic(comp, uri)
  return targeted
def _group_by_intrinsics_in_top_level_lambda(comp):
  """Groups the intrinsics in the first block local in the result of `comp`.

  This transformation creates an AST by replacing the tuple of called
  intrinsics found as the first local in the `building_blocks.Block` returned
  by the top level lambda with two new computations. The first computation is
  a tuple of tuples of called intrinsics, representing the original tuple of
  called intrinsics grouped by URI (a group holding a single intrinsic is not
  wrapped in a tuple). The second computation is a tuple of selections from
  the first computation, representing the original tuple of called intrinsics.

  It is necessary to group intrinsics before it is possible to merge them.

  `comp` is validated with `py_typecheck.check_type`. It must be a
  `building_blocks.Lambda` whose result is a `building_blocks.Block` and whose
  first local is a `building_blocks.Struct`. It is also validated with
  `tree_analysis.check_has_unique_names`. Failures in those checks propagate
  from the helpers, presumably as `TypeError` for the type checks.

  Args:
    comp: The `building_blocks.Lambda` to transform.

  Returns:
    A Python tuple `(fn, modified)`. If every called intrinsic in the first
    local has a distinct URI, `fn` is `comp` itself and `modified` is `False`.
    Otherwise `fn` is a new `building_blocks.Lambda` returning a
    `building_blocks.Block` whose first local is the tuple of called
    intrinsics grouped by URI, and `modified` is `True`.

  Raises:
    ValueError: If an element of the `building_blocks.Struct` bound to the
      first local of the block is not a called intrinsic.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.result, building_blocks.Block)
  tree_analysis.check_has_unique_names(comp)

  name_generator = building_block_factory.unique_name_generator(comp)

  name, first_local = comp.result.locals[0]
  py_typecheck.check_type(first_local, building_blocks.Struct)
  for element in first_local:
    if not building_block_analysis.is_called_intrinsic(element):
      raise ValueError(
          'Expected all the elements of the `building_blocks.Struct` to be '
          'called intrinsics, but found: \n{}'.format(element))

  # Describe how to pack the intrinsics into groups by URI and how to unpack
  # them again.
  #
  # groups[g] is the list of called intrinsics sharing the g-th distinct URI,
  # with groups ordered by the first occurrence of their URI in the original
  # tuple.
  #
  # packed_indexes[i] is (group index, index within that group) for the i-th
  # intrinsic of the original tuple. It is an implicit mapping from unpacked
  # index to packed position.
  #
  # group_index_by_uri gives O(1) lookup of a URI's group. This replaces
  # repeated list scans, which were quadratic in the number of intrinsics.
  group_index_by_uri = {}
  groups = []
  packed_indexes = []
  for called_intrinsic in first_local:
    uri = called_intrinsic.function.uri
    if uri not in group_index_by_uri:
      group_index_by_uri[uri] = len(groups)
      groups.append([])
    group_index = group_index_by_uri[uri]
    group = groups[group_index]
    group.append(called_intrinsic)
    packed_indexes.append((group_index, len(group) - 1))

  # If there are no duplicates, return early.
  if len(groups) == len(first_local):
    return comp, False

  packed_elements = [
      building_blocks.Struct(group) if len(group) > 1 else group[0]
      for group in groups
  ]
  packed_comp = building_blocks.Struct(packed_elements)
  packed_ref_name = next(name_generator)
  packed_ref_type = computation_types.to_type(packed_comp.type_signature)
  packed_ref = building_blocks.Reference(packed_ref_name, packed_ref_type)

  unpacked_elements = []
  for group_index, intrinsic_index in packed_indexes:
    sel = building_blocks.Selection(packed_ref, index=group_index)
    # Singleton groups were not wrapped in a Struct, so the group selection
    # already is the intrinsic's value.
    if len(groups[group_index]) > 1:
      sel = building_blocks.Selection(sel, index=intrinsic_index)
    unpacked_elements.append(sel)
  # NOTE(review): the elements of `unpacked_comp` are unnamed. If the original
  # first local had named elements, those names are not carried over. Confirm
  # that callers do not rely on them.
  unpacked_comp = building_blocks.Struct(unpacked_elements)

  # Copy the locals before editing them so the block owned by `comp` can never
  # be mutated, regardless of whether `Block.locals` returns a fresh list.
  variables = list(comp.result.locals)
  variables[0] = (name, unpacked_comp)
  variables.insert(0, (packed_ref_name, packed_comp))
  block = building_blocks.Block(variables, comp.result.result)
  fn = building_blocks.Lambda(comp.parameter_name, comp.parameter_type, block)
  return fn, True
def _update(comp):
  """Appends `comp` to `intrinsics` when it is a called intrinsic for `uri`.

  `comp` is never replaced: the result is always `(comp, False)`, presumably
  the "(result, modified)" pair expected by a tree-transformation driver.
  `intrinsics` and `uri` are free variables, presumably bound in an enclosing
  scope not shown here.
  """
  found = building_block_analysis.is_called_intrinsic(comp, uri)
  if found:
    intrinsics.append(comp)
  return comp, False
def _update(comp):
  """Records the URI of `comp` in `existing_uri` if it is a matching intrinsic.

  A node matches when `building_block_analysis.is_called_intrinsic(comp, uri)`
  holds. The result is always `(comp, False)`, presumably the
  "(result, modified)" pair expected by a tree-transformation driver.
  `existing_uri` and `uri` are free variables, presumably bound in an
  enclosing scope not shown here.
  """
  found = building_block_analysis.is_called_intrinsic(comp, uri)
  if found:
    existing_uri.add(comp.function.uri)
  return comp, False
def _find_federated_aggregate(comp):
  """Collects `comp` into `federated_agg` if it calls `FEDERATED_AGGREGATE`.

  The result is always `(comp, False)`; `comp` itself is returned unchanged.
  `federated_agg` is a free variable, presumably a list bound in an enclosing
  scope not shown here.
  """
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
  is_aggregate = building_block_analysis.is_called_intrinsic(
      comp, aggregate_uri)
  if is_aggregate:
    federated_agg.append(comp)
  return comp, False
def _find_federated_broadcast(comp):
  """Collects `comp` into `federated_broadcast` if it calls `FEDERATED_BROADCAST`.

  The result is always `(comp, False)`; `comp` itself is returned unchanged.
  `federated_broadcast` is a free variable, presumably a list bound in an
  enclosing scope not shown here.
  """
  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
  is_broadcast = building_block_analysis.is_called_intrinsic(
      comp, broadcast_uri)
  if is_broadcast:
    federated_broadcast.append(comp)
  return comp, False
def _count_broadcasts(comp):
  """Counts the called `FEDERATED_BROADCAST` intrinsics found in `comp`.

  Args:
    comp: The building block to search.

  Returns:
    The value of `tree_analysis.count` for nodes of `comp` that are called
    federated broadcasts.
  """
  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri

  def _is_broadcast(node):
    return building_block_analysis.is_called_intrinsic(node, broadcast_uri)

  return tree_analysis.count(comp, _is_broadcast)