Beispiel #1
0
 def test_does_not_reduce_no_unnecessary_ops(self):
   """Pruning a compiled identity must leave its TF op count unchanged."""
   identity_comp = building_block_factory.create_compiled_identity(tf.int32)
   pruned_proto = proto_transformations.prune_tensorflow_proto(
       identity_comp.proto)
   pruned_comp = building_blocks.CompiledComputation(pruned_proto)
   count_ops = building_block_analysis.count_tensorflow_ops_in
   self.assertEqual(count_ops(identity_comp), count_ops(pruned_comp))
Beispiel #2
0
 def test_reduces_unnecessary_ops(self):
   """Pruning a proto that contains an unneeded op strictly lowers its op count."""
   original_proto = _create_proto_with_unnecessary_op()
   original_comp = building_blocks.CompiledComputation(original_proto)
   pruned_comp = building_blocks.CompiledComputation(
       proto_transformations.prune_tensorflow_proto(original_proto))
   count_before = building_block_analysis.count_tensorflow_ops_in(
       original_comp)
   count_after = building_block_analysis.count_tensorflow_ops_in(pruned_comp)
   self.assertLess(count_after, count_before)
Beispiel #3
0
    def test_reduces_unnecessary_ops(self):
        """Pruning drops a constant that does not contribute to the result."""

        def bad_fn(x):
            # The constant is created but never used by the returned value.
            _ = tf.constant(0)
            return x

        comp = _create_compiled_computation(bad_fn, tf.int32)
        count_ops = building_block_analysis.count_tensorflow_ops_in
        op_count_before = count_ops(comp)
        pruned_comp = building_blocks.CompiledComputation(
            proto_transformations.prune_tensorflow_proto(comp.proto))
        self.assertLess(count_ops(pruned_comp), op_count_before)
Beispiel #4
0
 def _count_tf_ops(inner_comp):
     """Adds the TensorFlow op count of `inner_comp` to `total_tf_ops[0]`.

     `total_tf_ops` is a one-element list from the enclosing scope (not visible
     here), used as a mutable counter. Only `CompiledComputation` nodes whose
     proto holds a `tensorflow` computation contribute; all others are ignored.

     Returns:
       `(inner_comp, False)`: the node unchanged, plus a `False` flag
       (presumably "not modified" for a transform driver -- confirm against
       the caller).
     """
     is_compiled = isinstance(inner_comp, building_blocks.CompiledComputation)
     if is_compiled and (
             inner_comp.proto.WhichOneof('computation') == 'tensorflow'):
         total_tf_ops[0] += building_block_analysis.count_tensorflow_ops_in(
             inner_comp)
     return inner_comp, False
Beispiel #5
0
 def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
   """A 2-tuple of one TF identity has twice that identity's op count."""
   identity = building_block_factory.create_compiled_identity(tf.int32)
   ops_per_node = building_block_analysis.count_tensorflow_ops_in(identity)
   pair = building_blocks.Tuple([identity] * 2)
   ops_in_tree = tree_analysis.count_tensorflow_ops_under(pair)
   self.assertEqual(ops_in_tree, 2 * ops_per_node)
Beispiel #6
0
 def test_single_tensorflow_node_count_agrees_with_node_count(self):
   """Tree-level and node-level op counts agree on a lone TF node."""
   identity = building_block_factory.create_compiled_identity(tf.int32)
   expected = building_block_analysis.count_tensorflow_ops_in(identity)
   actual = tree_analysis.count_tensorflow_ops_under(identity)
   self.assertEqual(expected, actual)
Beispiel #7
0
    def test_counts_correct_number_of_ops_swith_function(self):
        """Counts ops of a TFF computation that calls a `tf.function` twice.

        NOTE(review): other methods with this exact name appear elsewhere in
        this file, and one of them expects 7 ops for the same computation. If
        they share a class, only the last definition runs. Confirm which count
        the targeted TensorFlow version actually produces.
        """
        arg_type = computation_types.TensorType(tf.int32, shape=[])

        @computations.tf_computation(arg_type)
        def foo(x):

            @tf.function
            def bar(x):
                return x + 1

            return bar(bar(x))

        op_count = building_block_analysis.count_tensorflow_ops_in(
            foo.to_building_block())
        self.assertEqual(op_count, 6)
    def test_counts_correct_number_of_ops_swith_function(self):
        """Counts ops of a deserialized TFF computation calling a `tf.function`.

        The computation is round-tripped through its proto before counting.

        NOTE(review): other methods with this exact name appear elsewhere in
        this file, and one of them expects 7 ops for an equivalent computation.
        If they share a class, only the last definition runs.
        """
        arg_type = computation_types.TensorType(tf.int32, shape=[])

        @computations.tf_computation(arg_type)
        def foo(x):

            @tf.function
            def bar(x):
                return x + 1

            return bar(bar(x))

        serialized = computation_impl.ComputationImpl.get_proto(foo)
        comp = building_blocks.ComputationBuildingBlock.from_proto(serialized)
        self.assertEqual(
            building_block_analysis.count_tensorflow_ops_in(comp), 6)
  def test_counts_correct_number_of_ops_simple_case(self):
    """Counts ops of a hand-built, parameterless graph computing `0 + 1`.

    The graph is built from two constants and one addition, and the test
    expects exactly three ops.
    """
    with tf.Graph().as_default() as graph:
      lhs = tf.constant(0)
      rhs = tf.constant(1)
      total = lhs + rhs

    _, result_binding = tensorflow_utils.capture_result_from_graph(
        total, graph)
    serialized_graph = serialization_utils.pack_graph_def(
        graph.as_graph_def())
    proto = pb.Computation(
        type=type_serialization.serialize_type(
            computation_types.FunctionType(None, tf.int32)),
        tensorflow=pb.TensorFlow(
            graph_def=serialized_graph,
            parameter=None,
            result=result_binding))
    comp = building_blocks.ComputationBuildingBlock.from_proto(proto)
    self.assertEqual(building_block_analysis.count_tensorflow_ops_in(comp), 3)
    def test_counts_correct_number_of_ops_swith_function(self):
        """Counts ops of a TFF computation that calls a `tf.function` twice."""
        arg_type = computation_types.TensorType(tf.int32, shape=[])

        @computations.tf_computation(arg_type)
        def foo(x):

            @tf.function
            def bar(x):
                return x + 1

            return bar(bar(x))

        op_count = building_block_analysis.count_tensorflow_ops_in(
            foo.to_building_block())
        # Expected breakdown of the 7 ops:
        #   In the tf.function body: one constant, one addition, and one
        #     identity on its result.
        #   In the TFF computation graph: one placeholder, two partitioned
        #     calls (one per invocation of `bar`), and one identity on the
        #     computation's result.
        self.assertEqual(op_count, 7)
Beispiel #11
0
    def test_counts_correct_number_of_ops_with_function(self):
        """Counts ops of a hand-built graph that calls a `tf.function` twice."""

        @tf.function
        def add_one(x):
            return x + 1

        with tf.Graph().as_default() as graph:
            param, param_binding = tensorflow_utils.stamp_parameter_in_graph(
                'x', tf.int32, graph)
            output = add_one(add_one(param))

        output_type, output_binding = (
            tensorflow_utils.capture_result_from_graph(output, graph))
        signature = computation_types.FunctionType(tf.int32, output_type)
        tf_section = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=param_binding,
            result=output_binding)
        proto = pb.Computation(
            type=type_serialization.serialize_type(signature),
            tensorflow=tf_section)
        comp = building_blocks.ComputationBuildingBlock.from_proto(proto)

        op_count = building_block_analysis.count_tensorflow_ops_in(comp)

        # Expected breakdown of the 7 ops:
        #   In the tf.function body: one constant, one addition, and one
        #     identity on its result.
        #   In the outer graph: one placeholder (for the argument), two
        #     partitioned calls (one per `add_one` call), and one identity on
        #     the computation's result.
        self.assertEqual(op_count, 7)
Beispiel #12
0
 def _count_tf_ops(inner_comp):
     nonlocal count_ops
     if (inner_comp.is_compiled_computation() and
             inner_comp.proto.WhichOneof('computation') == 'tensorflow'):
         count_ops += building_block_analysis.count_tensorflow_ops_in(
             inner_comp)
Beispiel #13
0
 def test_raises_on_reference(self):
     """Counting TF ops on a bare `Reference` raises ValueError."""
     self.assertRaises(
         ValueError,
         building_block_analysis.count_tensorflow_ops_in,
         building_blocks.Reference('x', tf.int32))
Beispiel #14
0
 def test_raises_on_none(self):
     """Passing None instead of a building block raises TypeError."""
     self.assertRaises(
         TypeError, building_block_analysis.count_tensorflow_ops_in, None)