def test_does_not_reduce_no_unnecessary_ops(self):
  """Checks that pruning a compiled identity keeps its op count unchanged."""
  identity = building_block_factory.create_compiled_identity(tf.int32)
  pruned_proto = proto_transformations.prune_tensorflow_proto(identity.proto)
  pruned_identity = building_blocks.CompiledComputation(pruned_proto)

  count_ops = building_block_analysis.count_tensorflow_ops_in
  count_before = count_ops(identity)
  count_after = count_ops(pruned_identity)

  self.assertEqual(count_before, count_after)
def test_reduces_unnecessary_ops(self):
  """Checks that pruning lowers the op count of the helper-built proto.

  The input comes from `_create_proto_with_unnecessary_op`; the pruned proto
  must contain strictly fewer TensorFlow ops than the original.
  """
  original_proto = _create_proto_with_unnecessary_op()
  original = building_blocks.CompiledComputation(original_proto)
  pruned_proto = proto_transformations.prune_tensorflow_proto(original_proto)
  pruned = building_blocks.CompiledComputation(pruned_proto)

  count_ops = building_block_analysis.count_tensorflow_ops_in
  count_before = count_ops(original)
  count_after = count_ops(pruned)

  self.assertLess(count_after, count_before)
def test_reduces_unnecessary_ops(self):
  """Checks that pruning lowers the op count when a constant goes unused."""

  def bad_fn(x):
    # The constant is created but never contributes to the returned value.
    _ = tf.constant(0)
    return x

  original = _create_compiled_computation(bad_fn, tf.int32)
  count_ops = building_block_analysis.count_tensorflow_ops_in
  count_before = count_ops(original)

  pruned_proto = proto_transformations.prune_tensorflow_proto(original.proto)
  pruned = building_blocks.CompiledComputation(pruned_proto)
  count_after = count_ops(pruned)

  self.assertLess(count_after, count_before)
def _count_tf_ops(inner_comp):
  """Adds the op count of a TensorFlow computation to `total_tf_ops[0]`.

  Args:
    inner_comp: The building block to inspect.

  Returns:
    The pair `(inner_comp, False)`; the input is handed back as-is.
  """
  is_tensorflow_computation = (
      isinstance(inner_comp, building_blocks.CompiledComputation) and
      inner_comp.proto.WhichOneof('computation') == 'tensorflow')
  if is_tensorflow_computation:
    # `total_tf_ops` comes from the enclosing scope; only slot 0 is updated.
    total_tf_ops[0] += building_block_analysis.count_tensorflow_ops_in(
        inner_comp)
  return inner_comp, False
def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
  """Checks a tuple holding one identity twice has double its op count."""
  identity = building_block_factory.create_compiled_identity(tf.int32)
  ops_per_node = building_block_analysis.count_tensorflow_ops_in(identity)

  pair = building_blocks.Tuple([identity, identity])
  ops_under_pair = tree_analysis.count_tensorflow_ops_under(pair)

  self.assertEqual(ops_under_pair, 2 * ops_per_node)
def test_single_tensorflow_node_count_agrees_with_node_count(self):
  """Checks the tree-level and node-level counts agree for a single node."""
  identity = building_block_factory.create_compiled_identity(tf.int32)

  ops_in_node = building_block_analysis.count_tensorflow_ops_in(identity)
  ops_under_tree = tree_analysis.count_tensorflow_ops_under(identity)

  self.assertEqual(ops_in_node, ops_under_tree)
def test_counts_correct_number_of_ops_swith_function(self):
  """Checks the op count of a computation applying a `tf.function` twice."""
  scalar_int32 = computation_types.TensorType(tf.int32, shape=[])

  @computations.tf_computation(scalar_int32)
  def foo(x):

    @tf.function
    def bar(x):
      return x + 1

    return bar(bar(x))

  comp = foo.to_building_block()
  op_count = building_block_analysis.count_tensorflow_ops_in(comp)

  self.assertEqual(op_count, 6)
def test_counts_correct_number_of_ops_swith_function(self):
  """Checks the op count of a computation applying a `tf.function` twice.

  The building block is rebuilt from the proto of the decorated computation
  before its ops are counted.
  """
  scalar_int32 = computation_types.TensorType(tf.int32, shape=[])

  @computations.tf_computation(scalar_int32)
  def foo(x):

    @tf.function
    def bar(x):
      return x + 1

    return bar(bar(x))

  foo_proto = computation_impl.ComputationImpl.get_proto(foo)
  comp = building_blocks.ComputationBuildingBlock.from_proto(foo_proto)
  op_count = building_block_analysis.count_tensorflow_ops_in(comp)

  self.assertEqual(op_count, 6)
def test_counts_correct_number_of_ops_simple_case(self):
  """Checks the op count of a graph built from two constants and their sum."""
  with tf.Graph().as_default() as graph:
    lhs = tf.constant(0)
    rhs = tf.constant(1)
    total = lhs + rhs

  _, result_binding = tensorflow_utils.capture_result_from_graph(total, graph)

  tensorflow_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  # The computation takes no parameter and returns an int32.
  signature = computation_types.FunctionType(None, tf.int32)
  comp_proto = pb.Computation(
      type=type_serialization.serialize_type(signature),
      tensorflow=tensorflow_proto)

  comp = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)
  op_count = building_block_analysis.count_tensorflow_ops_in(comp)

  self.assertEqual(op_count, 3)
def test_counts_correct_number_of_ops_swith_function(self):
  """Checks the op count of a computation applying a `tf.function` twice."""
  scalar_int32 = computation_types.TensorType(tf.int32, shape=[])

  @computations.tf_computation(scalar_int32)
  def foo(x):

    @tf.function
    def bar(x):
      return x + 1

    return bar(bar(x))

  comp = foo.to_building_block()
  op_count = building_block_analysis.count_tensorflow_ops_in(comp)

  # Expect 7 ops:
  #   Inside the tf.function:
  #     - one constant
  #     - one addition
  #     - one identity on the result
  #   Inside the tff_computation:
  #     - one placeholder
  #     - two partition calls
  #     - one identity on the tff_computation result
  self.assertEqual(op_count, 7)
def test_counts_correct_number_of_ops_with_function(self):
  """Checks the op count of a hand-built graph calling a `tf.function` twice."""

  @tf.function
  def add_one(x):
    return x + 1

  with tf.Graph().as_default() as graph:
    arg, arg_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', tf.int32, graph)
    output = add_one(add_one(arg))

  output_type, output_binding = tensorflow_utils.capture_result_from_graph(
      output, graph)

  signature = computation_types.FunctionType(tf.int32, output_type)
  tensorflow_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=arg_binding,
      result=output_binding)
  comp_proto = pb.Computation(
      type=type_serialization.serialize_type(signature),
      tensorflow=tensorflow_proto)

  comp = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)
  op_count = building_block_analysis.count_tensorflow_ops_in(comp)

  # Expect 7 ops:
  #   Inside the tf.function:
  #     - one constant
  #     - one addition
  #     - one identity on the result
  #   Inside the tff_computation:
  #     - one placeholder (for the argument)
  #     - two partition calls
  #     - one identity on the tff_computation result
  self.assertEqual(op_count, 7)
def _count_tf_ops(inner_comp): nonlocal count_ops if (inner_comp.is_compiled_computation() and inner_comp.proto.WhichOneof('computation') == 'tensorflow'): count_ops += building_block_analysis.count_tensorflow_ops_in( inner_comp)
def test_raises_on_reference(self):
  """Checks that counting ops in a `Reference` raises `ValueError`."""
  reference = building_blocks.Reference('x', tf.int32)

  with self.assertRaises(ValueError):
    building_block_analysis.count_tensorflow_ops_in(reference)
def test_raises_on_none(self):
  """Checks that counting ops in `None` raises `TypeError`."""
  self.assertRaises(
      TypeError, building_block_analysis.count_tensorflow_ops_in, None)