class EasyProfileMiddleware(BaseHTTPMiddleware):
    """Report the database queries made while handling each HTTP request.

    This is an asynchronous (ASGI) middleware built on
    ``BaseHTTPMiddleware``. Every request whose path is not excluded is run
    inside a ``SessionProfiler``. The collected stats are passed to the
    reporter once they reach the configured thresholds.

    :param app: ASGI application to wrap
    :param sqlalchemy.engine.base.Engine engine: sqlalchemy database engine
        handed to ``SessionProfiler``
    :param Reporter reporter: reporter instance; a ``StreamReporter`` is
        created when omitted
    :param list exclude_path: a list of regex patterns; requests whose path
        matches any of them (via ``re.match``) are not profiled
    :param int min_time: minimal total query duration (``stats["duration"]``)
        required for a request to be reported
    :param int min_query_count: minimal number of queries
        (``stats["total"]``) required for a request to be reported
    :raises TypeError: if ``reporter`` is given but is not a ``Reporter``
    """

    def __init__(self, app, engine=None, reporter=None, exclude_path=None,
                 min_time=0, min_query_count=1):
        if reporter:
            if not isinstance(reporter, Reporter):
                raise TypeError("reporter must be inherited from 'Reporter'")
            self.reporter = reporter
        else:
            self.reporter = StreamReporter()

        # Run the base initializer instead of skipping it. The profiling
        # coroutine is passed explicitly as the dispatch callable so the
        # base class keeps using it rather than its own ``dispatch`` method.
        super().__init__(app, dispatch=self.dispatch_func)
        self.app = app
        self.engine = engine
        self.exclude_path = exclude_path or []
        self.min_time = min_time
        self.min_query_count = min_query_count

    async def dispatch_func(self, request: Request, call_next):
        """Forward ``request`` downstream, profiling it unless excluded.

        Stats are reported in a ``finally`` clause. A request whose handler
        raises is therefore still reported, and the exception propagates
        unchanged.
        """
        path = request.url.path
        if self._ignore_request(path):
            # Excluded paths skip profiling entirely; no profiler is built.
            return await call_next(request)

        if request.method:
            path = "{0} {1}".format(request.method, path)

        profiler = SessionProfiler(self.engine)
        # NOTE(review): queries issued after ``call_next`` returns (e.g.
        # while a streaming body is sent) are presumably not captured by the
        # ``with`` block below -- confirm if streaming endpoints matter.
        try:
            with profiler:
                return await call_next(request)
        finally:
            self._report_stats(path, profiler.stats)

    def _ignore_request(self, path):
        """Return True if ``path`` matches any ``exclude_path`` pattern.

        Patterns are applied with ``re.match``, i.e. anchored at the start
        of the path.
        """
        return any(re.match(pattern, path) for pattern in self.exclude_path)

    def _report_stats(self, path, stats):
        """Send ``stats`` to the reporter when both thresholds are met."""
        if (stats["total"] >= self.min_query_count and
                stats["duration"] >= self.min_time):
            self.reporter.report(path, stats)
def __init__(self, app, engine=None, reporter=None, exclude_path=None,
             min_time=0, min_query_count=1):
    """Validate and store the middleware configuration.

    :param app: application wrapped by the middleware
    :param engine: sqlalchemy database engine used for profiling
    :param reporter: ``Reporter`` instance; a falsy value selects a fresh
        ``StreamReporter``
    :param exclude_path: regex patterns for paths to skip; a falsy value
        becomes an empty list
    :param min_time: minimal queries duration required for reporting
    :param min_query_count: minimal queries count required for reporting
    :raises TypeError: if a truthy ``reporter`` is not a ``Reporter``
    """
    if not reporter:
        reporter = StreamReporter()
    elif not isinstance(reporter, Reporter):
        raise TypeError("reporter must be inherited from 'Reporter'")
    self.reporter = reporter

    self.app = app
    self.engine = engine
    self.exclude_path = exclude_path if exclude_path else []
    self.min_time = min_time
    self.min_query_count = min_query_count
def test_report(self):
    """report() writes the path, stats table, summary and repeated SQL.

    NOTE(review): ``assertRegex`` interprets ``expected_output`` and ``text``
    as regular expressions. The reference table contains ``|`` separators,
    which a regex reads as alternation, so the first match is much looser
    than an exact comparison. Confirm whether that is intended.
    """
    dest = mock.Mock()
    reporter = StreamReporter(colorized=False, file=dest)
    reporter.report("test", expected_table_stats)

    expected_output = "\ntest"
    expected_output += expected_table

    total = expected_table_stats["total"]
    duration = expected_table_stats["duration"]
    summary = "\nTotal queries: {0} in {1:.3}s\n".format(total, duration)
    expected_output += summary

    # Text passed to the most recent write() call on the fake file.
    actual_output = dest.write.call_args[0][0]
    # assertRegexpMatches was a deprecated alias removed in Python 3.12.
    self.assertRegex(actual_output, expected_output)

    for statement, count in expected_table_stats["duplicates"].items():
        statement = sqlparse.format(
            statement, reindent=True, keyword_case="upper"
        )
        text = "\nRepeated {0} times:\n{1}\n".format(count + 1, statement)
        self.assertRegex(actual_output, text)
def test_initialization(self):
    """Constructor arguments land on the matching private attributes."""
    fake_file = mock.Mock()
    options = dict(
        medium=1,
        high=2,
        file=fake_file,
        colorized=False,
        display_duplicates=0,
    )
    reporter = StreamReporter(**options)

    self.assertEqual(reporter._medium, options["medium"])
    self.assertEqual(reporter._high, options["high"])
    self.assertEqual(reporter._file, fake_file)
    self.assertFalse(reporter._colorized)
    self.assertEqual(
        reporter._display_duplicates, options["display_duplicates"]
    )
def test__info_line_on_medium(self):
    """A count above the medium threshold is colorized bold yellow."""
    with mock.patch.object(StreamReporter, "_colorize") as colorize_spy:
        subject = StreamReporter()
        above_medium = subject._medium + 1
        subject._info_line("test", above_medium)
        colorize_spy.assert_called_with("test", ["bold"], fg="yellow")
def test__colorize_on_activated(self):
    """With colorized=True, _colorize calls the reporters' colorize helper."""
    target = "easy_profile.reporters.colorize"
    with mock.patch(target) as colorize_fn:
        StreamReporter(colorized=True)._colorize("test")
        colorize_fn.assert_called()
def test_initialization_error(self):
    """A medium threshold greater than the high threshold is rejected."""
    self.assertRaises(ValueError, StreamReporter, medium=100, high=50)
def test_initialization_default(self):
    """A bare StreamReporter uses the documented default settings."""
    defaults = StreamReporter()
    for attribute, expected in (("_medium", 50), ("_high", 100)):
        self.assertEqual(getattr(defaults, attribute), expected)
    self.assertTrue(defaults._colorized)
    self.assertEqual(defaults._display_duplicates, 5)
def test_stats_table_change_sep(self):
    """A custom ``sep`` is used wherever the reference table has ``|``."""
    custom_sep = "+"
    rendered = StreamReporter(colorized=False).stats_table(
        expected_table_stats, sep=custom_sep
    )
    reference = expected_table.replace("|", custom_sep)
    self.assertEqual(rendered.strip(), reference.strip())
def test_stats_table(self):
    """Default rendering equals the reference table, ignoring outer whitespace."""
    rendered = StreamReporter(colorized=False).stats_table(
        expected_table_stats
    )
    self.assertEqual(rendered.strip(), expected_table.strip())