class TestCombinationsTest(test_util.TestCase):
  """Verifies the test combinations: generated method names and runtime modes."""

  def _assert_names_generated(self, test_class, expected_test_names):
    """Asserts that each name in `expected_test_names` is an attribute of `test_class`.

    Args:
      test_class: Class whose `dir()` listing is searched.
      expected_test_names: Iterable of method names that must be present.
    """
    generated_names = dir(test_class)
    for name in expected_test_names:
      self.assertIn(name, generated_names)

  #
  # Name checks: every combination must yield the expected method names.
  #

  def test_generated_test_case_names(self):
    self._assert_names_generated(
        PretendTestCaseClass,
        ('test_snake_case_name_eager_no_tf_function',
         'test_snake_case_name_eager',
         'test_snake_case_name_graph',
         'testCamelCaseName_eager_no_tf_function',
         'testCamelCaseName_eager',
         'testCamelCaseName_graph'))

  def test_generated_parameterized_test_case_names(self):
    # The parameter name ('p123') sits between the base name and mode suffix.
    self._assert_names_generated(
        PretendParameterizedTestCaseClass,
        ('test_snake_case_name_p123_eager_no_tf_function',
         'test_snake_case_name_p123_eager',
         'test_snake_case_name_p123_graph',
         'testCamelCaseNamep123_eager_no_tf_function',
         'testCamelCaseNamep123_eager',
         'testCamelCaseNamep123_graph'))

  def test_generated_graph_and_eager_test_case_names(self):
    self._assert_names_generated(
        PretendTestCaseClassGraphAndEagerOnly,
        ('test_snake_case_name_eager',
         'test_snake_case_name_graph',
         'testCamelCaseName_eager',
         'testCamelCaseName_graph'))

  #
  # Behavior checks: each generator must actually put the runtime into the
  # mode it advertises.
  #

  @test_combinations.generate(
      test_combinations.combine(mode='graph'),
      test_combinations=[test_util.EagerGraphCombination()])
  def test_graph_mode_combination(self):
    # mode='graph' must leave eager execution switched off.
    is_eager = context.executing_eagerly()
    self.assertFalse(is_eager)

  @test_combinations.generate(
      test_combinations.combine(mode='eager'),
      test_combinations=[test_util.EagerGraphCombination()])
  def test_eager_mode_combination(self):
    # mode='eager' must switch eager execution on.
    is_eager = context.executing_eagerly()
    self.assertTrue(is_eager)

  @test_combinations.generate(
      test_combinations.combine(tf_function=''),
      test_combinations=[test_util.ExecuteFunctionsEagerlyCombination()])
  def test_tf_function_enabled_mode_combination(self):
    # An empty tf_function value means functions are traced, not run eagerly.
    runs_eagerly = tf.config.experimental_functions_run_eagerly()
    self.assertFalse(runs_eagerly)

  @test_combinations.generate(
      test_combinations.combine(tf_function='no_tf_function'),
      test_combinations=[test_util.ExecuteFunctionsEagerlyCombination()])
  def test_tf_function_disabled_mode_combination(self):
    # 'no_tf_function' must force functions to execute eagerly.
    runs_eagerly = tf.config.experimental_functions_run_eagerly()
    self.assertTrue(runs_eagerly)
# NOTE(review): an earlier class in this source is also named
# TestCombinationsTest. If both live in the same module, this definition
# replaces the earlier one and its tests never run -- confirm and rename one.
class TestCombinationsTest(tfp_test_util.TestCase):
  """Verifies the test combinations: generated method names and runtime modes."""

  def _assert_names_generated(self, test_class, expected_test_names):
    """Asserts that each name in `expected_test_names` is an attribute of `test_class`.

    Args:
      test_class: Class whose `dir()` listing is searched.
      expected_test_names: Iterable of method names that must be present.
    """
    generated_names = dir(test_class)
    for name in expected_test_names:
      self.assertIn(name, generated_names)

  #
  # These tests check that the generated names are as expected.
  #

  def test_generated_test_case_names(self):
    self._assert_names_generated(
        PretendTestCaseClass,
        ['test_something_test_mode_eager_tffunction_disabled',
         'test_something_test_mode_eager_tffunction_enabled',
         'test_something_test_mode_graph_tffunction_enabled'])

  def test_generated_parameterized_test_case_names(self):
    # The parameter name ('p123') sits between the base name and mode suffix.
    self._assert_names_generated(
        PretendParameterizedTestCaseClass,
        ['test_something_p123_test_mode_eager_tffunction_disabled',
         'test_something_p123_test_mode_eager_tffunction_enabled',
         'test_something_p123_test_mode_graph_tffunction_enabled'])

  def test_generated_graph_and_eager_test_case_names(self):
    # A graph-and-eager-only combination yields exactly one name per mode.
    # (The original list repeated the '_eager' entry, a copy-paste duplicate.)
    self._assert_names_generated(
        PretendTestCaseClassGraphAndEagerOnly,
        ['test_something_test_mode_eager',
         'test_something_test_mode_graph'])

  #
  # These tests ensure that the test generators do what they say on the tin.
  #

  @tf_test_combinations.generate(
      tf_test_combinations.combine(mode='graph'),
      test_combinations=[tf_combinations.EagerGraphCombination()])
  def test_graph_mode_combination(self):
    # mode='graph' must leave eager execution switched off.
    self.assertFalse(context.executing_eagerly())

  @tf_test_combinations.generate(
      tf_test_combinations.combine(mode='eager'),
      test_combinations=[tf_combinations.EagerGraphCombination()])
  def test_eager_mode_combination(self):
    # mode='eager' must switch eager execution on.
    self.assertTrue(context.executing_eagerly())

  @tf_test_combinations.generate(
      tf_test_combinations.combine(tf_function='enabled'),
      test_combinations=[tfp_test_util.ExecuteFunctionsEagerlyCombination()])
  def test_tf_function_enabled_mode_combination(self):
    # NOTE(review): reads a module-level flag on `def_function`; presumably
    # the flag the combination toggles -- confirm against the TF version used.
    self.assertFalse(def_function.RUN_FUNCTIONS_EAGERLY)

  @tf_test_combinations.generate(
      tf_test_combinations.combine(tf_function='disabled'),
      test_combinations=[tfp_test_util.ExecuteFunctionsEagerlyCombination()])
  def test_tf_function_disabled_mode_combination(self):
    # tf_function='disabled' must force functions to execute eagerly.
    self.assertTrue(def_function.RUN_FUNCTIONS_EAGERLY)