def test_graph_and_eager_modes(test_class_or_method=None):
  """Decorator that runs a single test in both graph mode and eager mode.

  Must be applied to subclasses of `parameterized.TestCase` (from
  absl/testing), or to a method of such a subclass.

  When applied to a test method, the method is replaced by two new test
  methods, one executed in graph mode and the other in eager mode. When
  applied to a test class, all the methods in the class are affected.

  Args:
    test_class_or_method: the `TestCase` class or method to decorate.

  Returns:
    decorator: A generated TF `test_combinations` decorator when
      `test_class_or_method` is not supplied; otherwise the result of
      applying that decorator to `test_class_or_method`.
  """
  mode_combinations = test_combinations.combine(mode=['graph', 'eager'])
  decorator = test_combinations.generate(
      mode_combinations,
      test_combinations=[combinations.EagerGraphCombination()])

  # Used as `@test_graph_and_eager_modes()`: no target yet, so hand back the
  # decorator itself. Used bare, the target is passed in and decorated here.
  if not test_class_or_method:
    return decorator
  return decorator(test_class_or_method)
def generate(combinations, test_combinations=()):
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distributed strategy should use this one instead of
  `tf.__internal__.test.combinations.generate`. This function has support of
  strategy combinations, GPU/TPU and multi worker support.

  See `tf.__internal__.test.combinations.generate` for usage.

  Args:
    combinations: the parameter combinations, forwarded unchanged to
      `combinations_lib.generate`.
    test_combinations: an iterable (tuple, list, ...) of extra test
      combination objects. They are appended after the default combinations
      built by this function.

  Returns:
    A decorator accepting either a test class or a test method. For a class,
    every plain function attribute whose name starts with
    `unittest.TestLoader.testMethodPrefix` is wrapped with
    `_multi_worker_test` before the class is passed to the combination
    decorator. For a method, the method is wrapped with `_multi_worker_test`
    and then passed to the combination decorator.
  """
  default_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  )
  # We apply our own decoration to handle multi worker tests before applying
  # framework.test_combinations.generate. The order is important since we need
  # framework.test_combinations.generate to apply all parameter modifiers
  # first.
  #
  # `tuple(...)` lets callers pass a list (or any iterable) of extra
  # combinations; concatenating a tuple with a list would raise TypeError.
  combination_decorator = combinations_lib.generate(
      combinations,
      test_combinations=default_combinations + tuple(test_combinations))

  def decorator(test_method_or_class):
    if isinstance(test_method_or_class, type):
      # If it's a test class.
      class_object = test_method_or_class
      # Decorate each test method with _multi_worker_test. Iterate over a
      # copy of the class dict because `setattr` below mutates the class.
      for name, test_method in six.iteritems(class_object.__dict__.copy()):
        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(test_method, types.FunctionType)):
          setattr(class_object, name, _multi_worker_test(test_method))
      return combination_decorator(class_object)
    else:
      return combination_decorator(_multi_worker_test(test_method_or_class))

  return decorator
def test_all_tf_execution_regimes(test_class_or_method=None):
  """Decorator that runs a test under several TF execution contexts.

  Must be applied to subclasses of `parameterized.TestCase` (from
  `absl/testing`), or to a method of such a subclass.

  When applied to a test method, the method is replaced by a collection of
  new test methods, each executed under a different set of context managers
  that control some aspect of the execution model. Three test scenario
  combinations are generated:

    1. Eager mode with `tf.function` decorations enabled
    2. Eager mode with `tf.function` decorations disabled
    3. Graph mode (everything)

  When applied to a test class, all the methods in the class are affected.

  Args:
    test_class_or_method: the `TestCase` class or method to decorate.

  Returns:
    decorator: A generated TF `test_combinations` decorator when
      `test_class_or_method` is not supplied; otherwise the result of
      applying that decorator to `test_class_or_method`.
  """
  graph_cases = test_combinations.combine(mode='graph', tf_function='enabled')
  eager_cases = test_combinations.combine(
      mode='eager', tf_function=['enabled', 'disabled'])
  decorator = test_combinations.generate(
      graph_cases + eager_cases,
      test_combinations=[
          combinations.EagerGraphCombination(),
          ExecuteFunctionsEagerlyCombination(),
      ])

  # Used as `@test_all_tf_execution_regimes()`: no target yet, so hand back
  # the decorator itself. Used bare, the target is passed in and decorated.
  if not test_class_or_method:
    return decorator
  return decorator(test_class_or_method)