def count_tensorflow_variables_under(comp):
  """Counts total TF variables in any TensorFlow computations under `comp`.

  Notice that this function is designed for the purpose of instrumentation,
  in particular to check the size and constituents of the TensorFlow
  artifacts generated.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` whose TF
      variables we wish to count.

  Returns:
    `integer` count of number of TF variables present in any
    `building_blocks.CompiledComputation` of the TensorFlow variety under
    `comp`.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  total_tf_vars = 0

  def _count_tf_vars(inner_comp):
    # `nonlocal` replaces the old single-element-list workaround that was
    # only needed for Python 2 compatibility.
    nonlocal total_tf_vars
    if (isinstance(inner_comp, building_blocks.CompiledComputation) and
        inner_comp.proto.WhichOneof('computation') == 'tensorflow'):
      total_tf_vars += building_block_analysis.count_tensorflow_variables_in(
          inner_comp)
    # Pure analysis pass: the node is returned unchanged and unmodified.
    return inner_comp, False

  transformation_utils.transform_postorder(comp, _count_tf_vars)
  return total_tf_vars
def check_intrinsics_whitelisted_for_reduction(comp):
  """Verifies every intrinsic under `comp` is on the reducible whitelist.

  The whitelist contains intrinsics that can currently be reduced to
  `FEDERATED_AGGREGATE`, `FEDERATED_BROADCAST`, or local processing.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to scan.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    ValueError: On the first `building_blocks.Intrinsic` found under `comp`
      whose URI is not whitelisted.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  reducible_uris = frozenset([
      intrinsic_defs.FEDERATED_AGGREGATE.uri,
      intrinsic_defs.FEDERATED_APPLY.uri,
      intrinsic_defs.FEDERATED_BROADCAST.uri,
      intrinsic_defs.FEDERATED_MAP.uri,
      intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri,
      intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri,
      intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri,
      intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri,
      intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri,
  ])

  def _raise_if_not_reducible(node):
    is_disallowed = (
        isinstance(node, building_blocks.Intrinsic) and
        node.uri not in reducible_uris)
    if is_disallowed:
      raise ValueError(
          'Encountered an Intrinsic not currently reducible to aggregate or '
          'broadcast, the intrinsic {}'.format(node.compact_representation()))
    return node, False

  transformation_utils.transform_postorder(comp, _raise_if_not_reducible)
def test_compile_computation(self):
  """Compiling a federated computation eliminates `federated_sum`."""

  @computations.federated_computation([
      computation_types.FederatedType(tf.float32, placements.CLIENTS),
      computation_types.FederatedType(tf.float32, placements.SERVER, True)
  ])
  def foo(temperatures, threshold):
    is_above_threshold = computations.tf_computation(
        lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
        [tf.float32, tf.float32])
    broadcast_threshold = intrinsics.federated_broadcast(threshold)
    return intrinsics.federated_sum(
        intrinsics.federated_map(is_above_threshold,
                                 [temperatures, broadcast_threshold]))

  pipeline = compiler_pipeline.CompilerPipeline(
      context_stack_impl.context_stack)
  compiled_foo = pipeline.compile(foo)

  def _assert_not_federated_sum(node):
    if isinstance(node, building_blocks.Intrinsic):
      self.assertNotEqual(node.uri, intrinsic_defs.FEDERATED_SUM.uri)
    return node, False

  compiled_proto = computation_impl.ComputationImpl.get_proto(compiled_foo)
  compiled_tree = building_blocks.ComputationBuildingBlock.from_proto(
      compiled_proto)
  transformation_utils.transform_postorder(compiled_tree,
                                           _assert_not_federated_sum)
def _check_no_functional_symbol_bindings(comp):
  """Encodes condition for completeness of direct extraction of calls.

  After checking this condition, all functions which are semantically called
  (IE, functions which will be invoked eventually by running the computation)
  are called directly, and we can simply extract them by pattern-matching on
  `building_blocks.Call`.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to check for
      lack of functional symbol bindings.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    ValueError: If `comp` has symbols bound to computations with type trees
      containing functional types.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

  def _check_for_bindings(comp_to_check):
    if comp_to_check.is_block():
      for name, local in comp_to_check.locals:
        if type_analysis.contains(local.type_signature,
                                  lambda x: x.is_function()):
          raise ValueError(
              'We make the assumption when reducing to '
              'call-dominant form that there are no symbols bound '
              'to computations with functional type; encountered '
              'the computation {c} of type {t} bound to symbol '
              '{s}. Failure here indicates an internal error in '
              'the construction of call-dominant form.'.format(
                  c=local, t=local.type_signature, s=name))
    # Return the node being visited (previously this returned the outer
    # `comp`, which would substitute the root for every visited node).
    return comp_to_check, False

  transformation_utils.transform_postorder(comp, _check_for_bindings)
def _get_unbound_ref(block):
  """Returns a `Reference` for the single unbound reference in `block`.

  Args:
    block: The building block to inspect.

  Returns:
    `None` if `block` has no unbound references; otherwise a
    `building_blocks.Reference` carrying the unbound name and the type
    signature of a reference with that name found under `block`.

  Raises:
    ValueError: If `block` has more than one unbound reference.
  """
  unbound_refs_by_comp = transformation_utils.get_map_of_unbound_references(
      block)
  unbound_names = unbound_refs_by_comp[block]
  if not unbound_names:
    return None
  if len(unbound_names) > 1:
    raise ValueError('`create_tensorflow_representing_block` must be passed '
                     'a block with at most a single unbound reference; '
                     'encountered the block {} with {} unbound '
                     'references.'.format(block, len(unbound_names)))

  ref_name = unbound_names.pop()
  ref_type = None

  def _record_ref_type(node):
    nonlocal ref_type
    if node.is_reference() and node.name == ref_name:
      ref_type = node.type_signature
    return node, False

  transformation_utils.transform_postorder(block, _record_ref_type)
  return building_blocks.Reference(ref_name, ref_type)
def check_has_single_placement(comp, single_placement):
  """Ensures every federated type under `comp` is placed at `single_placement`.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock`.
    single_placement: Instance of `placement_literals.PlacementLiteral` which
      should be the only placement present under `comp`.

  Raises:
    TypeError: If either argument has the wrong type.
    ValueError: If the AST under `comp` contains any
      `computation_types.FederatedType` other than `single_placement`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(single_placement, placement_literals.PlacementLiteral)

  def _raise_on_other_placement(node):
    node_type = node.type_signature
    if not isinstance(node_type, computation_types.FederatedType):
      return node, False
    if node_type.placement != single_placement:
      raise ValueError(
          'Comp contains a placement other than {}; '
          'placement {} on comp {} inside the structure. '.format(
              single_placement, node_type.placement,
              node.compact_representation()))
    return node, False

  transformation_utils.transform_postorder(comp, _raise_on_other_placement)
def dedupe_and_merge_tuple_intrinsics(comp, uri):
  r"""Merges tuples of called intrinsics into one called intrinsic.

  Args:
    comp: The computation building block to transform.
    uri: The URI of the intrinsic whose called instances should be merged;
      forwarded to `tree_transformations.MergeTupleIntrinsics`.

  Returns:
    A tuple `(transformed_comp, modified)`. `modified` is True if either the
    selection-collapsing pre-pass or the dedupe-and-merge pass changed the
    tree.
  """

  # TODO(b/147359721): The application of the function below is a workaround to
  # a known pattern preventing TFF from deduplicating, effectively because tree
  # equality won't determine that [a, a][0] and [a, a][1] are actually the same
  # thing. A fuller fix is planned, but requires increasing the invariants
  # respected further up the TFF compilation pipelines. That is, in order to
  # reason about sufficiency of our ability to detect duplicates at this layer,
  # we would very much prefer to be operating in the subset of TFF effectively
  # representing local computation.

  def _remove_selection_from_block_holding_tuple(comp):
    """Reduces selection from a block holding a tuple."""
    if (comp.is_selection() and comp.source.is_block() and
        comp.source.result.is_struct()):
      if comp.index is None:
        names = [
            x[0]
            for x in anonymous_tuple.iter_elements(comp.source.type_signature)
        ]
        index = names.index(comp.name)
      else:
        index = comp.index
      return building_blocks.Block(comp.source.locals,
                                   comp.source.result[index]), True
    return comp, False

  comp, selections_removed = transformation_utils.transform_postorder(
      comp, _remove_selection_from_block_holding_tuple)
  transform_spec = tree_transformations.MergeTupleIntrinsics(comp, uri)
  dedupe_and_merger = RemoveDuplicatesAndApplyTransform(comp, transform_spec)
  comp, merged = transformation_utils.transform_postorder(
      comp, dedupe_and_merger.transform)
  # Report a modification from either pass; previously a rewrite done only by
  # the pre-pass was reported as unmodified.
  return comp, selections_removed or merged
def _check_parameters_for_tf_block_generation(block):
  """Validates that `block` can be parsed into a single TF graph.

  Every local must be bound to a `Call` of a `CompiledComputation`, and the
  result may contain only `Reference`, `Selection` and `Tuple` nodes.

  Args:
    block: The `building_blocks.Block` to validate.

  Raises:
    TypeError: If `block` is not a `building_blocks.Block`.
    ValueError: If a local or the result violates the constraints above.
  """
  py_typecheck.check_type(block, building_blocks.Block)
  for _, bound_comp in block.locals:
    is_called_compiled = (
        isinstance(bound_comp, building_blocks.Call) and
        isinstance(bound_comp.function, building_blocks.CompiledComputation))
    if not is_called_compiled:
      raise ValueError(
          'create_tensorflow_representing_block may only be called '
          'on a block whose local variables are all bound to '
          'called TensorFlow computations; encountered a local '
          'bound to {}'.format(bound_comp))

  allowed_result_node_types = (building_blocks.Reference,
                               building_blocks.Selection,
                               building_blocks.Tuple)

  def _check_contains_only_refs_sels_and_tuples(node):
    if isinstance(node, allowed_result_node_types):
      return node, False
    raise ValueError(
        'create_tensorflow_representing_block may only be called '
        'on a block whose result contains only Selections, '
        'Tuples and References; encountered the building block '
        '{}.'.format(node))

  transformation_utils.transform_postorder(
      block.result, _check_contains_only_refs_sels_and_tuples)
def _visit_postorder(comp, function):
  """Calls `function` on each node of `comp` via `transform_postorder`.

  The return value of `function` is ignored and the tree is never modified.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

  def _visit_without_modifying(node):
    function(node)
    return node, False

  transformation_utils.transform_postorder(comp, _visit_without_modifying)
def get_uri_for_all_called_intrinsics(comp):
  """Returns the set of URIs of every intrinsic called under `comp`.

  Args:
    comp: The computation building block to scan.

  Returns:
    A `set` of URI strings, one per distinct intrinsic that appears as the
    function of a `building_blocks.Call` under `comp`.
  """
  existing_uri = set()

  def _update(inner_comp):
    # Previously this called `is_called_intrinsic(comp, uri)` with an
    # undefined name `uri`; we want every called intrinsic regardless of URI.
    if (isinstance(inner_comp, building_blocks.Call) and
        isinstance(inner_comp.function, building_blocks.Intrinsic)):
      existing_uri.add(inner_comp.function.uri)
    return inner_comp, False

  transformation_utils.transform_postorder(comp, _update)
  return existing_uri
def _visit_postorder(
    tree: building_blocks.ComputationBuildingBlock,
    function: Callable[[building_blocks.ComputationBuildingBlock], None]):
  """Applies `function` to every node of `tree` without rewriting it.

  Args:
    tree: The building block to traverse with `transform_postorder`.
    function: Callback invoked once per visited node; its result is ignored.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)

  def _apply_and_keep(building_block):
    function(building_block)
    return building_block, False

  transformation_utils.transform_postorder(tree, _apply_and_keep)
def assertNoLambdasOrBlocks(self, comp):
  """Asserts that `comp` contains no `Block`s and no called `Lambda`s.

  Args:
    comp: The computation building block to inspect.

  Raises:
    AssertionError: On the first `Block`, or `Call` whose function is a
      `Lambda`, found under `comp`.
  """

  def _transform(inner_comp):
    is_called_lambda = (
        isinstance(inner_comp, building_blocks.Call) and
        isinstance(inner_comp.function, building_blocks.Lambda))
    if is_called_lambda or isinstance(inner_comp, building_blocks.Block):
      raise AssertionError('Encountered disallowed computation: {}'.format(
          inner_comp.compact_representation()))
    # Nothing is rewritten here, so report the node as unmodified (this
    # previously returned True, falsely claiming a modification).
    return inner_comp, False

  transformation_utils.transform_postorder(comp, _transform)
def _inline_functions(comp):
  """Inlines block locals whose names are referenced with a functional type.

  Collects the names of all `Reference`s under `comp` whose type signature is
  a function type, then passes them as `variable_names` to
  `tree_transformations.inline_block_locals`.
  """
  function_ref_names = set()

  def _collect_function_refs(node):
    if node.is_reference() and node.type_signature.is_function():
      function_ref_names.add(node.name)
    return node, False

  transformation_utils.transform_postorder(comp, _collect_function_refs)
  return tree_transformations.inline_block_locals(
      comp, variable_names=function_ref_names)
def test_aggregate_with_selection_from_block_by_name_results_in_single_aggregate(
    self):
  """Aggregates over named selections from one block merge into one call."""
  data = building_blocks.Reference(
      'a', computation_types.FederatedType(tf.int32, placement_literals.CLIENTS))
  named_tuple = building_blocks.Tuple([('a', data), ('b', data)])
  block_holding_tup = building_blocks.Block([], named_tuple)
  selections = [
      building_blocks.Selection(source=block_holding_tup, name=selected_name)
      for selected_name in ('a', 'b')
  ]

  result = building_blocks.Data('aggregation_result', tf.int32)
  zero = building_blocks.Data('zero', tf.int32)
  accumulate = building_blocks.Lambda('accumulate_param', [tf.int32, tf.int32],
                                      result)
  merge = building_blocks.Lambda('merge_param', [tf.int32, tf.int32], result)
  report = building_blocks.Lambda('report_param', tf.int32, result)
  aggregates = tuple(
      building_block_factory.create_federated_aggregate(
          selection, zero, accumulate, merge, report)
      for selection in selections)
  comp = building_blocks.Tuple(aggregates)

  merged_comp, modified = transformations.dedupe_and_merge_tuple_intrinsics(
      comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  self.assertTrue(modified)

  found_aggregates = []

  def _collect_federated_aggregate(node):
    is_aggregate_call = (
        isinstance(node, building_blocks.Call) and
        isinstance(node.function, building_blocks.Intrinsic) and
        node.function.uri == intrinsic_defs.FEDERATED_AGGREGATE.uri)
    if is_aggregate_call:
      found_aggregates.append(node.function)
    return node, False

  transformation_utils.transform_postorder(merged_comp,
                                           _collect_federated_aggregate)
  self.assertLen(found_aggregates, 1)
  self.assertEqual(
      found_aggregates[0].type_signature.parameter[0].compact_representation(),
      '{<int32>}@CLIENTS')
def test_ops_not_duplicated_in_resulting_tensorflow(self):
  """Compares op growth of block-based vs. naively inlined TF generation."""

  def _construct_block_and_inlined_tuple(k):
    constant = building_block_factory.create_tensorflow_constant(
        computation_types.TensorType(tf.int32), 1)
    identity = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    chained_call = building_blocks.Call(identity, constant)
    # Simulating large TF computation by chaining `k` more identity calls.
    for _ in range(k):
      chained_call = building_blocks.Call(identity, chained_call)
    call_ref = building_blocks.Reference('call', chained_call.type_signature)
    block = building_blocks.Block([('call', chained_call)],
                                  building_blocks.Tuple([call_ref, call_ref]))
    inlined_tuple = building_blocks.Tuple([chained_call, chained_call])
    return block, inlined_tuple

  block_5, tuple_5 = _construct_block_and_inlined_tuple(5)
  block_10, tuple_10 = _construct_block_and_inlined_tuple(10)

  tf_block_5, _ = transformations.create_tensorflow_representing_block(block_5)
  tf_block_10, _ = transformations.create_tensorflow_representing_block(
      block_10)
  block_ops_5 = tree_analysis.count_tensorflow_ops_under(tf_block_5)
  block_ops_10 = tree_analysis.count_tensorflow_ops_under(tf_block_10)

  parser_callable = tree_to_cc_transformations.TFParser()
  naive_tf_5, _ = transformation_utils.transform_postorder(
      tuple_5, parser_callable)
  naive_tf_10, _ = transformation_utils.transform_postorder(
      tuple_10, parser_callable)
  tuple_ops_5 = tree_analysis.count_tensorflow_ops_under(naive_tf_5)
  tuple_ops_10 = tree_analysis.count_tensorflow_ops_under(naive_tf_10)

  # Block ops grow linearly in k with slope 1.
  self.assertEqual((block_ops_10 - block_ops_5) / 5, 1)
  # Inlined-tuple ops grow linearly in k with slope 2.
  self.assertEqual((tuple_ops_10 - tuple_ops_5) / 5, 2)
def _generate_simple_tensorflow(comp):
  """Naively generates TensorFlow to represent `comp`.

  Runs `insert_called_tf_identity_at_leaves`, then applies a `TFParser` to the
  result with `transform_postorder`.
  """
  parser = tree_to_cc_transformations.TFParser()
  with_identities, _ = tree_transformations.insert_called_tf_identity_at_leaves(
      comp)
  generated, _ = transformation_utils.transform_postorder(
      with_identities, parser)
  return generated
def _replace_selections(
    bb: building_blocks.ComputationBuildingBlock,
    ref_name: str,
    path_to_replacement: Dict[Tuple[int, ...],
                              building_blocks.ComputationBuildingBlock],
) -> building_blocks.ComputationBuildingBlock:
  """Identifies selection pattern and replaces with new binding.

  Note that this function is somewhat brittle in that it only replaces AST
  fragments of exactly the form `ref_name[i][j][k]` (for path `(i, j, k)`).
  That is, it will not detect `let x = ref_name[i][j] in x[k]` or similar.

  This is only sufficient because, at the point this function has been called,
  called lambdas have been replaced with blocks and blocks have been inlined,
  so there are no reference chains that must be traced back. Any reference
  which would eventually resolve to a part of a lambda's parameter instead
  refers to the parameter directly. Similarly, selections from tuples have been
  collapsed. The remaining concern would be selections via calls to opaque
  compiled computations, which we error on.

  Args:
    bb: Instance of `building_blocks.ComputationBuildingBlock` in which we wish
      to replace the selections from reference `ref_name` with any path in
      `path_to_replacement` with the corresponding building block.
    ref_name: Name of the reference to look for selections from.
    path_to_replacement: A map from selection path to the building block with
      which to replace the selection. Note: it is not valid to specify
      overlapping selection paths (where one path encompasses another).

  Returns:
    A possibly transformed version of `bb` with nodes matching the
    selection patterns replaced.

  Raises:
    ValueError: If a compiled computation is called directly on a reference
      named `ref_name`.
  """

  def _replace(inner_bb):
    # Walk down the chain of selections, recording the index at each level.
    path = []
    selection = inner_bb
    while selection.is_selection():
      path.append(selection.as_index())
      selection = selection.source
    # In ASTs like x[0][1], we'll see the last (outermost) selection first.
    path.reverse()
    path = tuple(path)
    if (selection.is_reference() and selection.name == ref_name and
        path in path_to_replacement):
      return path_to_replacement[path], True
    if (inner_bb.is_call() and inner_bb.function.is_compiled_computation() and
        inner_bb.argument is not None and inner_bb.argument.is_reference() and
        inner_bb.argument.name == ref_name):
      raise ValueError('Encountered called graph on reference pattern in TFF '
                       'AST; this means relying on pattern-matching when '
                       'rebinding arguments may be insufficient. Ensure that '
                       'arguments are rebound before decorating references '
                       'with called identity graphs.')
    return inner_bb, False

  result, _ = transformation_utils.transform_postorder(bb, _replace)
  return result
def _generate_simple_tensorflow(comp):
  """Naively generates TensorFlow to represent `comp`.

  First applies `tree_transformations.insert_called_tf_identity_at_leaves`,
  then runs a `tree_to_cc_transformations.TFParser` over the result in
  postorder.
  """
  parser = tree_to_cc_transformations.TFParser()
  decorated, _ = tree_transformations.insert_called_tf_identity_at_leaves(comp)
  parsed, _ = transformation_utils.transform_postorder(decorated, parser)
  return parsed
def test_constructs_broadcast_of_tuple_with_one_element(self):
  """Two identical broadcasts in a tuple are merged into one broadcast.

  The merged broadcast's member type is a one-element tuple, and the
  transformed tree's formatted representation is pinned exactly.
  """
  called_intrinsic = test_utils.create_dummy_called_federated_broadcast()
  calls = building_blocks.Tuple((called_intrinsic, called_intrinsic))
  comp = calls

  transformed_comp, modified = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
      comp, intrinsic_defs.FEDERATED_BROADCAST.uri)

  # Collects every called `federated_broadcast` in the transformed tree.
  federated_broadcast = []

  def _find_federated_broadcast(comp):
    if building_block_analysis.is_called_intrinsic(
        comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
      federated_broadcast.append(comp)
    return comp, False

  transformation_utils.transform_postorder(transformed_comp,
                                           _find_federated_broadcast)

  self.assertTrue(modified)
  # The input is unchanged: still two separate broadcasts.
  self.assertEqual(
      comp.compact_representation(),
      '<federated_broadcast(data),federated_broadcast(data)>')

  # Exactly one broadcast remains, over a one-element tuple.
  self.assertLen(federated_broadcast, 1)
  self.assertLen(federated_broadcast[0].type_signature.member, 1)
  self.assertEqual(
      transformed_comp.formatted_representation(),
      '(_var1 -> <\n'
      ' _var1[0],\n'
      ' _var1[0]\n'
      '>)((x -> <\n'
      ' x[0]\n'
      '>)((let\n'
      ' value=federated_broadcast(federated_apply(<\n'
      ' (arg -> <\n'
      ' arg\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >[0]\n'
      ' >))\n'
      ' in <\n'
      ' federated_map_all_equal(<\n'
      ' (arg -> arg[0]),\n'
      ' value\n'
      ' >)\n'
      '>)))')
def _extract_calls_and_blocks(comp):
  """Applies `ExtractComputation` to `comp` with a predicate matching calls.

  NOTE(review): despite the function name, the predicate selects only nodes
  for which `is_call()` is true, not blocks; confirm this is intended.

  Returns:
    The `(transformed_comp, modified)` result of `transform_postorder`.
  """
  block_extracter = tree_transformations.ExtractComputation(
      comp, lambda node: node.is_call())
  return transformation_utils.transform_postorder(comp,
                                                  block_extracter.transform)
def check_allowed_ops(
    comp: building_blocks.ComputationBuildingBlock,
    allowed_op_names: FrozenSet[str]
) -> Tuple[building_blocks.ComputationBuildingBlock, bool]:
  """Checks any Tensorflow computation contains only allowed ops.

  Args:
    comp: The building block to traverse.
    allowed_op_names: Names of the TensorFlow ops permitted under `comp`.

  Returns:
    The `(comp, modified)` pair produced by `transform_postorder` with a
    `VerifyAllowedOps` transform.
  """
  verifier = VerifyAllowedOps(allowed_op_names)
  result = transformation_utils.transform_postorder(comp, verifier.transform)
  return result
def check_disallowed_ops(
    comp: building_blocks.ComputationBuildingBlock,
    disallowed_op_names: FrozenSet[str]
) -> Tuple[building_blocks.ComputationBuildingBlock, bool]:
  """Raises error on disallowed ops in any Tensorflow computation.

  Args:
    comp: The building block to traverse.
    disallowed_op_names: Names of TensorFlow ops that must not appear.

  Returns:
    The `(comp, modified)` pair produced by `transform_postorder` with a
    `RaiseOnDisallowedOp` transform.
  """
  raiser = RaiseOnDisallowedOp(disallowed_op_names)
  result = transformation_utils.transform_postorder(comp, raiser.transform)
  return result
def test_unwraps_block_with_empty_locals(self):
  """A block with no locals is replaced by its result."""
  input_data = building_blocks.Data('b', tf.int32)
  empty_block = building_blocks.Block([], input_data)

  transformed, modified = transformation_utils.transform_postorder(
      empty_block, self._unused_block_remover.transform)

  self.assertTrue(modified)
  self.assertEqual(transformed.compact_representation(),
                   input_data.compact_representation())
def concatenate_function_outputs(first_function, second_function):
  """Constructs a new function concatenating the outputs of its arguments.

  Assumes that `first_function` and `second_function` already have unique
  names, and have declared parameters of the same type. The constructed
  function will bind its parameter to each of the parameters of
  `first_function` and `second_function`, and return the result of executing
  these functions in parallel and concatenating the outputs in a tuple.

  Args:
    first_function: Instance of `building_blocks.Lambda` whose result we wish
      to concatenate with the result of `second_function`.
    second_function: Instance of `building_blocks.Lambda` whose result we wish
      to concatenate with the result of `first_function`.

  Returns:
    A new instance of `building_blocks.Lambda` with unique names representing
    the computation described above.

  Raises:
    TypeError: If the arguments are not instances of `building_blocks.Lambda`,
    or declare parameters of different types.
  """
  py_typecheck.check_type(first_function, building_blocks.Lambda)
  py_typecheck.check_type(second_function, building_blocks.Lambda)
  tree_analysis.check_has_unique_names(first_function)
  tree_analysis.check_has_unique_names(second_function)

  if first_function.parameter_type != second_function.parameter_type:
    # Report the parameter types (what the message describes), not the full
    # function type signatures.
    raise TypeError(
        'Must pass two functions which declare the same parameter '
        'type to `concatenate_function_outputs`; you have passed '
        'one function which declared a parameter of type {}, and '
        'another which declares a parameter of type {}'.format(
            first_function.parameter_type, second_function.parameter_type))

  def _rename_first_function_arg(comp):
    # Point references to the first function's parameter at the second
    # function's parameter, so both bodies share one parameter.
    if comp.is_reference() and comp.name == first_function.parameter_name:
      if comp.type_signature != second_function.parameter_type:
        raise AssertionError('{}, {}'.format(comp.type_signature,
                                             second_function.parameter_type))
      return building_blocks.Reference(second_function.parameter_name,
                                       comp.type_signature), True
    return comp, False

  first_function, _ = transformation_utils.transform_postorder(
      first_function, _rename_first_function_arg)

  concatenated_function = building_blocks.Lambda(
      second_function.parameter_name, second_function.parameter_type,
      building_blocks.Struct([first_function.result, second_function.result]))

  renamed, _ = tree_transformations.uniquify_reference_names(
      concatenated_function)

  return renamed
def test_leaves_single_used_reference(self):
  """A block whose only local is referenced by its result is left intact."""
  local_data = building_blocks.Data('a', tf.int32)
  blk = building_blocks.Block([('x', local_data)],
                              building_blocks.Reference('x', tf.int32))

  transformed_blk, modified = transformation_utils.transform_postorder(
      blk, self._unused_block_remover.transform)

  self.assertFalse(modified)
  self.assertEqual(transformed_blk.compact_representation(),
                   blk.compact_representation())
def test_parameters_are_mapped_together(self):
  """All references in the concatenated function use its single parameter."""
  x_reference = building_blocks.Reference('x', tf.int32)
  x_lambda = building_blocks.Lambda('x', tf.int32, x_reference)
  y_reference = building_blocks.Reference('y', tf.int32)
  y_lambda = building_blocks.Lambda('y', tf.int32, y_reference)
  concatenated = transformations.concatenate_function_outputs(
      x_lambda, y_lambda)
  parameter_name = concatenated.parameter_name

  def _raise_on_other_name_reference(comp):
    if (isinstance(comp, building_blocks.Reference) and
        comp.name != parameter_name):
      raise ValueError(
          'Expected every reference to be named {}; encountered {}.'.format(
              parameter_name, comp.name))
    # Pure check: nothing is rewritten, so report the node as unmodified.
    return comp, False

  tree_analysis.check_has_unique_names(concatenated)
  transformation_utils.transform_postorder(concatenated,
                                           _raise_on_other_name_reference)
def parse_tff_to_tf(comp):
  """Normalizes `comp` and parses it into TensorFlow with a `TFParser`.

  Steps, in order:
  - insert called TF identities at the leaves;
  - replace called lambdas with blocks;
  - inline block locals;
  - replace selections from tuples with the selected element.
  A `TFParser` is then applied with `transform_postorder`.

  Returns:
    The `(new_comp, transformed)` pair from the parsing pass.
  """
  comp, _ = tree_transformations.insert_called_tf_identity_at_leaves(comp)
  parser_callable = tree_to_cc_transformations.TFParser()
  normalizing_passes = (
      tree_transformations.replace_called_lambda_with_block,
      tree_transformations.inline_block_locals,
      tree_transformations.replace_selection_from_tuple_with_element,
  )
  for normalize in normalizing_passes:
    comp, _ = normalize(comp)
  new_comp, transformed = transformation_utils.transform_postorder(
      comp, parser_callable)
  return new_comp, transformed
def test_removes_nested_blocks_with_unused_reference(self):
  """Nested blocks whose locals are unused collapse to the inner result."""
  input_data = building_blocks.Data('b', tf.int32)
  unused_local = ('x', building_blocks.Data('a', tf.int32))
  inner_block = building_blocks.Block([unused_local], input_data)
  outer_block = building_blocks.Block([('y', input_data)], inner_block)

  data, modified = transformation_utils.transform_postorder(
      outer_block, self._unused_block_remover.transform)

  self.assertTrue(modified)
  self.assertEqual(data.compact_representation(),
                   input_data.compact_representation())
def test_leaves_lone_referenced_local(self):
  """Only the referenced local survives unused-local removal."""
  ref = building_blocks.Reference('y', tf.int32)
  block_locals = [
      ('x', building_blocks.Data('a', tf.int32)),
      ('y', building_blocks.Data('b', tf.int32)),
  ]
  blk = building_blocks.Block(block_locals, ref)

  transformed_blk, modified = transformation_utils.transform_postorder(
      blk, self._unused_block_remover.transform)

  self.assertTrue(modified)
  self.assertEqual(transformed_blk.compact_representation(), '(let y=b in y)')
def count(comp, predicate=None):
  """Returns how many nodes under `comp` satisfy `predicate`.

  Args:
    comp: The computation to test.
    predicate: An optional Python function that takes a computation as a
      parameter and returns a boolean value. If `None`, all computations are
      counted.

  Returns:
    The integer number of matching nodes visited by `transform_postorder`.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  num_matches = 0

  def _tally(node):
    nonlocal num_matches
    if predicate is None or predicate(node):
      num_matches += 1
    return node, False

  transformation_utils.transform_postorder(comp, _tally)
  return num_matches