예제 #1
0
def dedupe_and_merge_tuple_intrinsics(comp, uri):
  """Merges tuples of called intrinsics into one called intrinsic.

  Runs two post-order passes over `comp`:

  1. A workaround pass that rewrites a selection from a block whose result is
     a struct, `(let <locals> in <e_0, ..., e_n>)[i]`, into
     `let <locals> in e_i` (the block keeps its locals; its result becomes
     the selected element).
  2. A pass driven by `RemoveDuplicatesAndApplyTransform`, configured with a
     `tree_transformations.MergeTupleIntrinsics` spec built for `uri`.

  Args:
    comp: The computation building block to transform.
    uri: The intrinsic URI handed to `MergeTupleIntrinsics`; presumably it
      selects which called intrinsics get merged (verify against
      `MergeTupleIntrinsics`).

  Returns:
    A `(transformed_comp, modified)` pair, where `modified` is True if either
    pass changed the tree.
  """

  # TODO(b/147359721): The application of the function below is a workaround to
  # a known pattern preventing TFF from deduplicating, effectively because tree
  # equality won't determine that [a, a][0] and [a, a][1] are actually the same
  # thing. A fuller fix is planned, but requires increasing the invariants
  # respected further up the TFF compilation pipelines. That is, in order to
  # reason about sufficiency of our ability to detect duplicates at this layer,
  # we would very much prefer to be operating in the subset of TFF effectively
  # representing local computation.

  def _remove_selection_from_block_holding_tuple(node):
    """Reduces selection from a block holding a tuple.

    Rewrites `(let <locals> in <e_0, ..., e_n>)[i]` (or selection by name)
    into `let <locals> in e_i`. Any other building block is returned
    unchanged.

    Args:
      node: A building block visited by `transform_postorder`.

    Returns:
      A `(building_block, modified)` pair.
    """
    if not (node.is_selection() and node.source.is_block() and
            node.source.result.is_struct()):
      return node, False
    block = node.source
    if node.index is None:
      # Selection by name: resolve the name to a positional index using the
      # element names of the block's (struct) type signature.
      names = [
          element[0]
          for element in anonymous_tuple.iter_elements(block.type_signature)
      ]
      index = names.index(node.name)
    else:
      index = node.index
    return building_blocks.Block(block.locals, block.result[index]), True

  comp, selections_reduced = transformation_utils.transform_postorder(
      comp, _remove_selection_from_block_holding_tuple)
  transform_spec = tree_transformations.MergeTupleIntrinsics(comp, uri)
  dedupe_and_merger = RemoveDuplicatesAndApplyTransform(comp, transform_spec)
  comp, merged = transformation_utils.transform_postorder(
      comp, dedupe_and_merger.transform)
  # Report a modification if *either* pass changed the tree, so callers that
  # rely on the flag also see rewrites made by the workaround pass.
  return comp, selections_reduced or merged
예제 #2
0
def dedupe_and_merge_tuple_intrinsics(comp, uri):
    """Merges tuples of called intrinsics into one called intrinsic.

    Builds a `RemoveDuplicatesAndApplyTransform` around a
    `tree_transformations.MergeTupleIntrinsics` spec for `uri`, then walks
    `comp` in post-order. The transform is applied only to nodes for which
    its `should_transform` returns True.

    Args:
        comp: The computation building block to transform.
        uri: The intrinsic URI handed to `MergeTupleIntrinsics`.

    Returns:
        Whatever `transformation_utils.transform_postorder` returns,
        presumably a `(transformed_comp, modified)` pair.
    """
    merge_spec = tree_transformations.MergeTupleIntrinsics(comp, uri)
    merger = RemoveDuplicatesAndApplyTransform(comp, merge_spec)

    def _apply_if_applicable(node):
        # Nodes the spec does not match are left untouched and reported as
        # unmodified.
        if not merger.should_transform(node):
            return node, False
        return merger.transform(node)

    return transformation_utils.transform_postorder(comp, _apply_if_applicable)