def test_prune_does_not_change_exeuction(self):
    """Running the pruned proto yields the same results as the original."""
    original_proto = _create_proto_with_unnecessary_op()
    pruned_proto = tensorflow_computation_transformations.prune_tensorflow_proto(
        original_proto
    )
    # Compare both protos on a handful of small integer arguments.
    for arg in range(5):
        expected = test_utils.run_tensorflow(original_proto, arg)
        actual = test_utils.run_tensorflow(pruned_proto, arg)
        self.assertEqual(expected, actual)
def test_does_not_reduce_no_unnecessary_ops(self):
    """Pruning a compiled identity leaves its TensorFlow op count unchanged."""
    identity = building_block_factory.create_compiled_identity(tf.int32)
    pruned_proto = tensorflow_computation_transformations.prune_tensorflow_proto(
        identity.proto
    )
    pruned_identity = building_blocks.CompiledComputation(pruned_proto)
    self.assertEqual(
        building_block_analysis.count_tensorflow_ops_in(identity),
        building_block_analysis.count_tensorflow_ops_in(pruned_identity),
    )
def test_reduces_unnecessary_ops(self):
    """Pruning strictly lowers the op count of a proto with an unnecessary op."""
    original_proto = _create_proto_with_unnecessary_op()
    original_comp = building_blocks.CompiledComputation(original_proto)
    pruned_proto = tensorflow_computation_transformations.prune_tensorflow_proto(
        original_proto
    )
    pruned_comp = building_blocks.CompiledComputation(pruned_proto)
    count_before = building_block_analysis.count_tensorflow_ops_in(original_comp)
    count_after = building_block_analysis.count_tensorflow_ops_in(pruned_comp)
    self.assertLess(count_after, count_before)
def test_raises_on_compiled_computation(self):
    """Passing the compiled-computation object itself raises TypeError."""
    identity = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32)
    )
    # The object (not its `.proto`) is handed to the pruning function here.
    with self.assertRaises(TypeError):
        tensorflow_computation_transformations.prune_tensorflow_proto(identity)
def test_raises_on_none(self):
    """Passing None to the pruning function raises TypeError."""
    self.assertRaises(
        TypeError,
        tensorflow_computation_transformations.prune_tensorflow_proto,
        None,
    )