class DjangoAutoTestSuite(unittest.TestSuite):
    """Test suite that bootstraps Django so ``setup.py test`` can run Django tests.

    Construction configures Django from the ``test_settings`` module, builds the
    suite with ``DiscoverRunner``, sets up the test environment and creates the
    test databases. ``run`` executes the tests and then always tears the
    databases and the test environment down again.
    """

    def __init__(self, *args, **kwargs):
        self._configure()
        self.test_runner = DiscoverRunner()
        tests = self.test_runner.build_suite()
        super(DjangoAutoTestSuite, self).__init__(tests=tests, *args, **kwargs)
        self.test_runner.setup_test_environment()
        try:
            self.test_dbs = self.test_runner.setup_databases()
        except Exception:
            # Don't leave the test environment half set up if DB creation fails.
            self.test_runner.teardown_test_environment()
            raise

    def _configure(self):
        """Configure Django from ``test_settings`` (unless already configured) and run ``django.setup()``."""
        test_settings = importlib.import_module("test_settings")
        # Only upper-case module attributes are settings (the same rule Django
        # applies to a settings module). Passing helpers/imports such as ``os``
        # makes newer versions of settings.configure() raise TypeError.
        setting_attrs = {
            attr: getattr(test_settings, attr)
            for attr in dir(test_settings)
            if attr.isupper()
        }
        if not django.conf.settings.configured:
            django.conf.settings.configure(**setting_attrs)
        django.setup()

    def run(self, result_obj, *args, **kwargs):
        """Run the suite, then tear down the test databases and environment.

        Teardown happens in ``finally`` so the databases are destroyed even if
        running the suite raises.
        """
        try:
            return super(DjangoAutoTestSuite, self).run(result_obj, *args, **kwargs)
        finally:
            self.test_runner.teardown_databases(self.test_dbs)
            self.test_runner.teardown_test_environment()
def handle(self, *args, **options):
    """Run the Behave test suite inside a Django test environment.

    Sets up Django's test environment and test databases with
    ``DiscoverRunner``, runs Behave with the command-line arguments that follow
    the command name (``sys.argv[2:]``), and always tears the databases and the
    environment down again, even if Behave raises (including ``SystemExit``).

    Raises:
        SystemExit: carrying Behave's status when it returns a non-zero
            (failing) status, so the process exit code reflects the test result.
    """
    # Configure django environment
    django_test_runner = DiscoverRunner()
    django_test_runner.setup_test_environment()
    try:
        old_config = django_test_runner.setup_databases()
        try:
            # Run Behave tests. Depending on the behave version, main() either
            # returns an exit status or calls sys.exit() itself.
            exit_status = behave_main(args=sys.argv[2:])
        finally:
            django_test_runner.teardown_databases(old_config)
    finally:
        # Teardown django environment
        django_test_runner.teardown_test_environment()
    if exit_status:
        sys.exit(exit_status)
class DjangoPlugin(Fixtures):
    """Pytest plugin that runs Django tests inside a managed test environment.

    Construction sets up Django's test environment and test databases through
    ``DiscoverRunner`` (optionally running South migrations) and resolves the
    configured live-server class. The hook methods then keep every database
    connection inside one transaction for the whole session (rolled back at
    the end), put items from modules with transactions last, and flush the
    databases after a transactional test when the next test is in another
    module.
    """

    def __init__(self, config):
        self.config = config
        self.check_markers()
        self.configure()
        self.original_connection_close = {}
        try:
            self.live_server_class = import_module(config.option.liveserver_class)
        except ImportError:
            # The option holds a dotted *class* path ("pkg.module.Class"):
            # import the module part and look the class up on it.
            module_path, _, class_name = config.option.liveserver_class.rpartition(".")
            self.live_server_class = getattr(import_module(module_path), class_name)

    def check_markers(self):
        """Set ``self.skip_trans``: True when the mark expression contains
        ``"not transaction"`` or the ``skip_trans`` option is set."""
        option = self.config.option
        self.skip_trans = bool("not transaction" in option.markexpr or option.skip_trans)

    def configure(self):
        """Set up Django's test environment and the (possibly reused) test databases.

        Raises:
            pytest.UsageError: if creating or migrating the test databases fails.
        """
        self.runner = DiscoverRunner(interactive=False, verbosity=self.config.option.verbose)
        self.runner.setup_test_environment()
        management.get_commands()  # load all commands first
        is_sqlite = settings.DATABASES.get("default", {}).get("ENGINE", "").endswith("sqlite3")
        wrap_database()
        # Database-name suffix taken from the worker id when running as a
        # distributed worker (``slaveinput``); not used for SQLite.
        db_postfix = getattr(self.config, "slaveinput", {}).get("slaveid", "")
        monkey_patch_creation_for_db_reuse(
            db_postfix if not is_sqlite else None,
            force=self.config.option.create_db,
        )
        migrate_db = self.config.option.migrate or self.config.option.create_db
        can_migrate = "south" in settings.INSTALLED_APPS
        if can_migrate:
            from south.management.commands import patch_for_test_db_setup
            patch_for_test_db_setup()
        try:
            self.runner.setup_databases()
            if migrate_db and can_migrate:
                management.call_command("migrate", verbosity=self.config.option.verbose)
        except Exception as exc:
            raise pytest.UsageError(exc)

    def pytest_pycollect_makemodule(self, path, parent):
        return Module(path, parent)

    @pytest.mark.tryfirst  # or trylast as it was ?
    def pytest_sessionstart(self, session):
        # turn off debug toolbar to speed up testing
        settings.MIDDLEWARE_CLASSES = [
            mid for mid in settings.MIDDLEWARE_CLASSES
            if not mid.startswith("debug_toolbar")
        ]
        for db in connections:
            conn = connections[db]
            # Start one session-wide transaction per connection and mark it as
            # being inside an atomic block; it is rolled back in
            # pytest_sessionfinish.
            conn.set_autocommit(False, force_begin_transaction_with_broken_autocommit=True)
            conn.in_atomic_block = True
            # Keep connections open for the session: replace close() with a
            # no-op and remember the real method so it can be restored.
            self.original_connection_close[db] = conn.close
            conn.close = nop

    def pytest_sessionfinish(self, session):
        self.runner.teardown_test_environment()
        for db in connections:
            conn = connections[db]
            # Leave the fake atomic block just long enough to roll back the
            # session-wide transaction (presumably because Django refuses a
            # rollback while in_atomic_block is set).
            conn.in_atomic_block = False
            transaction.rollback(using=db)
            conn.in_atomic_block = True
            # Restore the real close(); tolerate aliases that were not patched
            # instead of raising KeyError.
            original_close = self.original_connection_close.get(db)
            if original_close is not None:
                conn.close = original_close

    @pytest.mark.trylast
    def pytest_collection_modifyitems(self, items):
        """Reorder items in place: modules without transactions first, then the rest."""
        trans_items = []
        non_trans = []
        for item in items:
            if item.module.has_transactions:
                trans_items.append(item)
            else:
                non_trans.append(item)
        # NOTE(review): groupby() only groups *consecutive* items and the groups
        # are re-concatenated in their original order, so sorted_trans always
        # equals trans_items -- nothing is sorted here. Confirm the intended
        # ordering (e.g. non-transaction classes first within each module).
        sorted_trans = []
        for module, iterator in groupby(trans_items, lambda x: x.module):
            for item, it in groupby(iterator, lambda x: x.cls and is_transaction_test(x.cls)):
                sorted_trans.extend(it)
        items[:] = non_trans + sorted_trans

    def restore_database(self, item, nextitem):
        """Flush every configured database, then run setup on item's node chain."""
        for db in connections:
            management.call_command("flush", verbosity=0, interactive=False, database=db)
        # NOTE(review): all() stops at the first falsy result, and pytest node
        # setup() methods return None, so only the first node of the chain (the
        # session) is set up here, not "all ascending modules" as intended.
        # Looping over every node would also re-run setup on the finished item
        # itself during its teardown; decide which nodes should be re-set up.
        all(i.setup() for i in item.listchain())

    def pytest_runtest_protocol(self, item, nextitem):
        """After a TransactionTestCase item whose *next* item is in a different
        module, register a finalizer on this item that flushes the databases and
        re-runs setup (see restore_database)."""
        if item.cls is not None and is_transaction_test(item.cls):
            if nextitem is not None and nextitem.module != item.module:
                item._request.addfinalizer(lambda: self.restore_database(item, nextitem))

    @pytest.mark.tryfirst
    def pytest_pycollect_makeitem(self, collector, name, obj):
        """Shadow builtin unittest makeitem with patched class and function."""
        try:
            isunit = issubclass(obj, unittest.TestCase)
        except Exception:
            # obj is not a class; return None so other hooks collect it.
            # (KeyboardInterrupt is not an Exception subclass and still propagates.)
            return None
        if isunit:
            return SUnitTestCase(name, parent=collector)
        return None

    @pytest.mark.tryfirst
    def pytest_runtest_setup(self, item):
        if "transaction" in item.keywords and self.skip_trans:
            pytest.skip("excluding transaction test")

    def pytest_runtest_call(self, item, __multicall__):
        # NOTE(review): ``__multicall__`` is the legacy hook-chaining protocol;
        # later pytest/pluggy releases replaced it with hookwrapper -- confirm
        # the supported pytest version.
        return __multicall__.execute()
class TestRunnerTest(TestCase):
    """Tests for running suites through ``DiscoverRunner`` with query counting.

    Each test defines a throw-away ``TestCase`` subclass, runs it with a
    ``DiscoverRunner`` whose output is silenced, and inspects the totals
    recorded in ``RequestQueryCountManager.queries``.
    """

    def setUp(self):
        # Simple class that doesn't output to the standard output
        class StringIOTextRunner(TextTestRunner):
            def __init__(self, *args, **kwargs):
                kwargs['stream'] = StringIO()
                super().__init__(*args, **kwargs)

        self.test_runner = DiscoverRunner()
        self.test_runner.test_runner = StringIOTextRunner

    def tearDown(self):
        # Remove report files written by a run so tests stay independent.
        for setting_name in ('DETAIL_PATH', 'SUMMARY_PATH'):
            try:
                os.remove(RequestQueryCountConfig.get_setting(setting_name))
            except FileNotFoundError:
                pass

    def test_empty_test(self):
        class Test(TestCase):
            def test_foo(self):
                pass

            def test_bar(self):
                pass

        self.test_runner.setup_test_environment()
        try:
            self.test_runner.run_suite(
                TestLoader().loadTestsFromTestCase(testCaseClass=Test)
            )
        finally:
            # Always undo setup_test_environment(), even if the run raises.
            self.test_runner.teardown_test_environment()

        # check for empty tests (the old assertIsNotNone(RequestQueryCountManager,
        # 'queries') only checked that the class itself was not None)
        self.assertIsNotNone(RequestQueryCountManager.queries)
        self.assertIsInstance(RequestQueryCountManager.queries, TestResultQueryContainer)
        self.assertEqual(RequestQueryCountManager.queries.total, 0)
        # check if files are generated
        for setting_name in ('SUMMARY_PATH', 'DETAIL_PATH'):
            report_path = RequestQueryCountConfig.get_setting(setting_name)
            self.assertTrue(path.exists(report_path))
            self.assertTrue(path.isfile(report_path))

    @classmethod
    def get_id(cls, test_class, method_name):
        """Return the ``module.QualName.method`` key used in ``queries_by_testcase``."""
        return "{}.{}.{}".format(test_class.__module__, test_class.__qualname__, method_name)

    def test_runner_include_queries(self):
        class Test(TestCase):
            def test_foo(self):
                self.client.get('/url-1')

        self.test_runner.run_tests(
            None,
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        # Assert it ran one test
        self.assertEqual(len(RequestQueryCountManager.queries.queries_by_testcase), 1)
        test_foo_id = self.get_id(Test, 'test_foo')
        self.assertIn(test_foo_id, RequestQueryCountManager.queries.queries_by_testcase)
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[test_foo_id].total,
            1
        )

    def test_excluded_test(self):
        class Test(TestCase):
            @exclude_query_count()
            def test_foo(self):
                self.client.get('/url-1')

            def test_bar(self):
                self.client.get('/url-1')

        self.test_runner.run_suite(
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        # Assert test_foo has excluded queries
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_foo')].total,
            0
        )
        # Assert test_bar has some queries
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_bar')].total,
            1
        )

    def test_excluded_class(self):
        # Request '/url-1' (which records one query when not excluded, see
        # test_runner_include_queries). The previous relative 'path-1' matched
        # no route, so the zero counts held even if class exclusion was broken.
        @exclude_query_count()
        class Test(TestCase):
            def test_foo(self):
                self.client.get('/url-1')

            def test_bar(self):
                self.client.get('/url-1')

        self.test_runner.run_suite(
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        # Assert both tests have excluded queries
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_foo')].total,
            0
        )
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_bar')].total,
            0
        )

    def test_conditional_exclude(self):
        class Test(TestCase):
            @exclude_query_count(path='url-2')
            def test_exclude_path(self):
                self.client.get('/url-1')
                self.client.post('/url-2')

            @exclude_query_count(method='post')
            def test_exclude_method(self):
                self.client.get('/url-1')
                self.client.post('/url-2')

            @exclude_query_count(count=2)
            def test_exclude_count(self):
                self.client.get('/url-1')
                self.client.post('/url-2')
                # successive requests to the same url are additive
                self.client.put('/url-3')
                self.client.put('/url-3')
                self.client.put('/url-3')

        self.test_runner.run_suite(
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_exclude_path')].total,
            1
        )
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_exclude_method')].total,
            1
        )
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_exclude_count')].total,
            3
        )

    def test_nested_method_exclude(self):
        class Test(TestCase):
            @exclude_query_count(path='url-1')
            @exclude_query_count(method='post')
            @exclude_query_count(path='url-3')
            def test_foo(self):
                self.client.get('/url-1')
                self.client.post('/url-2')
                self.client.put('/url-3')

        self.test_runner.run_suite(
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_foo')].total,
            0
        )

    def test_nested_class_method_exclude(self):
        @exclude_query_count(path='url-1')
        class Test(TestCase):
            @exclude_query_count(method='post')
            def test_foo(self):
                self.client.get('/url-1')
                self.client.post('/url-2')
                self.client.put('/url-3')

        self.test_runner.run_suite(
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_foo')].total,
            1
        )

    def test_custom_setup_teardown(self):
        class Test(TestCase):
            def setUp(self):
                pass

            def tearDown(self):
                pass

            def test_foo(self):
                self.client.get('/url-1')

        self.test_runner.run_suite(
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        self.assertIn(
            self.get_id(Test, 'test_foo'),
            RequestQueryCountManager.queries.queries_by_testcase
        )
        self.assertEqual(
            RequestQueryCountManager.queries.queries_by_testcase[
                self.get_id(Test, 'test_foo')].total,
            1
        )
class TestMiddleWare(TestCase):
    """Tests for the query-counting ``Middleware`` and its enable setting."""

    def setUp(self):
        # Simple class that doesn't output to the standard output
        class StringIOTextRunner(TextTestRunner):
            def __init__(self, *args, **kwargs):
                kwargs['stream'] = StringIO()
                super().__init__(*args, **kwargs)

        self.test_runner = DiscoverRunner()
        self.test_runner.test_runner = StringIOTextRunner

    def tearDown(self):
        # Remove report files written by a run so tests stay independent.
        for setting_name in ('DETAIL_PATH', 'SUMMARY_PATH'):
            try:
                os.remove(RequestQueryCountConfig.get_setting(setting_name))
            except FileNotFoundError:
                pass

    def test_middleware_called(self):
        with mock.patch('test_query_counter.middleware.Middleware',
                        new=MagicMock(wraps=Middleware)) as mocked:
            self.client.get('/url-1')
            self.assertEqual(mocked.call_count, 1)

    def test_case_injected_one_test(self):
        class Test(TestCase):
            def test_foo(self):
                self.client.get('/url-1')

        self.test_runner.setup_test_environment()
        try:
            self.test_runner.run_suite(
                TestLoader().loadTestsFromTestCase(testCaseClass=Test)
            )
        finally:
            # Always undo setup_test_environment(), even if the run raises.
            self.test_runner.teardown_test_environment()
        self.assertEqual(RequestQueryCountManager.queries.total, 1)

    def test_case_injected_two_tests(self):
        class Test(TestCase):
            def test_foo(self):
                self.client.get('/url-1')

            def test_bar(self):
                self.client.get('/url-2')

        self.test_runner.run_suite(
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        self.assertEqual(RequestQueryCountManager.queries.total, 2)

    @override_settings(TEST_QUERY_COUNTER={'ENABLE': False})
    def test_case_disable_setting(self):
        class Test(TestCase):
            def test_foo(self):
                self.client.get('/url-1')

            def test_bar(self):
                self.client.get('/url-2')

        self.test_runner.run_tests(
            None,
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        self.assertIsNone(RequestQueryCountManager.queries)

    @override_settings(TEST_QUERY_COUNTER={'ENABLE': False})
    def test_disabled(self):
        mock_get_response = object()
        with self.assertRaises(MiddlewareNotUsed):
            Middleware(mock_get_response)

    def test_json_exists(self):
        class Test(TestCase):
            def test_foo(self):
                self.client.get('/url-1')

        detail_path = RequestQueryCountConfig.get_setting('DETAIL_PATH')
        self.assertFalse(path.exists(detail_path))
        self.test_runner.run_tests(
            None,
            TestLoader().loadTestsFromTestCase(testCaseClass=Test)
        )
        self.assertTrue(path.exists(
            RequestQueryCountConfig.get_setting('DETAIL_PATH'))
        )