def should_transform(self, comp):
  """Returns whether `comp` is eligible to be replaced by generated TensorFlow.

  A computation is eligible only if all of the following hold:

  * its type signature is TensorFlow-compatible, or it is a function whose
    parameter and result types are both TensorFlow-compatible;
  * it is neither a compiled computation nor a call to one (those are already
    the final product of TensorFlow generation, so there is nothing to do);
  * it has no unbound references, since such captures cannot be represented
    without further information;
  * it contains no `building_blocks.Intrinsic` or `building_blocks.Data`
    nodes.

  Args:
    comp: The `building_blocks.ComputationBuildingBlock` under consideration.

  Returns:
    `True` if `comp` should be transformed, `False` otherwise.
  """
  type_signature = comp.type_signature
  if not (type_analysis.is_tensorflow_compatible_type(type_signature) or
          (type_signature.is_function() and
           type_analysis.is_tensorflow_compatible_type(
               type_signature.parameter) and
           type_analysis.is_tensorflow_compatible_type(
               type_signature.result))):
    return False
  if comp.is_compiled_computation() or (
      comp.is_call() and comp.function.is_compiled_computation()):
    # These represent the final result of TF generation; no need to transform.
    return False
  unbound_refs = transformation_utils.get_map_of_unbound_references(
      comp)[comp]
  if unbound_refs:
    # We cannot represent these captures without further information.
    return False
  # `contains_types` accepts a tuple of types, so a single traversal of the
  # tree checks for both disallowed node kinds.
  return not tree_analysis.contains_types(
      comp, (building_blocks.Intrinsic, building_blocks.Data))
def remove_duplicate_called_graphs(comp):
  """Deduplicates called graphs for a subset of TFF AST constructs.

  Every call to a `building_blocks.CompiledComputation` found in `comp` (or in
  its result, if `comp` is a lambda) is replaced by a reference to a named
  binding. Calls that are equal according to `tree_analysis.trees_equal`
  share a single binding. The bindings and the rewritten tree are packed into
  one `building_blocks.Block`, which is then handed to
  `create_tensorflow_representing_block`.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` whose called
      graphs we wish to deduplicate. To be eligible, `comp` must be either a
      lambda whose result contains no lambdas or blocks, or another
      computation containing no lambdas or blocks. This restriction is needed
      because no effort is made to avoid pulling references out of their
      defining scope, except when `comp` is itself a lambda. If it is
      violated, a warning is logged and the function exits early.
      Additionally, `comp` must contain only computations which can be
      represented in TensorFlow, i.e., satisfy the type restriction in
      `type_analysis.is_tensorflow_compatible_type`. That restriction is not
      checked by this function.

  Returns:
    A tuple `(result, modified)` matching the
    `transformation_utils.TransformSpec` pattern. On success, `result` is
    either a called instance of `building_blocks.CompiledComputation` or a
    `building_blocks.CompiledComputation` itself, depending on whether `comp`
    is of non-functional or functional type respectively, and `modified` is
    `True`. On the early exit described above, the original `comp` is
    returned together with `False`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)

  is_lambda = comp.is_lambda()
  # For a lambda only its result is inspected and rewritten; the parameter is
  # re-attached after TensorFlow has been generated for the body.
  body = comp.result if is_lambda else comp

  disallowed_types = (building_blocks.Block, building_blocks.Lambda)
  if tree_analysis.contains_types(body, disallowed_types):
    logging.warning(
        'The preprocessors have failed to remove called lambdas '
        'and blocks; falling back to less efficient, but '
        'guaranteed, TensorFlow generation with computation %s.', comp)
    return comp, False

  # Ordered (name, called_graph) pairs; each distinct called graph is bound
  # exactly once.
  bindings = []

  def _replace_with_reference(node):
    """Swaps a called graph for a reference to its (possibly shared) binding."""
    if not (node.is_call() and node.function.is_compiled_computation()):
      return node, False
    bound_name = next(
        (name for name, graph in bindings
         if tree_analysis.trees_equal(graph, node)), None)
    if bound_name is None:
      bound_name = next(name_generator)
      bindings.append((bound_name, node))
    return building_blocks.Reference(bound_name, node.type_signature), True

  rewritten_body, _ = transformation_utils.transform_postorder(
      body, _replace_with_reference)
  tf_block, _ = create_tensorflow_representing_block(
      building_blocks.Block(bindings, rewritten_body))
  if not is_lambda:
    return tf_block, True

  # Re-wrap the generated body in a lambda with the original parameter, insert
  # called TF identities at the leaves, and run the TF parser over the result.
  wrapped = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                   tf_block)
  parser = tree_to_cc_transformations.TFParser()
  with_identities, _ = tree_transformations.insert_called_tf_identity_at_leaves(
      wrapped)
  compiled, _ = transformation_utils.transform_postorder(
      with_identities, parser)
  return compiled, True