Exemplo n.º 1
0
 def __init__(self,
              debug: bool = False,
              template_directory: typing.Optional[str] = None) -> None:
     """Build the application: a router wrapped in the standard middleware stack.

     :param debug: Initial debug flag, passed to the exception and
         server-error middleware.
     :param template_directory: Optional template directory handed to
         ``load_template_env``; ``None`` is allowed.
     """
     self._debug = debug
     self.router = Router()
     # Middleware chain, innermost first:
     # router -> exception handling -> server-error handling -> lifespan.
     self.exception_middleware = ExceptionMiddleware(self.router,
                                                     debug=debug)
     self.error_middleware = ServerErrorMiddleware(
         self.exception_middleware, debug=debug)
     self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
     self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
     self.template_env = self.load_template_env(template_directory)
Exemplo n.º 2
0
    def __init__(
        self,
        debug: bool = False,
        template_directory: Optional[str] = None,
        title: str = "Fast API",
        description: str = "",
        version: str = "0.1.0",
        openapi_url: Optional[str] = "/openapi.json",
        docs_url: Optional[str] = "/docs",
        redoc_url: Optional[str] = "/redoc",
        **extra: Any,
    ) -> None:
        """Build the application: router, middleware stack and OpenAPI settings.

        :param debug: Passed to the exception and server-error middleware.
        :param template_directory: Optional directory handed to ``load_template_env``.
        :param title: API title; must be non-empty when ``openapi_url`` is set.
        :param description: API description.
        :param version: API version; must be non-empty when ``openapi_url`` is set.
        :param openapi_url: Route of the OpenAPI schema; ``None`` disables it.
        :param docs_url: Route of a docs page; requires ``openapi_url``.
        :param redoc_url: Route of a second docs page; requires ``openapi_url``.
        :param extra: Arbitrary extra keyword arguments, stored as-is on ``self.extra``.
        :raises AssertionError: If the OpenAPI / docs settings are inconsistent.
        """
        self._debug = debug
        self.router: routing.APIRouter = routing.APIRouter()
        # Middleware chain, innermost first:
        # router -> exception handling -> server-error handling -> lifespan.
        self.exception_middleware = ExceptionMiddleware(self.router,
                                                        debug=debug)
        self.error_middleware = ServerErrorMiddleware(
            self.exception_middleware, debug=debug)
        self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
        self.schema_generator = None
        self.template_env = self.load_template_env(template_directory)

        self.title = title
        self.description = description
        self.version = version
        self.openapi_url = openapi_url
        self.docs_url = docs_url
        self.redoc_url = redoc_url
        self.extra = extra

        self.openapi_version = "3.0.2"

        # Generating a schema needs both a title and a version.
        if self.openapi_url:
            assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
            assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"

        # The docs pages render the schema, so they cannot exist without it.
        if self.docs_url or self.redoc_url:
            assert self.openapi_url, "The openapi_url is required for the docs"
        # Cache for the generated schema; stays ``None`` until something fills it.
        self.openapi_schema: Optional[Dict[str, Any]] = None
        # Further initialisation defined elsewhere in the class.
        self.setup()
Exemplo n.º 3
0
def test_async_lifespan_handler():
    """Async startup handlers fire on entering the client context; shutdown ones on leaving it."""
    fired = {"startup": False, "shutdown": False}
    handler = LifespanMiddleware(App)

    @handler.on_event("startup")
    async def run_startup():
        fired["startup"] = True

    @handler.on_event("shutdown")
    async def run_cleanup():
        fired["shutdown"] = True

    # Nothing runs merely by registering the handlers.
    assert not fired["startup"]
    assert not fired["shutdown"]

    with TestClient(handler):
        # Inside the context: startup has run, shutdown has not.
        assert fired["startup"]
        assert not fired["shutdown"]

    # Leaving the context triggers shutdown.
    assert fired["startup"]
    assert fired["shutdown"]
Exemplo n.º 4
0
    def __init__(
        self,
        *,
        debug=False,
        title=None,
        version=None,
        description=None,
        terms_of_service=None,
        contact=None,
        license=None,
        openapi=None,
        openapi_route="/schema.yml",
        static_dir="static",
        static_route="/static",
        templates_dir="templates",
        auto_escape=True,
        secret_key=DEFAULT_SECRET_KEY,
        enable_hsts=False,
        docs_route=None,
        cors=False,
        cors_params=DEFAULT_CORS_PARAMS,
        allowed_hosts=None,
    ):
        """Configure routes, static files, templates and the middleware stack.

        Creates ``static_dir`` and ``templates_dir`` on disk if they are missing.
        When ``static_route`` is ``None`` it is derived from ``static_dir`` with a
        leading ``/`` so the mount prefix can match incoming request paths.
        """
        self.background = BackgroundQueue()

        self.secret_key = secret_key
        # OpenAPI Info Object metadata.
        self.title = title
        self.version = version
        self.description = description
        self.terms_of_service = terms_of_service
        self.contact = contact
        self.license = license
        self.openapi_version = openapi
        # Retained so other code (e.g. a docs page) can reference the schema route.
        self.openapi_route = openapi_route

        if static_dir is not None:
            if static_route is None:
                # ASGI request paths always start with "/"; without the leading
                # slash a prefix check via ``str.startswith`` could never match.
                static_route = "/" + str(static_dir).lstrip("/")
            static_dir = Path(os.path.abspath(static_dir))

        self.static_dir = static_dir
        self.static_route = static_route

        self.built_in_templates_dir = Path(
            os.path.abspath(os.path.dirname(__file__) + "/templates"))

        if templates_dir is not None:
            templates_dir = Path(os.path.abspath(templates_dir))

        self.templates_dir = templates_dir or self.built_in_templates_dir

        self.apps = {}
        self.routes = {}
        self.before_requests = {"http": [], "ws": []}
        self.docs_theme = DEFAULT_API_THEME
        self.docs_route = docs_route
        self.schemas = {}
        self.session_cookie = DEFAULT_SESSION_COOKIE

        self.hsts_enabled = enable_hsts
        self.cors = cors
        self.cors_params = cors_params
        self.debug = debug

        if not allowed_hosts:
            # NOTE(review): every host is accepted by default; consider requiring
            # an explicit ``allowed_hosts`` when ``debug`` is False.
            allowed_hosts = ["*"]
        self.allowed_hosts = allowed_hosts

        # Make the static/templates directory if they don't exist.
        for _dir in (self.static_dir, self.templates_dir):
            if _dir is not None:
                os.makedirs(_dir, exist_ok=True)

        if self.static_dir is not None:
            self.whitenoise = WhiteNoise(application=self._notfound_wsgi_app)
            self.whitenoise.add_files(str(self.static_dir))

            # Also serve the docs theme's bundled static assets.
            self.whitenoise.add_files(
                (Path(apistar.__file__).parent / "themes" / self.docs_theme /
                 "static").resolve())

            self.mount(self.static_route, self.whitenoise)

        self.formats = get_formats()

        # Cached requests session.
        self._session = None

        if self.openapi_version:
            self.add_route(openapi_route, self.schema_response)

        if self.docs_route:
            self.add_route(self.docs_route, self.docs_response)

        self.default_endpoint = None
        # Middleware stacking: each ``add_middleware`` call wraps the current
        # ``self.app``, so the order of the calls below is significant.
        self.app = self.dispatch
        self.add_middleware(GZipMiddleware)

        if self.hsts_enabled:
            self.add_middleware(HTTPSRedirectMiddleware)

        self.add_middleware(TrustedHostMiddleware,
                            allowed_hosts=self.allowed_hosts)

        self.lifespan_handler = LifespanMiddleware(LifespanHandler)

        if self.cors:
            self.add_middleware(CORSMiddleware, **self.cors_params)
        self.add_middleware(ServerErrorMiddleware, debug=debug)

        # Jinja environment: user templates first, then the built-in ones.
        self.jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(
                [str(self.templates_dir),
                 str(self.built_in_templates_dir)],
                followlinks=True,
            ),
            autoescape=jinja2.select_autoescape(
                ["html", "xml"] if auto_escape else []),
        )
        self.jinja_values_base = {"api": self}  # Give reference to self.
        self.requests = (
            self.session()
        )  #: A Requests session that is connected to the ASGI app.
Exemplo n.º 5
0
class API:
    """The primary web-service class.

    :param debug: Passed to the server-error middleware; :meth:`run` also uses it
        as the default uvicorn ``debug`` flag.
    :param title: The title of the application (OpenAPI Info Object).
    :param version: The version of the OpenAPI document (OpenAPI Info Object).
    :param description: The description of the OpenAPI document (OpenAPI Info Object).
    :param terms_of_service: A URL to the Terms of Service for the API (OpenAPI Info Object).
    :param contact: The contact dictionary of the application (OpenAPI Contact Object).
    :param license: The license information of the exposed API (OpenAPI License Object).
    :param openapi: The OpenAPI version string; when falsy, no schema route is registered.
    :param openapi_route: The route the OpenAPI schema (YAML) is served from.
    :param static_dir: The directory to use for static files. Will be created for you if it doesn't already exist.
    :param static_route: The route static files are mounted at. When ``None``, ``static_dir`` prefixed with ``/`` is used.
    :param templates_dir: The directory to use for templates. Will be created for you if it doesn't already exist.
    :param auto_escape: If ``True``, HTML and XML templates will automatically be escaped.
    :param secret_key: Key used to sign the session cookie.
    :param enable_hsts: If ``True``, send all responses to HTTPS URLs.
    :param docs_route: If set, serve the rendered API documentation at this route.
    :param cors: If ``True``, add CORS middleware configured with ``cors_params``.
    :param cors_params: Keyword arguments for the CORS middleware.
    :param allowed_hosts: Hosts accepted by the trusted-host middleware; defaults to ``["*"]``.
    """

    status_codes = status_codes

    def __init__(
        self,
        *,
        debug=False,
        title=None,
        version=None,
        description=None,
        terms_of_service=None,
        contact=None,
        license=None,
        openapi=None,
        openapi_route="/schema.yml",
        static_dir="static",
        static_route="/static",
        templates_dir="templates",
        auto_escape=True,
        secret_key=DEFAULT_SECRET_KEY,
        enable_hsts=False,
        docs_route=None,
        cors=False,
        cors_params=DEFAULT_CORS_PARAMS,
        allowed_hosts=None,
    ):
        self.background = BackgroundQueue()

        self.secret_key = secret_key
        # OpenAPI Info Object metadata, consumed by ``_apispec``.
        self.title = title
        self.version = version
        self.description = description
        self.terms_of_service = terms_of_service
        self.contact = contact
        self.license = license
        self.openapi_version = openapi
        # Retained so ``docs`` links to where the schema is actually served.
        self.openapi_route = openapi_route

        if static_dir is not None:
            if static_route is None:
                # ASGI request paths always start with "/"; without the leading
                # slash the ``startswith`` mount check in ``__call__`` never matches.
                static_route = "/" + str(static_dir).lstrip("/")
            static_dir = Path(os.path.abspath(static_dir))

        self.static_dir = static_dir
        self.static_route = static_route

        self.built_in_templates_dir = Path(
            os.path.abspath(os.path.dirname(__file__) + "/templates"))

        if templates_dir is not None:
            templates_dir = Path(os.path.abspath(templates_dir))

        self.templates_dir = templates_dir or self.built_in_templates_dir

        self.apps = {}
        self.routes = {}
        self.before_requests = {"http": [], "ws": []}
        self.docs_theme = DEFAULT_API_THEME
        self.docs_route = docs_route
        self.schemas = {}
        self.session_cookie = DEFAULT_SESSION_COOKIE

        self.hsts_enabled = enable_hsts
        self.cors = cors
        self.cors_params = cors_params
        self.debug = debug

        if not allowed_hosts:
            # NOTE(review): every host is accepted by default; consider requiring
            # an explicit ``allowed_hosts`` when ``debug`` is False.
            allowed_hosts = ["*"]
        self.allowed_hosts = allowed_hosts

        # Make the static/templates directory if they don't exist.
        for _dir in (self.static_dir, self.templates_dir):
            if _dir is not None:
                os.makedirs(_dir, exist_ok=True)

        if self.static_dir is not None:
            self.whitenoise = WhiteNoise(application=self._notfound_wsgi_app)
            self.whitenoise.add_files(str(self.static_dir))

            # Also serve the docs theme's bundled static assets.
            self.whitenoise.add_files(
                (Path(apistar.__file__).parent / "themes" / self.docs_theme /
                 "static").resolve())

            self.mount(self.static_route, self.whitenoise)

        self.formats = get_formats()

        # Cached requests session.
        self._session = None

        if self.openapi_version:
            self.add_route(openapi_route, self.schema_response)

        if self.docs_route:
            self.add_route(self.docs_route, self.docs_response)

        self.default_endpoint = None
        # Each ``add_middleware`` call wraps the current ``self.app``, so
        # middleware added later sits further out; the order below matters.
        self.app = self.dispatch
        self.add_middleware(GZipMiddleware)

        if self.hsts_enabled:
            self.add_middleware(HTTPSRedirectMiddleware)

        self.add_middleware(TrustedHostMiddleware,
                            allowed_hosts=self.allowed_hosts)

        self.lifespan_handler = LifespanMiddleware(LifespanHandler)

        if self.cors:
            self.add_middleware(CORSMiddleware, **self.cors_params)
        self.add_middleware(ServerErrorMiddleware, debug=debug)

        # Jinja environment: user templates first, then the built-in ones.
        self.jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(
                [str(self.templates_dir),
                 str(self.built_in_templates_dir)],
                followlinks=True,
            ),
            autoescape=jinja2.select_autoescape(
                ["html", "xml"] if auto_escape else []),
        )
        self.jinja_values_base = {"api": self}  # Give reference to self.
        self.requests = (
            self.session()
        )  #: A Requests session that is connected to the ASGI app.

    @staticmethod
    def _default_wsgi_app(environ, start_response):
        """No-op WSGI application; never calls ``start_response``."""
        pass

    @staticmethod
    def _notfound_wsgi_app(environ, start_response):
        """WSGI app wrapped by WhiteNoise in ``__init__``; answers with a plain-text 404."""
        start_response("404 NOT FOUND", [("Content-Type", "text/plain")])
        return [b"Not Found."]

    def before_request(self, websocket=False):
        """Decorator registering a hook to run before each HTTP (or WebSocket) request.

        :param websocket: If ``True``, register for WebSocket connections instead of HTTP.
        """
        key = "ws" if websocket else "http"

        def decorator(f):
            self.before_requests.setdefault(key, []).append(f)
            return f

        return decorator

    @property
    def before_http_requests(self):
        """Hooks registered to run before HTTP requests."""
        return self.before_requests.get("http", [])

    @property
    def before_ws_requests(self):
        """Hooks registered to run before WebSocket connections."""
        return self.before_requests.get("ws", [])

    @property
    def _apispec(self):
        """Build a fresh ``APISpec`` from the app metadata, route docstrings and schemas."""
        info = {}
        if self.description is not None:
            info["description"] = self.description
        if self.terms_of_service is not None:
            info["termsOfService"] = self.terms_of_service
        if self.contact is not None:
            info["contact"] = self.contact
        if self.license is not None:
            info["license"] = self.license

        spec = APISpec(
            title=self.title,
            version=self.version,
            openapi_version=self.openapi_version,
            plugins=[MarshmallowPlugin()],
            info=info,
        )

        # Only routes whose description is non-empty contribute path operations.
        for path, route_object in self.routes.items():
            if route_object.description:
                operations = yaml_utils.load_operations_from_docstring(
                    route_object.description)
                spec.path(path=path, operations=operations)

        for name, schema in self.schemas.items():
            spec.components.schema(name, schema=schema)

        return spec

    @property
    def openapi(self):
        """The OpenAPI document rendered as YAML text."""
        return self._apispec.to_yaml()

    def add_middleware(self, middleware_cls, **middleware_config):
        """Wrap the current ASGI app in ``middleware_cls(app, **middleware_config)``."""
        self.app = middleware_cls(self.app, **middleware_config)

    def __call__(self, scope):
        """ASGI entry point: lifespan events, then mounted apps, then the middleware stack."""
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)

        path = scope["path"]
        root_path = scope.get("root_path", "")

        # Call into a submounted app, if one exists.
        for path_prefix, app in self.apps.items():
            if path.startswith(path_prefix):
                scope["path"] = path[len(path_prefix):]
                scope["root_path"] = root_path + path_prefix
                try:
                    return app(scope)
                except TypeError:
                    # A WSGI app called with a single argument raises TypeError;
                    # adapt it to ASGI and retry.
                    app = WsgiToAsgi(app)
                    return app(scope)

        return self.app(scope)

    def dispatch(self, scope):
        """Return the innermost ASGI callable, routing WebSocket and HTTP traffic."""
        async def asgi(receive, send):
            if scope["type"] == "lifespan":
                # NOTE(review): ``__call__`` intercepts lifespan scopes before this
                # point, and the returned handler is not awaited here.
                return self.lifespan_handler(scope)
            elif scope["type"] == "websocket":
                await self._dispatch_ws(scope=scope,
                                        receive=receive,
                                        send=send)
            else:
                req = models.Request(scope, receive=receive, api=self)
                resp = await self._dispatch_request(req,
                                                    scope=scope,
                                                    send=send,
                                                    receive=receive)
                await resp(receive, send)

        return asgi

    async def _dispatch_ws(self, scope, receive, send):
        """Run before-hooks and the matching WebSocket endpoint, or close the socket."""
        ws = WebSocket(scope=scope, receive=receive, send=send)

        route = self.path_matches_route(ws.url.path)
        route = self.routes.get(route)

        if route:
            for before_request in self.before_ws_requests:
                await self.background(before_request, ws=ws)
            await self.background(route.endpoint, ws)
        else:
            await send({"type": "websocket.close", "code": 1000})

    def add_schema(self, name, schema, check_existing=True):
        """Adds a marshmallow schema to the API specification.

        :param name: Component name of the schema in the OpenAPI document.
        :param schema: The schema class.
        :param check_existing: If ``True``, an AssertionError is raised if ``name`` is already registered.
        """
        if check_existing:
            assert name not in self.schemas, f"Schema {name!r} is already registered"

        self.schemas[name] = schema

    def schema(self, name, **options):
        """Decorator registering a marshmallow schema class under ``name``.

        Usage::

            from marshmallow import Schema, fields

            @api.schema("Pet")
            class PetSchema(Schema):
                name = fields.Str()

        """
        def decorator(f):
            self.add_schema(name=name, schema=f, **options)
            return f

        return decorator

    def path_matches_route(self, path):
        """Given a path portion of a URL, tests that it matches against any registered route.

        :param path: The path portion of a URL, to test all known routes against.
        :returns: The first matching route string, or ``None`` if nothing matches.
        """
        for (route, route_object) in self.routes.items():
            if route_object.does_match(path):
                return route

    @property
    def _signer(self):
        """A signer keyed with ``secret_key``, used for the session cookie."""
        return itsdangerous.Signer(self.secret_key)

    def _prepare_session(self, resp):
        """Serialize a non-empty ``resp.session`` into a signed, base64 JSON cookie."""
        if resp.session:
            data = self._signer.sign(
                b64encode(json.dumps(resp.session).encode("utf-8")))
            resp.cookies[self.session_cookie] = data.decode("utf-8")

    @staticmethod
    def no_response(req, resp, **params):
        """Fallback view handler that does nothing."""
        pass

    async def _dispatch_request(self, req, **options):
        """Route an HTTP request to its view (or the not-found handling) and return the response."""
        # Set formats on Request object.
        req.formats = self.formats

        route = self.path_matches_route(req.url.path)
        route = self.routes.get(route)
        resp = models.Response(req=req, formats=self.formats)

        if route:
            for before_request in self.before_http_requests:
                await self.background(before_request, req=req, resp=resp)

            await self._execute_route(route=route,
                                      req=req,
                                      resp=resp,
                                      **options)
        else:
            self.default_response(req=req, resp=resp, notfound=True)
        # Fills in a 200 status if the view left it unset.
        self.default_response(req=req, resp=resp)

        self._prepare_session(resp)

        return resp

    async def _execute_route(self, *, route, req, resp, **options):
        """Run a route's endpoint as a function view and/or as a class-based view.

        A function endpoint is called with ``(req, resp, **params)``. If that raises
        ``TypeError`` (or the endpoint is class-based), the endpoint is treated as a
        view object and its ``on_request`` then ``on_<method>`` handlers are run.
        Any other exception marks the response as a 500 and is re-raised.
        """
        params = route.incoming_matches(req.url.path)

        # Only fall through to the class-based path when the function call failed
        # with TypeError; starting at True ran the fallback for every function view.
        cont = False

        if route.is_function:
            try:
                # Run the view; ``self.background`` returns an awaitable.
                r = self.background(route.endpoint, req, resp, **params)
                if hasattr(r, "cr_running"):
                    await r
            except TypeError:
                cont = True
            except Exception:
                await self.background(self.default_response,
                                      req,
                                      resp,
                                      error=True)
                raise

        if route.is_class_based or cont:
            # Instantiate with path params, then without, else use the endpoint as-is.
            try:
                view = route.endpoint(**params)
            except TypeError:
                try:
                    view = route.endpoint()
                except TypeError:
                    view = route.endpoint

            # Run on_request first, then the method-specific handler.
            await self._call_view_handler(view, "on_request", req, resp, params)
            await self._call_view_handler(view, f"on_{req.method}", req, resp,
                                          params)

    async def _call_view_handler(self, view, handler_name, req, resp, params):
        """Run ``view.<handler_name>`` (or :meth:`no_response`); on error set a 500 and re-raise."""
        try:
            handler = getattr(view, handler_name, self.no_response)
            r = self.background(handler, req, resp, **params)
            if hasattr(r, "send"):
                await r
        except Exception:
            await self.background(self.default_response,
                                  req,
                                  resp,
                                  error=True)
            raise

    def add_event_handler(self, event_type, handler):
        """Adds an event handler to the API.

        :param event_type: A string in ("startup", "shutdown")
        :param handler: The function to run. Can be either a function or a coroutine.
        """

        self.lifespan_handler.add_event_handler(event_type, handler)

    def add_route(
        self,
        route=None,
        endpoint=None,
        *,
        default=False,
        static=False,
        check_existing=True,
        websocket=False,
        before_request=False,
    ):
        """Adds a route to the API.

        :param route: A string representation of the route. A random one is generated if ``None``.
        :param endpoint: The endpoint for the route -- can be a callable, or a class.
        :param default: If ``True``, all unknown requests will route to this view.
        :param static: If ``True``, and no endpoint was passed, render "static/index.html", and it will become a default route.
        :param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined.
        :param websocket: If ``True``, the route serves WebSocket connections.
        :param before_request: If ``True``, register ``endpoint`` as a before-request hook instead of a route.
        """
        if before_request:
            key = "ws" if websocket else "http"
            self.before_requests.setdefault(key, []).append(endpoint)
            return

        if route is None:
            route = f"/{uuid4().hex}"

        if check_existing:
            assert route not in self.routes, f"Route {route!r} already exists"

        if static:
            assert self.static_dir is not None
            if not endpoint:
                endpoint = self.static_response
                default = True

        if default:
            self.default_endpoint = endpoint

        self.routes[route] = Route(route, endpoint, websocket=websocket)
        # TODO: A better data structure or sort it once the app is loaded
        self.routes = dict(
            sorted(self.routes.items(), key=lambda item: item[1]._weight()))

    def default_response(self,
                         req=None,
                         resp=None,
                         websocket=False,
                         notfound=False,
                         error=False):
        """Fill in default status/body for not-found, error, or unset-status responses.

        Does nothing for WebSockets. A missing status becomes 200. On ``notfound``
        the default endpoint handles the request if one is set, otherwise a 404 is
        written; ``error`` writes a 500.
        """
        if websocket:
            return

        if resp.status_code is None:
            resp.status_code = 200

        if self.default_endpoint and notfound:
            self.default_endpoint(req=req, resp=resp)
        else:
            if notfound:
                resp.status_code = status_codes.HTTP_404
                resp.text = "Not found."
            if error:
                resp.status_code = status_codes.HTTP_500
                resp.text = "Application error."

    def docs_response(self, req, resp):
        """View serving the rendered API documentation page."""
        resp.html = self.docs

    def static_response(self, req, resp):
        """View serving ``static_dir/index.html``, or a 404 if it does not exist."""
        assert self.static_dir is not None

        index = (self.static_dir / "index.html").resolve()
        if os.path.exists(index):
            with open(index, "r") as f:
                resp.html = f.read()
        else:
            resp.status_code = status_codes.HTTP_404
            resp.text = "Not found."

    def schema_response(self, req, resp):
        """View serving the OpenAPI document as YAML."""
        resp.status_code = status_codes.HTTP_200
        resp.headers["Content-Type"] = "application/x-yaml"
        resp.content = self.openapi

    def redirect(self,
                 resp,
                 location,
                 *,
                 set_text=True,
                 status_code=status_codes.HTTP_301):
        """Redirects a given response to a given location.

        :param resp: The Response to mutate.
        :param location: The location of the redirect.
        :param set_text: If ``True``, sets the Redirect body content automatically.
        :param status_code: an `API.status_codes` attribute, or an integer, representing the HTTP status code of the redirect.
        """
        resp.status_code = status_code
        if set_text:
            resp.text = f"Redirecting to: {location}"
        resp.headers.update({"Location": location})

    def on_event(self, event_type: str, **args):
        """Decorator for registering functions or coroutines to run at lifespan events.

        Registration goes through :meth:`add_event_handler`, which supports
        ``"startup"`` and ``"shutdown"``. That method takes no extra keyword
        arguments, so passing any ``**args`` here currently raises ``TypeError``.

        Usage::

            @api.on_event('startup')
            async def open_database_connection_pool():
                ...

            @api.on_event('shutdown')
            async def close_database_connection_pool():
                ...

        """
        def decorator(func):
            self.add_event_handler(event_type, func, **args)
            return func

        return decorator

    def route(self, route=None, **options):
        """Decorator for creating new routes around function and class definitions.

        Usage::

            @api.route("/hello")
            def hello(req, resp):
                resp.text = "hello, world!"

        """
        def decorator(f):
            self.add_route(route, f, **options)
            return f

        return decorator

    def mount(self, route, app):
        """Mounts an WSGI / ASGI application at a given route.

        :param route: String representation of the route to be used (shouldn't be parameterized).
        :param app: The other WSGI / ASGI app.
        """
        self.apps.update({route: app})

    def session(self, base_url="http://;"):
        """Testing HTTP client. Returns a Requests session object, able to send HTTP requests to the Responder application.

        The client is created once and cached; ``base_url`` only applies to that first call.

        :param base_url: The URL to mount the connection adaptor to.
        """
        if self._session is None:
            self._session = TestClient(self, base_url=base_url)
        return self._session

    def _route_for(self, endpoint):
        """Return the first route whose endpoint (or endpoint name) equals ``endpoint``, else ``None``."""
        for route_object in self.routes.values():
            if endpoint in (route_object.endpoint, route_object.endpoint_name):
                return route_object

    def url_for(self, endpoint, **params):
        # TODO: Absolute_url
        """Given an endpoint, returns a rendered URL for its route.

        :param endpoint: The route endpoint you're searching for.
        :param params: Data to pass into the URL generator (for parameterized URLs).
        :raises ValueError: If no route uses ``endpoint``.
        """
        route_object = self._route_for(endpoint)
        if route_object:
            return route_object.url(**params)
        raise ValueError(f"No route found for endpoint {endpoint!r}")

    def static_url(self, asset):
        """Given a static asset, return its URL path."""
        assert None not in (self.static_dir, self.static_route)
        return f"{self.static_route}/{str(asset)}"

    @property
    def docs(self):
        """Render the API documentation page with the configured apistar theme."""
        loader = jinja2.PrefixLoader({
            self.docs_theme:
            jinja2.PackageLoader(
                "apistar", os.path.join("themes", self.docs_theme,
                                        "templates"))
        })
        env = jinja2.Environment(autoescape=True, loader=loader)
        document = apistar.document.Document()
        document.content = yaml.safe_load(self.openapi)

        template = env.get_template("/".join([self.docs_theme, "index.html"]))

        def static_url(asset):
            assert None not in (self.static_dir, self.static_route)
            return f"{self.static_route}/{asset}"

        return template.render(
            document=document,
            langs=["javascript", "python"],
            code_style=None,
            static_url=static_url,
            # Point at where the schema is actually registered, not a fixed path.
            schema_url=self.openapi_route,
        )

    def template(self, name_, **values):
        """Renders the given `jinja2 <http://jinja.pocoo.org/docs/>`_ template, with provided values supplied.

        Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``.

        :param name_: The filename of the jinja2 template, in ``templates_dir``.
        :param values: Data to pass into the template.
        """
        # Prepopulate values with base; caller values win on conflicts.
        values = {**self.jinja_values_base, **values}

        template = self.jinja_env.get_template(name_)
        return template.render(**values)

    def template_string(self, s_, **values):
        """Renders the given `jinja2 <http://jinja.pocoo.org/docs/>`_ template string, with provided values supplied.

        Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``.

        :param s_: The template to use.
        :param values: Data to pass into the template.
        """
        # Prepopulate values with base; caller values win on conflicts.
        values = {**self.jinja_values_base, **values}

        template = self.jinja_env.from_string(s_)
        return template.render(**values)

    def serve(self, *, address=None, port=None, debug=False, **options):
        """Runs the application with uvicorn. If the ``PORT`` environment
        variable is set, requests will be served on that port automatically to all
        known hosts.

        :param address: The address to bind to. Defaults to ``127.0.0.1``
            (``0.0.0.0`` when ``PORT`` is set).
        :param port: The port to bind to. Defaults to ``5042``; ``PORT`` overrides it.
        :param debug: Run uvicorn server in debug mode.
        :param options: Additional keyword arguments to send to ``uvicorn.run()``.
        """
        if "PORT" in os.environ:
            if address is None:
                address = "0.0.0.0"
            port = int(os.environ["PORT"])

        if address is None:
            address = "127.0.0.1"
        if port is None:
            port = 5042

        uvicorn.run(self, host=address, port=port, debug=debug, **options)

    def run(self, **kwargs):
        """Like :meth:`serve`, but ``debug`` defaults to the app's own ``debug`` flag."""
        if "debug" not in kwargs:
            kwargs.update({"debug": self.debug})
        self.serve(**kwargs)
Exemplo n.º 6
0
class Starlette:
    """ASGI application: a router wrapped in exception, server-error and lifespan middleware.

    Calling the instance runs, from the outside in:
    ``LifespanMiddleware`` -> ``ServerErrorMiddleware`` -> (middleware added
    via ``add_middleware``) -> ``ExceptionMiddleware`` -> ``Router``.
    """

    def __init__(self,
                 debug: bool = False,
                 template_directory: typing.Optional[str] = None) -> None:
        self._debug = debug
        self.router = Router()
        # Each middleware wraps the one created just before it, so the router
        # is innermost and lifespan handling is outermost.
        self.exception_middleware = ExceptionMiddleware(self.router,
                                                        debug=debug)
        self.error_middleware = ServerErrorMiddleware(
            self.exception_middleware, debug=debug)
        self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
        self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
        self.template_env = self.load_template_env(template_directory)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """The routes registered on the underlying router."""
        return self.router.routes

    @property
    def debug(self) -> bool:
        """Whether debug mode is enabled."""
        return self._debug

    @debug.setter
    def debug(self, value: bool) -> None:
        # Keep both error-handling middlewares in sync with the app flag.
        self._debug = value
        self.exception_middleware.debug = value
        self.error_middleware.debug = value

    def load_template_env(
            self,
            template_directory: typing.Optional[str] = None) -> typing.Any:
        """Build a jinja2 environment for ``template_directory``.

        Returns ``None`` when no directory is given. Otherwise returns an
        autoescaping ``jinja2.Environment`` that loads from that directory and
        exposes a ``url_for`` global. ``url_for`` reads ``request`` from the
        template context.
        """
        if template_directory is None:
            return None

        # Import jinja2 lazily so it is only required when templates are used.
        import jinja2

        # Jinja2 3.0 renamed ``contextfunction`` to ``pass_context`` and 3.1
        # removed the old name: prefer the new decorator, fall back for
        # older Jinja2 releases.
        pass_context = getattr(jinja2, "pass_context", None)
        if pass_context is None:
            pass_context = jinja2.contextfunction

        @pass_context
        def url_for(context: dict, name: str,
                    **path_params: typing.Any) -> str:
            request = context["request"]
            return request.url_for(name, **path_params)

        loader = jinja2.FileSystemLoader(str(template_directory))
        env = jinja2.Environment(loader=loader, autoescape=True)
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> typing.Any:
        """Load template ``name`` from the configured jinja2 environment.

        NOTE(review): ``template_env`` is None when no template directory was
        configured, in which case this raises AttributeError.
        """
        return self.template_env.get_template(name)

    @property
    def schema(self) -> dict:
        """The schema produced by ``schema_generator`` for the current routes."""
        assert self.schema_generator is not None
        return self.schema_generator.get_schema(self.routes)

    def on_event(self, event_type: str) -> typing.Callable:
        """Decorator registering a lifespan event handler for ``event_type``."""
        return self.lifespan_middleware.on_event(event_type)

    def mount(self, path: str, app: ASGIApp,
              name: typing.Optional[str] = None) -> None:
        """Mount ``app`` under ``path`` on the router."""
        self.router.mount(path, app=app, name=name)

    def host(self, host: str, app: ASGIApp,
             name: typing.Optional[str] = None) -> None:
        """Route requests for ``host`` to ``app``."""
        self.router.host(host, app=app, name=name)

    def add_middleware(self, middleware_class: type,
                       **kwargs: typing.Any) -> None:
        """Insert ``middleware_class`` just inside the server-error middleware.

        Each call wraps the current inner app, so the most recently added
        middleware runs first among those added this way.
        """
        self.error_middleware.app = middleware_class(self.error_middleware.app,
                                                     **kwargs)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable,
    ) -> None:
        """Register ``handler`` for an exception class or HTTP status code.

        ``500`` and ``Exception`` install the handler on the server-error
        middleware; anything else goes to the exception middleware.
        """
        if exc_class_or_status_code in (500, Exception):
            self.error_middleware.handler = handler
        else:
            self.exception_middleware.add_exception_handler(
                exc_class_or_status_code, handler)

    def add_event_handler(self, event_type: str,
                          func: typing.Callable) -> None:
        """Register ``func`` as a lifespan handler for ``event_type``."""
        self.lifespan_middleware.add_event_handler(event_type, func)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> None:
        """Register an HTTP endpoint ``route`` at ``path``."""
        self.router.add_route(path,
                              route,
                              methods=methods,
                              name=name,
                              include_in_schema=include_in_schema)

    def add_websocket_route(self,
                            path: str,
                            route: typing.Callable,
                            name: typing.Optional[str] = None) -> None:
        """Register a WebSocket endpoint ``route`` at ``path``."""
        self.router.add_websocket_route(path, route, name=name)

    def exception_handler(
        self, exc_class_or_status_code: typing.Union[int,
                                                     typing.Type[Exception]]
    ) -> typing.Callable:
        """Decorator form of :meth:`add_exception_handler`."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class_or_status_code, func)
            return func

        return decorator

    def route(
        self,
        path: str,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        """Decorator form of :meth:`add_route`."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str,
                        name: typing.Optional[str] = None) -> typing.Callable:
        """Decorator form of :meth:`add_websocket_route`."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def middleware(self, middleware_type: str) -> typing.Callable:
        """Decorator registering ``func`` as an HTTP dispatch middleware.

        Only ``"http"`` is accepted. The function is wrapped in
        ``BaseHTTPMiddleware`` via :meth:`add_middleware`.
        """
        assert (middleware_type == "http"
                ), 'Currently only middleware("http") is supported.'

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_middleware(BaseHTTPMiddleware, dispatch=func)
            return func

        return decorator

    def url_path_for(self, name: str, **path_params: str) -> URLPath:
        """Reverse-resolve the URL path of the route called ``name``."""
        return self.router.url_path_for(name, **path_params)

    def __call__(self, scope: Scope) -> ASGIInstance:
        # Expose the application to downstream code through the scope.
        scope["app"] = self
        return self.lifespan_middleware(scope)
Exemplo n.º 7
0
def test_raise_on_shutdown():
    """A RuntimeError must escape the TestClient context for ``RaiseOnShutdown``."""
    app = LifespanMiddleware(RaiseOnShutdown)

    with pytest.raises(RuntimeError), TestClient(app):
        pass
Exemplo n.º 8
0
def test_raise_on_startup():
    """A RuntimeError must escape the TestClient context for ``RaiseOnStartup``."""
    app = LifespanMiddleware(RaiseOnStartup)

    with pytest.raises(RuntimeError), TestClient(app):
        pass  # pragma: nocover
Exemplo n.º 9
0
 def _get_lifespan_middleware(self, app: ASGIApp):
     """Wrap ``app`` in a LifespanMiddleware carrying every registered handler.

     ``self._events`` is iterated as ``(event_type, handler)`` pairs, and each
     pair is registered on the new middleware in iteration order.
     """
     lifespan = LifespanMiddleware(app)
     for event_type, handler in self._events:
         lifespan.add_event_handler(event_type, handler)
     return lifespan
Exemplo n.º 10
0
    def __init__(
        self,
        templates_dir: str = "templates",
        static_dir: Optional[str] = "static",
        static_root: Optional[str] = "static",
        allowed_hosts: Optional[List[str]] = None,
        enable_cors: bool = False,
        cors_config: Optional[dict] = None,
        enable_hsts: bool = False,
        enable_gzip: bool = False,
        gzip_min_size: int = 1024,
        media_type: Optional[str] = Media.JSON,
    ):
        """Set up routers, static files, media handling and the middleware stack.

        # Parameters
        templates_dir (str): directory where templates are searched for.
        static_dir (str): directory of static files, or `None` to serve none.
        static_root (str): path prefix for static assets
            (falls back to `static_dir` when `None`).
        allowed_hosts (list of str): hosts the server accepts.
            Defaults to `["*"]`.
        enable_cors (bool): add `CORSMiddleware` configured by `cors_config`
            merged over `DEFAULT_CORS_CONFIG`.
        cors_config (dict): CORS configuration overrides.
        enable_hsts (bool): add `HTTPSRedirectMiddleware`.
        enable_gzip (bool): add `GZipMiddleware`.
        gzip_min_size (int): `minimum_size` passed to `GZipMiddleware`.
        media_type (str): media type used to serialize `res.media`.
        """
        super().__init__(templates_dir=templates_dir)

        # Debug mode defaults to `False` but it can be set in `.run()`.
        self._debug = False

        # Base ASGI app
        self.asgi = self.dispatch

        # Mounted apps
        self.apps: Dict[str, Any] = {}

        # Routers
        self.http_router = HTTPRouter()
        self.websocket_router = WebSocketRouter()

        # Test client
        self.client = self.build_client()

        # Static files
        if static_dir is not None:
            if static_root is None:
                static_root = static_dir
            self.mount(static_root, static(static_dir))

        # Media handlers
        self._media = Media(media_type=media_type)

        # HTTP middleware: the server-error layer wraps the HTTP-error layer,
        # which wraps the router. `self._debug` is still False at this point.
        self.exception_middleware = HTTPErrorMiddleware(self.http_router,
                                                        debug=self._debug)
        self.server_error_middleware = ServerErrorMiddleware(
            self.exception_middleware,
            handler=error_to_text,
            debug=self._debug)
        self.add_error_handler(HTTPError, error_to_text)

        # Lifespan middleware
        self.lifespan_middleware = LifespanMiddleware(self.dispatch_lifespan)

        # ASGI middleware: each call wraps `self.asgi`, so middleware added
        # later ends up further out in the stack.
        if allowed_hosts is None:
            allowed_hosts = ["*"]
        self.add_asgi_middleware(TrustedHostMiddleware,
                                 allowed_hosts=allowed_hosts)
        if enable_cors:
            cors_config = {**DEFAULT_CORS_CONFIG, **(cors_config or {})}
            self.add_asgi_middleware(CORSMiddleware, **cors_config)
        if enable_hsts:
            self.add_asgi_middleware(HTTPSRedirectMiddleware)
        if enable_gzip:
            self.add_asgi_middleware(GZipMiddleware,
                                     minimum_size=gzip_min_size)
Exemplo n.º 11
0
class API(TemplatesMixin, metaclass=DocsMeta):
    """The all-mighty API class.

    This class implements the [ASGI](https://asgi.readthedocs.io) protocol.

    # Example

    ```python
    >>> import bocadillo
    >>> api = bocadillo.API()
    ```

    # Parameters

    templates_dir (str):
        The name of the directory where templates are searched for,
        relative to the application entry point.
        Defaults to `"templates"`.
    static_dir (str):
        The name of the directory containing static files, relative to
        the application entry point. Set to `None` to not serve any static
        files.
        Defaults to `"static"`.
    static_root (str):
        The path prefix for static assets.
        Defaults to `"static"`.
    allowed_hosts (list of str, optional):
        A list of hosts which the server is allowed to run at.
        If the list contains `"*"`, any host is allowed.
        Defaults to `["*"]`.
    enable_cors (bool):
        If `True`, Cross Origin Resource Sharing will be configured according
        to `cors_config`. Defaults to `False`.
        See also [CORS](../guides/http/middleware.md#cors).
    cors_config (dict):
        A dictionary of CORS configuration parameters.
        Defaults to `dict(allow_origins=[], allow_methods=["GET"])`.
    enable_hsts (bool):
        If `True`, enable HSTS (HTTP Strict Transport Security) and automatically
        redirect HTTP traffic to HTTPS.
        Defaults to `False`.
        See also [HSTS](../guides/http/middleware.md#hsts).
    enable_gzip (bool):
        If `True`, enable GZip compression and automatically
        compress responses for clients that support it.
        Defaults to `False`.
        See also [GZip](../guides/http/middleware.md#gzip).
    gzip_min_size (int):
        If specified, compress only responses that
        have more bytes than the specified value.
        Defaults to `1024`.
    media_type (str):
        Determines how values given to `res.media` are serialized.
        Can be one of the supported media types.
        Defaults to `"application/json"`.
        See also [Media](../guides/http/media.md).
    """
    def __init__(
        self,
        templates_dir: str = "templates",
        static_dir: Optional[str] = "static",
        static_root: Optional[str] = "static",
        allowed_hosts: Optional[List[str]] = None,
        enable_cors: bool = False,
        cors_config: Optional[dict] = None,
        enable_hsts: bool = False,
        enable_gzip: bool = False,
        gzip_min_size: int = 1024,
        media_type: Optional[str] = Media.JSON,
    ):
        super().__init__(templates_dir=templates_dir)

        # Debug mode defaults to `False` but it can be set in `.run()`.
        self._debug = False

        # Base ASGI app
        self.asgi = self.dispatch

        # Mounted apps
        self.apps: Dict[str, Any] = {}

        # Routers
        self.http_router = HTTPRouter()
        self.websocket_router = WebSocketRouter()

        # Test client
        self.client = self.build_client()

        # Static files
        if static_dir is not None:
            if static_root is None:
                static_root = static_dir
            self.mount(static_root, static(static_dir))

        # Media handlers
        self._media = Media(media_type=media_type)

        # HTTP middleware: the server-error layer wraps the HTTP-error layer,
        # which wraps the router. The `debug` setter keeps both in sync later.
        self.exception_middleware = HTTPErrorMiddleware(self.http_router,
                                                        debug=self._debug)
        self.server_error_middleware = ServerErrorMiddleware(
            self.exception_middleware,
            handler=error_to_text,
            debug=self._debug)
        self.add_error_handler(HTTPError, error_to_text)

        # Lifespan middleware
        self.lifespan_middleware = LifespanMiddleware(self.dispatch_lifespan)

        # ASGI middleware: each call wraps `self.asgi`, so middleware added
        # later ends up further out in the stack.
        if allowed_hosts is None:
            allowed_hosts = ["*"]
        self.add_asgi_middleware(TrustedHostMiddleware,
                                 allowed_hosts=allowed_hosts)
        if enable_cors:
            cors_config = {**DEFAULT_CORS_CONFIG, **(cors_config or {})}
            self.add_asgi_middleware(CORSMiddleware, **cors_config)
        if enable_hsts:
            self.add_asgi_middleware(HTTPSRedirectMiddleware)
        if enable_gzip:
            self.add_asgi_middleware(GZipMiddleware,
                                     minimum_size=gzip_min_size)

    @property
    def debug(self) -> bool:
        """Whether debug mode is enabled."""
        return self._debug

    @debug.setter
    def debug(self, debug: bool):
        # Propagate the flag to both error-handling middlewares.
        self._debug = debug
        self.exception_middleware.debug = debug
        self.server_error_middleware.debug = debug

    def build_client(self, **kwargs) -> TestClient:
        """Return a `TestClient` bound to this application.

        # Parameters
        kwargs (dict): extra keyword arguments passed to `TestClient`.
        """
        return TestClient(self, **kwargs)

    def get_template_globals(self):
        """Return global variables available to all templates.

        # Returns
        variables (dict): a mapping of variable names to their values.
        """
        return {"url_for": self.url_for}

    def mount(self, prefix: str, app: Union[ASGIApp, WSGIApp]):
        """Mount another WSGI or ASGI app at the given prefix.

        A request is handed to `app` when its path equals `prefix` or
        continues below it (e.g. a mount at `"/myapp"` serves `"/myapp"` and
        `"/myapp/..."` but not `"/myapplication"`).

        # Parameters
        prefix (str): A path prefix where the app should be mounted, e.g. `"/myapp"`.
        app: An object implementing [WSGI](https://wsgi.readthedocs.io) or [ASGI](https://asgi.readthedocs.io) protocol.
        """
        if not prefix.startswith("/"):
            prefix = "/" + prefix
        self.apps[prefix] = app

    def recipe(self, recipe: RecipeBase):
        """Apply a recipe to this application by calling it with the API."""
        recipe(self)

    @property
    def media_type(self) -> str:
        """The currently configured media type.

        When setting it to a value outside of built-in or custom media types,
        an `UnsupportedMediaType` exception is raised.
        """
        return self._media.type

    @media_type.setter
    def media_type(self, media_type: str):
        self._media.type = media_type

    @property
    def media_handlers(self) -> dict:
        """The dictionary of supported media handlers.

        You can access, edit or replace this at will.
        """
        return self._media.handlers

    @media_handlers.setter
    def media_handlers(self, media_handlers: dict):
        self._media.handlers = media_handlers

    def add_error_handler(self, exception_cls: Type[Exception],
                          handler: ErrorHandler):
        """Register a new error handler.

        # Parameters
        exception_cls (Exception class):
            The type of exception that should be handled.
        handler (callable):
            The actual error handler, which is called when an instance of
            `exception_cls` is caught.
            Should accept a request, response and exception parameters.
        """
        self.exception_middleware.add_exception_handler(exception_cls, handler)

    def error_handler(self, exception_cls: Type[Exception]):
        """Register a new error handler (decorator syntax).

        # See Also
        - [add_error_handler](#add-error-handler)
        """
        def wrapper(handler):
            self.add_error_handler(exception_cls, handler)
            return handler

        return wrapper

    def route(self, pattern: str, *, name: str = None, namespace: str = None):
        """Register a new route by decorating a view.

        # Parameters
        pattern (str): a URL pattern.
        name (str):
            An optional name for the route.
            If a route already exists for this name, it is replaced.
            Defaults to a snake-cased version of the view's name.
        namespace (str):
            An optional namespace for the route. If given, it is prefixed to
            the name and separated by a colon.

        # See Also
        - [check_route](#check-route) for the route validation algorithm.
        """
        return self.http_router.route(pattern=pattern,
                                      name=name,
                                      namespace=namespace)

    def websocket_route(
        self,
        pattern: str,
        *,
        value_type: Optional[str] = None,
        receive_type: Optional[str] = None,
        send_type: Optional[str] = None,
        caught_close_codes: Optional[Tuple[int]] = None,
    ):
        """Register a WebSocket route by decorating a view.

        # Parameters
        pattern (str): a URL pattern.

        # See Also
        - [WebSocket](./websockets.md#websocket) for a description of keyword
        arguments.
        """
        # NOTE: use named keyword arguments instead of `**kwargs` to improve
        # their accessibility (e.g. for IDE discovery).
        return self.websocket_router.route(
            pattern,
            value_type=value_type,
            receive_type=receive_type,
            send_type=send_type,
            caught_close_codes=caught_close_codes,
        )

    def url_for(self, name: str, **kwargs) -> str:
        """Build the URL path for a named route.

        # Parameters
        name (str): the name of the route.
        kwargs (dict): route parameters.

        # Returns
        url (str): the URL path for a route.

        # Raises
        HTTPError(404) : if no route exists for the given `name`.
        """
        route = self.http_router.routes.get(name)
        if route is None:
            raise HTTPError(404)
        return route.url(**kwargs)

    def redirect(
        self,
        *,
        name: str = None,
        url: str = None,
        permanent: bool = False,
        **kwargs,
    ):
        """Redirect to another HTTP route.

        # Parameters
        name (str): name of the route to redirect to.
        url (str):
            URL of the route to redirect to (required if `name` is omitted).
        permanent (bool):
            If `False` (the default), returns a temporary redirection (302).
            If `True`, returns a permanent redirection (301).
        kwargs (dict):
            Route parameters.

        # Raises
        Redirection:
            an exception that will be caught to trigger a redirection.

        # See Also
        - [Redirecting](../guides/http/redirecting.md)
        """
        if name is not None:
            url = self.url_for(name=name, **kwargs)
        else:
            assert url is not None, "url is expected if no route name is given"
        raise Redirection(url=url, permanent=permanent)

    def add_middleware(self, middleware_cls, **kwargs):
        """Register a middleware class.

        The new middleware wraps whatever currently sits directly inside the
        HTTP error middleware, so the most recently added one runs first.

        # Parameters

        middleware_cls (Middleware class):
            A subclass of `bocadillo.Middleware`.

        # See Also
        - [Middleware](../guides/http/middleware.md)
        """
        self.exception_middleware.app = middleware_cls(
            self.exception_middleware.app, **kwargs)

    def add_asgi_middleware(self, middleware_cls, *args, **kwargs):
        """Register an ASGI middleware class.

        The new middleware wraps the current ASGI app, so it becomes the
        outermost layer for HTTP and WebSocket scopes.

        # Parameters
        middleware_cls (Middleware class):
            A class that complies with the ASGI specification.

        # See Also
        - [ASGI middleware](../guides/agnostic/asgi-middleware.md)
        - [ASGI](https://asgi.readthedocs.io)
        """
        self.asgi = middleware_cls(self.asgi, *args, **kwargs)

    def on(self, event: str, handler: Optional[EventHandler] = None):
        """Register an event handler.

        # Parameters
        event (str):
            Either `"startup"` (when the server boots) or `"shutdown"`
            (when the server stops).
        handler (callback, optional):
            The event handler. If not given, this should be used as a
            decorator.

        # Example

        ```python
        @api.on("startup")
        async def startup():
            pass

        async def shutdown():
            pass

        api.on("shutdown", shutdown)
        ```
        """
        if handler is None:

            def register(func):
                self.lifespan_middleware.add_event_handler(event, func)
                return func

            return register
        else:
            self.lifespan_middleware.add_event_handler(event, handler)
            return handler

    def dispatch_lifespan(self, scope: Scope):
        """Return an ASGI instance that acknowledges startup and shutdown."""
        # Strict implementation of the ASGI lifespan spec.
        # This is required because the Starlette `LifespanMiddleware`
        # does not send the `complete` responses.

        async def asgi(receive, send):
            message = await receive()
            assert message["type"] == "lifespan.startup"
            await send({"type": "lifespan.startup.complete"})

            message = await receive()
            assert message["type"] == "lifespan.shutdown"
            await send({"type": "lifespan.shutdown.complete"})

        return asgi

    async def dispatch_http(self, receive: Receive, send: Send, scope: Scope):
        """Handle one HTTP request through the HTTP middleware stack."""
        assert scope["type"] == "http"
        req = Request(scope, receive)
        res = Response(req, media=self._media)
        res = await self.server_error_middleware(req, res)
        await res(receive, send)
        # Re-raise the exception to allow the server to log the error
        # and for the test client to optionally re-raise it too.
        self.server_error_middleware.raise_if_exception()

    async def dispatch_websocket(self, receive: Receive, send: Send,
                                 scope: Scope):
        """Hand a WebSocket connection to the WebSocket router."""
        await self.websocket_router(scope, receive, send)

    def dispatch(self, scope: Scope) -> ASGIAppInstance:
        """Base ASGI app: choose the HTTP or WebSocket handler for `scope`."""
        if scope["type"] == "websocket":
            return partial(self.dispatch_websocket, scope=scope)
        else:
            assert scope["type"] == "http"
            return partial(self.dispatch_http, scope=scope)

    def __call__(self, scope: Scope) -> ASGIAppInstance:
        if scope["type"] == "lifespan":
            return self.lifespan_middleware(scope)

        path: str = scope["path"]

        # Return a sub-mounted extra app, if found
        for prefix, app in self.apps.items():
            # Match whole path segments only: a mount at "/static" must handle
            # "/static" and "/static/..." but not e.g. "/statistics".
            if path != prefix and not path.startswith(prefix.rstrip("/") + "/"):
                continue
            # Remove prefix from path so that the request is made according
            # to the mounted app's point of view.
            scope["path"] = path[len(prefix):]
            try:
                return app(scope)
            except TypeError:
                # Calling a WSGI app with a single argument raises TypeError;
                # fall back to the WSGI adapter in that case.
                return WSGIResponder(app, scope)

        return self.asgi(scope)

    def run(
        self,
        host: str = None,
        port: int = None,
        debug: bool = False,
        log_level: str = "info",
        _run: Callable = None,
        **kwargs,
    ):
        """Serve the application using [uvicorn](https://www.uvicorn.org).

        # Parameters

        host (str):
            The host to bind to.
            Defaults to `"127.0.0.1"` (localhost).
            If not given and `$PORT` is set, `"0.0.0.0"` will be used to
            serve to all known hosts.
        port (int):
            The port to bind to.
            Defaults to `8000` or (if set) the value of the `$PORT` environment
            variable.
        debug (bool):
            Whether to serve the application in debug mode. Defaults to `False`.
        log_level (str):
            A logging level for the debug logger. Must be a logging level
            from the `logging` module. Defaults to `"info"`.
        kwargs (dict):
            Extra keyword arguments that will be passed to the Uvicorn runner.

        # See Also
        - [Configuring host and port](../guides/api.md#configuring-host-and-port)
        - [Debug mode](../guides/api.md#debug-mode)
        - [Uvicorn settings](https://www.uvicorn.org/settings/) for all
        available keyword arguments.
        """
        if _run is None:  # pragma: no cover
            _run = run

        if "PORT" in os.environ:
            port = int(os.environ["PORT"])
            if host is None:
                host = "0.0.0.0"

        if host is None:
            host = "127.0.0.1"

        if port is None:
            port = 8000

        if debug:
            self.debug = True
            reloader = StatReload(get_logger(log_level))
            # NOTE(review): the reloader always receives uvicorn's `run`, so a
            # custom `_run` is ignored in debug mode — confirm this is intended.
            reloader.run(
                run,
                {
                    "app": self,
                    "host": host,
                    "port": port,
                    "log_level": log_level,
                    "debug": self.debug,
                    **kwargs,
                },
            )
        else:
            _run(self, host=host, port=port, **kwargs)