コード例 #1
0
 def __init__(self, debug: bool = False) -> None:
     """Create the router, lifespan handler, exception middleware and executor.

     :param debug: Passed to ``ExceptionMiddleware`` as its ``debug`` argument.
     """
     router = Router(routes=[])
     self.router = router
     # ``app`` starts out as the bare router.
     self.app = router
     self.lifespan_handler = LifespanHandler()
     self.exception_middleware = ExceptionMiddleware(router, debug=debug)
     self.executor = ThreadPoolExecutor()
コード例 #2
0
 def __init__(self,
              debug: bool = False,
              template_directory: str = None) -> None:
     """Build the router and the two error-handling layers around it.

     :param debug: Stored as ``_debug`` and passed to both middleware layers.
     :param template_directory: Passed to ``load_template_env``; its result
         is stored as ``template_env``.
     """
     self._debug = debug
     router = Router()
     self.router = router
     self.lifespan_handler = LifespanHandler()
     # ExceptionMiddleware receives the router; ServerErrorMiddleware
     # receives the ExceptionMiddleware.
     exception_layer = ExceptionMiddleware(router, debug=debug)
     self.exception_middleware = exception_layer
     self.error_middleware = ServerErrorMiddleware(exception_layer,
                                                   debug=debug)
     self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
     self.template_env = self.load_template_env(template_directory)
コード例 #3
0
ファイル: applications.py プロジェクト: bendog/starlette-api
    def __init__(self,
                 components: typing.Optional[typing.List[Component]] = None,
                 debug: bool = False) -> None:
        """Set up the injector, router, lifespan handler and exception handling.

        :param components: Components handed to ``Injector``; ``None`` is
            treated as an empty list.
        :param debug: Passed to ``ExceptionMiddleware`` as its ``debug`` argument.
        """
        # Initialize injector
        self.injector = Injector([] if components is None else components)

        router = Router()
        self.router = router
        self.app = router
        self.lifespan_handler = LifespanHandler()
        self.exception_middleware = ExceptionMiddleware(router, debug=debug)

        # Add exception handler for API exceptions
        self.add_exception_handler(exceptions.HTTPException,
                                   self.api_http_exception_handler)
コード例 #4
0
ファイル: test_lifespan.py プロジェクト: ryankask/starlette
def test_lifespan_handler():
    """Startup handlers fire on entering ``LifespanContext``; shutdown handlers fire on leaving it."""
    fired = {"startup": False, "shutdown": False}
    handler = LifespanHandler()

    @handler.on_event("startup")
    def run_startup():
        fired["startup"] = True

    @handler.on_event("shutdown")
    def run_cleanup():
        fired["shutdown"] = True

    # Nothing has run before the context is entered.
    assert fired == {"startup": False, "shutdown": False}
    with LifespanContext(handler):
        # Inside the context only the startup handler has run.
        assert fired == {"startup": True, "shutdown": False}
    # After leaving the context both handlers have run.
    assert fired == {"startup": True, "shutdown": True}
コード例 #5
0
    def __init__(
        self,
        *,
        debug=False,
        title=None,
        version=None,
        openapi=None,
        openapi_route="/schema.yml",
        static_dir="static",
        static_route="/static",
        templates_dir="templates",
        auto_escape=True,
        secret_key=DEFAULT_SECRET_KEY,
        enable_hsts=False,
        docs_route=None,
        cors=False,
    ):
        """Configure directories, static serving, routes, middleware and templating.

        All arguments are keyword-only.

        :param debug: If truthy, ``DebugMiddleware`` is added; the value is also passed to ``ExceptionMiddleware``.
        :param title: Stored as ``self.title``.
        :param version: Stored as ``self.version``.
        :param openapi: Stored as ``self.openapi_version``; when truthy, ``schema_response`` is registered at ``openapi_route``.
        :param openapi_route: Route used for ``schema_response`` when ``openapi`` is set.
        :param static_dir: Static files directory; made absolute and created if missing.
        :param static_route: Route at which the WhiteNoise static app is mounted.
        :param templates_dir: Templates directory; made absolute and created if missing.
        :param auto_escape: If true, autoescaping is selected for "html" and "xml" templates.
        :param secret_key: Stored as ``self.secret_key``.
        :param enable_hsts: If truthy, ``HTTPSRedirectMiddleware`` is added.
        :param docs_route: If truthy, ``docs_response`` is registered at this route.
        :param cors: If truthy, ``CORSMiddleware`` is added with ``self.cors_params``.
        """

        self.secret_key = secret_key
        self.title = title
        self.version = version
        self.openapi_version = openapi
        self.static_dir = Path(os.path.abspath(static_dir))
        self.static_route = static_route
        self.templates_dir = Path(os.path.abspath(templates_dir))
        # Templates directory located next to this module.
        self.built_in_templates_dir = Path(
            os.path.abspath(os.path.dirname(__file__) + "/templates")
        )
        self.routes = {}
        self.docs_theme = DEFAULT_API_THEME
        self.docs_route = docs_route
        self.schemas = {}
        self.session_cookie = DEFAULT_SESSION_COOKIE

        self.hsts_enabled = enable_hsts
        self.cors = cors
        # NOTE(review): this stores a reference to the module-level default, not a copy.
        self.cors_params = DEFAULT_CORS_PARAMS
        # Make the static/templates directory if they don't exist.
        for _dir in (self.static_dir, self.templates_dir):
            os.makedirs(_dir, exist_ok=True)

        self.whitenoise = WhiteNoise(application=self._default_wsgi_app)
        self.whitenoise.add_files(str(self.static_dir))

        # Also register the static assets shipped with the apistar docs theme.
        self.whitenoise.add_files(
            (
                Path(apistar.__file__).parent / "themes" / self.docs_theme / "static"
            ).resolve()
        )

        # ``self.apps`` must exist before ``mount`` is called.
        self.apps = {}
        self.mount(self.static_route, self.whitenoise)

        self.formats = get_formats()

        # Cached requests session.
        self._session = None
        self.background = BackgroundQueue()

        if self.openapi_version:
            self.add_route(openapi_route, self.schema_response)

        if self.docs_route:
            self.add_route(self.docs_route, self.docs_response)

        self.default_endpoint = None
        # ``self.app`` starts as the dispatcher; the middleware below is added
        # in this order (presumably each call wraps ``self.app`` -- confirm
        # against ``add_middleware``).
        self.app = self.dispatch
        self.add_middleware(GZipMiddleware)
        if debug:
            self.add_middleware(DebugMiddleware)

        if self.hsts_enabled:
            self.add_middleware(HTTPSRedirectMiddleware)
        self.lifespan_handler = LifespanHandler()

        if self.cors:
            self.add_middleware(CORSMiddleware, **self.cors_params)
        self.add_middleware(ExceptionMiddleware, debug=debug)

        # Jinja environment
        self.jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(
                [str(self.templates_dir), str(self.built_in_templates_dir)],
                followlinks=True,
            ),
            autoescape=jinja2.select_autoescape(["html", "xml"] if auto_escape else []),
        )
        self.jinja_values_base = {"api": self}  # Give reference to self.
        self.requests = (
            self.session()
        )  #: A Requests session that is connected to the ASGI app.
コード例 #6
0
class API:
    """The primary web-service class.

        :param static_dir: The directory to use for static files. Will be created for you if it doesn't already exist.
        :param templates_dir: The directory to use for templates. Will be created for you if it doesn't already exist.
        :param auto_escape: If ``True``, HTML and XML templates will automatically be escaped.
        :param enable_hsts: If ``True``, send all responses to HTTPS URLs.
    """

    status_codes = status_codes

    def __init__(
        self,
        *,
        debug=False,
        title=None,
        version=None,
        openapi=None,
        openapi_route="/schema.yml",
        static_dir="static",
        static_route="/static",
        templates_dir="templates",
        auto_escape=True,
        secret_key=DEFAULT_SECRET_KEY,
        enable_hsts=False,
        docs_route=None,
        cors=False,
    ):
        """Configure directories, static serving, routes, middleware and templating.

        All arguments are keyword-only.

        :param debug: If truthy, ``DebugMiddleware`` is added; the value is also passed to ``ExceptionMiddleware``.
        :param title: Title used for the generated API specification.
        :param version: Version used for the generated API specification.
        :param openapi: OpenAPI version for the specification; when truthy, ``schema_response`` is registered at ``openapi_route``.
        :param openapi_route: Route used for ``schema_response`` when ``openapi`` is set.
        :param static_dir: Static files directory; made absolute and created if missing.
        :param static_route: Route at which the WhiteNoise static app is mounted.
        :param templates_dir: Templates directory; made absolute and created if missing.
        :param auto_escape: If true, autoescaping is selected for "html" and "xml" templates.
        :param secret_key: Key given to the ``itsdangerous`` signer used for session cookies.
        :param enable_hsts: If truthy, ``HTTPSRedirectMiddleware`` is added.
        :param docs_route: If truthy, ``docs_response`` is registered at this route.
        :param cors: If truthy, ``CORSMiddleware`` is added with ``self.cors_params``.
        """

        self.secret_key = secret_key
        self.title = title
        self.version = version
        self.openapi_version = openapi
        self.static_dir = Path(os.path.abspath(static_dir))
        self.static_route = static_route
        self.templates_dir = Path(os.path.abspath(templates_dir))
        # Templates directory located next to this module.
        self.built_in_templates_dir = Path(
            os.path.abspath(os.path.dirname(__file__) + "/templates")
        )
        self.routes = {}
        self.docs_theme = DEFAULT_API_THEME
        self.docs_route = docs_route
        self.schemas = {}
        self.session_cookie = DEFAULT_SESSION_COOKIE

        self.hsts_enabled = enable_hsts
        self.cors = cors
        # NOTE(review): this stores a reference to the module-level default, not a copy.
        self.cors_params = DEFAULT_CORS_PARAMS
        # Make the static/templates directory if they don't exist.
        for _dir in (self.static_dir, self.templates_dir):
            os.makedirs(_dir, exist_ok=True)

        self.whitenoise = WhiteNoise(application=self._default_wsgi_app)
        self.whitenoise.add_files(str(self.static_dir))

        # Also register the static assets shipped with the apistar docs theme.
        self.whitenoise.add_files(
            (
                Path(apistar.__file__).parent / "themes" / self.docs_theme / "static"
            ).resolve()
        )

        # ``mount`` writes into ``self.apps``, so it must exist first.
        self.apps = {}
        self.mount(self.static_route, self.whitenoise)

        self.formats = get_formats()

        # Cached requests session.
        self._session = None
        self.background = BackgroundQueue()

        if self.openapi_version:
            self.add_route(openapi_route, self.schema_response)

        if self.docs_route:
            self.add_route(self.docs_route, self.docs_response)

        self.default_endpoint = None
        # ``self.app`` starts as the dispatcher. Each ``add_middleware`` call
        # wraps the current ``self.app``, so middleware added later sits
        # outside middleware added earlier.
        self.app = self.dispatch
        self.add_middleware(GZipMiddleware)
        if debug:
            self.add_middleware(DebugMiddleware)

        if self.hsts_enabled:
            self.add_middleware(HTTPSRedirectMiddleware)
        self.lifespan_handler = LifespanHandler()

        if self.cors:
            self.add_middleware(CORSMiddleware, **self.cors_params)
        self.add_middleware(ExceptionMiddleware, debug=debug)

        # Jinja environment
        self.jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(
                [str(self.templates_dir), str(self.built_in_templates_dir)],
                followlinks=True,
            ),
            autoescape=jinja2.select_autoescape(["html", "xml"] if auto_escape else []),
        )
        self.jinja_values_base = {"api": self}  # Give reference to self.
        # ``session()`` creates and caches the test client on first use.
        self.requests = (
            self.session()
        )  #: A Requests session that is connected to the ASGI app.

    @staticmethod
    def _default_wsgi_app(*args, **kwargs):
        """Placeholder application handed to ``WhiteNoise``; accepts anything and does nothing."""
        pass

    @property
    def _apispec(self):
        """Build a fresh ``APISpec`` from the registered routes and schemas.

        Routes with a falsy ``description`` are left out; every entry of
        ``self.schemas`` is added as a definition.
        """
        spec = APISpec(
            title=self.title,
            version=self.version,
            openapi_version=self.openapi_version,
            plugins=[MarshmallowPlugin()],
        )

        for path, route_object in self.routes.items():
            description = route_object.description
            if not description:
                continue
            # Operations are parsed out of the route's description text.
            spec.add_path(
                path=path,
                operations=yaml_utils.load_operations_from_docstring(description),
            )

        for schema_name, schema_obj in self.schemas.items():
            spec.definition(schema_name, schema=schema_obj)

        return spec

    @property
    def openapi(self):
        """The specification built by ``_apispec``, serialized with ``to_yaml()``."""
        return self._apispec.to_yaml()

    def add_middleware(self, middleware_cls, **middleware_config):
        """Wrap the current ``self.app`` in ``middleware_cls``.

        :param middleware_cls: Callable invoked as ``middleware_cls(app, **middleware_config)``.
        :param middleware_config: Keyword arguments forwarded to ``middleware_cls``.
        """
        # The result replaces ``self.app``, so the newest middleware is outermost.
        self.app = middleware_cls(self.app, **middleware_config)

    def __call__(self, scope):
        """ASGI entry point: route the scope to the lifespan handler, a mounted app, or ``self.app``.

        :param scope: The ASGI connection scope. When a mounted prefix matches,
            its ``path`` and ``root_path`` entries are rewritten in place.
        """
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)

        request_path = scope["path"]
        base_root = scope.get("root_path", "")

        # Call into a submounted app, if one exists (first matching prefix wins).
        for prefix, mounted in self.apps.items():
            if not request_path.startswith(prefix):
                continue

            scope["path"] = request_path[len(prefix) :]
            scope["root_path"] = base_root + prefix
            try:
                return mounted(scope)
            except TypeError:
                # Calling with just the scope failed; retry through the
                # WSGI-to-ASGI adapter (presumably the mounted app is WSGI).
                return WsgiToAsgi(mounted)(scope)

        return self.app(scope)

    def dispatch(self, scope):
        """Return the ASGI instance coroutine function that serves ``scope``.

        :param scope: The ASGI connection scope.
        :returns: An ``async`` callable taking ``(receive, send)``.
        """

        # Call the main dispatcher.
        async def asgi(receive, send):
            request = models.Request(scope, receive=receive, api=self)
            response = await self._dispatch_request(
                request, scope=scope, send=send, receive=receive
            )
            # The response object is itself awaited with the ASGI channels.
            await response(receive, send)

        return asgi

    def add_schema(self, name, schema, check_existing=True):
        """Adds a marshmallow schema to the API specification.

        :param name: The key under which the schema is stored in ``self.schemas``.
        :param schema: The schema to register.
        :param check_existing: If ``True``, an AssertionError is raised when ``name`` is already registered.
        """
        if check_existing:
            assert name not in self.schemas

        self.schemas[name] = schema

    def schema(self, name, **options):
        """Decorator that registers the decorated object as a schema under ``name``.

        Extra keyword arguments are forwarded to ``add_schema``. The decorated
        object is returned unchanged.

        Usage::

            from marshmallow import Schema, fields

            @api.schema("Pet")
            class PetSchema(Schema):
                name = fields.Str()

        """

        def register(schema_obj):
            self.add_schema(name=name, schema=schema_obj, **options)
            return schema_obj

        return register

    def path_matches_route(self, path):
        """Return the key of the first registered route matching ``path``, or ``None``.

        :param path: The path portion of a URL, tested against every route in
            ``self.routes`` in iteration order.
        """
        matches = (
            pattern
            for pattern, candidate in self.routes.items()
            if candidate.does_match(path)
        )
        return next(matches, None)

    def _prepare_cookies(self, resp):
        """Write ``resp.cookies`` into the response's ``Set-Cookie`` header.

        Does nothing when there are no cookies.

        NOTE(review): all cookies are joined into one header value as
        ``name=value;`` pairs separated by spaces -- confirm clients accept this
        when more than one cookie is set.
        """
        cookies = resp.cookies
        if not cookies:
            return

        pairs = (f"{key}={value};" for key, value in cookies.items())
        resp.headers["Set-Cookie"] = " ".join(pairs)

    @property
    def _signer(self):
        """A new ``itsdangerous.Signer`` keyed with ``self.secret_key`` (built on every access)."""
        return itsdangerous.Signer(self.secret_key)

    def _prepare_session(self, resp):
        """Store the signed session data in the response's session cookie.

        ``resp.session`` is JSON-encoded, base64-encoded and signed; the result
        is saved in ``resp.cookies`` under ``self.session_cookie``. Does nothing
        when the session is empty.
        """
        session_data = resp.session
        if not session_data:
            return

        encoded = b64encode(json.dumps(session_data).encode("utf-8"))
        signed = self._signer.sign(encoded)
        resp.cookies[self.session_cookie] = signed.decode("utf-8")

    @staticmethod
    def no_response(req, resp, **params):
        """No-op handler used when a class-based view lacks the requested ``on_*`` method."""
        pass

    async def _dispatch_request(self, req, **options):
        """Find the route for ``req``, run its endpoint and return the response object.

        :param req: The incoming request; ``req.url.path`` selects the route.
        :param options: Passed to ``WebSocket(**options)`` for websocket routes
            (``dispatch`` supplies ``scope``, ``send`` and ``receive``).
        :returns: The response object, after session and cookie preparation.
        """
        # Set formats on Request object.
        req.formats = self.formats

        # Get the route.
        route = self.path_matches_route(req.url.path)
        route = self.routes.get(route)

        # Create the response object.
        # NOTE(review): ``cont`` is only ever set inside the ``if route.is_function``
        # branch, so the ``or cont`` on the ``elif`` below can never be what makes
        # that branch run -- confirm whether a fall-through was intended.
        cont = False
        if route:
            if route.uses_websocket:
                resp = WebSocket(**options)

            else:
                resp = models.Response(req=req, formats=self.formats)

            params = route.incoming_matches(req.url.path)

            if route.is_function:
                try:
                    try:
                        # Run the view.
                        r = route.endpoint(req, resp, **params)
                        # If it's async, await it.
                        if hasattr(r, "cr_running"):
                            await r
                    except TypeError as e:
                        # NOTE(review): any TypeError raised while running the
                        # view is swallowed here, not only a call-signature mismatch.
                        cont = True
                except Exception:
                    # Mark the response as a 500, then let the error propagate.
                    self.default_response(req, resp, error=True)
                    raise

            elif route.is_class_based or cont:
                # Try the endpoint with the route params, then with no
                # arguments, and finally use the endpoint object itself.
                try:
                    view = route.endpoint(**params)
                except TypeError:
                    try:
                        view = route.endpoint()
                    except TypeError:
                        view = route.endpoint

                # Run on_request first.
                try:
                    # Run the view.
                    r = getattr(view, "on_request", self.no_response)(
                        req, resp, **params
                    )
                    # If it's async, await it.
                    if hasattr(r, "send"):
                        await r
                except Exception:
                    self.default_response(req, resp, error=True)
                    raise

                # Then run on_method.
                method = req.method
                try:
                    # Run the view.
                    r = getattr(view, f"on_{method}", self.no_response)(
                        req, resp, **params
                    )
                    # If it's async, await it.
                    if hasattr(r, "send"):
                        await r
                except Exception as e:
                    # NOTE(review): unlike the handlers above, this one does not
                    # re-raise; the exception is dropped after the 500 is set.

                    self.default_response(req, resp, error=True)

        else:
            # No matching route: build a plain response and apply not-found handling.
            resp = models.Response(req=req, formats=self.formats)
            self.default_response(req, resp, notfound=True)
        # With no flags this only fills in a status code when none was set.
        self.default_response(req, resp)

        # The session must be written into ``resp.cookies`` before the cookie header is built.
        self._prepare_session(resp)
        self._prepare_cookies(resp)

        return resp

    def add_event_handler(self, event_type, handler):
        """Adds an event handler to the API.

        Delegates to ``self.lifespan_handler.add_event_handler``.

        :param event_type: A string in ("startup", "shutdown")
        :param handler: The function to run. Can be either a function or a coroutine.
        """

        self.lifespan_handler.add_event_handler(event_type, handler)

    def add_route(
        self,
        route,
        endpoint=None,
        *,
        default=False,
        static=False,
        check_existing=True,
        websocket=False,
    ):
        """Register ``endpoint`` under ``route`` and re-sort the route table.

        :param route: A string representation of the route.
        :param endpoint: The endpoint for the route -- can be a callable, or a class.
        :param default: If ``True``, all unknown requests will route to this view.
        :param static: If ``True``, and no endpoint was passed, ``static_response`` is used as the endpoint and it becomes the default route.
        :param check_existing: If ``True``, an AssertionError will be raised, if the route is already defined.
        :param websocket: Passed to ``Route`` as its ``websocket`` argument.
        """
        if check_existing:
            assert route not in self.routes

        # ``static`` only takes effect when no endpoint was supplied.
        use_static_index = static and not endpoint
        if use_static_index:
            endpoint = self.static_response

        if default or use_static_index:
            self.default_endpoint = endpoint

        self.routes[route] = Route(route, endpoint, websocket=websocket)

        def weight_of(entry):
            return entry[1]._weight()

        # TODO: A better data structure or sort it once the app is loaded
        self.routes = dict(sorted(self.routes.items(), key=weight_of))

    def default_response(self, req, resp, notfound=False, error=False):
        """Fill in fallback status and body on ``resp``.

        :param req: The request, passed to the default endpoint when it is used.
        :param resp: The response to mutate.
        :param notfound: No route matched. The default endpoint handles the
            request if one is set; otherwise a 404 body is written.
        :param error: The endpoint failed; a 500 body is written. Ignored when
            the default endpoint handled a not-found request.
        """
        if resp.status_code is None:
            resp.status_code = 200

        if notfound and self.default_endpoint:
            self.default_endpoint(req, resp)
            return

        if notfound:
            resp.status_code = status_codes.HTTP_404
            resp.text = "Not found."

        # Applied last, so an error overrides a 404 set just above.
        if error:
            resp.status_code = status_codes.HTTP_500
            resp.text = "Application error."

    def docs_response(self, req, resp):
        """Endpoint that sets the response text to the rendered ``self.docs`` page."""
        resp.text = self.docs

    def static_response(self, req, resp):
        """Endpoint that serves ``index.html`` from the static directory.

        The response content is set to an empty string, then replaced with the
        file's text if the file exists.
        """
        index_path = (self.static_dir / "index.html").resolve()
        resp.content = ""
        if not os.path.exists(index_path):
            return

        with open(index_path, "r") as index_file:
            resp.text = index_file.read()

    def schema_response(self, req, resp):
        """Endpoint that returns the OpenAPI document (``self.openapi``) as YAML with a 200 status."""
        resp.status_code = status_codes.HTTP_200
        resp.headers["Content-Type"] = "application/x-yaml"
        resp.content = self.openapi

    def redirect(
        self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301
    ):
        """Turn ``resp`` into a redirect to ``location``.

        :param resp: The Response to mutate.
        :param location: The location of the redirect; written to the ``Location`` header.
        :param set_text: If ``True``, sets the Redirect body content automatically.
        :param status_code: an `API.status_codes` attribute, or an integer, representing the HTTP status code of the redirect.
        """

        # assert resp.status_code.is_300(status_code)

        resp.status_code = status_code

        if set_text:
            body = f"Redirecting to: {location}"
            resp.text = body

        location_header = {"Location": location}
        resp.headers.update(location_header)

    def on_event(self, event_type: str, **args):
        """Decorator that registers a function or coroutine as an event handler.

        The decorated callable is passed to ``add_event_handler`` together with
        any extra keyword arguments, and is returned unchanged.

        NOTE(review): ``add_event_handler`` on this class accepts only
        ``(event_type, handler)``, so passing extra keyword arguments here
        raises ``TypeError`` when the decorator is applied.

        Usage::

            @api.on_event('startup')
            async def open_database_connection_pool():
                ...

            @api.on_event('shutdown')
            async def close_database_connection_pool():
                ...

        """

        def register(handler):
            self.add_event_handler(event_type, handler, **args)
            return handler

        return register

    def route(self, route, **options):
        """Decorator that registers the decorated function or class at ``route``.

        Extra keyword arguments are forwarded to ``add_route``. The decorated
        object is returned unchanged.

        Usage::

            @api.route("/hello")
            def hello(req, resp):
                resp.text = "hello, world!"

        """

        def register(endpoint):
            self.add_route(route, endpoint, **options)
            return endpoint

        return register

    def mount(self, route, app):
        """Mounts a WSGI / ASGI application at a given route.

        An existing mount at the same route is replaced.

        :param route: String representation of the route to be used (shouldn't be parameterized).
        :param app: The other WSGI / ASGI app.
        """
        self.apps[route] = app

    def session(self, base_url="http://;"):
        """Testing HTTP client. Returns a Requests session object, able to send HTTP requests to the Responder application.

        The client is created on first use and cached on the instance.

        :param base_url: Accepted for API compatibility; not used by this method.
        """

        client = self._session
        if client is None:
            client = TestClient(self)
            self._session = client
        return client

    def _route_for(self, endpoint):
        """Return the first route whose endpoint or endpoint name equals ``endpoint``, or ``None``."""
        for candidate in self.routes.values():
            if candidate.endpoint == endpoint or candidate.endpoint_name == endpoint:
                return candidate
        return None

    def url_for(self, endpoint, **params):
        # TODO: Absolute_url
        """Given an endpoint, returns a rendered URL for its route.

        :param endpoint: The route endpoint (or endpoint name) you're searching for.
        :param params: Data to pass into the URL generator (for parameterized URLs).
        :raises ValueError: If no registered route matches ``endpoint``.
        """
        route_object = self._route_for(endpoint)
        if route_object:
            return route_object.url(**params)
        # Same exception type as before, now naming the endpoint that failed.
        raise ValueError(f"No route found for endpoint {endpoint!r}.")

    def static_url(self, asset):
        """Given a static asset, return its URL path (``<static_route>/<asset>``)."""
        prefix = self.static_route
        return f"{prefix}/{asset!s}"

    @property
    def docs(self):
        """Render the interactive documentation page for the current OpenAPI schema.

        The page is the ``index.html`` template of the apistar theme named by
        ``self.docs_theme``, rendered with a document parsed from ``self.openapi``.
        """
        theme = self.docs_theme

        def theme_asset_url(asset):
            return f"{self.static_route}/{asset}"

        theme_loader = jinja2.PackageLoader(
            "apistar", os.path.join("themes", theme, "templates")
        )
        # Theme templates are addressed as "<theme>/<template name>".
        env = jinja2.Environment(
            autoescape=True, loader=jinja2.PrefixLoader({theme: theme_loader})
        )

        document = apistar.document.Document()
        document.content = yaml.safe_load(self.openapi)

        page = env.get_template("/".join((self.docs_theme, "index.html")))
        # NOTE(review): the schema URL is fixed here and does not follow a
        # custom ``openapi_route``.
        return page.render(
            document=document,
            langs=["javascript", "python"],
            code_style=None,
            static_url=theme_asset_url,
            schema_url="/schema.yml",
        )

    def template(self, name_, **values):
        """Renders the given `jinja2 <http://jinja.pocoo.org/docs/>`_ template, with provided values supplied.

        Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``.

        :param name_: The filename of the jinja2 template, in ``templates_dir``.
        :param values: Data to pass into the template; these override the base values.
        """
        # Start from the base values, then let the caller's values win.
        context = dict(self.jinja_values_base)
        context.update(values)

        return self.jinja_env.get_template(name_).render(**context)

    def template_string(self, s_, **values):
        """Renders the given `jinja2 <http://jinja.pocoo.org/docs/>`_ template string, with provided values supplied.

        Note: The current ``api`` instance is by default passed into the view. This is set in the dict ``api.jinja_values_base``.

        :param s_: The template source to render.
        :param values: Data to pass into the template; these override the base values.
        """
        # Start from the base values, then let the caller's values win.
        context = dict(self.jinja_values_base)
        context.update(values)

        return self.jinja_env.from_string(s_).render(**context)

    def run(self, address=None, port=None, debug=False, **options):
        """Runs the application with uvicorn. If the ``PORT`` environment
        variable is set, requests will be served on that port automatically to all
        known hosts.

        :param address: The address to bind to. Defaults to ``0.0.0.0`` when ``PORT`` is set, otherwise ``127.0.0.1``.
        :param port: The port to bind to. ``PORT`` overrides it; if neither is given, 5042 is used.
        :param debug: Run uvicorn server in debug mode.
        :param options: Additional keyword arguments to send to ``uvicorn.run()``.
        """

        env_port = os.environ.get("PORT")
        if env_port is not None:
            # The environment variable wins over an explicitly passed port.
            port = int(env_port)
            fallback_address = "0.0.0.0"
        else:
            fallback_address = "127.0.0.1"

        if address is None:
            address = fallback_address
        if port is None:
            port = 5042

        uvicorn.run(self, host=address, port=port, debug=debug, **options)
コード例 #7
0
class Starlette:
    """ASGI application combining a router, lifespan events, two
    error-handling layers and optional Jinja2 templating.

    HTTP/WebSocket scopes are passed to ``error_middleware``, which was built
    around ``exception_middleware``, which was built around ``router``.
    Lifespan scopes go to ``lifespan_handler``.
    """

    def __init__(self,
                 debug: bool = False,
                 template_directory: str = None) -> None:
        """Build the router, lifespan handler and both middleware layers.

        :param debug: Stored as ``_debug`` and passed to both middleware layers.
        :param template_directory: Passed to ``load_template_env``.
        """
        self._debug = debug
        self.router = Router()
        self.lifespan_handler = LifespanHandler()
        self.exception_middleware = ExceptionMiddleware(self.router,
                                                        debug=debug)
        self.error_middleware = ServerErrorMiddleware(
            self.exception_middleware, debug=debug)
        self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
        self.template_env = self.load_template_env(template_directory)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """The router's route list."""
        return self.router.routes

    @property
    def debug(self) -> bool:
        """The debug flag last set on the application."""
        return self._debug

    @debug.setter
    def debug(self, value: bool) -> None:
        # Keep both middleware layers in step with the application flag.
        self._debug = value
        self.exception_middleware.debug = value
        self.error_middleware.debug = value

    def load_template_env(self, template_directory: str = None) -> typing.Any:
        """Create the Jinja2 environment, or return ``None`` without a directory.

        The environment autoescapes and exposes a ``url_for`` global that
        delegates to ``request.url_for`` of the template context.
        """
        if template_directory is None:
            return None

        # Import jinja2 lazily.
        import jinja2

        # NOTE(review): ``contextfunction`` was removed in Jinja2 3.1 in favour
        # of ``pass_context`` -- confirm the Jinja2 version this targets.
        @jinja2.contextfunction
        def url_for(context: dict, name: str,
                    **path_params: typing.Any) -> str:
            request = context["request"]
            return request.url_for(name, **path_params)

        loader = jinja2.FileSystemLoader(str(template_directory))
        env = jinja2.Environment(loader=loader, autoescape=True)
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> typing.Any:
        """Look up template ``name`` in ``template_env``."""
        return self.template_env.get_template(name)

    @property
    def schema(self) -> dict:
        """The schema produced by ``schema_generator`` for the current routes.

        An assertion fails if no schema generator has been set.
        """
        assert self.schema_generator is not None
        return self.schema_generator.get_schema(self.routes)

    def on_event(self, event_type: str) -> typing.Callable:
        """Return the lifespan handler's decorator for ``event_type``."""
        return self.lifespan_handler.on_event(event_type)

    def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
        """Mount ``app`` under ``path`` on the router."""
        self.router.mount(path, app=app, name=name)

    def add_middleware(self, middleware_class: type,
                       **kwargs: typing.Any) -> None:
        """Insert ``middleware_class`` directly inside ``error_middleware``.

        The new middleware wraps whatever ``error_middleware`` wrapped before,
        so repeated calls stack.
        """
        self.error_middleware.app = middleware_class(self.error_middleware.app,
                                                     **kwargs)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable,
    ) -> None:
        """Register ``handler`` for an exception class or status code.

        ``500`` and ``Exception`` replace the ``error_middleware`` handler;
        everything else is registered on ``exception_middleware``.
        """
        if exc_class_or_status_code in (500, Exception):
            self.error_middleware.handler = handler
        else:
            self.exception_middleware.add_exception_handler(
                exc_class_or_status_code, handler)

    def add_event_handler(self, event_type: str,
                          func: typing.Callable) -> None:
        """Register ``func`` for ``event_type`` on the lifespan handler."""
        self.lifespan_handler.add_event_handler(event_type, func)

    def add_route(
        self,
        path: str,
        route: typing.Callable,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> None:
        """Register ``route`` at ``path`` on the router.

        :param path: The path to register.
        :param route: The endpoint callable.
        :param methods: Passed through to the router.
        :param name: Passed through to the router.
        :param include_in_schema: Passed through to the router.
        """
        # Forward ``include_in_schema`` as the ``route`` decorator does;
        # previously the argument was accepted here but silently dropped.
        self.router.add_route(
            path,
            route,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )

    def add_websocket_route(self,
                            path: str,
                            route: typing.Callable,
                            name: str = None) -> None:
        """Register the websocket endpoint ``route`` at ``path``."""
        self.router.add_websocket_route(path, route, name=name)

    def exception_handler(
        self, exc_class_or_status_code: typing.Union[int,
                                                     typing.Type[Exception]]
    ) -> typing.Callable:
        """Decorator form of ``add_exception_handler``."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class_or_status_code, func)
            return func

        return decorator

    def route(
        self,
        path: str,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        """Decorator that registers the decorated function at ``path``."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str, name: str = None) -> typing.Callable:
        """Decorator that registers the decorated function as a websocket route."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def middleware(self, middleware_type: str) -> typing.Callable:
        """Decorator that installs ``func`` as the dispatch of a ``BaseHTTPMiddleware``.

        Only ``"http"`` is accepted as ``middleware_type``.
        """
        assert (middleware_type == "http"
                ), 'Currently only middleware("http") is supported.'

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_middleware(BaseHTTPMiddleware, dispatch=func)
            return func

        return decorator

    def url_path_for(self, name: str, **path_params: str) -> URLPath:
        """Delegate URL path lookup to the router."""
        return self.router.url_path_for(name, **path_params)

    def __call__(self, scope: Scope) -> ASGIInstance:
        """ASGI entry point: store the app in the scope, then dispatch it."""
        scope["app"] = self
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)
        return self.error_middleware(scope)
コード例 #8
0
class Starlette:
    """ASGI application built from a router wrapped in an exception middleware.

    ``app`` always refers to the bare router. Middleware is layered between
    ``exception_middleware`` and the router.
    """

    def __init__(self, debug: bool = False) -> None:
        """Create the router, lifespan handler, exception middleware and executor.

        :param debug: Passed to ``ExceptionMiddleware`` as its ``debug`` argument.
        """
        self.router = Router(routes=[])
        self.lifespan_handler = LifespanHandler()
        self.app = self.router
        self.exception_middleware = ExceptionMiddleware(self.router,
                                                        debug=debug)
        self.executor = ThreadPoolExecutor()

    @property
    def debug(self) -> bool:
        """The debug flag held by the exception middleware."""
        return self.exception_middleware.debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self.exception_middleware.debug = value

    def on_event(self, event_type: str) -> typing.Callable:
        """Return the lifespan handler's decorator for ``event_type``."""
        return self.lifespan_handler.on_event(event_type)

    def mount(self,
              path: str,
              app: ASGIApp,
              methods: typing.Sequence[str] = None) -> None:
        """Append a ``PathPrefix`` route that forwards ``path`` to ``app``."""
        prefix = PathPrefix(path, app=app, methods=methods)
        self.router.routes.append(prefix)

    def add_middleware(self, middleware_class: type,
                       **kwargs: typing.Any) -> None:
        """Insert ``middleware_class`` directly inside the exception middleware.

        :param middleware_class: Called as ``middleware_class(app, **kwargs)``.
        :param kwargs: Keyword arguments forwarded to ``middleware_class``.
        """
        # Wrap what the exception middleware currently wraps, so repeated
        # calls stack. Wrapping ``self.app`` (always the bare router) made
        # each call discard the middleware added before it.
        self.exception_middleware.app = middleware_class(
            self.exception_middleware.app, **kwargs)

    def add_exception_handler(self, exc_class: type,
                              handler: typing.Callable) -> None:
        """Register ``handler`` for ``exc_class`` on the exception middleware."""
        self.exception_middleware.add_exception_handler(exc_class, handler)

    def add_route(self,
                  path: str,
                  route: typing.Callable,
                  methods: typing.Sequence[str] = None) -> None:
        """Register an HTTP route.

        A ``route`` that is not a class is wrapped with ``request_response``
        and, if ``methods`` is not given, restricted to ``GET``.
        """
        if not inspect.isclass(route):
            route = request_response(route)
            if methods is None:
                methods = ("GET", )

        instance = Path(path, route, protocol="http", methods=methods)
        self.router.routes.append(instance)

    def add_websocket_route(self, path: str, route: typing.Callable) -> None:
        """Register a websocket route, wrapping a non-class ``route`` with ``websocket_session``."""
        if not inspect.isclass(route):
            route = websocket_session(route)

        instance = Path(path, route, protocol="websocket")
        self.router.routes.append(instance)

    def exception_handler(self, exc_class: type) -> typing.Callable:
        """Decorator form of ``add_exception_handler``."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class, func)
            return func

        return decorator

    def route(self,
              path: str,
              methods: typing.Sequence[str] = None) -> typing.Callable:
        """Decorator form of ``add_route``."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_route(path, func, methods=methods)
            return func

        return decorator

    def websocket_route(self, path: str) -> typing.Callable:
        """Decorator form of ``add_websocket_route``."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_websocket_route(path, func)
            return func

        return decorator

    def __call__(self, scope: Scope) -> ASGIInstance:
        """ASGI entry point: store the app and executor in the scope, then dispatch it."""
        scope["app"] = self
        scope["executor"] = self.executor
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)
        return self.exception_middleware(scope)
コード例 #9
0
class Starlette:
    """ASGI application built from a router wrapped in an exception middleware.

    ``app`` always refers to the bare router. Middleware is layered between
    ``exception_middleware`` and the router.
    """

    def __init__(self, debug: bool = False) -> None:
        """Create the router, lifespan handler and exception middleware.

        :param debug: Passed to ``ExceptionMiddleware`` as its ``debug`` argument.
        """
        self.router = Router()
        self.lifespan_handler = LifespanHandler()
        self.app = self.router
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """The router's route list."""
        return self.router.routes

    @property
    def debug(self) -> bool:
        """The debug flag held by the exception middleware."""
        return self.exception_middleware.debug

    @debug.setter
    def debug(self, value: bool) -> None:
        self.exception_middleware.debug = value

    def on_event(self, event_type: str) -> typing.Callable:
        """Return the lifespan handler's decorator for ``event_type``."""
        return self.lifespan_handler.on_event(event_type)

    def mount(self, path: str, app: ASGIApp) -> None:
        """Mount ``app`` under ``path`` on the router."""
        self.router.mount(path, app=app)

    def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None:
        """Insert ``middleware_class`` directly inside the exception middleware.

        :param middleware_class: Called as ``middleware_class(app, **kwargs)``.
        :param kwargs: Keyword arguments forwarded to ``middleware_class``.
        """
        # Wrap what the exception middleware currently wraps, so repeated
        # calls stack. Wrapping ``self.app`` (always the bare router) made
        # each call discard the middleware added before it.
        self.exception_middleware.app = middleware_class(
            self.exception_middleware.app, **kwargs
        )

    def add_exception_handler(self, exc_class: type, handler: typing.Callable) -> None:
        """Register ``handler`` for ``exc_class`` on the exception middleware."""
        self.exception_middleware.add_exception_handler(exc_class, handler)

    def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
        """Register ``func`` for ``event_type`` on the lifespan handler."""
        self.lifespan_handler.add_event_handler(event_type, func)

    def add_route(
        self, path: str, route: typing.Callable, methods: typing.List[str] = None
    ) -> None:
        """Register ``route`` at ``path`` on the router."""
        self.router.add_route(path, route, methods=methods)

    def add_graphql_route(self, path: str, schema: typing.Any) -> None:
        """Register a GraphQL ``schema`` at ``path`` on the router."""
        self.router.add_graphql_route(path, schema)

    def add_websocket_route(self, path: str, route: typing.Callable) -> None:
        """Register the websocket endpoint ``route`` at ``path`` on the router."""
        self.router.add_websocket_route(path, route)

    def exception_handler(self, exc_class: type) -> typing.Callable:
        """Decorator form of ``add_exception_handler``."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_exception_handler(exc_class, func)
            return func

        return decorator

    def route(self, path: str, methods: typing.List[str] = None) -> typing.Callable:
        """Decorator that registers the decorated function at ``path``."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_route(path, func, methods=methods)
            return func

        return decorator

    def websocket_route(self, path: str) -> typing.Callable:
        """Decorator that registers the decorated function as a websocket route."""
        def decorator(func: typing.Callable) -> typing.Callable:
            self.router.add_websocket_route(path, func)
            return func

        return decorator

    def url_path_for(self, name: str, **path_params: str) -> URL:
        """Delegate URL path lookup to the router."""
        return self.router.url_path_for(name, **path_params)

    def __call__(self, scope: Scope) -> ASGIInstance:
        """ASGI entry point: store the app and router in the scope, then dispatch it."""
        scope["app"] = self
        scope["router"] = self.router
        if scope["type"] == "lifespan":
            return self.lifespan_handler(scope)
        return self.exception_middleware(scope)