# Example #1
0
def test_graph_and_eager_modes(test_class_or_method=None):
    """Decorator for generating graph and eager mode tests from a single test.

    Must be applied to subclasses of `parameterized.TestCase` (from
    absl/testing), or a method of such a subclass.

    When applied to a test method, this decorator results in the replacement of
    that method with two new test methods, one executed in graph mode and the
    other in eager mode.

    When applied to a test class, all the methods in the class are affected.

    May be used bare (`@test_graph_and_eager_modes`) or called with no
    arguments (`@test_graph_and_eager_modes()`).

    Args:
      test_class_or_method: the `TestCase` class or method to decorate, or
        `None` to get the decorator itself back.

    Returns:
      decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that class or method.
    """
    decorator = test_combinations.generate(
        test_combinations.combine(mode=['graph', 'eager']),
        test_combinations=[EagerGraphCombination()])

    # Compare against None explicitly (as the docstring promises) instead of
    # relying on the truthiness of the decorated object.
    if test_class_or_method is not None:
        return decorator(test_class_or_method)
    return decorator
# Example #2
0
def test_graph_mode_only(test_class_or_method=None):
    """Decorator for ensuring tests run in graph mode.

    Must be applied to subclasses of `parameterized.TestCase` (from
    absl/testing), or a method of such a subclass.

    When applied to a test method, this decorator results in the replacement of
    that method with one new test method, executed in graph mode.

    When applied to a test class, all the methods in the class are affected.

    May be used bare (`@test_graph_mode_only`) or called with no arguments
    (`@test_graph_mode_only()`).

    Args:
      test_class_or_method: the `TestCase` class or method to decorate, or
        `None` to get the decorator itself back.

    Returns:
      decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that class or method.

    Raises:
      SkipTest: Raised when not running in the TF backend (when `JAX_MODE` or
        `NUMPY_MODE` is truthy). The check runs as soon as this function is
        called, i.e. at decoration time, so the exception propagates out of
        the decoration site rather than out of an individual test run.
    """
    # Graph mode only exists in the TF backend; there is nothing to run here
    # under the JAX or NumPy backends.
    if JAX_MODE or NUMPY_MODE:
        raise unittest.SkipTest(
            'Ignoring TF Graph Mode tests in non-TF backends.')

    decorator = test_combinations.generate(
        test_combinations.combine(mode=['graph']),
        test_combinations=[EagerGraphCombination()])

    # Compare against None explicitly (as the docstring promises) instead of
    # relying on the truthiness of the decorated object.
    if test_class_or_method is not None:
        return decorator(test_class_or_method)
    return decorator
# Example #3
0
def test_all_tf_execution_regimes(test_class_or_method=None):
  """Decorator for generating a collection of tests in various contexts.

  Must be applied to subclasses of `parameterized.TestCase` (from
  `absl/testing`), or a method of such a subclass.

  When applied to a test method, this decorator results in the replacement of
  that method with a collection of new test methods, each executed under a
  different set of context managers that control some aspect of the execution
  model. This decorator generates three test scenario combinations:

    1. Eager mode with `tf.function` decorations enabled
    2. Eager mode with `tf.function` decorations disabled
    3. Graph mode (everything)

  When applied to a test class, all the methods in the class are affected.

  May be used bare (`@test_all_tf_execution_regimes`) or called with no
  arguments (`@test_all_tf_execution_regimes()`).

  Args:
    test_class_or_method: the `TestCase` class or method to decorate, or
      `None` to get the decorator itself back.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
    `test_class_or_method` is not `None`, the generated decorator applied to
    that class or method.
  """
  # Graph mode is generated once, with an empty `tf_function` setting.
  graph_combinations = test_combinations.combine(mode='graph', tf_function='')
  # Eager mode is generated twice: with `tf.function` left enabled ('') and
  # with it disabled ('no_tf_function').
  eager_combinations = test_combinations.combine(
      mode='eager', tf_function=['', 'no_tf_function'])

  decorator = test_combinations.generate(
      graph_combinations + eager_combinations,
      test_combinations=[
          EagerGraphCombination(),
          ExecuteFunctionsEagerlyCombination(),
      ])

  # Compare against None explicitly (as the docstring promises) instead of
  # relying on the truthiness of the decorated object.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator