def _count_ops_parameterized_by_layers(k):
  """Counts TF ops for the deduplicated and the naive lowering of one tuple.

  Args:
    k: Forwarded to `_construct_inlined_tuple`; presumably the number of
      layers in the generated computation (TODO confirm against that helper).

  Returns:
    A pair `(deduplicated_op_count, naive_op_count)`:
    `deduplicated_op_count` is the TensorFlow op count of the result of
    `compiler_transformations.remove_duplicate_called_graphs`, and
    `naive_op_count` is the op count obtained by running a
    `transformations.TFParser` postorder over the same inlined tuple.
  """
  inlined = _construct_inlined_tuple(k)

  # Lowering that binds each distinct called graph only once.
  deduplicated_tf, _ = compiler_transformations.remove_duplicate_called_graphs(
      inlined)
  deduplicated_op_count = tree_analysis.count_tensorflow_ops_under(
      deduplicated_tf)

  # Naive lowering: parse the inlined tuple directly, duplicates included.
  parser = transformations.TFParser()
  naive_tf, _ = transformation_utils.transform_postorder(inlined, parser)
  naive_op_count = tree_analysis.count_tensorflow_ops_under(naive_tf)

  return deduplicated_op_count, naive_op_count
def test_ops_not_duplicated_in_resulting_tensorflow(self):
  """Checks that lowering a Block does not duplicate a shared called graph.

  The same chain of identity calls is lowered two ways:
  - As a Block that binds the chain once and references it twice.
  - As a Tuple that inlines the chain twice.

  Growing the chain by 5 calls should add 5 ops to the Block's TensorFlow
  but 10 ops to the naively parsed Tuple's TensorFlow.
  """

  def _construct_block_and_inlined_tuple(k):
    """Builds a chain of k + 1 identity calls and two pairs that use it.

    Args:
      k: Number of identity calls stacked on top of the first one.

    Returns:
      A pair `(block, inlined_tuple)`:
      - `block` binds the call chain once to the local `call` and returns a
        tuple of two references to it.
      - `inlined_tuple` is a tuple containing the call chain itself twice.
    """
    concrete_int = building_block_factory.create_tensorflow_constant(
        tf.int32, 1)
    first_tf_id = building_block_factory.create_compiled_identity(tf.int32)
    called_tf_id = building_blocks.Call(first_tf_id, concrete_int)
    for _ in range(k):
      # Simulating large TF computation
      called_tf_id = building_blocks.Call(first_tf_id, called_tf_id)
    ref_to_call = building_blocks.Reference('call',
                                            called_tf_id.type_signature)
    block_locals = [('call', called_tf_id)]
    block = building_blocks.Block(
        block_locals, building_blocks.Tuple([ref_to_call, ref_to_call]))
    inlined_tuple = building_blocks.Tuple([called_tf_id, called_tf_id])
    return block, inlined_tuple

  def _count_block_and_tuple_ops(k):
    """Returns `(block_ops, tuple_ops)` for the structures built from `k`."""
    block, inlined_tuple = _construct_block_and_inlined_tuple(k)
    tf_representing_block, _ = (
        compiler_transformations.create_tensorflow_representing_block(block))
    block_ops = tree_analysis.count_tensorflow_ops_under(
        tf_representing_block)
    parser_callable = transformations.TFParser()
    naively_generated_tf, _ = transformation_utils.transform_postorder(
        inlined_tuple, parser_callable)
    tuple_ops = tree_analysis.count_tensorflow_ops_under(naively_generated_tf)
    return block_ops, tuple_ops

  block_ops_with_5_ids, tuple_ops_with_5_ids = _count_block_and_tuple_ops(5)
  block_ops_with_10_ids, tuple_ops_with_10_ids = _count_block_and_tuple_ops(
      10)

  # Going from k=5 to k=10 adds 5 identity calls to the chain. Compare
  # integer deltas directly rather than via float division.
  # Block ops grow by one op per added call (slope 1 in k).
  self.assertEqual(block_ops_with_10_ids - block_ops_with_5_ids, 5)
  # The inlined tuple contains the chain twice (slope 2 in k).
  self.assertEqual(tuple_ops_with_10_ids - tuple_ops_with_5_ids, 10)
def remove_duplicate_called_graphs(comp):
  """Deduplicates called graphs for a subset of TFF AST constructs.

  Every call of a `building_blocks.CompiledComputation` is bound once as a
  local of a `building_blocks.Block`; structurally equal calls (according to
  `tree_analysis.trees_equal`) share one binding. The resulting block is then
  turned into TensorFlow.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` whose called
      graphs we wish to deduplicate, according to `tree_analysis.trees_equal`.

      Eligibility: `comp` must be either a lambda whose body contains no
      lambdas or blocks, or another computation containing no lambdas or
      blocks. This restriction is necessary because
      `remove_duplicate_called_graphs` makes no effort to ensure that it is
      not pulling references out of their defining scope, except for the case
      where `comp` is a lambda itself. If this assumption is violated, the
      function logs a warning and returns early (see Returns).

      Additionally, `comp` must contain only computations which can be
      represented in TensorFlow, i.e., satisfy the type restriction in
      `type_utils.is_tensorflow_compatible_type`.

  Returns:
    A tuple `(result, modified)`, matching the
    `transformation_utils.TransformSpec` pattern.
    - On success, `result` is either a called instance of
      `building_blocks.CompiledComputation` or a
      `building_blocks.CompiledComputation` itself, depending on whether
      `comp` is of non-functional or functional type respectively, and
      `modified` is `True`.
    - If `comp` (or, when `comp` is a lambda, its result) contains a lambda
      or block, `result` is `comp` unchanged and `modified` is `False`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)

  # For a lambda only its result is rewritten; the lambda is rebuilt around
  # the generated TensorFlow below, so parameter references stay in scope.
  if isinstance(comp, building_blocks.Lambda):
    comp_to_transform = comp.result
  else:
    comp_to_transform = comp

  if tree_analysis.count_types(
      comp_to_transform, (building_blocks.Lambda, building_blocks.Block)) > 0:
    logging.warning(
        'The preprocessors have failed to remove called lambdas '
        'and blocks; falling back to less efficient, but '
        'guaranteed, TensorFlow generation with computation %s.', comp)
    return comp, False

  # (name, called graph) pairs. Each distinct called graph is recorded once
  # and becomes one local of the block built below.
  leaf_called_graphs = []

  def _pack_called_graphs_into_block(inner_comp):
    """Replaces a called graph with a reference to its deduplicated binding."""
    if not (isinstance(inner_comp, building_blocks.Call) and
            isinstance(inner_comp.function,
                       building_blocks.CompiledComputation)):
      return inner_comp, False
    for name, existing in leaf_called_graphs:
      if tree_analysis.trees_equal(existing, inner_comp):
        return building_blocks.Reference(name, inner_comp.type_signature), True
    new_name = next(name_generator)
    leaf_called_graphs.append((new_name, inner_comp))
    return building_blocks.Reference(new_name, inner_comp.type_signature), True

  # Presumably `transform_postorder` rewrites children before parents. If so,
  # a call's argument is bound before the call itself, and the block locals
  # end up in dependency order.
  transformed_result, _ = transformation_utils.transform_postorder(
      comp_to_transform, _pack_called_graphs_into_block)
  packed_into_block = building_blocks.Block(leaf_called_graphs,
                                            transformed_result)
  parsed, _ = create_tensorflow_representing_block(packed_into_block)

  if isinstance(comp, building_blocks.Lambda):
    # Re-wrap the generated TensorFlow in the original lambda signature, then
    # lower that lambda with the same helper used for simple computations.
    tff_func = building_blocks.Lambda(comp.parameter_name,
                                      comp.parameter_type, parsed)
    return _generate_simple_tensorflow(tff_func), True
  return parsed, True
def _generate_simple_tensorflow(comp):
  """Lowers `comp` to TensorFlow via identity insertion plus `TFParser`.

  Args:
    comp: The building block to lower.

  Returns:
    The block obtained by first applying
    `transformations.insert_called_tf_identity_at_leaves` to `comp`, then
    running a fresh `transformations.TFParser` over the result with
    `transformation_utils.transform_postorder`. The "modified" flags returned
    by both transformations are discarded.
  """
  parser = transformations.TFParser()
  with_identities, _ = transformations.insert_called_tf_identity_at_leaves(
      comp)
  generated, _ = transformation_utils.transform_postorder(
      with_identities, parser)
  return generated