def visit_preorder(
    tree: building_blocks.ComputationBuildingBlock,
    function: Callable[[building_blocks.ComputationBuildingBlock], None]):
  """Calls `function` on each node of `tree`, visiting parents before children.

  The traversal is observational only: every node is handed to `function` for
  its side effects, and each visit reports the node back unchanged with a
  `False` modification flag, so the tree is never rewritten.

  Args:
    tree: The `building_blocks.ComputationBuildingBlock` to walk. It is
      validated with `py_typecheck.check_type` before the walk starts.
    function: Callable invoked once per visited node; its return value is
      ignored.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)

  def _observe_without_modifying(node):
    function(node)
    unchanged = (node, False)
    return unchanged

  transformation_utils.transform_preorder(tree, _observe_without_modifying)
def compile_local_computation_to_tensorflow(comp):
  """Compiles any fully specified local function to a TensorFlow computation.

  Args:
    comp: The `building_blocks.ComputationBuildingBlock` to compile.

  Returns:
    A `(computation, modified)` pair. If `comp` is already a compiled
    computation, or a call whose function is one, it is returned as-is paired
    with `False`. Otherwise the pair produced by running a fresh
    `TensorFlowGenerator` over `comp` in preorder is returned.
  """
  already_compiled = comp.is_compiled_computation() or (
      comp.is_call() and comp.function.is_compiled_computation())
  if already_compiled:
    # This is already the shape TF generation produces; skip the traversal.
    return comp, False
  generator = TensorFlowGenerator()
  return transformation_utils.transform_preorder(comp, generator.transform)
def compile_local_subcomputations_to_tensorflow(
    comp: building_blocks.ComputationBuildingBlock,
) -> building_blocks.ComputationBuildingBlock:
  """Compiles subcomputations to TensorFlow where possible.

  A subcomputation is compiled when it is "local" and has no unbound
  references. Local means that no node in its subtree is any of the following:
  an intrinsic, a data node, a placement, a node whose type contains federated
  types, or an XLA-compiled computation. The walk is preorder, so the largest
  such subtree is compiled as a single unit rather than piecemeal.

  Args:
    comp: The `building_blocks.ComputationBuildingBlock` to transform. Compiled
      computations inside it are first unpacked via
      `unpack_compiled_computations`.

  Returns:
    A `building_blocks.ComputationBuildingBlock` in which each maximal local,
    closed subcomputation has been replaced by its compiled form.
  """
  comp = unpack_compiled_computations(comp)
  # Memoizes `_is_local` per node. Both outcomes are cached so that the
  # preorder walk below, which queries a node and then its descendants, does
  # not re-analyze the same subtree repeatedly.
  local_cache = {}

  def _is_local(subcomp):
    cached = local_cache.get(subcomp)
    if cached is not None:
      return cached
    if (subcomp.is_intrinsic() or subcomp.is_data() or subcomp.is_placement()
        or type_analysis.contains_federated_types(subcomp.type_signature)):
      result = False
    elif (subcomp.is_compiled_computation() and
          subcomp.proto.WhichOneof('computation') == 'xla'):
      # XLA-compiled computations are treated as non-local and left alone.
      result = False
    else:
      # `all` short-circuits on the first non-local child, like a loop would.
      result = all(_is_local(child) for child in subcomp.children())
    local_cache[subcomp] = result
    return result

  unbound_ref_map = transformation_utils.get_map_of_unbound_references(comp)

  def _compile_if_local(subcomp):
    if _is_local(subcomp) and not unbound_ref_map[subcomp]:
      # `compile_local_computation_to_tensorflow` returns a
      # `(computation, modified)` pair. Unpack it so the transform hands back a
      # building block rather than a nested tuple. Reporting `True` stops the
      # preorder walk from descending into the replaced subtree.
      compiled, _ = compile_local_computation_to_tensorflow(subcomp)
      return compiled, True
    return subcomp, False

  # Note: this transformation is preorder so that local subcomputations are not
  # first transformed to TensorFlow if they have a parent local computation
  # which could have instead been transformed into a larger single block of
  # TensorFlow.
  comp, _ = transformation_utils.transform_preorder(comp, _compile_if_local)
  return comp
def compile_local_computations_to_tensorflow(comp):
  """Compiles any fully specified local functions to a TensorFlow computation.

  This function walks the AST backing `comp` in a preorder manner, calling out
  to TF-generating functions when it encounters a subcomputation which can be
  represented in TensorFlow. The fact that this function walks preorder is
  extremely important to efficiency of the generated TensorFlow; if we instead
  traversed in a bottom-up fashion, we could potentially generate duplicated
  structures where such duplication is unnecessary.

  Consider for example a computation with structure:

    [TFComp()[0], TFComp()[1]]

  Due to its preorder walk, this function will call out to TF-generating
  utilities with the *entire* structure above; this structure still has enough
  information to detect that TFComp() should be equivalent in both invocations
  (at least, according to TFF's functional specification).

  If we traversed the AST backing `comp` in a bottom-up fashion, we would
  instead make separate calls to TF generation functions, resulting in a
  structure like:

    [TFComp0(), TFComp1()]

  where the graphs backing TFComp0 and TFComp1 share some graph substructure.
  TFF does not inspect the substructures of the graphs it generates, and would
  simply declare each of the above to be fully distinct invocations, and would
  require that each run when the resulting graph is invoked.

  We provide this function to ensure that callers of TFF's TF-generation
  utilities are usually shielded from such concerns.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` whose local
      computations we wish to compile to TensorFlow.

  Returns:
    A tuple whose first element represents an equivalent computation, but whose
    local computations are represented as TensorFlow graphs. The second element
    of this tuple is a Boolean indicating whether any transformation was made.
    If `comp` is already a compiled computation, or a call whose function is a
    compiled computation, it is returned unchanged paired with `False`.

  Raises:
    TypeError: If `comp` contains any compiled computation whose kind is not
      `'tensorflow'` (for example, `'xla'`).
  """
  non_tf_compiled_comp_types = set()

  def _visit(subcomp):
    # Record every compiled-computation kind other than TensorFlow so that all
    # offending kinds can be reported in a single error.
    if subcomp.is_compiled_computation():
      kind = subcomp.proto.WhichOneof('computation')
      if kind != 'tensorflow':
        non_tf_compiled_comp_types.add(kind)

  tree_analysis.visit_postorder(comp, _visit)
  if non_tf_compiled_comp_types:
    raise TypeError(
        'Encountered non-TensorFlow compiled computation types {} '
        'in argument {} to '
        '`compile_local_computations_to_tensorflow`.'.format(
            non_tf_compiled_comp_types, comp.formatted_representation()))
  # Once validated, compilation is exactly the single-computation path:
  # short-circuit already-compiled shapes, else run the TF generator preorder.
  return compile_local_computation_to_tensorflow(comp)
def preprocess_for_tf_parse(comp):
  """Deduplicates called graphs in the AST nested in Selections and Tuples.

  Runs an `IntermediateParser` over `comp` in preorder and discards the
  traversal's modification flag. The rewritten computation is then passed to
  `remove_duplicate_called_graphs`, whose result is returned unchanged.

  Args:
    comp: The building block to preprocess.

  Returns:
    Whatever `remove_duplicate_called_graphs` returns for the parsed
    computation.
  """
  parsed, _ = transformation_utils.transform_preorder(
      comp, IntermediateParser())
  return remove_duplicate_called_graphs(parsed)