async def __call__(self, receive, send): response = PlainTextResponse("OK", status_code=200) await response(receive, send) raise HTTPException(status_code=406) router = Router(routes=[ Route("/runtime_error", endpoint=raise_runtime_error), Route("/not_acceptable", endpoint=not_acceptable), Route("/not_modified", endpoint=not_modified), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), ]) app = ExceptionMiddleware(router) client = TestClient(app) def test_server_error(): with pytest.raises(RuntimeError): response = client.get("/runtime_error") allow_500_client = TestClient(app, raise_server_exceptions=False) response = allow_500_client.get("/runtime_error") assert response.status_code == 500 assert response.text == "Internal Server Error" def test_debug_enabled(): app = ExceptionMiddleware(router)
def __init__(
    self,
    *,
    debug=False,
    title=None,
    version=None,
    description=None,
    terms_of_service=None,
    contact=None,
    license=None,
    openapi=None,
    openapi_route="/schema.yml",
    static_dir="static",
    static_route="/static",
    templates_dir="templates",
    auto_escape=True,
    secret_key=DEFAULT_SECRET_KEY,
    enable_hsts=False,
    docs_route=None,
    cors=False,
    cors_params=DEFAULT_CORS_PARAMS,
    allowed_hosts=None,
):
    """Build the application, its router and its middleware stack.

    :param debug: stored on ``self.debug`` and passed to both
        ``ExceptionMiddleware`` and ``ServerErrorMiddleware``.
    :param title: forwarded to ``OpenAPISchema``.
    :param version: forwarded to ``OpenAPISchema``.
    :param description: forwarded to ``OpenAPISchema``.
    :param terms_of_service: forwarded to ``OpenAPISchema``.
    :param contact: forwarded to ``OpenAPISchema``.
    :param license: forwarded to ``OpenAPISchema``.
    :param openapi: an ``OpenAPISchema`` is attached as ``self.openapi`` only
        when this or ``docs_route`` is truthy.
    :param openapi_route: forwarded to ``OpenAPISchema``.
    :param static_dir: static-file directory. When not ``None`` it is made
        absolute, created on disk if missing, and mounted at ``static_route``.
    :param static_route: URL prefix for static files. Falls back to
        ``static_dir`` when ``None`` and a ``static_dir`` is given.
    :param templates_dir: directory passed to ``Templates``.
    :param auto_escape: NOTE(review): not used anywhere in this method.
        Confirm whether it should be forwarded to ``Templates``.
    :param secret_key: passed to ``SessionMiddleware``.
    :param enable_hsts: when true, ``HTTPSRedirectMiddleware`` is added.
    :param docs_route: see ``openapi``; also forwarded to ``OpenAPISchema``.
    :param cors: when true, ``CORSMiddleware`` is added with ``cors_params``.
    :param cors_params: keyword arguments for ``CORSMiddleware``.
    :param allowed_hosts: passed to ``TrustedHostMiddleware``. Any falsy
        value is replaced by ``["*"]``.
    """
    self.background = BackgroundQueue()

    self.secret_key = secret_key

    self.router = Router()

    if static_dir is not None:
        if static_route is None:
            static_route = static_dir
        static_dir = Path(os.path.abspath(static_dir))

    self.static_dir = static_dir
    self.static_route = static_route

    self.hsts_enabled = enable_hsts
    self.cors = cors
    self.cors_params = cors_params
    self.debug = debug

    if not allowed_hosts:
        # NOTE: a check that raised RuntimeError when ``allowed_hosts`` was
        # missing and ``debug`` was False is currently disabled, so every host
        # is trusted by default.
        allowed_hosts = ["*"]
    self.allowed_hosts = allowed_hosts

    # Both steps depend on the same condition, so test it once.
    if self.static_dir is not None:
        os.makedirs(self.static_dir, exist_ok=True)
        self.mount(self.static_route, self.static_app)

    self.formats = get_formats()

    # Cached requests session.
    self._session = None

    self.default_endpoint = None
    self.app = ExceptionMiddleware(self.router, debug=debug)
    self.add_middleware(GZipMiddleware)

    if self.hsts_enabled:
        self.add_middleware(HTTPSRedirectMiddleware)

    self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)

    if self.cors:
        self.add_middleware(CORSMiddleware, **self.cors_params)
    self.add_middleware(ServerErrorMiddleware, debug=debug)
    self.add_middleware(SessionMiddleware, secret_key=self.secret_key)

    if openapi or docs_route:
        self.openapi = OpenAPISchema(
            app=self,
            title=title,
            version=version,
            openapi=openapi,
            docs_route=docs_route,
            description=description,
            terms_of_service=terms_of_service,
            contact=contact,
            license=license,
            openapi_route=openapi_route,
            static_route=static_route,
        )

    # TODO: Update docs for templates
    self.templates = Templates(directory=templates_dir)

    #: A Requests session that is connected to the ASGI app.
    self.requests = self.session()
import uvicorn
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.responses import JSONResponse
from starlette.routing import Router, Path, PathPrefix
from starlette.middleware.cors import CORSMiddleware

# NOTE: this doesn't currently work with starlette 0.3.6 from PyPI; install
# Starlette from GitHub instead.
from demo.apps import homepage, chat


def error_handler(request, exc):
    """Render an ``HTTPException`` as JSON ``{"detail": ...}`` with its status code."""
    payload = {"detail": exc.detail}
    return JSONResponse(payload, status_code=exc.status_code)


# Outermost to innermost: exception handling -> CORS -> routing.
app = ExceptionMiddleware(
    CORSMiddleware(
        Router([
            Path('/', app=homepage.app, methods=['GET']),
            PathPrefix('/chat', app=chat.app),
        ]),
        allow_origins=['*'],
    )
)
app.add_exception_handler(HTTPException, error_handler)

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=8000)
class Starlette:
    """ASGI application built from a ``Router`` and two layers of error handling.

    Requests pass through ``ServerErrorMiddleware`` first, then any middleware
    registered with :meth:`add_middleware` (most recently added outermost),
    then ``ExceptionMiddleware``, and finally the ``Router``.
    """

    def __init__(
        self,
        debug: bool = False,
        routes: typing.Optional[typing.List[BaseRoute]] = None,
    ) -> None:
        self._debug = debug
        self.router = Router(routes)
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
        self.error_middleware = ServerErrorMiddleware(
            self.exception_middleware, debug=debug
        )

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """The router's route list."""
        return self.router.routes

    @property
    def debug(self) -> bool:
        return self._debug

    @debug.setter
    def debug(self, value: bool) -> None:
        # Keep both middleware layers in sync with the application flag.
        self._debug = value
        self.exception_middleware.debug = value
        self.error_middleware.debug = value

    def on_event(self, event_type: str) -> typing.Callable:
        """Delegate to the router's lifespan ``on_event``."""
        return self.router.lifespan.on_event(event_type)

    def mount(
        self, path: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:
        self.router.mount(path, app=app, name=name)

    def host(
        self, host: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:
        self.router.host(host, app=app, name=name)

    def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None:
        """Insert ``middleware_class`` directly inside ``ServerErrorMiddleware``.

        The new middleware wraps whatever ``ServerErrorMiddleware`` currently
        wraps, so later calls end up outermost among user middleware.
        """
        self.error_middleware.app = middleware_class(
            self.error_middleware.app, **kwargs
        )

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable,
    ) -> None:
        """Register ``handler`` for an exception class or HTTP status code.

        ``500`` and ``Exception`` replace the ``ServerErrorMiddleware`` handler.
        Anything else is registered on ``ExceptionMiddleware``.
        """
        if exc_class_or_status_code in (500, Exception):
            self.error_middleware.handler = handler
        else:
            self.exception_middleware.add_exception_handler(
                exc_class_or_status_code, handler
            )

    def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
        self.router.lifespan.add_event_handler(event_type, func)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> None:
        self.router.add_route(
            path,
            route,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )

    def add_websocket_route(
        self, path: str, route: typing.Callable, name: typing.Optional[str] = None
    ) -> None:
        self.router.add_websocket_route(path, route, name=name)

    def exception_handler(
        self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
    ) -> typing.Callable:
        """Decorator form of :meth:`add_exception_handler`."""

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class_or_status_code, func)
            return func

        return decorator

    def route(
        self,
        path: str,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        """Decorator that registers the function as an HTTP route."""

        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(
        self, path: str, name: typing.Optional[str] = None
    ) -> typing.Callable:
        """Decorator that registers the function as a WebSocket route."""

        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def middleware(self, middleware_type: str) -> typing.Callable:
        """Decorator that wraps the function in ``BaseHTTPMiddleware`` as ``dispatch``.

        Only ``"http"`` is accepted. Note that the check is an ``assert``
        and is therefore skipped under ``python -O``.
        """
        assert (
            middleware_type == "http"
        ), 'Currently only middleware("http") is supported.'

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_middleware(BaseHTTPMiddleware, dispatch=func)
            return func

        return decorator

    def url_path_for(self, name: str, **path_params: str) -> URLPath:
        return self.router.url_path_for(name, **path_params)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Expose the application to everything downstream via the scope.
        scope["app"] = self
        await self.error_middleware(scope, receive, send)
def create_admin(
    tables: t.Sequence[t.Type[Table]],
    auth_table: t.Type[BaseUser] = BaseUser,
    session_table: t.Type[SessionsBase] = SessionsBase,
    session_expiry: timedelta = timedelta(hours=1),
    max_session_expiry: timedelta = timedelta(days=7),
    increase_expiry: t.Optional[timedelta] = timedelta(minutes=20),
    page_size: int = 15,
    read_only: bool = False,
    rate_limit_provider: t.Optional[RateLimitProvider] = None,
    production: bool = False,
    site_name: str = "Piccolo Admin",
    auto_include_related: bool = True,
    allowed_hosts: t.Optional[t.Sequence[str]] = None,
):
    """
    Build the admin ASGI app: an ``AdminRouter`` wrapped in
    ``CSRFMiddleware``, wrapped in ``ExceptionMiddleware``.

    :param tables: Each of the tables will be added to the admin.
    :param auth_table: Either a BaseUser, or BaseUser subclass table, which is
        used for fetching users.
    :param session_table: Either a SessionBase, or SessionBase subclass table,
        which is used for storing and querying session tokens.
    :param session_expiry: How long a session is valid for.
    :param max_session_expiry: The maximum time a session is valid for, taking
        into account any refreshes using `increase_expiry`.
    :param increase_expiry: If set, the `session_expiry` will be increased by
        this amount if it's close to expiry.
    :param page_size: The admin API paginates content - this sets the default
        number of results on each page.
    :param read_only: If True, all non auth endpoints only respond to GET
        requests - the admin can still be viewed, and the data can be filtered.
        Useful for creating online demos.
    :param rate_limit_provider: Rate limiting middleware is used to protect the
        login endpoint against brute force attack. If not set, an
        InMemoryLimitProvider will be configured with reasonable defaults.
    :param production: If True, the admin will enforce stronger security - for
        example, the cookies used will be secure, meaning they are only sent
        over HTTPS.
    :param site_name: Specify a different site name in the admin UI (default
        Piccolo Admin).
    :param auto_include_related: If a table has foreign keys to other tables,
        those tables will also be included in the admin by default, if not
        already specified. Otherwise the admin won't work as expected.
    :param allowed_hosts: This is used by the CSRF middleware as an additional
        layer of protection when the admin is run under HTTPS. It must be a
        sequence of strings, such as ['my_site.com']. Defaults to an empty
        list.
    """
    if allowed_hosts is None:
        # Fresh list per call: a literal ``[]`` default would be one object
        # shared by every admin this function ever builds.
        allowed_hosts = []

    if auto_include_related:
        tables = get_all_tables(tables)

    return ExceptionMiddleware(
        CSRFMiddleware(
            AdminRouter(
                *tables,
                auth_table=auth_table,
                session_table=session_table,
                session_expiry=session_expiry,
                max_session_expiry=max_session_expiry,
                increase_expiry=increase_expiry,
                page_size=page_size,
                read_only=read_only,
                rate_limit_provider=rate_limit_provider,
                production=production,
                site_name=site_name,
            ),
            allowed_hosts=allowed_hosts,
        )
    )
class Starlette:
    """ASGI application with routing, error handling, lifespan and templates.

    The stack is ``LifespanMiddleware`` -> ``ServerErrorMiddleware`` -> any
    middleware added via :meth:`add_middleware` -> ``ExceptionMiddleware`` ->
    ``Router``.
    """

    def __init__(
        self, debug: bool = False, template_directory: typing.Optional[str] = None
    ) -> None:
        self._debug = debug
        self.router = Router()
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
        self.error_middleware = ServerErrorMiddleware(
            self.exception_middleware, debug=debug
        )
        self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
        self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
        self.template_env = self.load_template_env(template_directory)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        return self.router.routes

    @property
    def debug(self) -> bool:
        return self._debug

    @debug.setter
    def debug(self, value: bool) -> None:
        # Keep both middleware layers in sync with the application flag.
        self._debug = value
        self.exception_middleware.debug = value
        self.error_middleware.debug = value

    def load_template_env(
        self, template_directory: typing.Optional[str] = None
    ) -> typing.Any:
        """Return a Jinja2 ``Environment`` for ``template_directory``, or ``None``.

        Autoescaping is enabled and a ``url_for`` global is registered that
        resolves through ``context["request"].url_for``.
        """
        if template_directory is None:
            return None

        # Import jinja2 lazily.
        import jinja2

        # ``contextfunction`` was deprecated in Jinja2 3.0 and removed in 3.1
        # in favour of ``pass_context``; use whichever this install provides.
        pass_context = getattr(jinja2, "pass_context", None) or jinja2.contextfunction

        @pass_context
        def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
            request = context["request"]
            return request.url_for(name, **path_params)

        loader = jinja2.FileSystemLoader(str(template_directory))
        env = jinja2.Environment(loader=loader, autoescape=True)
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> typing.Any:
        # NOTE: raises AttributeError when no template_directory was configured
        # (``template_env`` is ``None``).
        return self.template_env.get_template(name)

    @property
    def schema(self) -> dict:
        assert self.schema_generator is not None
        return self.schema_generator.get_schema(self.routes)

    def on_event(self, event_type: str) -> typing.Callable:
        return self.lifespan_middleware.on_event(event_type)

    def mount(
        self, path: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:
        self.router.mount(path, app=app, name=name)

    def host(
        self, host: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:
        self.router.host(host, app=app, name=name)

    def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None:
        """Insert ``middleware_class`` directly inside ``ServerErrorMiddleware``."""
        self.error_middleware.app = middleware_class(
            self.error_middleware.app, **kwargs
        )

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable,
    ) -> None:
        """``500``/``Exception`` go to ``ServerErrorMiddleware``; others to ``ExceptionMiddleware``."""
        if exc_class_or_status_code in (500, Exception):
            self.error_middleware.handler = handler
        else:
            self.exception_middleware.add_exception_handler(
                exc_class_or_status_code, handler
            )

    def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
        self.lifespan_middleware.add_event_handler(event_type, func)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> None:
        self.router.add_route(
            path,
            route,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )

    def add_websocket_route(
        self, path: str, route: typing.Callable, name: typing.Optional[str] = None
    ) -> None:
        self.router.add_websocket_route(path, route, name=name)

    def exception_handler(
        self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
    ) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class_or_status_code, func)
            return func

        return decorator

    def route(
        self,
        path: str,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(
        self, path: str, name: typing.Optional[str] = None
    ) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def middleware(self, middleware_type: str) -> typing.Callable:
        # NOTE: ``assert`` is skipped under ``python -O``.
        assert (
            middleware_type == "http"
        ), 'Currently only middleware("http") is supported.'

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_middleware(BaseHTTPMiddleware, dispatch=func)
            return func

        return decorator

    def url_path_for(self, name: str, **path_params: str) -> URLPath:
        return self.router.url_path_for(name, **path_params)

    def __call__(self, scope: Scope) -> ASGIInstance:
        scope["app"] = self
        return self.lifespan_middleware(scope)
class Starlette:
    """Minimal ASGI application: a ``Router`` wrapped in ``ExceptionMiddleware``."""

    def __init__(self, debug=False) -> None:
        self.router = Router(routes=[])
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)

    @property
    def debug(self) -> bool:
        """Debug flag, stored on and read from the exception middleware."""
        return self.exception_middleware.debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self.exception_middleware.debug = value

    def mount(self, path: str, app: ASGIApp, methods=None) -> None:
        """Append a ``PathPrefix`` route that delegates to ``app`` under ``path``."""
        self.router.routes.append(PathPrefix(path, app=app, methods=methods))

    def add_exception_handler(self, exc_class: type, handler) -> None:
        self.exception_middleware.add_exception_handler(exc_class, handler)

    def add_route(self, path: str, route, methods=None) -> None:
        """Append an HTTP route; non-class endpoints go through ``request_response``.

        ``methods`` defaults to ``["GET"]``.
        """
        endpoint = route if inspect.isclass(route) else request_response(route)
        http_methods = ["GET"] if methods is None else methods
        self.router.routes.append(
            Path(path, endpoint, protocol="http", methods=http_methods)
        )

    def add_websocket_route(self, path: str, route) -> None:
        """Append a WebSocket route; non-class endpoints go through ``websocket_session``."""
        endpoint = route if inspect.isclass(route) else websocket_session(route)
        self.router.routes.append(Path(path, endpoint, protocol="websocket"))

    def exception_handler(self, exc_class: type):
        """Decorator form of :meth:`add_exception_handler`."""
        def register(handler_fn):
            self.add_exception_handler(exc_class, handler_fn)
            return handler_fn
        return register

    def route(self, path: str, methods=None):
        """Decorator form of :meth:`add_route`."""
        def register(endpoint_fn):
            self.add_route(path, endpoint_fn, methods=methods)
            return endpoint_fn
        return register

    def websocket_route(self, path: str):
        """Decorator form of :meth:`add_websocket_route`."""
        def register(endpoint_fn):
            self.add_websocket_route(path, endpoint_fn)
            return endpoint_fn
        return register

    def __call__(self, scope: Scope) -> ASGIInstance:
        scope["app"] = self
        return self.exception_middleware(scope)
def __init__(self, debug=False) -> None:
    """Create an empty ``Router`` and wrap it in ``ExceptionMiddleware``.

    :param debug: forwarded to ``ExceptionMiddleware``.
    """
    router = Router(routes=[])
    self.router = router
    self.exception_middleware = ExceptionMiddleware(router, debug=debug)
class Starlette:
    """ASGI application: a ``Router`` behind ``ExceptionMiddleware`` plus lifespan events.

    Stack for non-lifespan scopes: ``ExceptionMiddleware`` -> middleware added
    via :meth:`add_middleware` (most recent outermost) -> ``Router``.
    """

    def __init__(self, debug: bool = False) -> None:
        self.router = Router()
        self.lifespan_handler = LifespanHandler()
        # Innermost app that ``add_middleware`` wraps; starts as the bare router.
        self.app = self.router
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        return self.router.routes

    @property
    def debug(self) -> bool:
        return self.exception_middleware.debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self.exception_middleware.debug = value

    def on_event(self, event_type: str) -> typing.Callable:
        return self.lifespan_handler.on_event(event_type)

    def mount(self, path: str, app: ASGIApp) -> None:
        self.router.mount(path, app=app)

    def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None:
        """Wrap the current app in ``middleware_class(app, **kwargs)``.

        The previous implementation always wrapped the bare router, so each
        call silently discarded any middleware added before it. Updating
        ``self.app`` makes successive calls stack.
        """
        self.app = middleware_class(self.app, **kwargs)
        self.exception_middleware.app = self.app

    def add_exception_handler(self, exc_class: type, handler: typing.Callable) -> None:
        self.exception_middleware.add_exception_handler(exc_class, handler)

    def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
        self.lifespan_handler.add_event_handler(event_type, func)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.Optional[typing.List[str]] = None,
    ) -> None:
        self.router.add_route(path, route, methods=methods)

    def add_graphql_route(self, path: str, schema: typing.Any) -> None:
        self.router.add_graphql_route(path, schema)

    def add_websocket_route(self, path: str, route: typing.Callable) -> None:
        self.router.add_websocket_route(path, route)

    def exception_handler(self, exc_class: type) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class, func)
            return func

        return decorator

    def route(
        self, path: str, methods: typing.Optional[typing.List[str]] = None
    ) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(path, func, methods=methods)
            return func

        return decorator

    def websocket_route(self, path: str) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func)
            return func

        return decorator

    def url_path_for(self, name: str, **path_params: str) -> URL:
        return self.router.url_path_for(name, **path_params)

    def __call__(self, scope: Scope) -> ASGIInstance:
        scope["app"] = self
        scope["router"] = self.router
        # Lifespan scopes bypass the HTTP/WebSocket middleware stack.
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)
        return self.exception_middleware(scope)
class Starlette:
    """ASGI application: routes behind ``ExceptionMiddleware``, plus lifespan events.

    Every scope also receives this app's ``ThreadPoolExecutor`` as
    ``scope["executor"]``.
    """

    def __init__(self, debug: bool = False) -> None:
        self.router = Router(routes=[])
        self.lifespan_handler = LifespanHandler()
        # Innermost app that ``add_middleware`` wraps; starts as the bare router.
        self.app = self.router
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
        self.executor = ThreadPoolExecutor()

    @property
    def debug(self) -> bool:
        return self.exception_middleware.debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self.exception_middleware.debug = value

    def on_event(self, event_type: str) -> typing.Callable:
        return self.lifespan_handler.on_event(event_type)

    def mount(
        self,
        path: str,
        app: ASGIApp,
        methods: typing.Optional[typing.Sequence[str]] = None,
    ) -> None:
        prefix = PathPrefix(path, app=app, methods=methods)
        self.router.routes.append(prefix)

    def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None:
        """Wrap the current app in ``middleware_class(app, **kwargs)``.

        Updating ``self.app`` makes successive calls stack instead of each
        one replacing the previously added middleware.
        """
        self.app = middleware_class(self.app, **kwargs)
        self.exception_middleware.app = self.app

    def add_exception_handler(self, exc_class: type, handler: typing.Callable) -> None:
        self.exception_middleware.add_exception_handler(exc_class, handler)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.Optional[typing.Sequence[str]] = None,
    ) -> None:
        """Append an HTTP route; ``methods`` defaults to ``("GET",)``."""
        if not inspect.isclass(route):
            route = request_response(route)
        if methods is None:
            methods = ("GET",)
        instance = Path(path, route, protocol="http", methods=methods)
        self.router.routes.append(instance)

    def add_websocket_route(self, path: str, route: typing.Callable) -> None:
        if not inspect.isclass(route):
            route = websocket_session(route)
        instance = Path(path, route, protocol="websocket")
        self.router.routes.append(instance)

    def exception_handler(self, exc_class: type) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class, func)
            return func

        return decorator

    def route(
        self, path: str, methods: typing.Optional[typing.Sequence[str]] = None
    ) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_route(path, func, methods=methods)
            return func

        return decorator

    def websocket_route(self, path: str) -> typing.Callable:
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_websocket_route(path, func)
            return func

        return decorator

    def __call__(self, scope: Scope) -> ASGIInstance:
        scope["app"] = self
        scope["executor"] = self.executor
        # Lifespan scopes bypass the HTTP/WebSocket middleware stack.
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)
        return self.exception_middleware(scope)
def __init__(self, debug: bool = False) -> None:
    """Set up the router, lifespan handler and exception middleware.

    :param debug: forwarded to ``ExceptionMiddleware``.
    """
    router = Router()
    self.router = router
    self.lifespan_handler = LifespanHandler()
    # ``app`` initially refers to the router itself.
    self.app = router
    self.exception_middleware = ExceptionMiddleware(router, debug=debug)
import os
from starlette.applications import Starlette
from starlette.exceptions import ExceptionMiddleware
from starlette.middleware.wsgi import WSGIMiddleware
from starlette.routing import Router, Path, PathPrefix
import uvicorn
from wsgi import application
from .asgi_app import app as asgi_app

# Requests under /v2 go to the ASGI app; everything else falls through to the
# WSGI application.
app = Router([
    PathPrefix("/v2", app=asgi_app),
    PathPrefix("", app=WSGIMiddleware(application)),
])

# Debug mode is on only for the exact string "True" (also the default when
# DJANGO_DEBUG is unset).
DEBUG = os.getenv("DJANGO_DEBUG", "True")

# Always install ExceptionMiddleware so exception handling does not silently
# disappear outside debug mode; DEBUG is passed through as its ``debug`` flag
# instead of deciding whether the middleware exists at all.
app = ExceptionMiddleware(app, debug=DEBUG == "True")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
Route( "/login/", session_login(), name="login", ), Route("/logout/", session_logout(), name="login"), Mount( "/secret/", AuthenticationMiddleware( ProtectedEndpoint, SessionsAuthBackend( admin_only=True, superuser_only=True, active_only=True), ), ), ]) APP = ExceptionMiddleware(ROUTER) ############################################################################### class SessionTestCase(TestCase): credentials = {"username": "******", "password": "******"} wrong_credentials = {"username": "******", "password": "******"} def setUp(self): SessionsBase.create_table().run_sync() BaseUser.create_table().run_sync() def tearDown(self): SessionsBase.alter().drop_table().run_sync() BaseUser.alter().drop_table().run_sync()
def create_admin(
    tables: t.Sequence[t.Union[t.Type[Table], TableConfig]],
    forms: t.Optional[t.List[FormConfig]] = None,
    auth_table: t.Optional[t.Type[BaseUser]] = None,
    session_table: t.Optional[t.Type[SessionsBase]] = None,
    session_expiry: timedelta = timedelta(hours=1),
    max_session_expiry: timedelta = timedelta(days=7),
    increase_expiry: t.Optional[timedelta] = timedelta(minutes=20),
    page_size: int = 15,
    read_only: bool = False,
    rate_limit_provider: t.Optional[RateLimitProvider] = None,
    production: bool = False,
    site_name: str = "Piccolo Admin",
    auto_include_related: bool = True,
    allowed_hosts: t.Optional[t.Sequence[str]] = None,
):
    """
    Build the admin ASGI app: an ``AdminRouter`` wrapped in
    ``CSRFMiddleware``, wrapped in ``ExceptionMiddleware``.

    :param tables:
        Each of the tables will be added to the admin.
    :param forms:
        For each :class:`FormConfig <piccolo_admin.endpoints.FormConfig>`
        specified, a form will automatically be rendered in the user
        interface, accessible via the sidebar. Defaults to no forms.
    :param auth_table:
        Either a :class:`BaseUser <piccolo.apps.user.tables.BaseUser>`, or
        ``BaseUser`` subclass table, which is used for fetching users.
        Defaults to ``BaseUser`` if none is specified.
    :param session_table:
        Either a :class:`SessionsBase <piccolo_api.session_auth.tables.SessionsBase>`,
        or ``SessionsBase`` subclass table, which is used for storing and
        querying session tokens. Defaults to ``SessionsBase`` if none is
        specified.
    :param session_expiry:
        How long a session is valid for.
    :param max_session_expiry:
        The maximum time a session is valid for, taking into account any
        refreshes using ``increase_expiry``.
    :param increase_expiry:
        If set, the ``session_expiry`` will be increased by this amount if it's
        close to expiry.
    :param page_size:
        The admin API paginates content - this sets the default number of
        results on each page.
    :param read_only:
        If ``True``, all non auth endpoints only respond to GET requests - the
        admin can still be viewed, and the data can be filtered. Useful for
        creating online demos.
    :param rate_limit_provider:
        Rate limiting middleware is used to protect the login endpoint against
        brute force attack. If not set, an
        :class:`InMemoryLimitProvider <piccolo_api.rate_limiting.middleware.InMemoryLimitProvider>`
        will be configured with reasonable defaults.
    :param production:
        If ``True``, the admin will enforce stronger security - for example,
        the cookies used will be secure, meaning they are only sent over
        HTTPS.
    :param site_name:
        Specify a different site name in the admin UI (default
        ``'Piccolo Admin'``).
    :param auto_include_related:
        If a table has foreign keys to other tables, those tables will also be
        included in the admin by default, if not already specified. Otherwise
        the admin won't work as expected.
    :param allowed_hosts:
        This is used by the :class:`CSRFMiddleware <piccolo_api.csrf.middleware.CSRFMiddleware>`
        as an additional layer of protection when the admin is run under
        HTTPS. It must be a sequence of strings, such as ``['my_site.com']``.
        Defaults to an empty list.

    """  # noqa: E501
    # Fresh lists per call: literal ``[]`` defaults would be single objects
    # shared by every admin this function builds.
    forms = [] if forms is None else forms
    allowed_hosts = [] if allowed_hosts is None else allowed_hosts

    auth_table = auth_table or BaseUser
    session_table = session_table or SessionsBase

    if auto_include_related:
        # Map each table class to the TableConfig the caller supplied for it,
        # or None when a bare Table class was passed.
        table_config_map: t.Dict[t.Type[Table], t.Optional[TableConfig]] = {}
        for table_or_config in tables:
            if isinstance(table_or_config, TableConfig):
                table_config_map[table_or_config.table_class] = table_or_config
            else:
                table_config_map[table_or_config] = None

        all_table_classes = get_all_tables(tuple(table_config_map.keys()))

        # Keep the caller's TableConfig where one exists; related tables that
        # were pulled in automatically are passed as bare Table classes.
        tables = [
            table_config_map.get(table_class) or table_class
            for table_class in all_table_classes
        ]

    return ExceptionMiddleware(
        CSRFMiddleware(
            AdminRouter(
                *tables,
                forms=forms,
                auth_table=auth_table,
                session_table=session_table,
                session_expiry=session_expiry,
                max_session_expiry=max_session_expiry,
                increase_expiry=increase_expiry,
                page_size=page_size,
                read_only=read_only,
                rate_limit_provider=rate_limit_provider,
                production=production,
                site_name=site_name,
            ),
            allowed_hosts=allowed_hosts,
        )
    )
DEFAULT_COOKIE_NAME, DEFAULT_HEADER_NAME, CSRFMiddleware, ) async def app(scope, receive, send): await send({ "type": "http.response.start", "status": 200, "headers": [[b"content-type", b"text/plain"]], }) await send({"type": "http.response.body", "body": b"Hello, world!"}) WRAPPED_APP = ExceptionMiddleware(CSRFMiddleware(app, allow_form_param=True)) HOST_RESTRICTED_APP = ExceptionMiddleware( CSRFMiddleware(app, allowed_hosts=["foo.com"], allow_form_param=True)) class TestCSRFMiddleware(TestCase): csrf_token = CSRFMiddleware.get_new_token() incorrect_csrf_token = "abc123" def test_get_request(self): """ Make sure a cookie was set. """ client = TestClient(WRAPPED_APP) response = client.get("/")