예제 #1
0
def visit_preorder(
    tree: building_blocks.ComputationBuildingBlock,
    function: Callable[[building_blocks.ComputationBuildingBlock], None]):
  """Invokes `function` on the building blocks of `tree`, parents first.

  The walk is delegated to `transformation_utils.transform_preorder`; the
  callback wrapper reports every node as unmodified, so the tree itself is
  never rewritten and the transform's result is discarded.

  Args:
    tree: The `building_blocks.ComputationBuildingBlock` to walk. It is
      validated with `py_typecheck.check_type` before the walk starts.
    function: Callback run once per visited building block, for its side
      effects only; its return value is ignored.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)

  def _call_and_keep(node):
    # Run the user callback for its side effects, then hand the node back
    # untouched and flag it as unmodified.
    function(node)
    unchanged = False
    return node, unchanged

  transformation_utils.transform_preorder(tree, _call_and_keep)
예제 #2
0
def compile_local_computation_to_tensorflow(comp):
    """Compiles a fully specified local computation to TensorFlow.

    Args:
      comp: The building block to compile.

    Returns:
      A `(computation, modified)` pair. Input that is already the product of
      TF generation (a compiled computation, or a call whose function is a
      compiled computation) is returned as `(comp, False)`. Anything else is
      passed to `transformation_utils.transform_preorder` with a fresh
      `TensorFlowGenerator`'s `transform`, and that call's result is returned.
    """
    already_compiled = comp.is_compiled_computation() or (
        comp.is_call() and comp.function.is_compiled_computation())
    if not already_compiled:
        generator = TensorFlowGenerator()
        return transformation_utils.transform_preorder(
            comp, generator.transform)
    # Final output of TF generation; short-circuit without transforming.
    return comp, False
예제 #3
0
def compile_local_subcomputations_to_tensorflow(
    comp: building_blocks.ComputationBuildingBlock,
) -> building_blocks.ComputationBuildingBlock:
    """Compiles subcomputations to TensorFlow where possible.

    A subcomputation is compiled when it is "local" and has no unbound
    references. "Local" means that neither the subcomputation nor any
    descendant is an intrinsic, data, a placement, an XLA compiled
    computation, or has a type signature containing federated types.

    Args:
      comp: The building block whose local subcomputations should be
        compiled.

    Returns:
      The transformed building block.
    """
    comp = unpack_compiled_computations(comp)
    # Memoizes `_is_local` per node. Both outcomes are cached so each node is
    # analyzed at most once, even though the preorder walk below queries
    # every ancestor before its descendants.
    local_cache = {}

    def _is_local(comp):
        if comp in local_cache:
            return local_cache[comp]
        if (comp.is_intrinsic() or comp.is_data() or comp.is_placement() or
                type_analysis.contains_federated_types(comp.type_signature)):
            result = False
        elif (comp.is_compiled_computation()
              and comp.proto.WhichOneof('computation') == 'xla'):
            result = False
        else:
            # Local only if every child is local; `all` stops at the first
            # non-local child.
            result = all(_is_local(child) for child in comp.children())
        local_cache[comp] = result
        return result

    unbound_ref_map = transformation_utils.get_map_of_unbound_references(comp)

    def _compile_if_local(comp):
        if _is_local(comp) and not unbound_ref_map[comp]:
            # `compile_local_computation_to_tensorflow` returns a
            # `(computation, modified)` pair; keep only the computation so the
            # walk receives a building block rather than a nested tuple.
            compiled, _ = compile_local_computation_to_tensorflow(comp)
            return compiled, True
        return comp, False

    # Note: this transformation is preorder so that local subcomputations are not
    # first transformed to TensorFlow if they have a parent local computation
    # which could have instead been transformed into a larger single block of
    # TensorFlow.
    comp, _ = transformation_utils.transform_preorder(comp, _compile_if_local)
    return comp
예제 #4
0
def compile_local_computations_to_tensorflow(comp):
    """Compiles any fully specified local functions to a TensorFlow computation.

    This function walks the AST backing `comp` in a preorder manner, calling out
    to TF-generating functions when it encounters a subcomputation which can be
    represented in TensorFlow. The fact that this function walks preorder is
    extremely important to efficiency of the generated TensorFlow; if we instead
    traversed in a bottom-up fashion, we could potentially generate duplicated
    structures where such duplication is unnecessary.

    Consider for example a computation with structure:

      [TFComp()[0], TFComp()[1]]

    Due to its preorder walk, this function will call out to TF-generating
    utilities with the *entire* structure above; this structure still has enough
    information to detect that TFComp() should be equivalent in both invocations
    (at least, according to TFF's functional specification). If we traversed the
    AST backing `comp` in a bottom-up fashion, we would instead make separate
    calls to TF generation functions, resulting in a structure like:

      [TFComp0(), TFComp1()]

    where the graphs backing TFComp0 and TFComp1 share some graph substructure.
    TFF does not inspect the substructures of the graphs it generates, and would
    simply declare each of the above to be fully distinct invocations, and would
    require that each run when the resulting graph is invoked.

    We provide this function to ensure that callers of TFF's TF-generation
    utilities are usually shielded from such concerns.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` whose local
        computations we wish to compile to TensorFlow.

    Returns:
      A tuple whose first element represents an equivalent computation, but whose
      local computations are represented as TensorFlow graphs. The second element
      of this tuple is a Boolean indicating whether any transformation was made.

    Raises:
      TypeError: If `comp` contains a compiled computation whose proto holds
        something other than TensorFlow (for example XLA).
    """

    non_tf_compiled_comp_types = set()

    def _visit(comp):
        # Record the proto kind of every compiled computation that is not
        # TensorFlow so all offenders can be reported together below.
        if comp.is_compiled_computation():
            computation_kind = comp.proto.WhichOneof('computation')
            if computation_kind != 'tensorflow':
                non_tf_compiled_comp_types.add(computation_kind)

    tree_analysis.visit_postorder(comp, _visit)
    if non_tf_compiled_comp_types:
        raise TypeError(
            'Encountered non-TensorFlow compiled computation types {} '
            'in argument {} to '
            '`compile_local_computations_to_tensorflow`.'.format(
                non_tf_compiled_comp_types, comp.formatted_representation()))

    if comp.is_compiled_computation() or (
            comp.is_call() and comp.function.is_compiled_computation()):
        # These represent the final result of TF generation; no need to transform,
        # so we short-circuit here.
        return comp, False
    local_tf_generator = TensorFlowGenerator()
    return transformation_utils.transform_preorder(
        comp, local_tf_generator.transform)
예제 #5
0
def preprocess_for_tf_parse(comp):
  """Deduplicates called graphs nested inside Selections and Tuples.

  Args:
    comp: The building block to preprocess.

  Returns:
    The result of `remove_duplicate_called_graphs` applied to `comp` after it
    has been rewritten preorder using an `IntermediateParser` instance as the
    transform.
  """
  rewritten, _ = transformation_utils.transform_preorder(
      comp, IntermediateParser())
  return remove_duplicate_called_graphs(rewritten)