예제 #1
0
    def test_raises_on_compiled_computation(self):
        """Passing a compiled computation instead of its proto raises."""

        def fn(x):
            return x

        not_a_proto = _create_compiled_computation(fn, tf.int32)
        self.assertRaises(TypeError,
                          proto_transformations.prune_tensorflow_proto,
                          not_a_proto)
예제 #2
0
 def test_prune_does_not_change_exeuction(self):
   """Pruning must not change the computed result for any tested input."""
   # NOTE(review): "exeuction" is a typo for "execution"; the name is kept so
   # existing test filters that reference it keep matching.
   proto = _create_proto_with_unnecessary_op()
   reduced_proto = proto_transformations.prune_tensorflow_proto(proto)
   for k in range(5):
     # subTest reports which input diverged and keeps checking the rest,
     # instead of stopping at the first mismatch with no input context.
     with self.subTest(k=k):
       self.assertEqual(
           test_utils.run_tensorflow(proto, k),
           test_utils.run_tensorflow(reduced_proto, k))
예제 #3
0
 def test_does_not_reduce_no_unnecessary_ops(self):
   """Pruning a compiled identity leaves its op count unchanged."""
   identity = building_block_factory.create_compiled_identity(tf.int32)
   pruned_proto = proto_transformations.prune_tensorflow_proto(identity.proto)
   pruned = building_blocks.CompiledComputation(pruned_proto)
   self.assertEqual(
       building_block_analysis.count_tensorflow_ops_in(identity),
       building_block_analysis.count_tensorflow_ops_in(pruned))
예제 #4
0
 def test_reduces_unnecessary_ops(self):
   """Pruning strictly lowers the op count of a proto with an unneeded op."""
   unpruned_proto = _create_proto_with_unnecessary_op()
   unpruned = building_blocks.CompiledComputation(unpruned_proto)
   pruned = building_blocks.CompiledComputation(
       proto_transformations.prune_tensorflow_proto(unpruned_proto))
   count_ops = building_block_analysis.count_tensorflow_ops_in
   ops_unpruned = count_ops(unpruned)
   ops_pruned = count_ops(pruned)
   self.assertLess(ops_pruned, ops_unpruned)
예제 #5
0
    def test_does_not_reduce_no_unnecessary_ops(self):
        """Pruning an identity computation leaves the op count unchanged."""

        def fn(x):
            return x

        original = _create_compiled_computation(fn, tf.int32)
        pruned_proto = proto_transformations.prune_tensorflow_proto(
            original.proto)
        pruned = computation_building_blocks.CompiledComputation(pruned_proto)
        count_ops = computation_building_block_utils.count_tensorflow_ops_in
        self.assertEqual(count_ops(original), count_ops(pruned))
예제 #6
0
    def test_reduces_unnecessary_ops(self):
        """A computation with an unused constant loses ops when pruned."""

        def bad_fn(x):
            # The constant never contributes to the returned value, so the
            # test expects pruning to remove it.
            _ = tf.constant(0)
            return x

        original = _create_compiled_computation(bad_fn, tf.int32)
        count_ops = building_block_analysis.count_tensorflow_ops_in
        ops_before = count_ops(original)
        pruned = building_blocks.CompiledComputation(
            proto_transformations.prune_tensorflow_proto(original.proto))
        self.assertLess(count_ops(pruned), ops_before)
예제 #7
0
    def test_prune_does_not_change_exeuction(self):
        """Pruning must not change results for any of the tested inputs."""
        # NOTE(review): "exeuction" is a typo for "execution"; the name is
        # kept so existing test filters that reference it keep matching.

        def bad_fn(x):
            # Unused constant: the op that pruning is expected to remove.
            _ = tf.constant(0)
            return x

        comp = _create_compiled_computation(bad_fn, tf.int32)
        reduced_proto = proto_transformations.prune_tensorflow_proto(
            comp.proto)
        reduced_comp = building_blocks.CompiledComputation(reduced_proto)

        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        orig_executable = to_computation(comp)
        reduced_executable = to_computation(reduced_comp)
        for k in range(5):
            # subTest reports which input diverged and keeps checking the
            # rest, instead of stopping at the first mismatch.
            with self.subTest(k=k):
                self.assertEqual(orig_executable(k), reduced_executable(k))
예제 #8
0
 def test_raises_on_none(self):
   """Pruning rejects a None argument with TypeError."""
   self.assertRaises(
       TypeError, proto_transformations.prune_tensorflow_proto, None)
예제 #9
0
 def test_raises_on_compiled_computation(self):
   """Passing a compiled computation instead of its proto raises TypeError."""
   identity = building_block_factory.create_compiled_identity(tf.int32)
   self.assertRaises(
       TypeError, proto_transformations.prune_tensorflow_proto, identity)