コード例 #1
0
ファイル: easy_profile_asgi.py プロジェクト: lab-grid/flow
class EasyProfileMiddleware(BaseHTTPMiddleware):
    """ASGI HTTP middleware that profiles database queries per request.

    Every request whose path is not excluded is handled inside a
    ``SessionProfiler`` context; once the downstream handler returns (or
    raises), the collected statistics are passed to the reporter, provided
    they reach the configured thresholds.

    :param app: ASGI application wrapped by this middleware
    :param sqlalchemy.engine.base.Engine engine: sqlalchemy database engine
    :param Reporter reporter: reporter instance; a ``StreamReporter`` is
        created when omitted
    :param list exclude_path: a list of regex patterns; a request whose path
        matches any of them (``re.match``, i.e. anchored at the start of the
        path) is not profiled
    :param int min_time: minimal total query duration required to report
    :param int min_query_count: minimal number of queries required to report
    :raises TypeError: if ``reporter`` is given but is not a ``Reporter``
    """

    def __init__(self,
                 app,
                 engine=None,
                 reporter=None,
                 exclude_path=None,
                 min_time=0,
                 min_query_count=1):
        # NOTE(review): ``super().__init__()`` is intentionally still not
        # called here. Starlette's ``BaseHTTPMiddleware.__init__`` appears to
        # assign ``self.dispatch_func`` itself, which would shadow the method
        # defined below -- confirm before adding the call.
        if reporter and not isinstance(reporter, Reporter):
            raise TypeError("reporter must be inherited from 'Reporter'")

        self.reporter = reporter or StreamReporter()
        self.app = app
        self.engine = engine
        self.exclude_path = exclude_path or []
        self.min_time = min_time
        self.min_query_count = min_query_count

    async def dispatch_func(self, request: Request, call_next):
        """Forward the request downstream, profiling it unless excluded.

        :param Request request: incoming HTTP request
        :param call_next: awaitable callable that passes the request on and
            returns the response
        :return: the response produced by ``call_next``
        """
        path = request.url.path
        if self._ignore_request(path):
            # Excluded requests bypass profiling entirely, so no profiler
            # object is created for them.
            return await call_next(request)

        if request.method:
            path = "{0} {1}".format(request.method, path)

        profiler = SessionProfiler(self.engine)
        try:
            with profiler:
                return await call_next(request)
        finally:
            # Report even when the downstream handler raised.
            self._report_stats(path, profiler.stats)

    def _ignore_request(self, path):
        """Return True if ``path`` matches any ``exclude_path`` pattern."""
        return any(re.match(pattern, path) for pattern in self.exclude_path)

    def _report_stats(self, path, stats):
        """Send ``stats`` to the reporter when both thresholds are reached.

        :param str path: label for the report (HTTP method and URL path)
        :param stats: mapping with at least ``"total"`` and ``"duration"``
        """
        enough_queries = stats["total"] >= self.min_query_count
        if enough_queries and stats["duration"] >= self.min_time:
            self.reporter.report(path, stats)
コード例 #2
0
    def test_report(self):
        """Check the text ``StreamReporter.report`` writes to its file.

        The output is expected to contain the path, the statistics table,
        the "Total queries" summary and one "Repeated N times" section per
        duplicated statement.
        """
        dest = mock.Mock()
        reporter = StreamReporter(colorized=False, file=dest)
        reporter.report("test", expected_table_stats)

        total = expected_table_stats["total"]
        duration = expected_table_stats["duration"]
        summary = "\nTotal queries: {0} in {1:.3}s\n".format(total, duration)
        expected_output = "\ntest" + expected_table + summary

        # Only the first positional argument of the most recent ``write``
        # call is inspected.
        actual_output = dest.write.call_args[0][0]

        # ``assertRegex`` replaces the deprecated alias
        # ``assertRegexpMatches``, which was removed in Python 3.12.
        # NOTE(review): the expected text is used as a regular expression
        # without escaping, so any regex metacharacters in it are not matched
        # literally -- consider ``re.escape`` or ``assertIn`` after confirming
        # the exact output.
        self.assertRegex(actual_output, expected_output)

        for statement, count in expected_table_stats["duplicates"].items():
            statement = sqlparse.format(
                statement, reindent=True, keyword_case="upper"
            )
            text = "\nRepeated {0} times:\n{1}\n".format(count + 1, statement)
            self.assertRegex(actual_output, text)