def _generate_simple_tensorflow(comp):
  """Naively generates TensorFlow to represent `comp`.

  Every leaf of `comp` is first wrapped in a called TensorFlow identity, and
  the resulting tree is then rewritten bottom-up by a `TFParser`.

  Args:
    comp: The building block to convert.

  Returns:
    The building block produced by the post-order `TFParser` pass.
  """
  parser = tree_to_cc_transformations.TFParser()
  with_identities, _ = tree_transformations.insert_called_tf_identity_at_leaves(
      comp)
  generated, _ = transformation_utils.transform_postorder(
      with_identities, parser)
  return generated
def _generate_simple_tensorflow(comp):
  """Naively generates TensorFlow to represent `comp`.

  Wraps the leaves of `comp` in called TensorFlow identities, then applies a
  `TFParser` to the whole tree in post-order.

  Args:
    comp: The building block to convert.

  Returns:
    The building block produced by `transformation_utils.transform_postorder`
    with the `TFParser`; the accompanying "modified" flag is discarded.
  """
  tf_parser = tree_to_cc_transformations.TFParser()
  leaves_wrapped, _ = tree_transformations.insert_called_tf_identity_at_leaves(
      comp)
  result, _ = transformation_utils.transform_postorder(leaves_wrapped,
                                                       tf_parser)
  return result
def parse_tff_to_tf(comp):
  """Parses the TFF construct `comp` into TensorFlow.

  The input is prepared in stages before parsing:
    1. leaves are wrapped in called TensorFlow identities;
    2. called lambdas are replaced with blocks;
    3. block locals are inlined;
    4. selections from tuples are replaced by the selected element.
  A `TFParser` is then applied to the prepared tree in post-order.

  Args:
    comp: The building block to parse.

  Returns:
    A `(new_comp, transformed)` pair as produced by
    `transformation_utils.transform_postorder` with the `TFParser`.
  """
  comp, _ = tree_transformations.insert_called_tf_identity_at_leaves(comp)
  tf_parser = tree_to_cc_transformations.TFParser()
  preprocessing_passes = (
      tree_transformations.replace_called_lambda_with_block,
      tree_transformations.inline_block_locals,
      tree_transformations.replace_selection_from_tuple_with_element,
  )
  for preprocess in preprocessing_passes:
    comp, _ = preprocess(comp)
  parsed, modified = transformation_utils.transform_postorder(comp, tf_parser)
  return parsed, modified
def transform(self, local_function):
  """Generates TensorFlow for `local_function` when it is eligible.

  Args:
    local_function: The building block to transform.

  Returns:
    `(local_function, False)` if `self.should_transform(local_function)` is
    falsy. Otherwise a pair whose second element is `True`. The first element
    is the output of `remove_duplicate_called_graphs` when that output is
    already a compiled computation, or a call to one. In every other case it
    is the output of a post-order pass with `self._naive_tf_parser`.
  """
  if not self.should_transform(local_function):
    return local_function, False

  without_refs, _ = remove_called_lambdas_and_blocks(local_function)
  deduplicated, _ = remove_duplicate_called_graphs(without_refs)

  is_already_compiled = deduplicated.is_compiled_computation() or (
      deduplicated.is_call() and
      deduplicated.function.is_compiled_computation())
  if is_already_compiled:
    return deduplicated, True

  # TODO(b/146430051): We should only end up in this case if
  # `remove_called_lambdas_and_blocks` above is in its failure mode, IE,
  # failing to resolve references due to too-deep indirection; we should
  # remove this extra case and simply raise if we fail here when we fix the
  # attached bug.
  with_identities, _ = (
      tree_transformations.insert_called_tf_identity_at_leaves(deduplicated))
  compiled, _ = transformation_utils.transform_postorder(
      with_identities, self._naive_tf_parser)
  return compiled, True
def remove_duplicate_called_graphs(comp):
  """Deduplicates called graphs for a subset of TFF AST constructs.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` whose called
      graphs we wish to deduplicate, according to `tree_analysis.trees_equal`.
      For `comp` to be eligible here, it must be either a lambda itself whose
      body contains no lambdas or blocks, or another computation containing no
      lambdas or blocks. This restriction is necessary because
      `remove_duplicate_called_graphs` makes no effort to ensure that it is not
      pulling references out of their defining scope, except for the case
      where `comp` is a lambda itself. This function exits early and logs a
      warning if this assumption is violated. Additionally, `comp` must
      contain only computations which can be represented in TensorFlow, IE,
      satisfy the type restriction in
      `type_analysis.is_tensorflow_compatible_type`.

  Returns:
    A pair `(result, modified)`, matching the
    `transformation_utils.TransformSpec` pattern.

    If `comp` violates the lambda/block restriction above, a warning is logged
    and the pair is `(comp, False)`, with `comp` returned untouched.

    Otherwise `modified` is `True`. `result` is either a called instance of
    `building_blocks.CompiledComputation` or a
    `building_blocks.CompiledComputation` itself, depending on whether `comp`
    is of non-functional or functional type respectively.

  Errors raised by `py_typecheck.check_type` and
  `tree_analysis.check_has_unique_names` propagate to the caller.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)

  # For a lambda, only its result is deduplicated. The hoisted bindings then
  # live inside the lambda, so references to its parameter stay in scope.
  is_lambda = comp.is_lambda()
  comp_to_check = comp.result if is_lambda else comp
  if tree_analysis.contains_types(comp_to_check, (
      building_blocks.Block,
      building_blocks.Lambda,
  )):
    logging.warning(
        'The preprocessors have failed to remove called lambdas '
        'and blocks; falling back to less efficient, but '
        'guaranteed, TensorFlow generation with computation %s.', comp)
    return comp, False

  # (name, called graph) pairs; one entry per distinct called graph seen.
  leaf_called_graphs = []

  def _pack_called_graphs_into_block(inner_comp):
    """Packs deduplicated bindings to called graphs in `leaf_called_graphs`."""
    if inner_comp.is_call() and inner_comp.function.is_compiled_computation():
      for (name, x) in leaf_called_graphs:
        if tree_analysis.trees_equal(x, inner_comp):
          # Reuse the binding of a structurally identical called graph.
          return building_blocks.Reference(name,
                                           inner_comp.type_signature), True
      new_name = next(name_generator)
      leaf_called_graphs.append((new_name, inner_comp))
      return building_blocks.Reference(new_name,
                                       inner_comp.type_signature), True
    return inner_comp, False

  # Shared by both cases: hoist each distinct called graph into a block local,
  # then generate TensorFlow for that block.
  transformed_result, _ = transformation_utils.transform_postorder(
      comp_to_check, _pack_called_graphs_into_block)
  packed_into_block = building_blocks.Block(leaf_called_graphs,
                                            transformed_result)
  parsed, _ = create_tensorflow_representing_block(packed_into_block)
  if not is_lambda:
    return parsed, True

  # Re-wrap the generated body in the original lambda and parse the whole
  # function into TensorFlow.
  tff_func = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    parsed)
  tf_parser_callable = tree_to_cc_transformations.TFParser()
  with_identities, _ = tree_transformations.insert_called_tf_identity_at_leaves(
      tff_func)
  tf_generated, _ = transformation_utils.transform_postorder(
      with_identities, tf_parser_callable)
  return tf_generated, True