def test_raises_on_compiled_computation(self):
  """The result of `_create_compiled_computation` is rejected with TypeError.

  The object is handed to `prune_tensorflow_proto` directly, rather than via
  a `.proto` attribute.
  """

  def identity_fn(x):
    return x

  compiled = _create_compiled_computation(identity_fn, tf.int32)
  self.assertRaises(
      TypeError, proto_transformations.prune_tensorflow_proto, compiled)
def test_prune_does_not_change_exeuction(self):
  """Original and pruned protos give equal results for inputs 0 through 4."""
  original_proto = _create_proto_with_unnecessary_op()
  pruned_proto = proto_transformations.prune_tensorflow_proto(original_proto)
  for arg in range(5):
    # Run the original first, then the pruned proto, on the same argument.
    expected = test_utils.run_tensorflow(original_proto, arg)
    actual = test_utils.run_tensorflow(pruned_proto, arg)
    self.assertEqual(expected, actual)
def test_does_not_reduce_no_unnecessary_ops(self):
  """Pruning a compiled identity leaves its TensorFlow op count unchanged."""
  identity = building_block_factory.create_compiled_identity(tf.int32)
  pruned_proto = proto_transformations.prune_tensorflow_proto(identity.proto)
  pruned_identity = building_blocks.CompiledComputation(pruned_proto)
  count_before = building_block_analysis.count_tensorflow_ops_in(identity)
  count_after = building_block_analysis.count_tensorflow_ops_in(
      pruned_identity)
  self.assertEqual(count_before, count_after)
def test_reduces_unnecessary_ops(self):
  """Pruning strictly lowers the op count of the unnecessary-op proto."""
  original_proto = _create_proto_with_unnecessary_op()
  original = building_blocks.CompiledComputation(original_proto)
  pruned_proto = proto_transformations.prune_tensorflow_proto(original_proto)
  pruned = building_blocks.CompiledComputation(pruned_proto)
  count_before = building_block_analysis.count_tensorflow_ops_in(original)
  count_after = building_block_analysis.count_tensorflow_ops_in(pruned)
  # The pruned computation must contain strictly fewer ops.
  self.assertLess(count_after, count_before)
def test_does_not_reduce_no_unnecessary_ops(self):
  """Pruning an identity computation leaves its op count unchanged."""

  def identity_fn(x):
    return x

  count_ops = computation_building_block_utils.count_tensorflow_ops_in
  original = _create_compiled_computation(identity_fn, tf.int32)
  pruned_proto = proto_transformations.prune_tensorflow_proto(original.proto)
  pruned = computation_building_blocks.CompiledComputation(pruned_proto)
  count_before = count_ops(original)
  count_after = count_ops(pruned)
  self.assertEqual(count_before, count_after)
def test_reduces_unnecessary_ops(self):
  """Pruning strictly lowers the op count when an unused constant exists."""

  def fn_with_unused_constant(x):
    # The constant is created but never feeds into the returned value.
    _ = tf.constant(0)
    return x

  original = _create_compiled_computation(fn_with_unused_constant, tf.int32)
  count_before = building_block_analysis.count_tensorflow_ops_in(original)
  pruned_proto = proto_transformations.prune_tensorflow_proto(original.proto)
  pruned = building_blocks.CompiledComputation(pruned_proto)
  count_after = building_block_analysis.count_tensorflow_ops_in(pruned)
  self.assertLess(count_after, count_before)
def test_prune_does_not_change_exeuction(self):
  """Original and pruned computations agree on inputs 0 through 4."""

  def fn_with_unused_constant(x):
    # The constant is created but never feeds into the returned value.
    _ = tf.constant(0)
    return x

  to_computation = computation_wrapper_instances.building_block_to_computation
  original = _create_compiled_computation(fn_with_unused_constant, tf.int32)
  pruned_proto = proto_transformations.prune_tensorflow_proto(original.proto)
  pruned = building_blocks.CompiledComputation(pruned_proto)
  run_original = to_computation(original)
  run_pruned = to_computation(pruned)
  for arg in range(5):
    # Invoke the original first, then the pruned computation.
    expected = run_original(arg)
    actual = run_pruned(arg)
    self.assertEqual(expected, actual)
def test_raises_on_none(self):
  """`prune_tensorflow_proto(None)` raises TypeError."""
  self.assertRaises(
      TypeError, proto_transformations.prune_tensorflow_proto, None)
def test_raises_on_compiled_computation(self):
  """A compiled identity passed directly is rejected with TypeError.

  The object is handed to `prune_tensorflow_proto` directly, rather than via
  a `.proto` attribute.
  """
  compiled_identity = building_block_factory.create_compiled_identity(tf.int32)
  self.assertRaises(
      TypeError, proto_transformations.prune_tensorflow_proto,
      compiled_identity)