def test_graph_and_eager_modes(test_class_or_method=None):
  """Decorator for generating graph and eager mode tests from a single test.

  Must be applied to subclasses of `parameterized.TestCase` (from
  absl/testing), or a method of such a subclass.

  When applied to a test method, this decorator results in the replacement of
  that method with two new test methods, one executed in graph mode and the
  other in eager mode.

  When applied to a test class, all the methods in the class are affected.

  May be used either bare (`@test_graph_and_eager_modes`) or called with no
  arguments (`@test_graph_and_eager_modes()`).

  Args:
    test_class_or_method: the `TestCase` class or method to decorate.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that function.
  """
  decorator = test_combinations.generate(
      test_combinations.combine(mode=['graph', 'eager']),
      test_combinations=[EagerGraphCombination()])
  # Bare usage passes the decorated object directly; called usage passes
  # nothing, so hand back the decorator for Python to apply.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator
def test_graph_mode_only(test_class_or_method=None):
  """Decorator for ensuring tests run in graph mode.

  Must be applied to subclasses of `parameterized.TestCase` (from
  absl/testing), or a method of such a subclass.

  When applied to a test method, this decorator results in the replacement of
  that method with one new test method, executed in graph mode.

  When applied to a test class, all the methods in the class are affected.

  May be used either bare (`@test_graph_mode_only`) or called with no
  arguments (`@test_graph_mode_only()`).

  Args:
    test_class_or_method: the `TestCase` class or method to decorate.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that function.

  Raises:
    SkipTest: Raised when not running in the TF backend.
  """
  # This check runs when the decorator itself is evaluated (i.e. at decoration
  # time, before any test executes), so the SkipTest propagates out of the
  # decoration site rather than out of an individual test.
  if JAX_MODE or NUMPY_MODE:
    raise unittest.SkipTest(
        'Ignoring TF Graph Mode tests in non-TF backends.')
  decorator = test_combinations.generate(
      test_combinations.combine(mode=['graph']),
      test_combinations=[EagerGraphCombination()])
  # Bare usage passes the decorated object directly; called usage passes
  # nothing, so hand back the decorator for Python to apply.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator
def test_all_tf_execution_regimes(test_class_or_method=None):
  """Decorator for generating a collection of tests in various contexts.

  Must be applied to subclasses of `parameterized.TestCase` (from
  `absl/testing`), or a method of such a subclass.

  When applied to a test method, this decorator results in the replacement of
  that method with a collection of new test methods, each executed under a
  different set of context managers that control some aspect of the execution
  model.

  This decorator generates three test scenario combinations:

    1. Eager mode with `tf.function` decorations enabled
    2. Eager mode with `tf.function` decorations disabled
    3. Graph mode (everything)

  When applied to a test class, all the methods in the class are affected.

  May be used either bare (`@test_all_tf_execution_regimes`) or called with no
  arguments (`@test_all_tf_execution_regimes()`).

  Args:
    test_class_or_method: the `TestCase` class or method to decorate.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that function.
  """
  # One graph-mode combination plus two eager-mode combinations (one per
  # `tf_function` value) yields the three scenarios listed in the docstring.
  graph_combinations = test_combinations.combine(mode='graph', tf_function='')
  eager_combinations = test_combinations.combine(
      mode='eager', tf_function=['', 'no_tf_function'])
  decorator = test_combinations.generate(
      graph_combinations + eager_combinations,
      test_combinations=[
          EagerGraphCombination(),
          ExecuteFunctionsEagerlyCombination(),
      ])
  # Bare usage passes the decorated object directly; called usage passes
  # nothing, so hand back the decorator for Python to apply.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator