Ejemplo n.º 1
0
def test_graph_and_eager_modes(test_class_or_method=None):
  """Decorator for generating graph and eager mode tests from a single test.

  Must be applied to subclasses of `parameterized.TestCase` (from
  `absl/testing`), or a method of such a subclass.

  When applied to a test method, this decorator replaces that method with two
  new test methods, one executed in graph mode and the other in eager mode.

  When applied to a test class, all the methods in the class are affected.

  The decorator may be used bare (`@test_graph_and_eager_modes`) or called
  with no arguments (`@test_graph_and_eager_modes()`).

  Args:
    test_class_or_method: the `TestCase` class or method to decorate, or
      `None` to obtain the decorator itself.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
    `test_class_or_method` is not `None`, the generated decorator applied to
    that class or method.
  """
  decorator = test_combinations.generate(
      test_combinations.combine(mode=['graph', 'eager']),
      test_combinations=[combinations.EagerGraphCombination()])

  # Check against `None` explicitly, matching the documented contract, rather
  # than relying on the truthiness of the decorated object.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator
Ejemplo n.º 2
0
def generate(combinations, test_combinations=()):
    # pylint: disable=g-doc-args,g-doc-return-or-yield
    """Distributed adapter of `tf.__internal__.test.combinations.generate`.

    All tests with distributed strategy should use this one instead of
    `tf.__internal__.test.combinations.generate`. On top of the framework
    behavior, it supports strategy combinations, GPU/TPU combinations and
    multi-worker tests.

    See `tf.__internal__.test.combinations.generate` for usage.

    Args:
      combinations: the parameter combinations to generate tests for.
      test_combinations: extra `TestCombination` objects, appended after the
        built-in defaults. Any iterable (tuple, list, ...) or `None` is
        accepted.

    Returns:
      A decorator that can be applied to a test method or a test class.
    """
    # pylint: enable=g-doc-args,g-doc-return-or-yield
    default_combinations = (
        framework_combinations.EagerGraphCombination(),
        framework_combinations.TFVersionCombination(),
        ClusterCombination(),
        DistributionCombination(),
        GPUCombination(),
        TPUCombination(),
    )
    # Normalize to a tuple: `tuple + list` would otherwise raise TypeError for
    # callers that pass a list of extra combinations.
    extra_combinations = (() if test_combinations is None
                          else tuple(test_combinations))
    # We apply our own decoration to handle multi worker tests before applying
    # framework.test_combinations.generate. The order is important since we need
    # framework.test_combinations.generate to apply all parameter modifiers first.
    combination_decorator = combinations_lib.generate(
        combinations,
        test_combinations=default_combinations + extra_combinations)

    def decorator(test_method_or_class):
        # A single test method: wrap it, then apply the combinations.
        if not isinstance(test_method_or_class, type):
            return combination_decorator(
                _multi_worker_test(test_method_or_class))

        # A test class: wrap each test method with _multi_worker_test. Iterate
        # over a snapshot because setattr mutates the class namespace.
        class_object = test_method_or_class
        for name, test_method in list(class_object.__dict__.items()):
            if (name.startswith(unittest.TestLoader.testMethodPrefix)
                    and isinstance(test_method, types.FunctionType)):
                setattr(class_object, name, _multi_worker_test(test_method))
        return combination_decorator(class_object)

    return decorator
Ejemplo n.º 3
0
def test_all_tf_execution_regimes(test_class_or_method=None):
  """Decorator for generating a collection of tests in various contexts.

  Must be applied to subclasses of `parameterized.TestCase` (from
  `absl/testing`), or a method of such a subclass.

  When applied to a test method, this decorator results in the replacement of
  that method with a collection of new test methods, each executed under a
  different set of context managers that control some aspect of the execution
  model. This decorator generates three test scenario combinations:

    1. Eager mode with `tf.function` decorations enabled
    2. Eager mode with `tf.function` decorations disabled
    3. Graph mode with `tf.function` decorations enabled

  When applied to a test class, all the methods in the class are affected.

  The decorator may be used bare (`@test_all_tf_execution_regimes`) or called
  with no arguments (`@test_all_tf_execution_regimes()`).

  Args:
    test_class_or_method: the `TestCase` class or method to decorate, or
      `None` to obtain the decorator itself.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
    `test_class_or_method` is not `None`, the generated decorator applied to
    that class or method.
  """
  decorator = test_combinations.generate(
      (test_combinations.combine(mode='graph',
                                 tf_function='enabled') +
       test_combinations.combine(mode='eager',
                                 tf_function=['enabled', 'disabled'])),
      test_combinations=[
          combinations.EagerGraphCombination(),
          ExecuteFunctionsEagerlyCombination(),
      ])

  # Check against `None` explicitly, matching the documented contract, rather
  # than relying on the truthiness of the decorated object.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator
Ejemplo n.º 4
0
class TestCombinationsTest(test_case.TestCase, parameterized.TestCase):
    """Checks generated test names and the execution context of each regime."""

    #
    # These tests check that the generated names are as expected.
    #
    def test_generated_test_case_names(self):
        expected_test_names = [
            'test_something_test_mode_eager_tffunction_disabled',
            'test_something_test_mode_eager_tffunction_enabled',
            'test_something_test_mode_graph_tffunction_enabled',
        ]

        for expected_test_name in expected_test_names:
            self.assertIn(expected_test_name, dir(PretendTestCaseClass))

    def test_generated_parameterized_test_case_names(self):
        expected_test_names = [
            'test_something_p123_test_mode_eager_tffunction_disabled',
            'test_something_p123_test_mode_eager_tffunction_enabled',
            'test_something_p123_test_mode_graph_tffunction_enabled',
        ]

        for expected_test_name in expected_test_names:
            self.assertIn(expected_test_name,
                          dir(PretendParameterizedTestCaseClass))

    def test_generated_graph_and_eager_test_case_names(self):
        # One generated method is expected per execution mode; each name is
        # listed exactly once.
        expected_test_names = [
            'test_something_test_mode_eager',
            'test_something_test_mode_graph',
        ]

        for expected_test_name in expected_test_names:
            self.assertIn(expected_test_name,
                          dir(PretendTestCaseClassGraphAndEagerOnly))

    #
    # These tests ensure that the test generators do what they say on the tin.
    #
    @tf_test_combinations.generate(
        tf_test_combinations.combine(mode='graph'),
        test_combinations=[tf_combinations.EagerGraphCombination()])
    def test_graph_mode_combination(self):
        self.assertFalse(context.executing_eagerly())

    @tf_test_combinations.generate(
        tf_test_combinations.combine(mode='eager'),
        test_combinations=[tf_combinations.EagerGraphCombination()])
    def test_eager_mode_combination(self):
        self.assertTrue(context.executing_eagerly())

    # With tf.function enabled, functions must not be forced to run eagerly.
    @tf_test_combinations.generate(
        tf_test_combinations.combine(tf_function='enabled'),
        test_combinations=[
            test_combinations.ExecuteFunctionsEagerlyCombination()
        ])
    def test_tf_function_enabled_mode_combination(self):
        self.assertFalse(def_function.RUN_FUNCTIONS_EAGERLY)

    # With tf.function disabled, functions are expected to run eagerly.
    @tf_test_combinations.generate(
        tf_test_combinations.combine(tf_function='disabled'),
        test_combinations=[
            test_combinations.ExecuteFunctionsEagerlyCombination()
        ])
    def test_tf_function_disabled_mode_combination(self):
        self.assertTrue(def_function.RUN_FUNCTIONS_EAGERLY)