Пример #1
0
    async def __call__(self, receive, send):
        """Send a complete 200 "OK" response, then raise HTTPException(406).

        The exception is raised only after the response has already been
        awaited, i.e. once the response call has returned.
        """
        await PlainTextResponse("OK", status_code=200)(receive, send)
        raise HTTPException(status_code=406)


# Routes exercised by the tests below. "/runtime_error" is registered twice:
# once as an HTTP route and once as a WebSocket route, both backed by the
# same endpoint.
router = Router(routes=[
    Route("/runtime_error", endpoint=raise_runtime_error),
    Route("/not_acceptable", endpoint=not_acceptable),
    Route("/not_modified", endpoint=not_modified),
    Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse),
    WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
])

# Application under test: the router wrapped in ExceptionMiddleware, plus a
# module-level test client built with the TestClient defaults.
app = ExceptionMiddleware(router)
client = TestClient(app)


def test_server_error():
    """An unhandled RuntimeError propagates, or becomes a 500 when allowed."""
    # The module-level client re-raises the endpoint's exception.
    with pytest.raises(RuntimeError):
        client.get("/runtime_error")

    # With raise_server_exceptions=False the client returns the 500 response.
    lenient_client = TestClient(app, raise_server_exceptions=False)
    response = lenient_client.get("/runtime_error")
    assert (response.status_code, response.text) == (
        500,
        "Internal Server Error",
    )


def test_debug_enabled():
    app = ExceptionMiddleware(router)
Пример #2
0
    def __init__(
        self,
        *,
        debug=False,
        title=None,
        version=None,
        description=None,
        terms_of_service=None,
        contact=None,
        license=None,
        openapi=None,
        openapi_route="/schema.yml",
        static_dir="static",
        static_route="/static",
        templates_dir="templates",
        auto_escape=True,
        secret_key=DEFAULT_SECRET_KEY,
        enable_hsts=False,
        docs_route=None,
        cors=False,
        cors_params=DEFAULT_CORS_PARAMS,
        allowed_hosts=None,
    ):
        """Configure the router, static files, middleware, schema and templates.

        All arguments are keyword-only.

        :param debug: Stored on the instance and forwarded to
            ``ExceptionMiddleware`` and ``ServerErrorMiddleware``.
        :param title, version, description, terms_of_service, contact, license,
            openapi, openapi_route, docs_route: Forwarded to ``OpenAPISchema``,
            which is only created when ``openapi`` or ``docs_route`` is truthy.
        :param static_dir: Directory for static files; converted to an absolute
            ``Path`` and created if missing. ``None`` disables the static mount.
        :param static_route: Mount point for static files; falls back to the
            original ``static_dir`` value when ``None``.
        :param templates_dir: Directory handed to ``Templates``.
        :param auto_escape: Accepted but not referenced in this method.
        :param secret_key: Passed to ``SessionMiddleware``.
        :param enable_hsts: When truthy, ``HTTPSRedirectMiddleware`` is added.
        :param cors: When truthy, ``CORSMiddleware`` is added with
            ``cors_params`` as keyword arguments.
        :param allowed_hosts: Passed to ``TrustedHostMiddleware``; a falsy value
            is replaced by ``["*"]``.
        """
        self.background = BackgroundQueue()

        self.secret_key = secret_key

        self.router = Router()

        if static_dir is not None:
            # The route fallback uses the value as given by the caller; only
            # afterwards is static_dir turned into an absolute Path.
            if static_route is None:
                static_route = static_dir
            static_dir = Path(os.path.abspath(static_dir))

        self.static_dir = static_dir
        self.static_route = static_route

        self.hsts_enabled = enable_hsts
        self.cors = cors
        self.cors_params = cors_params
        self.debug = debug

        if not allowed_hosts:
            # The stricter check below is disabled, so a missing or empty
            # value falls back to the wildcard entry.
            # if not debug:
            #     raise RuntimeError(
            #         "You need to specify `allowed_hosts` when debug is set to False"
            #     )
            allowed_hosts = ["*"]
        self.allowed_hosts = allowed_hosts

        # Make sure the static directory exists before it is mounted.
        if self.static_dir is not None:
            os.makedirs(self.static_dir, exist_ok=True)

        if self.static_dir is not None:
            self.mount(self.static_route, self.static_app)

        self.formats = get_formats()

        # Cached requests session.
        self._session = None

        self.default_endpoint = None
        # NOTE(review): the add_middleware calls below presumably wrap
        # self.app, so their order determines the middleware stack; confirm
        # against add_middleware before reordering anything here.
        self.app = ExceptionMiddleware(self.router, debug=debug)
        self.add_middleware(GZipMiddleware)

        if self.hsts_enabled:
            self.add_middleware(HTTPSRedirectMiddleware)

        self.add_middleware(TrustedHostMiddleware,
                            allowed_hosts=self.allowed_hosts)

        if self.cors:
            self.add_middleware(CORSMiddleware, **self.cors_params)
        self.add_middleware(ServerErrorMiddleware, debug=debug)
        self.add_middleware(SessionMiddleware, secret_key=self.secret_key)

        # self.openapi is only assigned when a schema or docs route is wanted.
        if openapi or docs_route:
            self.openapi = OpenAPISchema(
                app=self,
                title=title,
                version=version,
                openapi=openapi,
                docs_route=docs_route,
                description=description,
                terms_of_service=terms_of_service,
                contact=contact,
                license=license,
                openapi_route=openapi_route,
                static_route=static_route,
            )

        # TODO: Update docs for templates
        self.templates = Templates(directory=templates_dir)
        self.requests = (
            self.session()
        )  #: A Requests session that is connected to the ASGI app.
Пример #3
0
import uvicorn
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.responses import JSONResponse
from starlette.routing import Router, Path, PathPrefix
from starlette.middleware.cors import CORSMiddleware  # this isn't currently working with starlette 0.3.6 on PyPI, but you can import from github.
from demo.apps import homepage, chat

# `app` is rebound three times; each step wraps the previous value.
# 1. Routing: GET "/" goes to the homepage app, the "/chat" prefix to the chat app.
app = Router([
    Path('/', app=homepage.app, methods=['GET']),
    PathPrefix('/chat', app=chat.app),
])

# 2. CORS with every origin allowed.
app = CORSMiddleware(app, allow_origins=['*'])

# 3. Exception handling as the outermost layer.
app = ExceptionMiddleware(app)


def error_handler(request, exc):
    """Render an exception as JSON: ``{"detail": ...}`` with its status code.

    ``request`` is accepted but not used; ``exc`` must expose ``detail`` and
    ``status_code`` attributes.
    """
    payload = {"detail": exc.detail}
    return JSONResponse(payload, status_code=exc.status_code)


# Register error_handler for HTTPException on the ExceptionMiddleware instance.
app.add_exception_handler(HTTPException, error_handler)

# When run as a script, serve the app with uvicorn on all interfaces, port 8000.
if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=8000)
Пример #4
0
class Starlette:
    """Application facade: a router wrapped in two error-handling layers.

    Calls flow ``error_middleware`` -> ``exception_middleware`` -> ``router``.
    Most methods delegate directly to the router.
    """

    def __init__(self,
                 debug: bool = False,
                 routes: typing.List[BaseRoute] = None) -> None:
        """Build the router and stack both middleware layers on top of it."""
        router = Router(routes)
        handled_errors = ExceptionMiddleware(router, debug=debug)
        server_errors = ServerErrorMiddleware(handled_errors, debug=debug)

        self._debug = debug
        self.router = router
        self.exception_middleware = handled_errors
        self.error_middleware = server_errors

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """The router's route list."""
        return self.router.routes

    @property
    def debug(self) -> bool:
        """Current debug flag."""
        return self._debug

    @debug.setter
    def debug(self, value: bool) -> None:
        # Keep both middleware layers in step with the application flag.
        self._debug = value
        for layer in (self.exception_middleware, self.error_middleware):
            layer.debug = value

    def on_event(self, event_type: str) -> typing.Callable:
        """Return the router lifespan's decorator for ``event_type``."""
        lifespan = self.router.lifespan
        return lifespan.on_event(event_type)

    def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
        """Mount ``app`` under ``path`` on the router."""
        self.router.mount(path, app=app, name=name)

    def host(self, host: str, app: ASGIApp, name: str = None) -> None:
        """Register ``app`` for ``host`` on the router."""
        self.router.host(host, app=app, name=name)

    def add_middleware(self, middleware_class: type,
                       **kwargs: typing.Any) -> None:
        """Wrap the app currently inside the error middleware with a new layer."""
        inner_app = self.error_middleware.app
        self.error_middleware.app = middleware_class(inner_app, **kwargs)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable,
    ) -> None:
        """Register ``handler``; 500 and ``Exception`` go to the error middleware."""
        if exc_class_or_status_code not in (500, Exception):
            self.exception_middleware.add_exception_handler(
                exc_class_or_status_code, handler)
            return
        self.error_middleware.handler = handler

    def add_event_handler(self, event_type: str,
                          func: typing.Callable) -> None:
        """Register ``func`` for ``event_type`` on the router's lifespan."""
        lifespan = self.router.lifespan
        lifespan.add_event_handler(event_type, func)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> None:
        """Add an HTTP route to the router."""
        self.router.add_route(
            path,
            route,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )

    def add_websocket_route(self,
                            path: str,
                            route: typing.Callable,
                            name: str = None) -> None:
        """Add a WebSocket route to the router."""
        self.router.add_websocket_route(path, route, name=name)

    def exception_handler(
        self, exc_class_or_status_code: typing.Union[int,
                                                     typing.Type[Exception]]
    ) -> typing.Callable:
        """Decorator form of :meth:`add_exception_handler`."""
        def register(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class_or_status_code, func)
            return func

        return register

    def route(
        self,
        path: str,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        """Decorator that adds the decorated callable as an HTTP route."""
        def register(func: typing.Callable) -> typing.Callable:
            self.router.add_route(path,
                                  func,
                                  methods=methods,
                                  name=name,
                                  include_in_schema=include_in_schema)
            return func

        return register

    def websocket_route(self, path: str, name: str = None) -> typing.Callable:
        """Decorator that adds the decorated callable as a WebSocket route."""
        def register(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func, name=name)
            return func

        return register

    def middleware(self, middleware_type: str) -> typing.Callable:
        """Decorator that installs ``func`` as a ``BaseHTTPMiddleware`` dispatch."""
        is_http = middleware_type == "http"
        assert is_http, 'Currently only middleware("http") is supported.'

        def register(func: typing.Callable) -> typing.Callable:
            self.add_middleware(BaseHTTPMiddleware, dispatch=func)
            return func

        return register

    def url_path_for(self, name: str, **path_params: str) -> URLPath:
        """Delegate URL reversal to the router."""
        return self.router.url_path_for(name, **path_params)

    async def __call__(self, scope: Scope, receive: Receive,
                       send: Send) -> None:
        """ASGI entry point: record the app on the scope, then dispatch."""
        scope["app"] = self
        outermost = self.error_middleware
        await outermost(scope, receive, send)
Пример #5
0
def create_admin(
    tables: t.Sequence[t.Type[Table]],
    auth_table: t.Type[BaseUser] = BaseUser,
    session_table: t.Type[SessionsBase] = SessionsBase,
    session_expiry: timedelta = timedelta(hours=1),
    max_session_expiry: timedelta = timedelta(days=7),
    increase_expiry: t.Optional[timedelta] = timedelta(minutes=20),
    page_size: int = 15,
    read_only: bool = False,
    rate_limit_provider: t.Optional[RateLimitProvider] = None,
    production: bool = False,
    site_name: str = "Piccolo Admin",
    auto_include_related: bool = True,
    allowed_hosts: t.Optional[t.Sequence[str]] = None,
):
    """
    Build the admin ASGI app: an ``AdminRouter`` wrapped in ``CSRFMiddleware``
    and then in ``ExceptionMiddleware``.

    :param tables:
        Each of the tables will be added to the admin.
    :param auth_table:
        Either a BaseUser, or BaseUser subclass table, which is used for
        fetching users.
    :param session_table:
        Either a SessionsBase, or SessionsBase subclass table, which is used
        for storing and querying session tokens.
    :param session_expiry:
        How long a session is valid for.
    :param max_session_expiry:
        The maximum time a session is valid for, taking into account any
        refreshes using `increase_expiry`.
    :param increase_expiry:
        If set, the `session_expiry` will be increased by this amount if it's
        close to expiry.
    :param page_size:
        The admin API paginates content - this sets the default number of
        results on each page.
    :param read_only:
        If True, all non auth endpoints only respond to GET requests - the
        admin can still be viewed, and the data can be filtered. Useful for
        creating online demos.
    :param rate_limit_provider:
        Rate limiting middleware is used to protect the login endpoint
        against brute force attack. If not set, an InMemoryLimitProvider
        will be configured with reasonable defaults.
    :param production:
        If True, the admin will enforce stronger security - for example,
        the cookies used will be secure, meaning they are only sent over
        HTTPS.
    :param site_name:
        Specify a different site name in the admin UI (default Piccolo Admin).
    :param auto_include_related:
        If a table has foreign keys to other tables, those tables will also be
        included in the admin by default, if not already specified. Otherwise
        the admin won't work as expected.
    :param allowed_hosts:
        This is used by the CSRF middleware as an additional layer of
        protection when the admin is run under HTTPS. It must be a sequence of
        strings, such as ['my_site.com']. Defaults to an empty list.
    """
    # A fresh list per call: a literal ``[]`` default would be one object
    # shared by every admin app created without an explicit value.
    if allowed_hosts is None:
        allowed_hosts = []

    if auto_include_related:
        tables = get_all_tables(tables)

    admin_router = AdminRouter(
        *tables,
        auth_table=auth_table,
        session_table=session_table,
        session_expiry=session_expiry,
        max_session_expiry=max_session_expiry,
        increase_expiry=increase_expiry,
        page_size=page_size,
        read_only=read_only,
        rate_limit_provider=rate_limit_provider,
        production=production,
        site_name=site_name,
    )
    csrf_protected = CSRFMiddleware(admin_router, allowed_hosts=allowed_hosts)
    return ExceptionMiddleware(csrf_protected)
Пример #6
0
class Starlette:
    """Application facade: router, error handling, lifespan and templates.

    Calls flow ``lifespan_middleware`` -> ``error_middleware`` ->
    ``exception_middleware`` -> ``router``.
    """

    def __init__(self,
                 debug: bool = False,
                 template_directory: str = None) -> None:
        """Build the middleware stack and, optionally, the template environment.

        :param debug: Forwarded to both error-handling middleware layers.
        :param template_directory: Directory for Jinja2 templates; when
            ``None`` no template environment is created.
        """
        self._debug = debug
        self.router = Router()
        self.exception_middleware = ExceptionMiddleware(self.router,
                                                        debug=debug)
        self.error_middleware = ServerErrorMiddleware(
            self.exception_middleware, debug=debug)
        self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
        self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
        self.template_env = self.load_template_env(template_directory)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """The router's route list."""
        return self.router.routes

    @property
    def debug(self) -> bool:
        """Current debug flag."""
        return self._debug

    @debug.setter
    def debug(self, value: bool) -> None:
        # Keep both middleware layers in step with the application flag.
        self._debug = value
        self.exception_middleware.debug = value
        self.error_middleware.debug = value

    def load_template_env(self, template_directory: str = None) -> typing.Any:
        """Return a Jinja2 environment for ``template_directory``, or ``None``.

        The environment has autoescaping enabled and exposes a ``url_for``
        global that resolves names through the ``request`` in the template
        context.
        """
        if template_directory is None:
            return None

        # Import jinja2 lazily.
        import jinja2

        # ``contextfunction`` was removed in Jinja2 3.1 in favour of
        # ``pass_context``; prefer the new name and fall back on old releases.
        pass_context = getattr(jinja2, "pass_context",
                               None) or jinja2.contextfunction

        @pass_context
        def url_for(context: dict, name: str,
                    **path_params: typing.Any) -> str:
            request = context["request"]
            return request.url_for(name, **path_params)

        loader = jinja2.FileSystemLoader(str(template_directory))
        env = jinja2.Environment(loader=loader, autoescape=True)
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> typing.Any:
        """Look up template ``name`` in ``template_env``.

        ``template_env`` is ``None`` when no template directory was given, in
        which case this attribute access fails.
        """
        return self.template_env.get_template(name)

    @property
    def schema(self) -> dict:
        """Schema for the current routes; requires ``schema_generator`` to be set."""
        assert self.schema_generator is not None
        return self.schema_generator.get_schema(self.routes)

    def on_event(self, event_type: str) -> typing.Callable:
        """Return the lifespan middleware's decorator for ``event_type``."""
        return self.lifespan_middleware.on_event(event_type)

    def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
        """Mount ``app`` under ``path`` on the router."""
        self.router.mount(path, app=app, name=name)

    def host(self, host: str, app: ASGIApp, name: str = None) -> None:
        """Register ``app`` for ``host`` on the router."""
        self.router.host(host, app=app, name=name)

    def add_middleware(self, middleware_class: type,
                       **kwargs: typing.Any) -> None:
        """Wrap the app currently inside the error middleware with a new layer."""
        self.error_middleware.app = middleware_class(self.error_middleware.app,
                                                     **kwargs)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable,
    ) -> None:
        """Register ``handler``; 500 and ``Exception`` go to the error middleware."""
        if exc_class_or_status_code in (500, Exception):
            self.error_middleware.handler = handler
        else:
            self.exception_middleware.add_exception_handler(
                exc_class_or_status_code, handler)

    def add_event_handler(self, event_type: str,
                          func: typing.Callable) -> None:
        """Register ``func`` for ``event_type`` on the lifespan middleware."""
        self.lifespan_middleware.add_event_handler(event_type, func)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> None:
        """Add an HTTP route to the router."""
        self.router.add_route(path,
                              route,
                              methods=methods,
                              name=name,
                              include_in_schema=include_in_schema)

    def add_websocket_route(self,
                            path: str,
                            route: typing.Callable,
                            name: str = None) -> None:
        """Add a WebSocket route to the router."""
        self.router.add_websocket_route(path, route, name=name)

    def exception_handler(
        self, exc_class_or_status_code: typing.Union[int,
                                                     typing.Type[Exception]]
    ) -> typing.Callable:
        """Decorator form of :meth:`add_exception_handler`."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class_or_status_code, func)
            return func

        return decorator

    def route(
        self,
        path: str,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        """Decorator that adds the decorated callable as an HTTP route."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str, name: str = None) -> typing.Callable:
        """Decorator that adds the decorated callable as a WebSocket route."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def middleware(self, middleware_type: str) -> typing.Callable:
        """Decorator that installs ``func`` as a ``BaseHTTPMiddleware`` dispatch."""
        assert (middleware_type == "http"
                ), 'Currently only middleware("http") is supported.'

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_middleware(BaseHTTPMiddleware, dispatch=func)
            return func

        return decorator

    def url_path_for(self, name: str, **path_params: str) -> URLPath:
        """Delegate URL reversal to the router."""
        return self.router.url_path_for(name, **path_params)

    def __call__(self, scope: Scope) -> ASGIInstance:
        """ASGI entry point: record the app on the scope, then dispatch."""
        scope["app"] = self
        return self.lifespan_middleware(scope)
Пример #7
0
class Starlette:
    """Minimal application facade: a router wrapped in ExceptionMiddleware."""

    def __init__(self, debug=False) -> None:
        """Create an empty router and wrap it in the exception middleware."""
        router = Router(routes=[])
        self.router = router
        self.exception_middleware = ExceptionMiddleware(router, debug=debug)

    @property
    def debug(self) -> bool:
        """Debug flag, stored on the exception middleware."""
        return self.exception_middleware.debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self.exception_middleware.debug = value

    def mount(self, path: str, app: ASGIApp, methods=None) -> None:
        """Append a ``PathPrefix`` route that forwards ``path`` to ``app``."""
        self.router.routes.append(PathPrefix(path, app=app, methods=methods))

    def add_exception_handler(self, exc_class: type, handler) -> None:
        """Register ``handler`` for ``exc_class`` on the exception middleware."""
        self.exception_middleware.add_exception_handler(exc_class, handler)

    def add_route(self, path: str, route, methods=None) -> None:
        """Append an HTTP route.

        A non-class ``route`` is wrapped with ``request_response`` and, when
        no methods are given, restricted to GET. Classes are used as-is.
        """
        is_plain_callable = not inspect.isclass(route)
        if is_plain_callable:
            route = request_response(route)
            methods = ["GET"] if methods is None else methods

        self.router.routes.append(
            Path(path, route, protocol="http", methods=methods))

    def add_websocket_route(self, path: str, route) -> None:
        """Append a WebSocket route; non-class routes get ``websocket_session``."""
        is_plain_callable = not inspect.isclass(route)
        if is_plain_callable:
            route = websocket_session(route)

        self.router.routes.append(Path(path, route, protocol="websocket"))

    def exception_handler(self, exc_class: type):
        """Decorator form of :meth:`add_exception_handler`."""
        def register(func):
            self.add_exception_handler(exc_class, func)
            return func

        return register

    def route(self, path: str, methods=None):
        """Decorator form of :meth:`add_route`."""
        def register(func):
            self.add_route(path, func, methods=methods)
            return func

        return register

    def websocket_route(self, path: str):
        """Decorator form of :meth:`add_websocket_route`."""
        def register(func):
            self.add_websocket_route(path, func)
            return func

        return register

    def __call__(self, scope: Scope) -> ASGIInstance:
        """ASGI entry point: record the app on the scope, then dispatch."""
        scope["app"] = self
        dispatch = self.exception_middleware
        return dispatch(scope)
Пример #8
0
 def __init__(self, debug=False) -> None:
     """Create an empty router and wrap it in the exception middleware."""
     router = Router(routes=[])
     self.router = router
     self.exception_middleware = ExceptionMiddleware(router, debug=debug)
Пример #9
0
class Starlette:
    """Application facade: a router behind ExceptionMiddleware, plus lifespan.

    ``lifespan`` scopes go to ``lifespan_handler``; everything else goes
    through ``exception_middleware``.
    """

    def __init__(self, debug: bool = False) -> None:
        """Create the router, the lifespan handler and the exception middleware."""
        self.router = Router()
        self.lifespan_handler = LifespanHandler()
        self.app = self.router
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """The router's route list."""
        return self.router.routes

    @property
    def debug(self) -> bool:
        """Debug flag, stored on the exception middleware."""
        return self.exception_middleware.debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self.exception_middleware.debug = value

    def on_event(self, event_type: str) -> typing.Callable:
        """Return the lifespan handler's decorator for ``event_type``."""
        return self.lifespan_handler.on_event(event_type)

    def mount(self, path: str, app: ASGIApp) -> None:
        """Mount ``app`` under ``path`` on the router."""
        self.router.mount(path, app=app)

    def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None:
        """Stack ``middleware_class`` on top of the app behind the exception middleware.

        Each call wraps whatever the exception middleware currently dispatches
        to, so middleware added earlier stays in the chain.
        """
        # Wrapping ``self.app`` (always the bare router) would discard every
        # previously added middleware. Fall back to it only if nothing has
        # been installed on the exception middleware yet.
        current_app = getattr(self.exception_middleware, "app", self.app)
        self.exception_middleware.app = middleware_class(current_app, **kwargs)

    def add_exception_handler(self, exc_class: type, handler: typing.Callable) -> None:
        """Register ``handler`` for ``exc_class`` on the exception middleware."""
        self.exception_middleware.add_exception_handler(exc_class, handler)

    def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
        """Register ``func`` for ``event_type`` on the lifespan handler."""
        self.lifespan_handler.add_event_handler(event_type, func)

    def add_route(
        self, path: str, route: typing.Callable, methods: typing.List[str] = None
    ) -> None:
        """Add an HTTP route to the router."""
        self.router.add_route(path, route, methods=methods)

    def add_graphql_route(self, path: str, schema: typing.Any) -> None:
        """Add a GraphQL route for ``schema`` to the router."""
        self.router.add_graphql_route(path, schema)

    def add_websocket_route(self, path: str, route: typing.Callable) -> None:
        """Add a WebSocket route to the router."""
        self.router.add_websocket_route(path, route)

    def exception_handler(self, exc_class: type) -> typing.Callable:
        """Decorator form of :meth:`add_exception_handler`."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class, func)
            return func

        return decorator

    def route(self, path: str, methods: typing.List[str] = None) -> typing.Callable:
        """Decorator that adds the decorated callable as an HTTP route."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(path, func, methods=methods)
            return func

        return decorator

    def websocket_route(self, path: str) -> typing.Callable:
        """Decorator that adds the decorated callable as a WebSocket route."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func)
            return func

        return decorator

    def url_path_for(self, name: str, **path_params: str) -> URL:
        """Delegate URL reversal to the router."""
        return self.router.url_path_for(name, **path_params)

    def __call__(self, scope: Scope) -> ASGIInstance:
        """ASGI entry point: annotate the scope, then dispatch by scope type."""
        scope["app"] = self
        scope["router"] = self.router
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)
        return self.exception_middleware(scope)
Пример #10
0
class Starlette:
    """Application facade: a router behind ExceptionMiddleware, plus lifespan.

    ``lifespan`` scopes go to ``lifespan_handler``; everything else goes
    through ``exception_middleware``. A thread pool executor is published on
    every scope under ``"executor"``.
    """

    def __init__(self, debug: bool = False) -> None:
        """Create the router, lifespan handler, exception middleware and executor."""
        self.router = Router(routes=[])
        self.lifespan_handler = LifespanHandler()
        self.app = self.router
        self.exception_middleware = ExceptionMiddleware(self.router,
                                                        debug=debug)
        self.executor = ThreadPoolExecutor()

    @property
    def debug(self) -> bool:
        """Debug flag, stored on the exception middleware."""
        return self.exception_middleware.debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self.exception_middleware.debug = value

    def on_event(self, event_type: str) -> typing.Callable:
        """Return the lifespan handler's decorator for ``event_type``."""
        return self.lifespan_handler.on_event(event_type)

    def mount(self,
              path: str,
              app: ASGIApp,
              methods: typing.Sequence[str] = None) -> None:
        """Append a ``PathPrefix`` route that forwards ``path`` to ``app``."""
        prefix = PathPrefix(path, app=app, methods=methods)
        self.router.routes.append(prefix)

    def add_middleware(self, middleware_class: type,
                       **kwargs: typing.Any) -> None:
        """Stack ``middleware_class`` on top of the app behind the exception middleware.

        Each call wraps whatever the exception middleware currently dispatches
        to, so middleware added earlier stays in the chain.
        """
        # Wrapping ``self.app`` (always the bare router) would discard every
        # previously added middleware. Fall back to it only if nothing has
        # been installed on the exception middleware yet.
        current_app = getattr(self.exception_middleware, "app", self.app)
        self.exception_middleware.app = middleware_class(current_app, **kwargs)

    def add_exception_handler(self, exc_class: type,
                              handler: typing.Callable) -> None:
        """Register ``handler`` for ``exc_class`` on the exception middleware."""
        self.exception_middleware.add_exception_handler(exc_class, handler)

    def add_route(self,
                  path: str,
                  route: typing.Callable,
                  methods: typing.Sequence[str] = None) -> None:
        """Append an HTTP route.

        A non-class ``route`` is wrapped with ``request_response`` and, when
        no methods are given, restricted to GET. Classes are used as-is.
        """
        if not inspect.isclass(route):
            route = request_response(route)
            if methods is None:
                methods = ("GET", )

        instance = Path(path, route, protocol="http", methods=methods)
        self.router.routes.append(instance)

    def add_websocket_route(self, path: str, route: typing.Callable) -> None:
        """Append a WebSocket route; non-class routes get ``websocket_session``."""
        if not inspect.isclass(route):
            route = websocket_session(route)

        instance = Path(path, route, protocol="websocket")
        self.router.routes.append(instance)

    def exception_handler(self, exc_class: type) -> typing.Callable:
        """Decorator form of :meth:`add_exception_handler`."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class, func)
            return func

        return decorator

    def route(self,
              path: str,
              methods: typing.Sequence[str] = None) -> typing.Callable:
        """Decorator form of :meth:`add_route`."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_route(path, func, methods=methods)
            return func

        return decorator

    def websocket_route(self, path: str) -> typing.Callable:
        """Decorator form of :meth:`add_websocket_route`."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_websocket_route(path, func)
            return func

        return decorator

    def __call__(self, scope: Scope) -> ASGIInstance:
        """ASGI entry point: annotate the scope, then dispatch by scope type."""
        scope["app"] = self
        scope["executor"] = self.executor
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)
        return self.exception_middleware(scope)
Пример #11
0
 def __init__(self, debug: bool = False) -> None:
     """Create the router, the lifespan handler and the exception middleware."""
     router = Router()
     self.router = router
     self.lifespan_handler = LifespanHandler()
     # ``app`` starts out as the bare router.
     self.app = router
     self.exception_middleware = ExceptionMiddleware(router, debug=debug)
Пример #12
0
import os
from starlette.applications import Starlette
from starlette.exceptions import ExceptionMiddleware
from starlette.middleware.wsgi import WSGIMiddleware
from starlette.routing import Router, Path, PathPrefix
import uvicorn
from wsgi import application
from .asgi_app import app as asgi_app

# Two prefixes: "/v2" is served by the ASGI app, and the empty prefix hands
# requests to the WSGI application through WSGIMiddleware.
# NOTE(review): presumably prefixes are tried in list order, so "" acts as the
# catch-all; confirm against Router.
app = Router([
    PathPrefix("/v2", app=asgi_app),
    PathPrefix("", app=WSGIMiddleware(application)),
])
# Compared as a string below. With the variable unset the default is "True",
# so debug mode is on unless DJANGO_DEBUG is set to some other value.
DEBUG = os.getenv("DJANGO_DEBUG", "True")

# ExceptionMiddleware is only installed in debug mode.
if DEBUG == "True":
    app = ExceptionMiddleware(app, debug=True)

# When run as a script, serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
Пример #13
0
    Route(
        "/login/",
        session_login(),
        name="login",
    ),
    Route("/logout/", session_logout(), name="login"),
    Mount(
        "/secret/",
        AuthenticationMiddleware(
            ProtectedEndpoint,
            SessionsAuthBackend(
                admin_only=True, superuser_only=True, active_only=True),
        ),
    ),
])
# Application under test: ROUTER wrapped in ExceptionMiddleware.
APP = ExceptionMiddleware(ROUTER)

###############################################################################


class SessionTestCase(TestCase):
    """Session tests; the session and user tables are recreated per test."""

    credentials = {"username": "******", "password": "******"}
    wrong_credentials = {"username": "******", "password": "******"}

    def setUp(self):
        # Create the sessions table first, then the users table.
        for table in (SessionsBase, BaseUser):
            table.create_table().run_sync()

    def tearDown(self):
        # Drop both tables again, in the same order they were created.
        for table in (SessionsBase, BaseUser):
            table.alter().drop_table().run_sync()
Пример #14
0
def create_admin(
    tables: t.Sequence[t.Union[t.Type[Table], TableConfig]],
    forms: t.Optional[t.List[FormConfig]] = None,
    auth_table: t.Optional[t.Type[BaseUser]] = None,
    session_table: t.Optional[t.Type[SessionsBase]] = None,
    session_expiry: timedelta = timedelta(hours=1),
    max_session_expiry: timedelta = timedelta(days=7),
    increase_expiry: t.Optional[timedelta] = timedelta(minutes=20),
    page_size: int = 15,
    read_only: bool = False,
    rate_limit_provider: t.Optional[RateLimitProvider] = None,
    production: bool = False,
    site_name: str = "Piccolo Admin",
    auto_include_related: bool = True,
    allowed_hosts: t.Optional[t.Sequence[str]] = None,
):
    """
    Build the admin ASGI app: an ``AdminRouter`` wrapped in ``CSRFMiddleware``
    and then in ``ExceptionMiddleware``.

    :param tables:
        Each of the tables will be added to the admin.
    :param forms:
        For each :class:`FormConfig <piccolo_admin.endpoints.FormConfig>`
        specified, a form will automatically be rendered in the user interface,
        accessible via the sidebar. Defaults to an empty list.
    :param auth_table:
        Either a :class:`BaseUser <piccolo.apps.user.tables.BaseUser>`, or
        ``BaseUser`` subclass table, which is used for fetching users.
        Defaults to ``BaseUser`` if none is specified.
    :param session_table:
        Either a :class:`SessionsBase <piccolo_api.session_auth.tables.SessionsBase>`,
        or ``SessionsBase`` subclass table, which is used for storing and
        querying session tokens. Defaults to ``SessionsBase`` if none is
        specified.
    :param session_expiry:
        How long a session is valid for.
    :param max_session_expiry:
        The maximum time a session is valid for, taking into account any
        refreshes using ``increase_expiry``.
    :param increase_expiry:
        If set, the ``session_expiry`` will be increased by this amount if it's
        close to expiry.
    :param page_size:
        The admin API paginates content - this sets the default number of
        results on each page.
    :param read_only:
        If ``True``, all non auth endpoints only respond to GET requests - the
        admin can still be viewed, and the data can be filtered. Useful for
        creating online demos.
    :param rate_limit_provider:
        Rate limiting middleware is used to protect the login endpoint
        against brute force attack. If not set, an
        :class:`InMemoryLimitProvider <piccolo_api.rate_limiting.middleware.InMemoryLimitProvider>`
        will be configured with reasonable defaults.
    :param production:
        If ``True``, the admin will enforce stronger security - for example,
        the cookies used will be secure, meaning they are only sent over
        HTTPS.
    :param site_name:
        Specify a different site name in the admin UI (default
        ``'Piccolo Admin'``).
    :param auto_include_related:
        If a table has foreign keys to other tables, those tables will also be
        included in the admin by default, if not already specified. Otherwise
        the admin won't work as expected.
    :param allowed_hosts:
        This is used by the :class:`CSRFMiddleware <piccolo_api.csrf.middleware.CSRFMiddleware>`
        as an additional layer of protection when the admin is run under HTTPS.
        It must be a sequence of strings, such as ``['my_site.com']``.
        Defaults to an empty list.

    """  # noqa: E501
    # Fresh lists per call: literal ``[]`` defaults would be single objects
    # shared by every admin app created without explicit values.
    if forms is None:
        forms = []
    if allowed_hosts is None:
        allowed_hosts = []

    auth_table = auth_table or BaseUser
    session_table = session_table or SessionsBase

    if auto_include_related:
        # Remember which table classes came with an explicit TableConfig so
        # the config can be restored after related tables are discovered.
        table_config_map: t.Dict[t.Type[Table], t.Optional[TableConfig]] = {}

        for entry in tables:
            if isinstance(entry, TableConfig):
                table_config_map[entry.table_class] = entry
            else:
                table_config_map[entry] = None

        all_table_classes = get_all_tables(tuple(table_config_map.keys()))

        # Prefer the caller's TableConfig where one exists; newly discovered
        # related tables are passed on as bare classes.
        all_table_classes_with_configs: t.List[
            t.Union[t.Type[Table], TableConfig]
        ] = []
        for table_class in all_table_classes:
            table_config = table_config_map.get(table_class)
            if table_config:
                all_table_classes_with_configs.append(table_config)
            else:
                all_table_classes_with_configs.append(table_class)

        tables = all_table_classes_with_configs

    return ExceptionMiddleware(
        CSRFMiddleware(
            AdminRouter(
                *tables,
                forms=forms,
                auth_table=auth_table,
                session_table=session_table,
                session_expiry=session_expiry,
                max_session_expiry=max_session_expiry,
                increase_expiry=increase_expiry,
                page_size=page_size,
                read_only=read_only,
                rate_limit_provider=rate_limit_provider,
                production=production,
                site_name=site_name,
            ),
            allowed_hosts=allowed_hosts,
        )
    )
Пример #15
0
    DEFAULT_COOKIE_NAME,
    DEFAULT_HEADER_NAME,
    CSRFMiddleware,
)


async def app(scope, receive, send):
    """Minimal ASGI app: send a 200 plain-text "Hello, world!" response.

    ``scope`` and ``receive`` are accepted but not used.
    """
    messages = (
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"text/plain"]],
        },
        {"type": "http.response.body", "body": b"Hello, world!"},
    )
    # Response start first, then the body.
    for message in messages:
        await send(message)


# The bare app behind CSRFMiddleware (form parameter allowed), wrapped in
# ExceptionMiddleware.
WRAPPED_APP = ExceptionMiddleware(CSRFMiddleware(app, allow_form_param=True))
# Same stack, but with CSRFMiddleware's allowed_hosts set to ["foo.com"].
HOST_RESTRICTED_APP = ExceptionMiddleware(
    CSRFMiddleware(app, allowed_hosts=["foo.com"], allow_form_param=True))


class TestCSRFMiddleware(TestCase):

    csrf_token = CSRFMiddleware.get_new_token()
    incorrect_csrf_token = "abc123"

    def test_get_request(self):
        """
        Make sure a cookie was set.
        """
        client = TestClient(WRAPPED_APP)
        response = client.get("/")