def test_counts_correct_variables_with_function(self):
  """A variable created under `tf.init_scope` in a `tf.function` counts once.

  `add_one` is applied twice to a parameter stamped into a fresh graph. The
  graph is serialized into a `pb.Computation`, and the variable count of the
  resulting building block is asserted to be 1.
  """

  @tf.function
  def add_one(x):
    # Create the variable outside the function body's graph.
    with tf.init_scope():
      increment = tf.Variable(1)
    return x + increment

  with tf.Graph().as_default() as graph:
    param, param_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', tf.int32, graph)
    twice_incremented = add_one(add_one(param))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        twice_incremented, graph)

  fn_type = computation_types.FunctionType(tf.int32, result_type)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=packed_graph_def,
      parameter=param_binding,
      result=result_binding)
  comp_proto = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=tf_proto)

  comp = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)
  self.assertEqual(
      building_block_analysis.count_tensorflow_variables_in(comp), 1)
def _count_tf_vars(inner_comp):
  """Adds the variable count of a TensorFlow computation to a running total.

  `total_tf_vars` comes from an enclosing scope. It is a one-element list, so
  the total is updated by item assignment rather than rebinding.

  Args:
    inner_comp: The building block currently being visited.

  Returns:
    The tuple `(inner_comp, False)`. The node is returned unchanged; the
    `False` is presumably the "was modified" flag of the calling traversal
    (TODO confirm against the caller).
  """
  is_tensorflow_comp = (
      isinstance(inner_comp, building_blocks.CompiledComputation) and
      inner_comp.proto.WhichOneof('computation') == 'tensorflow')
  if is_tensorflow_comp:
    num_vars = building_block_analysis.count_tensorflow_variables_in(
        inner_comp)
    total_tf_vars[0] += num_vars
  return inner_comp, False
def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
  """A 2-tuple of one computation reports twice that computation's count.

  Despite the "op count" in the name, this test compares TensorFlow
  *variable* counts.
  """
  single = _create_two_variable_tensorflow()
  per_node = building_block_analysis.count_tensorflow_variables_in(single)

  # The same building block appears twice in the struct.
  pair = building_blocks.Struct([single] * 2)

  self.assertEqual(
      tree_analysis.count_tensorflow_variables_under(pair), per_node * 2)
def test_counts_correct_variables_with_function(self):
  """One variable reused by a nested `tf.function` is counted exactly once.

  `foo` creates a single variable. It calls the inner `tf.function` `bar` twice,
  and `bar` reads and increments that variable. The building block produced
  from `foo` is asserted to contain one variable.
  """

  @computations.tf_computation(tf.int32)
  def foo(x):
    counter = tf.Variable(initial_value=0)

    @tf.function
    def bar(x):
      counter.assign_add(1)
      return x + counter, tf.shape(counter)

    first_call = bar(x)
    return bar(first_call[0])

  comp = foo.to_building_block()
  self.assertEqual(
      building_block_analysis.count_tensorflow_variables_in(comp), 1)
def test_counts_no_variables(self):
  """A graph containing only two constants and their sum has zero variables."""
  with tf.Graph().as_default() as graph:
    lhs = tf.constant(0)
    rhs = tf.constant(1)
    total = lhs + rhs
    _, result_binding = tensorflow_utils.capture_result_from_graph(
        total, graph)

  # The computation takes no parameter and returns an int32.
  no_arg_type = computation_types.FunctionType(None, tf.int32)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  comp_proto = pb.Computation(
      type=type_serialization.serialize_type(no_arg_type),
      tensorflow=tf_proto)

  comp = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)
  self.assertEqual(
      building_block_analysis.count_tensorflow_variables_in(comp), 0)
def test_counts_correct_variables_with_function(self):
  """One variable reused by a nested `tf.function` is counted exactly once.

  The computation proto is obtained via `ComputationImpl.get_proto` and
  deserialized into a building block, whose variable count must be 1.
  """

  @computations.tf_computation(tf.int32)
  def foo(x):
    counter = tf.Variable(initial_value=0)

    @tf.function
    def bar(x):
      counter.assign_add(1)
      return x + counter, tf.shape(counter)

    first_call = bar(x)
    return bar(first_call[0])

  comp_proto = computation_impl.ComputationImpl.get_proto(foo)
  comp = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)
  self.assertEqual(
      building_block_analysis.count_tensorflow_variables_in(comp), 1)
def _count_tf_vars(inner_comp): nonlocal count_vars if (inner_comp.is_compiled_computation() and inner_comp.proto.WhichOneof('computation') == 'tensorflow'): count_vars += building_block_analysis.count_tensorflow_variables_in( inner_comp)
def test_raises_on_none(self):
  """Passing None instead of a building block raises TypeError."""
  self.assertRaises(
      TypeError, building_block_analysis.count_tensorflow_variables_in, None)