Esempio n. 1
0
class API:
    """The primary web-service class.

        :param static_dir: The directory to use for static files. Will be created for you if it doesn't already exist.
        :param templates_dir: The directory to use for templates. Will be created for you if it doesn't already exist.
        :param auto_escape: If ``True``, HTML and XML templates will automatically be escaped.
        :param enable_hsts: If ``True``, send all responses to HTTPS URLs.
        :param title: The title of the application (OpenAPI Info Object)
        :param version: The version of the OpenAPI document (OpenAPI Info Object)
        :param description: The description of the OpenAPI document (OpenAPI Info Object)
        :param terms_of_service: A URL to the Terms of Service for the API (OpenAPI Info Object)
        :param contact: The contact dictionary of the application (OpenAPI Contact Object)
        :param license: The license information of the exposed API (OpenAPI License Object)
    """

    status_codes = status_codes

    def __init__(
        self,
        *,
        debug=False,
        title=None,
        version=None,
        description=None,
        terms_of_service=None,
        contact=None,
        license=None,
        openapi=None,
        openapi_route="/schema.yml",
        static_dir="static",
        static_route="/static",
        templates_dir="templates",
        auto_escape=True,
        secret_key=DEFAULT_SECRET_KEY,
        enable_hsts=False,
        docs_route=None,
        cors=False,
        cors_params=DEFAULT_CORS_PARAMS,
        allowed_hosts=None,
    ):
        """Configure directories, routes, the middleware stack and templating.

        Parameters not covered by the class docstring:

        :param debug: Stored as ``self.debug`` and passed to ``ServerErrorMiddleware``.
        :param openapi: The OpenAPI version; when truthy, the schema route is registered.
        :param openapi_route: The route at which the OpenAPI schema is served.
        :param static_route: The route at which static files are mounted. Falls back to ``static_dir`` when ``None``.
        :param secret_key: The key used to sign the session cookie.
        :param docs_route: The route at which the documentation page is served, if given.
        :param cors: If truthy, ``CORSMiddleware`` is added, configured with ``cors_params``.
        :param cors_params: Keyword arguments for ``CORSMiddleware``.
        :param allowed_hosts: Hosts passed to ``TrustedHostMiddleware``. Defaults to ``["*"]`` when empty or ``None``.
        """
        self.background = BackgroundQueue()

        self.secret_key = secret_key
        self.title = title
        self.version = version
        self.description = description
        self.terms_of_service = terms_of_service
        self.contact = contact
        self.license = license
        self.openapi_version = openapi

        # Static files are optional. When enabled, the directory is stored as
        # an absolute Path and the route falls back to the directory name.
        if static_dir is not None:
            if static_route is None:
                static_route = static_dir
            static_dir = Path(os.path.abspath(static_dir))

        self.static_dir = static_dir
        self.static_route = static_route

        # Templates shipped next to this module.
        self.built_in_templates_dir = Path(
            os.path.abspath(os.path.dirname(__file__) + "/templates"))

        if templates_dir is not None:
            templates_dir = Path(os.path.abspath(templates_dir))

        # Without a user templates directory, the built-in one is used.
        self.templates_dir = templates_dir or self.built_in_templates_dir

        # `self.apps` must exist before `self.mount()` is called below.
        self.apps = {}
        self.routes = {}
        self.before_requests = {"http": [], "ws": []}
        self.docs_theme = DEFAULT_API_THEME
        self.docs_route = docs_route
        self.schemas = {}
        self.session_cookie = DEFAULT_SESSION_COOKIE

        self.hsts_enabled = enable_hsts
        self.cors = cors
        self.cors_params = cors_params
        self.debug = debug

        # An empty or missing host list falls back to the "*" wildcard.
        if not allowed_hosts:
            # if not debug:
            #     raise RuntimeError(
            #         "You need to specify `allowed_hosts` when debug is set to False"
            #     )
            allowed_hosts = ["*"]
        self.allowed_hosts = allowed_hosts

        # Make the static/templates directory if they don't exist.
        for _dir in (self.static_dir, self.templates_dir):
            if _dir is not None:
                os.makedirs(_dir, exist_ok=True)

        # Serve the user's static directory and the docs theme's static
        # assets through one WhiteNoise app, mounted at `static_route`.
        if self.static_dir is not None:
            self.whitenoise = WhiteNoise(application=self._notfound_wsgi_app)
            self.whitenoise.add_files(str(self.static_dir))

            self.whitenoise.add_files(
                (Path(apistar.__file__).parent / "themes" / self.docs_theme /
                 "static").resolve())

            self.mount(self.static_route, self.whitenoise)

        self.formats = get_formats()

        # Cached requests session.
        self._session = None

        # The schema and docs endpoints are only registered on request.
        if self.openapi_version:
            self.add_route(openapi_route, self.schema_response)

        if self.docs_route:
            self.add_route(self.docs_route, self.docs_response)

        self.default_endpoint = None
        # `dispatch` is the innermost app. Every `add_middleware()` call wraps
        # the current `self.app`, so middleware added later ends up outermost.
        self.app = self.dispatch
        self.add_middleware(GZipMiddleware)

        if self.hsts_enabled:
            self.add_middleware(HTTPSRedirectMiddleware)

        self.add_middleware(TrustedHostMiddleware,
                            allowed_hosts=self.allowed_hosts)

        self.lifespan_handler = Lifespan()

        if self.cors:
            self.add_middleware(CORSMiddleware, **self.cors_params)
        # Added last, so it wraps every other middleware.
        self.add_middleware(ServerErrorMiddleware, debug=debug)

        # Jinja environment
        # The loader lists the user templates directory before the built-in one.
        self.jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(
                [str(self.templates_dir),
                 str(self.built_in_templates_dir)],
                followlinks=True,
            ),
            autoescape=jinja2.select_autoescape(
                ["html", "xml"] if auto_escape else []),
        )
        self.jinja_values_base = {"api": self}  # Give reference to self.
        # `session()` creates the test client on first use and caches it.
        self.requests = (
            self.session()
        )  #: A Requests session that is connected to the ASGI app.

    @staticmethod
    def _default_wsgi_app(environ, start_response):
        pass

    @staticmethod
    def _notfound_wsgi_app(environ, start_response):
        start_response("404 NOT FOUND", [("Content-Type", "text/plain")])
        return [b"Not Found."]

    def before_request(self, websocket=False):
        def decorator(f):
            if websocket:
                self.before_requests.setdefault("ws", []).append(f)
            else:
                self.before_requests.setdefault("http", []).append(f)
            return f

        return decorator

    @property
    def before_http_requests(self):
        return self.before_requests.get("http", [])

    @property
    def before_ws_requests(self):
        return self.before_requests.get("ws", [])

    @property
    def _apispec(self):
        """Build an ``APISpec`` describing the registered routes and schemas.

        A new spec object is assembled on every access.
        """
        # Only the info fields that were actually provided are included.
        optional_info = (
            ("description", self.description),
            ("termsOfService", self.terms_of_service),
            ("contact", self.contact),
            ("license", self.license),
        )
        info = {key: value for key, value in optional_info if value is not None}

        spec = APISpec(
            title=self.title,
            version=self.version,
            openapi_version=self.openapi_version,
            plugins=[MarshmallowPlugin()],
            info=info,
        )

        # Routes without a description contribute no path entry.
        for path, route_object in self.routes.items():
            description = route_object.description
            if description:
                operations = yaml_utils.load_operations_from_docstring(
                    description)
                spec.path(path=path, operations=operations)

        for schema_name, schema in self.schemas.items():
            spec.components.schema(schema_name, schema=schema)

        return spec

    @property
    def openapi(self):
        return self._apispec.to_yaml()

    def add_middleware(self, middleware_cls, **middleware_config):
        self.app = middleware_cls(self.app, **middleware_config)

    def __call__(self, scope):

        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)

        path = scope["path"]
        root_path = scope.get("root_path", "")

        # Call into a submounted app, if one exists.
        for path_prefix, app in self.apps.items():
            if path.startswith(path_prefix):
                scope["path"] = path[len(path_prefix):]
                scope["root_path"] = root_path + path_prefix
                try:
                    return app(scope)
                except TypeError:
                    app = WSGIMiddleware(app)
                    return app(scope)

        return self.app(scope)

    def dispatch(self, scope):
        # Call the main dispatcher.
        async def asgi(receive, send):
            nonlocal scope, self

            assert scope["type"] in ("http", "websocket")

            if scope["type"] == "websocket":
                await self._dispatch_ws(scope=scope,
                                        receive=receive,
                                        send=send)
            else:
                req = models.Request(scope, receive=receive, api=self)
                resp = await self._dispatch_http(req,
                                                 scope=scope,
                                                 send=send,
                                                 receive=receive)
                await resp(receive, send)

        return asgi

    async def _dispatch_http(self, req, **options):
        """Produce the response for an HTTP request.

        When a route matches, the HTTP before-request hooks run first and the
        route is then executed. Otherwise the not-found handling of
        ``default_response`` is applied.

        :param req: The incoming request.
        :param options: Extra keyword arguments forwarded to ``_execute_route``.
        :return: The populated response object.
        """
        # Set formats on Request object.
        req.formats = self.formats

        # Get the route.
        matched = self.routes.get(self.path_matches_route(req.url.path))
        resp = models.Response(req=req, formats=self.formats)

        if not matched:
            self.default_response(req=req, resp=resp, notfound=True)
        else:
            for hook in self.before_http_requests:
                await self.background(hook, req=req, resp=resp)

            await self._execute_route(route=matched,
                                      req=req,
                                      resp=resp,
                                      **options)

        # Fills in the default status code if none has been set yet.
        self.default_response(req=req, resp=resp)

        self._prepare_session(resp)

        return resp

    async def _dispatch_ws(self, scope, receive, send):
        """Serve a WebSocket connection.

        When a route matches the path, the WebSocket before-request hooks run
        and then the route's endpoint. Otherwise a close message is sent.

        :param scope: The ASGI connection scope.
        :param receive: The ASGI ``receive`` callable.
        :param send: The ASGI ``send`` callable.
        """
        ws = WebSocket(scope=scope, receive=receive, send=send)

        matched = self.routes.get(self.path_matches_route(ws.url.path))

        if not matched:
            await send({"type": "websocket.close", "code": 1000})
            return

        for hook in self.before_ws_requests:
            await self.background(hook, ws=ws)
        await self.background(matched.endpoint, ws)

    def add_schema(self, name, schema, check_existing=True):
        """Adds a mashmallow schema to the API specification."""
        if check_existing:
            assert name not in self.schemas

        self.schemas[name] = schema

    def schema(self, name, **options):
        """Decorator for creating new routes around function and class definitions.

        Usage::

            from marshmallow import Schema, fields

            @api.schema("Pet")
            class PetSchema(Schema):
                name = fields.Str()

        """
        def decorator(f):
            self.add_schema(name=name, schema=f, **options)
            return f

        return decorator

    def path_matches_route(self, path):
        """Given a path portion of a URL, tests that it matches against any registered route.

        :param path: The path portion of a URL, to test all known routes against.
        """
        for (route, route_object) in self.routes.items():
            if route_object.does_match(path):
                return route

    @property
    def _signer(self):
        """An ``itsdangerous.Signer`` keyed with this API's ``secret_key``."""
        key = self.secret_key
        return itsdangerous.Signer(key)

    def _prepare_session(self, resp):

        if resp.session:
            data = self._signer.sign(
                b64encode(json.dumps(resp.session).encode("utf-8")))
            resp.cookies[self.session_cookie] = data.decode("utf-8")

    @staticmethod
    def no_response(req, resp, **params):
        pass

    async def _execute_route(self, *, route, req, resp, **options):

        params = route.incoming_matches(req.url.path)

        cont = True

        if route.is_function:
            try:
                try:
                    # Run the view.
                    r = self.background(route.endpoint, req, resp, **params)
                    # If it's async, await it.
                    if hasattr(r, "cr_running"):
                        await r
                except TypeError as e:
                    cont = True
            except Exception:
                await self.background(self.default_response,
                                      req,
                                      resp,
                                      error=True)
                raise

        if route.is_class_based or cont:
            try:
                view = route.endpoint(**params)
            except TypeError:
                try:
                    view = route.endpoint()
                except TypeError:
                    view = route.endpoint
                    pass

            # Run on_request first.
            try:
                # Run the view.
                r = getattr(view, "on_request", self.no_response)
                r = self.background(r, req, resp, **params)
                # If it's async, await it.
                if hasattr(r, "send"):
                    await r
            except Exception:
                await self.background(self.default_response,
                                      req,
                                      resp,
                                      error=True)
                raise

            # Then run on_method.
            method = req.method
            try:
                # Run the view.
                r = getattr(view, f"on_{method}", self.no_response)
                r = self.background(r, req, resp, **params)
                # If it's async, await it.
                if hasattr(r, "send"):
                    await r
            except Exception:
                await self.background(self.default_response,
                                      req,
                                      resp,
                                      error=True)
                raise

    def add_event_handler(self, event_type, handler):
        """Adds an event handler to the API.

        :param event_type: A string in ("startup", "shutdown")
        :param handler: The function to run. Can be either a function or a coroutine.
        """

        self.lifespan_handler.add_event_handler(event_type, handler)

    def add_route(
        self,
        route=None,
        endpoint=None,
        *,
        default=False,
        static=False,
        check_existing=True,
        websocket=False,
        before_request=False,
    ):
        """Adds a route to the API.

        :param route: A string representation of the route.
        :param endpoint: The endpoint for the route -- can be a callable, or a class.
        :param default: If ``True``, all unknown requests will route to this view.
        :param static: If ``True``, and no endpoint was passed, render "static/index.html", and it will become a default route.
        :param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined.
        """
        if before_request:
            if websocket:
                self.before_requests.setdefault("ws", []).append(endpoint)
            else:
                self.before_requests.setdefault("http", []).append(endpoint)
            return

        if route is None:
            route = f"/{uuid4().hex}"

        if check_existing:
            assert route not in self.routes

        if static:
            assert self.static_dir is not None
            if not endpoint:
                endpoint = self.static_response
                default = True

        if default:
            self.default_endpoint = endpoint

        self.routes[route] = Route(route, endpoint, websocket=websocket)
        # TODO: A better data structure or sort it once the app is loaded
        self.routes = dict(
            sorted(self.routes.items(), key=lambda item: item[1]._weight()))

    def default_response(self,
                         req=None,
                         resp=None,
                         websocket=False,
                         notfound=False,
                         error=False):
        if websocket:
            return

        if resp.status_code is None:
            resp.status_code = 200

        if self.default_endpoint and notfound:
            self.default_endpoint(req=req, resp=resp)
        else:
            if notfound:
                resp.status_code = status_codes.HTTP_404
                resp.text = "Not found."
            if error:
                resp.status_code = status_codes.HTTP_500
                resp.text = "Application error."

    def docs_response(self, req, resp):
        resp.html = self.docs

    def static_response(self, req, resp):

        assert self.static_dir is not None

        index = (self.static_dir / "index.html").resolve()
        if os.path.exists(index):
            with open(index, "r") as f:
                resp.html = f.read()
        else:
            resp.status_code = status_codes.HTTP_404
            resp.text = "Not found."

    def schema_response(self, req, resp):
        """Endpoint that serves the OpenAPI schema as YAML.

        :param req: The request (unused).
        :param resp: The response to mutate.
        """
        resp.status_code = status_codes.HTTP_200
        resp.headers["Content-Type"] = "application/x-yaml"
        yaml_document = self.openapi
        resp.content = yaml_document

    def redirect(self,
                 resp,
                 location,
                 *,
                 set_text=True,
                 status_code=status_codes.HTTP_301):
        """Redirects a given response to a given location.

        :param resp: The Response to mutate.
        :param location: The location of the redirect.
        :param set_text: If ``True``, sets the Redirect body content automatically.
        :param status_code: an `API.status_codes` attribute, or an integer, representing the HTTP status code of the redirect.
        """

        # assert resp.status_code.is_300(status_code)

        resp.status_code = status_code
        if set_text:
            resp.text = f"Redirecting to: {location}"
        resp.headers.update({"Location": location})

    def on_event(self, event_type: str, **args):
        """Decorator for registering functions or coroutines to run at certain events
        Supported events: startup, shutdown

        Usage::

            @api.on_event('startup')
            async def open_database_connection_pool():
                ...

            @api.on_event('shutdown')
            async def close_database_connection_pool():
                ...

        """
        def decorator(func):
            self.add_event_handler(event_type, func, **args)
            return func

        return decorator

    def route(self, route=None, **options):
        """Decorator for creating new routes around function and class definitions.

        Usage::

            @api.route("/hello")
            def hello(req, resp):
                resp.text = "hello, world!"

        """
        def decorator(f):
            self.add_route(route, f, **options)
            return f

        return decorator

    def mount(self, route, app):
        """Mounts an WSGI / ASGI application at a given route.

        :param route: String representation of the route to be used (shouldn't be parameterized).
        :param app: The other WSGI / ASGI app.
        """
        self.apps.update({route: app})

    def session(self, base_url="http://;"):
        """Testing HTTP client. Returns a Requests session object, able to send HTTP requests to the Responder application.

        :param base_url: The URL to mount the connection adaptor to.
        """

        if self._session is None:
            self._session = TestClient(self, base_url=base_url)
        return self._session

    def _route_for(self, endpoint):
        for route_object in self.routes.values():
            if endpoint in (route_object.endpoint, route_object.endpoint_name):
                return route_object

    def url_for(self, endpoint, **params):
        # TODO: Absolute_url
        """Given an endpoint, returns a rendered URL for its route.

        :param endpoint: The route endpoint you're searching for.
        :param params: Data to pass into the URL generator (for parameterized URLs).
        """
        route_object = self._route_for(endpoint)
        if route_object:
            return route_object.url(**params)
        raise ValueError

    def static_url(self, asset):
        """Given a static asset, return its URL path."""
        assert None not in (self.static_dir, self.static_route)
        return f"{self.static_route}/{str(asset)}"

    @property
    def docs(self):
        """The documentation page, rendered from the docs theme's ``index.html`` template.

        The OpenAPI YAML of this API is parsed and handed to the template as
        an ``apistar`` document.
        """
        theme = self.docs_theme
        theme_templates = os.path.join("themes", theme, "templates")

        # Templates are addressed as "<theme>/<name>" through the prefix loader.
        loader = jinja2.PrefixLoader(
            {theme: jinja2.PackageLoader("apistar", theme_templates)})
        env = jinja2.Environment(autoescape=True, loader=loader)

        document = apistar.document.Document()
        document.content = yaml.safe_load(self.openapi)

        template = env.get_template("/".join([theme, "index.html"]))

        def asset_url(asset):
            assert None not in (self.static_dir, self.static_route)
            return f"{self.static_route}/{asset}"

        return template.render(
            document=document,
            langs=["javascript", "python"],
            code_style=None,
            static_url=asset_url,
            schema_url="/schema.yml",
        )

    def template(self, name_, **values):
        """Renders the given `jinja2 <http://jinja.pocoo.org/docs/>`_ template, with provided values supplied.

        Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``.

        :param name_: The filename of the jinja2 template, in ``templates_dir``.
        :param values: Data to pass into the template.
        """
        # Prepopulate values with base
        values = {**self.jinja_values_base, **values}

        template = self.jinja_env.get_template(name_)
        return template.render(**values)

    def template_string(self, s_, **values):
        """Renders the given `jinja2 <http://jinja.pocoo.org/docs/>`_ template string, with provided values supplied.

        Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``.

        :param s_: The template to use.
        :param values: Data to pass into the template.
        """
        # Prepopulate values with base
        values = {**self.jinja_values_base, **values}

        template = self.jinja_env.from_string(s_)
        return template.render(**values)

    def serve(self, *, address=None, port=None, debug=False, **options):
        """Runs the application with uvicorn. If the ``PORT`` environment
        variable is set, requests will be served on that port, and on
        ``0.0.0.0`` unless an address was given explicitly.

        :param address: The address to bind to. Defaults to ``127.0.0.1``.
        :param port: The port to bind to. Defaults to ``5042``.
        :param debug: Run uvicorn server in debug mode.
        :param options: Additional keyword arguments to send to ``uvicorn.run()``.
        """
        if "PORT" in os.environ:
            # The environment variable overrides any port passed in.
            port = int(os.environ["PORT"])
            address = "0.0.0.0" if address is None else address

        address = "127.0.0.1" if address is None else address
        port = 5042 if port is None else port

        uvicorn.run(self, host=address, port=port, debug=debug, **options)

    def run(self, **kwargs):
        if "debug" not in kwargs:
            kwargs.update({"debug": self.debug})
        self.serve(**kwargs)
Esempio n. 2
0
class App(RoutingMixin, metaclass=DocsMeta):
    """The all-mighty application class.

    This class implements the [ASGI](https://asgi.readthedocs.io) protocol.

    # Example

    ```python
    >>> from bocadillo import App
    >>> app = App()
    ```

    # Parameters
    name (str):
        An optional name for the app.
    static_dir (str):
        The name of the directory containing static files, relative to
        the application entry point. Set to `None` to not serve any static
        files.
        Defaults to `"static"`.
    static_root (str):
        The path prefix for static assets.
        Defaults to `"static"`.
    static_config (dict):
        Extra static files configuration attributes.
        See also #::bocadillo.staticfiles#static.
    allowed_hosts (list of str, optional):
        A list of hosts which the server is allowed to run at.
        If the list contains `"*"`, any host is allowed.
        Defaults to `["*"]`.
    enable_sessions (bool):
        If `True`, cookie-based sessions are enabled. This requires the
        `SECRET_KEY` environment variable to be set to a non-empty value;
        otherwise a `RuntimeError` is raised. Defaults to `False`.
    enable_cors (bool):
        If `True`, Cross Origin Resource Sharing will be configured according
        to `cors_config`. Defaults to `False`.
        See also [CORS](../guides/http/middleware.md#cors).
    cors_config (dict):
        A dictionary of CORS configuration parameters.
        Defaults to `dict(allow_origins=[], allow_methods=["GET"])`.
    enable_hsts (bool):
        If `True`, enable HSTS (HTTP Strict Transport Security) and automatically
        redirect HTTP traffic to HTTPS.
        Defaults to `False`.
        See also [HSTS](../guides/http/middleware.md#hsts).
    enable_gzip (bool):
        If `True`, enable GZip compression and automatically
        compress responses for clients that support it.
        Defaults to `False`.
        See also [GZip](../guides/http/middleware.md#gzip).
    gzip_min_size (int):
        If specified, compress only responses that
        have more bytes than the specified value.
        Defaults to `1024`.
    media_type (str):
        Determines how values given to `res.media` are serialized.
        Can be one of the supported media types.
        Defaults to `"application/json"`.
        See also [Media](../guides/http/media.md).

    # Attributes
    media_handlers (dict):
        The dictionary of media handlers.
        You can access, edit or replace this at will.
    """

    import_string: Optional[str]

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)

        # HACK: get the Python module path where this app was instantiated.
        # This import string is passed to uvicorn in debug mode.
        # See the `.run()` method.
        # Entry 0 of the stack is this method; entry 1 is its caller.
        caller = inspect.stack()[1]
        instance.import_string = _get_module(caller.filename)

        return instance

    def __init__(
        self,
        name: Optional[str] = None,
        *,
        static_dir: Optional[str] = "static",
        static_root: Optional[str] = "static",
        static_config: Optional[dict] = None,
        allowed_hosts: Optional[List[str]] = None,
        enable_sessions: bool = False,
        enable_cors: bool = False,
        cors_config: Optional[dict] = None,
        enable_hsts: bool = False,
        enable_gzip: bool = False,
        gzip_min_size: int = 1024,
        media_type: str = CONTENT_TYPE.JSON,
        **kwargs,
    ):
        # Extra keyword arguments are forwarded to the parent initializer.
        super().__init__(**kwargs)

        self.name = name

        # Debug mode defaults to `False` but it can be set in `.run()`.
        self._debug = False

        # Base ASGI app
        self.asgi = self.dispatch

        # Mounted (children) apps
        # These registries must exist before `self.mount()` is called below.
        self._prefix_to_app: Dict[str, Any] = {}
        self._name_to_prefix_and_app: Dict[str, Tuple[str, App]] = {}
        self._static_apps: Dict[str, WhiteNoise] = {}

        # Static files
        # Without an explicit root, the directory name is used as the prefix.
        if static_dir is not None:
            if static_root is None:
                static_root = static_dir
            self.mount(static_root, static(static_dir, **(static_config
                                                          or {})))

        # Media
        # The handlers must be set first: the `media_type` setter validates
        # the value against `self.media_handlers`.
        self.media_handlers = get_default_handlers()
        self._media_type = ""
        self.media_type = media_type

        # HTTP middleware
        # The server error middleware wraps the exception middleware, which
        # wraps the HTTP router.
        self.exception_middleware = HTTPErrorMiddleware(self.http_router,
                                                        debug=self._debug)
        self.server_error_middleware = ServerErrorMiddleware(
            self.exception_middleware,
            handler=error_to_text,
            debug=self._debug)
        self.add_error_handler(HTTPError, error_to_text)

        # Lifespan middleware
        self._lifespan = Lifespan()

        # ASGI middleware
        if allowed_hosts is None:
            allowed_hosts = ["*"]
        self.add_asgi_middleware(TrustedHostMiddleware,
                                 allowed_hosts=allowed_hosts)
        if enable_sessions:
            # Sessions need a non-empty SECRET_KEY in the environment;
            # a RuntimeError is raised otherwise.
            secret_key = os.environ.get("SECRET_KEY", "")
            if not secret_key:
                raise RuntimeError(
                    "Use of sessions requires SECRET_KEY environment variable")
            self.add_asgi_middleware(SessionMiddleware, secret_key=secret_key)
        if enable_cors:
            # User-provided settings override the defaults key by key.
            cors_config = {**DEFAULT_CORS_CONFIG, **(cors_config or {})}
            self.add_asgi_middleware(CORSMiddleware, **cors_config)
        if enable_hsts:
            self.add_asgi_middleware(HTTPSRedirectMiddleware)
        if enable_gzip:
            self.add_asgi_middleware(GZipMiddleware,
                                     minimum_size=gzip_min_size)

        # Built-in providers.
        self._frozen = False
        self._http_context = create_context_provider("req", "res")

    def _app_providers(self):  # pylint: disable=method-hidden
        # Freezes the providers on the first call only; always returns a
        # no-op context manager.
        if self._frozen:
            return nullcontext()

        freeze_providers()
        self._frozen = True
        # Shadow this method on the instance so that subsequent calls
        # do nothing but build a `nullcontext`.
        self._app_providers = nullcontext
        return nullcontext()

    @property
    @deprecated(
        since="0.13",
        removal="0.14",
        alternative=("create_client", "/api/testing.md#create-client"),
    )
    def client(self):
        # Deprecated accessor: a new client is built on every access.
        test_client = create_client(self)
        return test_client

    @property
    def debug(self) -> bool:
        # Whether debug mode is currently enabled.
        return self._debug

    @debug.setter
    def debug(self, debug: bool):
        self._debug = debug
        # Keep both error-handling middleware in sync with the app's flag.
        for middleware in (self.exception_middleware,
                           self.server_error_middleware):
            middleware.debug = debug

    @property
    def media_type(self) -> str:
        """The media type configured when instantiating the application."""
        return self._media_type

    @media_type.setter
    def media_type(self, media_type: str):
        # Only media types that have a registered handler are accepted.
        if media_type in self.media_handlers:
            self._media_type = media_type
            return
        raise UnsupportedMediaType(media_type, handlers=self.media_handlers)

    def url_for(self, name: str, **kwargs) -> str:
        # Implement route name lookup across sub-apps: when this app's own
        # lookup fails, a name of the form "<app_name>:<route_name>" is
        # resolved against the named app.
        try:
            return super().url_for(name, **kwargs)
        except HTTPError as exc:
            app_name, _, route_name = name.partition(":")

            if route_name:
                return self._url_for_app(app_name, route_name, **kwargs)

            # No app name given.
            raise exc from None

    def _url_for_app(self, app_name: str, name: str, **kwargs) -> str:
        # Resolves a route name within the app called `app_name`.
        if app_name == self.name:
            # NOTE: this allows referencing this app's routes
            # both with and without the namespace.
            return self._get_own_url_for(name, **kwargs)

        try:
            prefix, app = self._name_to_prefix_and_app[app_name]
        except KeyError as key_exc:
            # Unknown app name.
            raise HTTPError(404) from key_exc

        # The sub-app's URL is relative to the prefix it is mounted at.
        return prefix + app.url_for(name, **kwargs)

    def _get_own_url_for(self, name: str, **kwargs) -> str:
        # NOTE: recipes hook into this method to prepend their
        # prefix to the URL.
        own_url = super().url_for(name, **kwargs)
        return own_url

    def mount(self, prefix: str, app: Union["App", ASGIApp, WSGIApp]):
        """Mount another WSGI or ASGI app at the given prefix.

        A leading `"/"` is added to the prefix if it is missing.

        [WSGI]: https://wsgi.readthedocs.io
        [ASGI]: https://asgi.readthedocs.io

        # Parameters
        prefix (str):
            A path prefix where the app should be mounted, e.g. `"/myapp"`.
        app:
            an object implementing the [WSGI] or [ASGI] protocol.
        """
        normalized = prefix if prefix.startswith("/") else "/" + prefix

        self._prefix_to_app[normalized] = app

        # Named `App` instances are also indexed by name.
        if isinstance(app, App) and app.name is not None:
            self._name_to_prefix_and_app[app.name] = (normalized, app)

        # Static file apps are tracked separately.
        if isinstance(app, WhiteNoise):
            self._static_apps[normalized] = app

    def recipe(self, recipe: "Recipe"):
        """Apply a recipe to this application.

        # Parameters
        recipe:
            a #::bocadillo.recipes#Recipe or #::bocadillo.recipes#RecipeBook
            whose `apply()` method is called with this application.

        # See Also
        - [Recipes](../guides/agnostic/recipes.md)
        """
        apply_to = recipe.apply
        apply_to(self)

    def add_error_handler(self, exception_cls: Type[_E],
                          handler: ErrorHandler):
        """Register a new error handler on the exception middleware.

        # Parameters
        exception_cls (exception class):
            The type of exception that should be handled.
        handler (callable):
            The actual error handler, which is called when an instance of
            `exception_cls` is caught.
            Should accept a request, response and exception parameters.
        """
        register = self.exception_middleware.add_exception_handler
        register(exception_cls, handler)

    def error_handler(self, exception_cls: Type[Exception]):
        """Register a new error handler (decorator syntax).

        The decorated handler is passed to `add_error_handler()` and
        returned unchanged.

        # See Also
        - [add_error_handler](#add-error-handler)
        """
        def register(handler):
            self.add_error_handler(exception_cls, handler)
            return handler

        return register

    def add_middleware(self, middleware_cls, **kwargs):
        """Register a middleware class.

        The middleware wraps the app currently held by the exception
        middleware and takes its place.

        # Parameters
        middleware_cls: a subclass of #::bocadillo.middleware#Middleware.
        kwargs: extra keyword arguments passed to `middleware_cls`.

        # See Also
        - [Middleware](../guides/http/middleware.md)
        """
        http_stack = self.exception_middleware
        http_stack.app = middleware_cls(http_stack.app, app=self, **kwargs)

    def add_asgi_middleware(self, middleware_cls, **kwargs):
        """Register an ASGI middleware class.

        The middleware wraps the current `asgi` callable and replaces it.

        # Parameters
        middleware_cls: a class that complies with the ASGI specification.
        kwargs: extra keyword arguments passed to `middleware_cls`.

        # See Also
        - [ASGI middleware](../guides/agnostic/asgi-middleware.md)
        - [ASGI](https://asgi.readthedocs.io)
        """
        # Subclasses of `ASGIMiddleware` additionally receive this
        # application as a second positional argument.
        if issubclass(middleware_cls, ASGIMiddleware):
            extra_args = (self, )
        else:
            extra_args = ()
        self.asgi = middleware_cls(self.asgi, *extra_args, **kwargs)

    def on(self, event: str, handler: Optional[EventHandler] = None):
        """Register an event handler.

        # Parameters
        event (str):
            Either `"startup"` (when the server boots) or `"shutdown"`
            (when the server stops).
        handler (callback, optional):
            The event handler. If not given, this should be used as a
            decorator.

        # Returns
        The handler itself when one was given, otherwise a decorator that
        registers the decorated function and returns it unchanged.

        # Example

        ```python
        @app.on("startup")
        async def startup():
            pass

        async def shutdown():
            pass

        app.on("shutdown", shutdown)
        ```
        """
        if handler is not None:
            # Direct call: `app.on("startup", func)`.
            self._lifespan.add_event_handler(event, handler)
            return handler

        # Decorator usage: `@app.on("startup")`.
        def decorator(fn):
            self._lifespan.add_event_handler(event, fn)
            return fn

        return decorator

    async def dispatch_http(self, receive: Receive, send: Send, scope: Scope):
        """Process an HTTP request and send the resulting response.

        A request/response pair is built from the ASGI arguments, run
        through the server error middleware, and the response it returns
        is then sent.
        """
        request = Request(scope, receive)
        media_type = self.media_type
        response = Response(
            request,
            media_type=media_type,
            media_handler=self.media_handlers[media_type],
        )

        with self._http_context.assign(req=request, res=response):
            response = await self.server_error_middleware(request, response)
            await response(receive, send)
            # Re-raise the exception to allow the server to log the error
            # and for the test client to optionally re-raise it too.
            self.server_error_middleware.raise_if_exception()

    async def dispatch_websocket(self, receive: Receive, send: Send,
                                 scope: Scope):
        """Hand a WebSocket connection over to the WebSocket router."""
        websocket_app = self.websocket_router
        await websocket_app(scope, receive, send)

    def dispatch(self, scope: Scope) -> ASGIAppInstance:
        """Return the ASGI app instance that should handle `scope`.

        The first mounted app whose prefix starts the request path wins.
        Otherwise the connection goes to this app's own WebSocket or HTTP
        dispatcher, depending on the scope type.
        """
        with self._app_providers():
            path: str = scope["path"]

            # Look for a sub-mounted extra app (first matching prefix).
            mounted = next(
                ((prefix, app) for prefix, app in self._prefix_to_app.items()
                 if path.startswith(prefix)),
                None,
            )

            if mounted is not None:
                prefix, app = mounted
                # Remove prefix from path so that the request is made according
                # to the mounted app's point of view.
                scope["path"] = path[len(prefix):]
                try:
                    return app(scope)
                except TypeError:
                    # The app could not be called with a single `scope`
                    # argument: wrap it in a WSGI responder instead.
                    return WSGIResponder(app, scope)

            if scope["type"] != "websocket":
                assert scope["type"] == "http"
                return partial(self.dispatch_http, scope=scope)

            return partial(self.dispatch_websocket, scope=scope)

    def __call__(self, scope: Scope) -> ASGIAppInstance:
        """ASGI entry point.

        Lifespan scopes go to the lifespan handler; every other scope goes
        through the ASGI middleware stack.
        """
        is_lifespan = scope["type"] == "lifespan"
        handler = self._lifespan if is_lifespan else self.asgi
        return handler(scope)

    def run(
        self,
        host: str = None,
        port: int = None,
        debug: bool = None,
        declared_as: str = "app",
        _run: Callable = None,
        **kwargs,
    ):
        """Serve the application using [uvicorn](https://www.uvicorn.org).

        # Parameters

        host (str):
            The host to bind to.
            Defaults to `"127.0.0.1"` (localhost).
            If not given and `$PORT` is set, `"0.0.0.0"` will be used to
            serve to all known hosts.
        port (int):
            The port to bind to.
            Defaults to `8000`. If the `$PORT` environment variable is set,
            its value is used instead, even when a port was passed.
        debug (bool):
            Whether to serve the application in debug mode. Defaults to `False`,
            except if the `BOCADILLO_DEBUG` environment variable is set.
        declared_as (str):
            The name under which the application is declared.
            This is only used when `debug=True` to indicate to
            uvicorn how to import the application object.
            Defaults to `"app"`.
        kwargs (dict):
            Extra keyword arguments passed to the uvicorn runner.

        # See Also
        - [Configuring host and port](../guides/app.md#configuring-host-and-port)
        - [Debug mode](../guides/app.md#debug-mode)
        - [Uvicorn settings](https://www.uvicorn.org/settings/) for all
        available keyword arguments.
        """
        # `_run` allows the actual runner to be swapped out.
        runner = run if _run is None else _run  # pragma: no cover

        env = os.environ

        if "PORT" in env:
            port = int(env["PORT"])
            host = "0.0.0.0" if host is None else host

        host = "127.0.0.1" if host is None else host
        port = 8000 if port is None else port

        if debug is None:
            debug = env.get("BOCADILLO_DEBUG", False)

        # By default, the runner is given the application object itself.
        target = self

        if debug:
            self.debug = kwargs["debug"] = True

            # Reload static files in development.
            # See: http://whitenoise.evans.io/en/stable/base.html#autorefresh
            for static_app in self._static_apps.values():
                static_app.autorefresh = True

            if self.import_string is not None:
                target = f"{self.import_string}:{declared_as}"
            else:
                # The import string could not be inferred.
                # We're probably in the REPL.
                warnings.warn(
                    "Could not infer application module. "
                    "uvicorn won't be able to hot reload on changes.")

        runner(target, host=host, port=port, **kwargs)
Esempio n. 3
0
class App(RoutingMixin, metaclass=DocsMeta):
    """The all-mighty application class.

    This class implements the [ASGI](https://asgi.readthedocs.io) protocol.

    [CORSMiddleware]: https://www.starlette.io/middleware/#corsmiddleware
    [SessionMiddleware]: https://www.starlette.io/middleware/#sessionmiddleware

    # Example

    ```python
    >>> from bocadillo import App
    >>> app = App()
    ```

    # Parameters
    name (str):
        An optional name for the app.
    static_dir (str):
        The name of the directory containing static files, relative to
        the application entry point. Set to `None` to not serve any static
        files.
        Defaults to `"static"`.
    static_root (str):
        The path prefix for static assets.
        Defaults to `"static"`.
    static_config (dict):
        Extra static files configuration attributes.
        See also #::bocadillo.staticfiles#static.
    allowed_hosts (list of str, optional):
        A list of hosts which the server is allowed to run at.
        If the list contains `"*"`, any host is allowed.
        Defaults to `["*"]`.
    enable_sessions (bool):
        If `True`, cookie-based signed sessions are enabled according to the
        `sessions_config`. The secret key must be non-empty and can also be
        set via the `SECRET_KEY` environment variable.
        Defaults to `False`.
    sessions_config (dict):
        A dictionary of sessions configuration parameters.
        The dictionary is copied; it is never modified.
        See [SessionMiddleware].
    enable_cors (bool):
        If `True`, Cross Origin Resource Sharing is configured according
        to `cors_config`. Defaults to `False`.
        See also [CORS](../guides/http/middleware.md#cors).
    cors_config (dict):
        A dictionary of CORS configuration parameters.
        Defaults to `dict(allow_origins=[], allow_methods=["GET"])`.
        See [CORSMiddleware].
    enable_hsts (bool):
        If `True`, enable HSTS (HTTP Strict Transport Security) and automatically
        redirect HTTP traffic to HTTPS.
        Defaults to `False`.
        See also [HSTS](../guides/http/middleware.md#hsts).
    enable_gzip (bool):
        If `True`, enable GZip compression and automatically
        compress responses for clients that support it.
        Defaults to `False`.
        See also [GZip](../guides/http/middleware.md#gzip).
    gzip_min_size (int):
        If specified, compress only responses that
        have more bytes than the specified value.
        Defaults to `1024`.
    media_type (str):
        Determines how values given to `res.media` are serialized.
        Can be one of the supported media types.
        Defaults to `"application/json"`.
        See also [Media](../guides/http/media.md).

    # Attributes
    media_handlers (dict):
        The dictionary of media handlers.
        You can access, edit or replace this at will.
    """

    __slots__ = (
        "name",
        "asgi",
        "_prefix_to_app",
        "_name_to_prefix_and_app",
        "_static_apps",
        "media_handlers",
        "_media_type",
        "exception_middleware",
        "server_error_middleware",
        "_lifespan",
        "_store",
        "_frozen",
    )

    def __init__(
        self,
        name: str = None,
        *,
        static_dir: Optional[str] = "static",
        static_root: Optional[str] = "static",
        static_config: dict = None,
        allowed_hosts: List[str] = None,
        enable_sessions: bool = False,
        sessions_config: dict = None,
        enable_cors: bool = False,
        cors_config: dict = None,
        enable_hsts: bool = False,
        enable_gzip: bool = False,
        gzip_min_size: int = 1024,
        media_type: str = CONTENT_TYPE.JSON,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.name = name

        # Base ASGI app
        self.asgi = self.dispatch

        # Mounted (children) apps
        self._prefix_to_app: Dict[str, Any] = {}
        self._name_to_prefix_and_app: Dict[str, Tuple[str, App]] = {}
        self._static_apps: Dict[str, WhiteNoise] = {}

        # Static files
        if static_dir is not None:
            if static_root is None:
                static_root = static_dir
            self.mount(static_root, static(static_dir, **(static_config
                                                          or {})))

        # Media
        self.media_handlers = get_default_handlers()
        # NOTE: the backing attribute is initialized first, then the value
        # is assigned through the property so that it gets validated.
        self._media_type = ""
        self.media_type = media_type

        # HTTP middleware
        self.exception_middleware = HTTPErrorMiddleware(self.http_router)
        self.server_error_middleware = ServerErrorMiddleware(
            self.exception_middleware, handler=error_to_text)
        self.add_error_handler(HTTPError, error_to_text)
        self.add_error_handler(typesystem.ValidationError, on_validation_error)

        # Lifespan middleware
        self._lifespan = Lifespan()

        # ASGI middleware

        if allowed_hosts is None:
            allowed_hosts = ["*"]
        self.add_asgi_middleware(TrustedHostMiddleware,
                                 allowed_hosts=allowed_hosts)

        if enable_sessions:
            # Work on a copy: the secret key resolved below (possibly read
            # from the environment) must not leak into the caller's dict.
            sessions_config = dict(sessions_config or {})

            try:
                from starlette.middleware.sessions import SessionMiddleware
            except ImportError as exc:  # pragma: no cover
                if "itsdangerous" in str(exc):
                    raise ImportError(
                        "Please install the [sessions] extra to use sessions: "
                        "`pip install bocadillo[sessions]`.") from exc
                raise exc from None

            # An explicitly configured key takes precedence over the
            # `SECRET_KEY` environment variable.
            secret_key = sessions_config.get("secret_key")
            if secret_key is None:
                secret_key = os.environ.get("SECRET_KEY", "")

            if not secret_key:
                raise MissingSecretKey(
                    "A non-empty secret key must be set to use sessions. "
                    "Pass a 'secret_key' to 'sessions_config', or set the "
                    "SECRET_KEY environment variable.")

            sessions_config["secret_key"] = secret_key

            self.add_asgi_middleware(SessionMiddleware, **sessions_config)

        if enable_cors:
            # User-provided values override the defaults key by key.
            cors_config = {**DEFAULT_CORS_CONFIG, **(cors_config or {})}
            self.add_asgi_middleware(CORSMiddleware, **cors_config)

        if enable_hsts:
            self.add_asgi_middleware(HTTPSRedirectMiddleware)

        if enable_gzip:
            self.add_asgi_middleware(GZipMiddleware,
                                     minimum_size=gzip_min_size)

        # Providers.

        self._store = _STORE

        # NOTE: discover providers from `providerconf` at instantiation time,
        # so that further declared views correctly resolve providers.
        self._store.discover_default()

        self.on("startup", self._store.enter_session)
        self.on("shutdown", self._store.exit_session)

        self._frozen = False

    def _app_providers(self):  # pylint: disable=method-hidden
        """Freeze the provider store on first use.

        Returns a no-op context manager so that callers can use this in a
        `with` statement.
        """
        if not self._frozen:
            self._store.freeze()
            self._frozen = True
        return nullcontext()

    @property
    def media_type(self) -> str:
        """The media type configured when instantiating the application."""
        return self._media_type

    @media_type.setter
    def media_type(self, media_type: str):
        # Only media types that have a registered handler are accepted.
        if media_type not in self.media_handlers:
            raise UnsupportedMediaType(media_type,
                                       handlers=self.media_handlers)
        self._media_type = media_type

    def url_for(self, name: str, **kwargs) -> str:
        """Build the URL of a named route, looking across sub-apps if needed.

        The lookup is first delegated to the parent class. If that raises an
        `HTTPError` and `name` has the form `"<app_name>:<route_name>"`, the
        route is looked up through `_url_for_app()` instead.
        """
        try:
            return super().url_for(name, **kwargs)
        except HTTPError as exc:
            app_name, _, name = name.partition(":")

            if not name:
                # No app name given.
                raise exc from None

            return self._url_for_app(app_name, name, **kwargs)

    def _url_for_app(self, app_name: str, name: str, **kwargs) -> str:
        """Build the URL of route `name` within the app called `app_name`.

        Raises `HTTPError(404)` if no mounted app is registered under
        `app_name`.
        """
        if app_name == self.name:
            # NOTE: this allows this app's own routes to be referenced
            # both with and without the namespace.
            return self._get_own_url_for(name, **kwargs)

        try:
            prefix, app = self._name_to_prefix_and_app[app_name]
        except KeyError as key_exc:
            raise HTTPError(404) from key_exc
        else:
            return prefix + app.url_for(name, **kwargs)

    def _get_own_url_for(self, name: str, **kwargs) -> str:
        # NOTE: recipes hook into this method to prepend their
        # prefix to the URL.
        return super().url_for(name, **kwargs)

    def mount(self, prefix: str, app: Union["App", ASGIApp, WSGIApp]):
        """Mount another WSGI or ASGI app at the given prefix.

        [WSGI]: https://wsgi.readthedocs.io
        [ASGI]: https://asgi.readthedocs.io

        # Parameters
        prefix (str):
            A path prefix where the app should be mounted, e.g. `"/myapp"`.
            A leading `"/"` is added if missing.
        app:
            an object implementing the [WSGI] or [ASGI] protocol.
        """
        if not prefix.startswith("/"):
            prefix = "/" + prefix

        self._prefix_to_app[prefix] = app

        # Named `App` instances are also indexed by name, which is what
        # `_url_for_app()` reads to resolve namespaced route names.
        if isinstance(app, App) and app.name is not None:
            self._name_to_prefix_and_app[app.name] = (prefix, app)

        if isinstance(app, WhiteNoise):
            self._static_apps[prefix] = app

    def recipe(self, recipe: "Recipe"):
        """Apply a recipe.

        # Parameters
        recipe:
            a #::bocadillo.recipes#Recipe or #::bocadillo.recipes#RecipeBook
            to be applied to the application.

        # See Also
        - [Recipes](../guides/agnostic/recipes.md)
        """
        recipe.apply(self)

    def add_error_handler(self, exception_cls: Type[_E],
                          handler: ErrorHandler):
        """Register a new error handler.

        # Parameters
        exception_cls (exception class):
            The type of exception that should be handled.
        handler (callable):
            The actual error handler, which is called when an instance of
            `exception_cls` is caught.
            Should accept a request, response and exception parameters.
        """
        self.exception_middleware.add_exception_handler(exception_cls, handler)

    def error_handler(self, exception_cls: Type[Exception]):
        """Register a new error handler (decorator syntax).

        # See Also
        - [add_error_handler](#add-error-handler)
        """
        def wrapper(handler):
            self.add_error_handler(exception_cls, handler)
            return handler

        return wrapper

    def add_middleware(self, middleware_cls, **kwargs):
        """Register a middleware class.

        # Parameters
        middleware_cls: a subclass of #::bocadillo.middleware#Middleware.

        # See Also
        - [Middleware](../guides/http/middleware.md)
        """
        self.exception_middleware.app = middleware_cls(
            self.exception_middleware.app, app=self, **kwargs)

    def add_asgi_middleware(self, middleware_cls, **kwargs):
        """Register an ASGI middleware class.

        # Parameters
        middleware_cls: a class that complies with the ASGI specification.

        # See Also
        - [ASGI middleware](../guides/agnostic/asgi-middleware.md)
        - [ASGI](https://asgi.readthedocs.io)
        """
        # Subclasses of `ASGIMiddleware` additionally receive this
        # application as a second positional argument.
        args = (self, ) if issubclass(middleware_cls, ASGIMiddleware) else ()
        self.asgi = middleware_cls(self.asgi, *args, **kwargs)

    def on(self, event: str, handler: Optional[EventHandler] = None):
        """Register an event handler.

        # Parameters
        event (str):
            Either `"startup"` (when the server boots) or `"shutdown"`
            (when the server stops).
        handler (callback, optional):
            The event handler. If not given, this should be used as a
            decorator.

        # Example

        ```python
        @app.on("startup")
        async def startup():
            pass

        async def shutdown():
            pass

        app.on("shutdown", shutdown)
        ```
        """
        if handler is None:

            def register(func):
                self._lifespan.add_event_handler(event, func)
                return func

            return register

        self._lifespan.add_event_handler(event, handler)
        return handler

    async def dispatch_http(self, receive: Receive, send: Send, scope: Scope):
        """Process an HTTP request and send the resulting response.

        A request/response pair is built from the ASGI arguments, run
        through the server error middleware, and the response it returns
        is then sent.
        """
        req = Request(scope, receive)
        res = Response(
            req,
            media_type=self.media_type,
            media_handler=self.media_handlers[self.media_type],
        )

        res: Response = await self.server_error_middleware(req, res)
        await res(receive, send)
        # Re-raise the exception to allow the server to log the error
        # and for the test client to optionally re-raise it too.
        self.server_error_middleware.raise_if_exception()

    async def dispatch_websocket(self, receive: Receive, send: Send,
                                 scope: Scope):
        """Hand a WebSocket connection over to the WebSocket router."""
        await self.websocket_router(scope, receive, send)

    def dispatch(self, scope: Scope) -> ASGIAppInstance:
        """Return the ASGI app instance that should handle `scope`.

        The first mounted app whose prefix starts the request path wins.
        Otherwise the connection goes to this app's own WebSocket or HTTP
        dispatcher, depending on the scope type.
        """
        with self._app_providers():
            path: str = scope["path"]

            # Return a sub-mounted extra app, if found
            for prefix, app in self._prefix_to_app.items():
                if not path.startswith(prefix):
                    continue
                # Remove prefix from path so that the request is made according
                # to the mounted app's point of view.
                scope["path"] = path[len(prefix):]
                try:
                    return app(scope)
                except TypeError:
                    # The app could not be called with a single `scope`
                    # argument: wrap it in a WSGI responder instead.
                    return WSGIResponder(app, scope)

            if scope["type"] == "websocket":
                return partial(self.dispatch_websocket, scope=scope)

            assert scope["type"] == "http"
            return partial(self.dispatch_http, scope=scope)

    def __call__(self, scope: Scope) -> ASGIAppInstance:
        """ASGI entry point.

        Lifespan scopes go to the lifespan handler; every other scope goes
        through the ASGI middleware stack.
        """
        if scope["type"] == "lifespan":
            return self._lifespan(scope)
        return self.asgi(scope)
Esempio n. 4
0
class Router:
    """Hold a list of routes and dispatch incoming ASGI connections to them.

    Lifespan scopes are delegated to a dedicated `Lifespan` handler.
    """

    __slots__ = ("routes", "lifespan")

    def __init__(self):
        self.lifespan = Lifespan()
        self.routes: typing.List[BaseRoute] = []

    def add_route(self, route: BaseRoute) -> None:
        """Append `route` to the list of registered routes."""
        self.routes.append(route)

    def include(self, other: "Router", prefix: str = ""):
        """Include the routes of another router.

        # Parameters
        other (Router): the router whose routes are added to this one.
        prefix (str):
            if non-empty, each route is re-created with its pattern (or
            mount path) joined to this prefix; otherwise the route objects
            are added as they are.
        """
        for original in other.routes:
            assert isinstance(original, (HTTPRoute, WebSocketRoute, Mount))

            if not prefix:
                self.add_route(original)
                continue

            if isinstance(original, HTTPRoute):
                prefixed = HTTPRoute(
                    pattern=_join(prefix, original.pattern),
                    view=original.view,
                )
            elif isinstance(original, WebSocketRoute):
                prefixed = WebSocketRoute(
                    pattern=_join(prefix, original.pattern),
                    view=original.view,
                    **original.ws_kwargs,
                )
            else:
                prefixed = Mount(
                    path=_join(prefix, original.path),
                    app=original.app,
                )

            self.add_route(prefixed)

    def mount(self, path: str, app: ASGIApp):
        """Mount an ASGI or WSGI app at the given path."""
        mounted = Mount(path, app)
        return self.add_route(mounted)

    def on(self, event: str, handler: typing.Optional[EventHandler] = None):
        """Register a lifespan event handler.

        Can be called directly with a `handler`, or used as a decorator
        when `handler` is not given. The handler is returned unchanged.
        """
        if handler is not None:
            self.lifespan.add_event_handler(event, handler)
            return handler

        def decorator(fn):
            self.lifespan.add_event_handler(event, fn)
            return fn

        return decorator

    def route(self, pattern: str, methods: typing.List[str] = None):
        """Register an HTTP route by decorating a view.

        The decorator returns the created route, not the decorated view.

        # Parameters
        pattern (str): an URL pattern.
        methods (list of str): passed as `methods` when wrapping the view.
        """
        def decorate(view: typing.Any) -> HTTPRoute:
            http_route = HTTPRoute(pattern, View(view, methods=methods))
            self.add_route(http_route)
            return http_route

        return decorate

    def websocket_route(
        self,
        pattern: str,
        *,
        auto_accept: bool = True,
        value_type: str = None,
        receive_type: str = None,
        send_type: str = None,
        caught_close_codes: typing.Tuple[int] = None,
    ):
        """Register a WebSocket route by decorating a view.

        See #::bocadillo.websockets#WebSocket for a description of keyword
        parameters.

        The decorator returns the created route, not the decorated view.

        # Parameters
        pattern (str): an URL pattern.
        """
        # Keyword options are forwarded to the route as they were given.
        ws_options = dict(
            auto_accept=auto_accept,
            value_type=value_type,
            receive_type=receive_type,
            send_type=send_type,
            caught_close_codes=caught_close_codes,
        )

        def decorate(view: typing.Any) -> WebSocketRoute:
            ws_route = WebSocketRoute(pattern, WebSocketView(view),
                                      **ws_options)
            self.add_route(ws_route)
            return ws_route

        return decorate

    def _find_route(self, scope: dict) -> typing.Optional[BaseRoute]:
        """Return the first route matching `scope`, or `None`.

        When a route matches, `scope` is updated in place with the child
        scope reported by that route.
        """
        for candidate in self.routes:
            matched, child_scope = candidate.matches(scope)
            if not matched:
                continue
            scope.update(child_scope)
            return candidate
        return None

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Dispatch an ASGI connection to the matching route.

        If no route matches: an HTTP request whose path lacks a trailing
        slash may be redirected to the slash-terminated path, a WebSocket
        connection is closed with code 403, and anything else results in
        `HTTPError(404)`.
        """
        scope["send"] = send  # See: `RequestResponseMiddleware`.

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        matched = self._find_route(scope)

        if matched is not None:
            try:
                await matched(scope, receive, send)
            except Redirect as redirect:
                # The redirect response is stored on the scope rather
                # than sent from here.
                scope["res"] = redirect.response
            return

        if (scope["type"] == "http" and not scope["path"].endswith("/")
                and redirect_trailing_slash_enabled()):
            # Retry the lookup on a copy of the scope, with a trailing
            # slash appended to the path.
            slashed_scope = {**scope, "path": scope["path"] + "/"}
            if self._find_route(slashed_scope) is not None:
                target_url = URL(scope=slashed_scope)
                scope["res"] = Redirect(str(target_url)).response
                return

        if scope["type"] == "websocket":
            await WebSocketClose(code=403)(receive, send)
            return

        raise HTTPError(404)