Example #1
0
class TestCombinationsTest(test_util.TestCase):
    """Exercises the test-combination decorators.

    The first group of tests looks up method names on pretend test classes
    defined elsewhere in this module and checks that the expected generated
    variants exist. The second group runs under a single combination and
    asserts that the execution environment matches what that combination
    promises.
    """

    def _assert_generated_names_present(self, test_case_cls, names):
        """Assert that every entry of `names` appears in `dir(test_case_cls)`."""
        available = dir(test_case_cls)
        for name in names:
            self.assertIn(name, available)

    #
    # Generated-name checks.
    #
    def test_generated_test_case_names(self):
        """Snake- and camel-case methods each get three named variants."""
        self._assert_generated_names_present(
            PretendTestCaseClass,
            [
                'test_snake_case_name_eager_no_tf_function',
                'test_snake_case_name_eager',
                'test_snake_case_name_graph',
                'testCamelCaseName_eager_no_tf_function',
                'testCamelCaseName_eager',
                'testCamelCaseName_graph',
            ],
        )

    def test_generated_parameterized_test_case_names(self):
        """Parameter suffixes (p123) precede the mode / tf_function suffix."""
        self._assert_generated_names_present(
            PretendParameterizedTestCaseClass,
            [
                'test_snake_case_name_p123_eager_no_tf_function',
                'test_snake_case_name_p123_eager',
                'test_snake_case_name_p123_graph',
                'testCamelCaseNamep123_eager_no_tf_function',
                'testCamelCaseNamep123_eager',
                'testCamelCaseNamep123_graph',
            ],
        )

    def test_generated_graph_and_eager_test_case_names(self):
        """The graph-and-eager-only class gets just `_eager` and `_graph` variants."""
        self._assert_generated_names_present(
            PretendTestCaseClassGraphAndEagerOnly,
            [
                'test_snake_case_name_eager',
                'test_snake_case_name_graph',
                'testCamelCaseName_eager',
                'testCamelCaseName_graph',
            ],
        )

    #
    # Runtime checks: each combination must actually set up the environment
    # it advertises.
    #
    @test_combinations.generate(
        test_combinations.combine(mode='graph'),
        test_combinations=[test_util.EagerGraphCombination()],
    )
    def test_graph_mode_combination(self):
        """With mode='graph' the body must not be executing eagerly."""
        self.assertFalse(context.executing_eagerly())

    @test_combinations.generate(
        test_combinations.combine(mode='eager'),
        test_combinations=[test_util.EagerGraphCombination()],
    )
    def test_eager_mode_combination(self):
        """With mode='eager' the body must be executing eagerly."""
        self.assertTrue(context.executing_eagerly())

    @test_combinations.generate(
        test_combinations.combine(tf_function=''),
        test_combinations=[test_util.ExecuteFunctionsEagerlyCombination()],
    )
    def test_tf_function_enabled_mode_combination(self):
        """An empty tf_function value must leave tf.function compilation on."""
        self.assertFalse(tf.config.experimental_functions_run_eagerly())

    @test_combinations.generate(
        test_combinations.combine(tf_function='no_tf_function'),
        test_combinations=[test_util.ExecuteFunctionsEagerlyCombination()],
    )
    def test_tf_function_disabled_mode_combination(self):
        """tf_function='no_tf_function' must force functions to run eagerly."""
        self.assertTrue(tf.config.experimental_functions_run_eagerly())
Example #2
0
class TestCombinationsTest(tfp_test_util.TestCase):
    """Exercises the TFP test-combination decorators.

    The first group of tests looks up method names on pretend test classes
    defined elsewhere in this module and checks that the expected generated
    variants exist. The second group runs under a single combination and
    asserts that the execution environment matches what that combination
    promises.
    """

    def _assert_generated_names_present(self, test_case_cls, names):
        """Assert that every entry of `names` appears in `dir(test_case_cls)`."""
        available = dir(test_case_cls)
        for name in names:
            self.assertIn(name, available)

    #
    # These tests check that the generated names are as expected.
    #
    def test_generated_test_case_names(self):
        """A plain test method gets one variant per mode / tf_function pair."""
        self._assert_generated_names_present(
            PretendTestCaseClass,
            [
                'test_something_test_mode_eager_tffunction_disabled',
                'test_something_test_mode_eager_tffunction_enabled',
                'test_something_test_mode_graph_tffunction_enabled',
            ],
        )

    def test_generated_parameterized_test_case_names(self):
        """Parameter suffixes (p123) precede the test_mode / tffunction suffix."""
        self._assert_generated_names_present(
            PretendParameterizedTestCaseClass,
            [
                'test_something_p123_test_mode_eager_tffunction_disabled',
                'test_something_p123_test_mode_eager_tffunction_enabled',
                'test_something_p123_test_mode_graph_tffunction_enabled',
            ],
        )

    def test_generated_graph_and_eager_test_case_names(self):
        """The graph-and-eager-only class gets just the eager and graph variants."""
        # The original list repeated 'test_something_test_mode_eager'; the
        # duplicate added nothing, so each expected name is now listed once.
        self._assert_generated_names_present(
            PretendTestCaseClassGraphAndEagerOnly,
            [
                'test_something_test_mode_eager',
                'test_something_test_mode_graph',
            ],
        )

    #
    # These tests ensure that the test generators do what they say on the tin.
    #
    @tf_test_combinations.generate(
        tf_test_combinations.combine(mode='graph'),
        test_combinations=[tf_combinations.EagerGraphCombination()])
    def test_graph_mode_combination(self):
        """With mode='graph' the body must not be executing eagerly."""
        self.assertFalse(context.executing_eagerly())

    @tf_test_combinations.generate(
        tf_test_combinations.combine(mode='eager'),
        test_combinations=[tf_combinations.EagerGraphCombination()])
    def test_eager_mode_combination(self):
        """With mode='eager' the body must be executing eagerly."""
        self.assertTrue(context.executing_eagerly())

    # NOTE(review): the two tests below read the module-level flag
    # def_function.RUN_FUNCTIONS_EAGERLY directly; newer TF releases expose
    # tf.config.functions_run_eagerly() instead — confirm against pinned TF.
    @tf_test_combinations.generate(
        tf_test_combinations.combine(tf_function='enabled'),
        test_combinations=[tfp_test_util.ExecuteFunctionsEagerlyCombination()])
    def test_tf_function_enabled_mode_combination(self):
        """tf_function='enabled' must leave RUN_FUNCTIONS_EAGERLY false."""
        self.assertFalse(def_function.RUN_FUNCTIONS_EAGERLY)

    @tf_test_combinations.generate(
        tf_test_combinations.combine(tf_function='disabled'),
        test_combinations=[tfp_test_util.ExecuteFunctionsEagerlyCombination()])
    def test_tf_function_disabled_mode_combination(self):
        """tf_function='disabled' must set RUN_FUNCTIONS_EAGERLY true."""
        self.assertTrue(def_function.RUN_FUNCTIONS_EAGERLY)