def test_prune_does_not_change_exeuction(self):
  """Checks that pruning does not change the computation's results.

  Builds the proto from `_create_proto_with_unnecessary_op`, prunes it with
  `prune_tensorflow_proto`, then runs both the original and the pruned proto
  via `test_utils.run_tensorflow` on each of the inputs 0 through 4. The
  results must be equal for every input.
  """
  # NOTE(review): "exeuction" is a typo for "execution"; the method name is
  # kept unchanged so any test filters that reference it still match.
  proto = _create_proto_with_unnecessary_op()
  reduced_proto = tensorflow_computation_transformations.prune_tensorflow_proto(
      proto)
  for k in range(5):
    # subTest makes a failure report the offending input and lets the
    # remaining inputs still be checked.
    with self.subTest(k=k):
      self.assertEqual(
          test_utils.run_tensorflow(proto, k),
          test_utils.run_tensorflow(reduced_proto, k))
 def test_does_not_reduce_no_unnecessary_ops(self):
   """Checks that pruning a proto with no removable ops keeps its op count.

   Builds a compiled identity over an int32 tensor, prunes its proto, wraps the
   result in a new `CompiledComputation`, and requires the TensorFlow op counts
   of the original and pruned computations to be equal.
   """
   # Build an explicit TensorType, matching how the compiled identity is built
   # elsewhere in this test class, rather than passing the raw tf.int32 dtype.
   # NOTE(review): presumably the factory expects a computation_types.Type —
   # confirm against building_block_factory.
   tensor_type = computation_types.TensorType(tf.int32)
   comp = building_block_factory.create_compiled_identity(tensor_type)
   pruned = building_blocks.CompiledComputation(
       tensorflow_computation_transformations.prune_tensorflow_proto(
           comp.proto))
   ops_before = building_block_analysis.count_tensorflow_ops_in(comp)
   ops_after = building_block_analysis.count_tensorflow_ops_in(pruned)
   self.assertEqual(ops_before, ops_after)
 def test_reduces_unnecessary_ops(self):
   """Checks that pruning removes ops from a proto that has an unneeded one.

   Wraps both the proto from `_create_proto_with_unnecessary_op` and its pruned
   version in `CompiledComputation`s and requires the pruned one to contain
   strictly fewer TensorFlow ops.
   """
   original_proto = _create_proto_with_unnecessary_op()
   original_comp = building_blocks.CompiledComputation(original_proto)
   pruned_comp = building_blocks.CompiledComputation(
       tensorflow_computation_transformations.prune_tensorflow_proto(
           original_proto))
   ops_in_original = building_block_analysis.count_tensorflow_ops_in(
       original_comp)
   ops_in_pruned = building_block_analysis.count_tensorflow_ops_in(
       pruned_comp)
   self.assertLess(ops_in_pruned, ops_in_original)
 def test_raises_on_compiled_computation(self):
   """Checks that passing a CompiledComputation instead of a proto fails.

   `prune_tensorflow_proto` is handed the compiled identity itself (not its
   `.proto`) and must raise `TypeError`.
   """
   identity_comp = building_block_factory.create_compiled_identity(
       computation_types.TensorType(tf.int32))
   self.assertRaises(
       TypeError,
       tensorflow_computation_transformations.prune_tensorflow_proto,
       identity_comp)
 def test_raises_on_none(self):
   """Checks that `prune_tensorflow_proto(None)` raises `TypeError`."""
   prune_fn = tensorflow_computation_transformations.prune_tensorflow_proto
   self.assertRaises(TypeError, prune_fn, None)