def test_graph_and_eager_modes(test_class_or_method=None):
  """Decorator for generating graph and eager mode tests from a single test.

  Must be applied to subclasses of `parameterized.TestCase` (from
  absl/testing), or a method of such a subclass.

  When applied to a test method, this decorator results in the replacement of
  that method with two new test methods, one executed in graph mode and the
  other in eager mode.

  When applied to a test class, all the methods in the class are affected.

  May be used bare (`@test_graph_and_eager_modes`) or called with no argument
  (`@test_graph_and_eager_modes()`), in which case the decorator itself is
  returned.

  Args:
    test_class_or_method: the `TestCase` class or method to decorate.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that function.
  """
  decorator = test_combinations.generate(
      test_combinations.combine(mode=['graph', 'eager']),
      test_combinations=[combinations.EagerGraphCombination()])

  # Compare against `None` explicitly, as documented, instead of relying on
  # the truthiness of the decorated object.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator
def generate(combinations, test_combinations=()):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distributed strategy should use this one instead of
  `tf.__internal__.test.combinations.generate`. This function adds support for
  strategy combinations, GPU/TPU and multi-worker tests.

  See `tf.__internal__.test.combinations.generate` for usage.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  default_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  )
  # We apply our own decoration to handle multi worker tests before applying
  # framework.test_combinations.generate. The order is important since we need
  # framework.test_combinations.generate to apply all parameter modifiers first.
  # `tuple(...)` lets callers pass `test_combinations` as a list (a common
  # form); concatenating a tuple with a list would otherwise raise TypeError.
  combination_decorator = combinations_lib.generate(
      combinations,
      test_combinations=default_combinations + tuple(test_combinations))

  def decorator(test_method_or_class):
    if isinstance(test_method_or_class, type):
      # If it's a test class.
      class_object = test_method_or_class
      # Decorate each test method with _multi_worker_test. Iterate over a
      # snapshot of the class namespace because setattr mutates it.
      for name, test_method in list(class_object.__dict__.items()):
        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(test_method, types.FunctionType)):
          setattr(class_object, name, _multi_worker_test(test_method))
      return combination_decorator(class_object)
    else:
      return combination_decorator(_multi_worker_test(test_method_or_class))

  return decorator
def test_all_tf_execution_regimes(test_class_or_method=None):
  """Decorator for generating a collection of tests in various contexts.

  Must be applied to subclasses of `parameterized.TestCase` (from
  `absl/testing`), or a method of such a subclass.

  When applied to a test method, this decorator results in the replacement of
  that method with a collection of new test methods, each executed under a
  different set of context managers that control some aspect of the execution
  model. This decorator generates three test scenario combinations:

    1. Eager mode with `tf.function` decorations enabled
    2. Eager mode with `tf.function` decorations disabled
    3. Graph mode (everything)

  When applied to a test class, all the methods in the class are affected.

  May be used bare (`@test_all_tf_execution_regimes`) or called with no
  argument (`@test_all_tf_execution_regimes()`), in which case the decorator
  itself is returned.

  Args:
    test_class_or_method: the `TestCase` class or method to decorate.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that function.
  """
  decorator = test_combinations.generate(
      (test_combinations.combine(mode='graph', tf_function='enabled') +
       test_combinations.combine(mode='eager',
                                 tf_function=['enabled', 'disabled'])),
      test_combinations=[
          combinations.EagerGraphCombination(),
          ExecuteFunctionsEagerlyCombination(),
      ])

  # Compare against `None` explicitly, as documented, instead of relying on
  # the truthiness of the decorated object.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator
class TestCombinationsTest(test_case.TestCase, parameterized.TestCase):
  """Checks generated test names and the contexts the combinations set up."""

  #
  # These tests check that the generated names are as expected.
  #
  def test_generated_test_case_names(self):
    expected_test_names = [
        'test_something_test_mode_eager_tffunction_disabled',
        'test_something_test_mode_eager_tffunction_enabled',
        'test_something_test_mode_graph_tffunction_enabled',
    ]

    for expected_test_name in expected_test_names:
      self.assertIn(expected_test_name, dir(PretendTestCaseClass))

  def test_generated_parameterized_test_case_names(self):
    expected_test_names = [
        'test_something_p123_test_mode_eager_tffunction_disabled',
        'test_something_p123_test_mode_eager_tffunction_enabled',
        'test_something_p123_test_mode_graph_tffunction_enabled',
    ]

    for expected_test_name in expected_test_names:
      self.assertIn(expected_test_name,
                    dir(PretendParameterizedTestCaseClass))

  def test_generated_graph_and_eager_test_case_names(self):
    # One expected generated name per execution mode.
    expected_test_names = [
        'test_something_test_mode_eager',
        'test_something_test_mode_graph',
    ]

    for expected_test_name in expected_test_names:
      self.assertIn(expected_test_name,
                    dir(PretendTestCaseClassGraphAndEagerOnly))

  #
  # These tests ensure that the test generators do what they say on the tin.
  #
  @tf_test_combinations.generate(
      tf_test_combinations.combine(mode='graph'),
      test_combinations=[tf_combinations.EagerGraphCombination()])
  def test_graph_mode_combination(self):
    self.assertFalse(context.executing_eagerly())

  @tf_test_combinations.generate(
      tf_test_combinations.combine(mode='eager'),
      test_combinations=[tf_combinations.EagerGraphCombination()])
  def test_eager_mode_combination(self):
    self.assertTrue(context.executing_eagerly())

  @tf_test_combinations.generate(
      tf_test_combinations.combine(tf_function='enabled'),
      test_combinations=[
          test_combinations.ExecuteFunctionsEagerlyCombination()
      ])
  def test_tf_function_enabled_mode_combination(self):
    self.assertFalse(def_function.RUN_FUNCTIONS_EAGERLY)

  @tf_test_combinations.generate(
      tf_test_combinations.combine(tf_function='disabled'),
      test_combinations=[
          test_combinations.ExecuteFunctionsEagerlyCombination()
      ])
  def test_tf_function_disabled_mode_combination(self):
    self.assertTrue(def_function.RUN_FUNCTIONS_EAGERLY)