def lambda_handler(self, event, context):
    """Serve one AWS Lambda invocation through the wrapped ASGI application.

    Runs the application's lifespan startup, dispatches a single HTTP request
    built from ``event`` to ``self._asgi`` (ASGI 2-style: ``self._asgi(scope)``
    returns a callable that is awaited with ``receive``/``send``), then runs
    lifespan shutdown, even if the application raised.

    :param event: the Lambda event dict; ``body`` and ``isBase64Encoded`` are
        read here, and the whole event is passed to ``get_connection_scope``.
    :param context: the Lambda context object, forwarded to
        ``self.get_connection_scope``.
    :returns: a dict populated from the application's response messages with
        ``statusCode``, ``headers``, ``isBase64Encoded`` and (when a body
        message was sent) ``body``. Non-UTF-8 response bodies are returned
        base64-encoded with ``isBase64Encoded`` set to True.
    """
    loop = asyncio.get_event_loop()
    lifespan = Lifespan(self._asgi)
    loop.create_task(lifespan.run())
    loop.run_until_complete(lifespan.wait_startup())

    connection_scope = self.get_connection_scope(event=event, context=context)

    async def _receive() -> Message:
        body = event.get('body')
        if body is None:
            # No request body at all: ASGI requires bytes, never None.
            body = b''
        elif event.get('isBase64Encoded'):
            body = base64.standard_b64decode(body)
        elif isinstance(body, str):
            # ASGI requires the request body as bytes.
            body = body.encode('utf-8')
        return {'type': 'http.request', 'body': body, 'more_body': False}

    response = {}
    body_chunks = []

    async def _send(message: Message) -> None:
        if message['type'] == 'http.response.start':
            response["statusCode"] = message['status']
            response["isBase64Encoded"] = False
            response["headers"] = {
                k.decode('utf-8'): v.decode('utf-8')
                for k, v in message.get('headers', [])
            }
        elif message['type'] == 'http.response.body':
            # A response may be streamed over several messages
            # (more_body=True); keep every chunk, not just the last one.
            body_chunks.append(message.get('body', b''))

    asgi = self._asgi(connection_scope)
    try:
        loop.run_until_complete(asgi(_receive, _send))
    finally:
        # Always give the application a chance to clean up.
        loop.run_until_complete(lifespan.wait_shutdown())

    if body_chunks:
        payload = b''.join(body_chunks)
        try:
            response["body"] = payload.decode('utf-8')
        except UnicodeDecodeError:
            # Binary payloads cannot be returned as text; send them
            # base64-encoded and flag that in the response.
            response["body"] = base64.standard_b64encode(payload).decode('ascii')
            response["isBase64Encoded"] = True

    return response
Beispiel #2
0
class LifespanContext:
    """Async context manager that drives an ASGI app's lifespan protocol.

    Entering schedules ``Lifespan.run()`` as a background task and waits for
    startup, then yields the wrapped application unchanged. Exiting waits for
    shutdown. ``__aexit__`` returns None, so exceptions raised inside the
    ``async with`` body propagate.

    Usage::

        async with LifespanContext(app) as app:
            ...
    """

    def __init__(self, app):
        self.app = app
        self.lifespan = Lifespan(app)
        # Strong reference to the background lifespan task, set on entry.
        self._task = None

    async def __aenter__(self):
        loop = asyncio.get_event_loop()
        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        self._task = loop.create_task(self.lifespan.run())
        await self.lifespan.wait_startup()
        return self.app

    async def __aexit__(self, exc_type, exc, tb):
        await self.lifespan.wait_shutdown()
Beispiel #3
0
    def lambda_handler(self, event, context):
        """Serve one API Gateway proxy event through the wrapped ASGI application.

        Runs lifespan startup, builds an ASGI ``http`` connection scope from
        ``event``, dispatches it to ``self._asgi`` (ASGI 2-style: the app is
        called with the scope, and the result is awaited with
        ``receive``/``send``), then runs lifespan shutdown, even if the
        application raised.

        :param event: the Lambda event dict. Reads ``httpMethod``, ``path``,
            ``queryStringParameters``, ``headers``, ``requestContext``,
            ``body`` and ``isBase64Encoded``.
        :param context: the Lambda context object, exposed to the app under
            the scope's ``x-aws-lambda`` key.
        :returns: a dict with ``statusCode``, ``headers``, ``isBase64Encoded``
            and (when a body message was sent) ``body``. Non-UTF-8 response
            bodies are returned base64-encoded with ``isBase64Encoded`` True.
        """
        loop = asyncio.get_event_loop()
        lifespan = Lifespan(self._asgi)
        loop.create_task(lifespan.run())
        loop.run_until_complete(lifespan.wait_startup())

        # These may be missing or None in the event (no query string / no
        # headers); urlencode(None) would raise TypeError.
        query_params = event.get('queryStringParameters') or {}
        request_headers = event.get('headers') or {}

        connection_scope = {
            'type': 'http',
            'http_version': '1.1',
            'scheme': 'http',
            'method': event['httpMethod'],
            'root_path': '',
            'path': event['path'],
            # ASGI requires query_string as bytes; urlencode() output is ASCII.
            'query_string': urllib.parse.urlencode(query_params).encode('ascii'),
            # ASGI requires headers as [name, value] byte pairs with
            # lower-cased names.
            'headers': [
                [name.lower().encode('utf-8'), value.encode('utf-8')]
                for name, value in request_headers.items()
            ],
            'x-aws-lambda': {
                'requestContext': event['requestContext'],
                'lambdaContext': context
            }
        }

        async def _receive() -> Message:
            body = event.get('body')
            if body is None:
                # No request body at all: ASGI requires bytes, never None.
                body = b''
            elif event.get('isBase64Encoded'):
                body = base64.standard_b64decode(body)
            elif isinstance(body, str):
                # ASGI requires the request body as bytes.
                body = body.encode('utf-8')
            return {'type': 'http.request', 'body': body, 'more_body': False}

        response = {}
        body_chunks = []

        async def _send(message: Message) -> None:
            if message['type'] == 'http.response.start':
                response["statusCode"] = message['status']
                response["isBase64Encoded"] = False
                response["headers"] = {
                    k.decode('utf-8'): v.decode('utf-8')
                    for k, v in message.get('headers', [])
                }
            elif message['type'] == 'http.response.body':
                # A response may be streamed over several messages
                # (more_body=True); keep every chunk, not just the last one.
                body_chunks.append(message.get('body', b''))

        asgi = self._asgi(connection_scope)
        try:
            loop.run_until_complete(asgi(_receive, _send))
        finally:
            # Always give the application a chance to clean up.
            loop.run_until_complete(lifespan.wait_shutdown())

        if body_chunks:
            payload = b''.join(body_chunks)
            try:
                response["body"] = payload.decode('utf-8')
            except UnicodeDecodeError:
                # Binary payloads cannot be returned as text; send them
                # base64-encoded and flag that in the response.
                response["body"] = base64.standard_b64encode(payload).decode('ascii')
                response["isBase64Encoded"] = True

        return response
Beispiel #4
0
class Server:
    """Serve an ASGI application on an injected asyncio event loop.

    Every collaborator is passed in: the loop, logger, protocol factory,
    the shared ``connections``/``tasks`` collections, the ``state`` dict
    holding ``total_requests``, and an optional ``ready_event`` set once the
    server is listening. ``run()`` blocks in ``loop.run_forever()`` until
    ``tick()`` performs a graceful shutdown and stops the loop.
    """

    def __init__(
        self,
        app,
        host,
        port,
        uds,
        sock,
        logger,
        loop,
        connections,
        tasks,
        state,
        limit_max_requests,
        create_protocol,
        on_tick,
        install_signal_handlers,
        ready_event,
    ):
        # Application and listening endpoint (checked in the order sock,
        # then uds, then host/port by create_server).
        self.app = app
        self.host, self.port = host, port
        self.uds = uds
        self.sock = sock
        # Runtime collaborators shared with the protocol instances.
        self.logger = logger
        self.loop = loop
        self.connections = connections
        self.tasks = tasks
        self.state = state
        self.limit_max_requests = limit_max_requests
        self.create_protocol = create_protocol
        self.on_tick = on_tick
        self.install_signal_handlers = install_signal_handlers
        self.ready_event = ready_event
        # Flipped by handle_exit(); polled once per second by tick().
        self.should_exit = False
        self.pid = os.getpid()

    def set_signal_handlers(self):
        """Route every signal in HANDLED_SIGNALS to handle_exit, if enabled."""
        if self.install_signal_handlers:
            try:
                for signum in HANDLED_SIGNALS:
                    self.loop.add_signal_handler(
                        signum, self.handle_exit, signum, None
                    )
            except NotImplementedError:
                # The loop cannot install handlers (e.g. on Windows), so fall
                # back to the plain signal module.
                for signum in HANDLED_SIGNALS:
                    signal.signal(signum, self.handle_exit)

    def handle_exit(self, sig, frame):
        """Request a graceful shutdown on the next tick."""
        self.should_exit = True

    def run(self):
        """Start the lifespan (if supported), bind the server, then run the loop."""
        self.logger.info("Started server process [{}]".format(self.pid))
        self.set_signal_handlers()
        self.lifespan = Lifespan(self.app, self.logger)
        if not self.lifespan.is_enabled:
            self.logger.debug("Lifespan protocol is not recognized by the application")
        else:
            self.logger.info("Waiting for application startup.")
            self.loop.create_task(self.lifespan.run())
            self.loop.run_until_complete(self.lifespan.wait_startup())
        self.loop.run_until_complete(self.create_server())
        self.loop.create_task(self.tick())
        if self.ready_event is not None:
            self.ready_event.set()
        self.loop.run_forever()

    async def create_server(self):
        """Bind the listening server and log where it can be reached."""
        if self.sock is not None:
            # Adopt a socket that was created and bound by the caller.
            self.server = await self.loop.create_server(
                self.create_protocol, sock=self.sock
            )
            template = "Uvicorn running on socket %s (Press CTRL+C to quit)"
            where = str(self.sock.getsockname())
        elif self.uds is not None:
            # Listen on a UNIX domain socket path.
            self.server = await self.loop.create_unix_server(
                self.create_protocol, path=self.uds
            )
            template = "Uvicorn running on unix socket %s (Press CTRL+C to quit)"
            where = self.uds
        else:
            # Default: bind a TCP socket from the host/port pair.
            self.server = await self.loop.create_server(
                self.create_protocol, host=self.host, port=self.port
            )
            template = "Uvicorn running on http://%s:%d (Press CTRL+C to quit)"
            where = (self.host, self.port)
        self.logger.info(template % where)

    async def tick(self):
        """Poll once per second until exit is requested, then shut down gracefully.

        Exit happens when ``should_exit`` is set or, if a request limit is
        configured, once ``state["total_requests"]`` reaches it. Shutdown
        closes the server, asks open connections to shut down, waits for
        connections and background tasks to drain, runs lifespan shutdown
        when enabled, and finally stops the loop.
        """
        has_request_limit = self.limit_max_requests is not None

        while not self.should_exit:
            limit_reached = (
                has_request_limit
                and self.state["total_requests"] >= self.limit_max_requests
            )
            if limit_reached:
                break
            self.on_tick()
            await asyncio.sleep(1)

        self.logger.info("Stopping server process [{}]".format(self.pid))
        self.server.close()
        await self.server.wait_closed()
        for conn in list(self.connections):
            conn.shutdown()

        # Brief grace period before reporting anything still open.
        await asyncio.sleep(0.1)
        if self.connections:
            self.logger.info("Waiting for connections to close.")
            while self.connections:
                await asyncio.sleep(0.1)
        if self.tasks:
            self.logger.info("Waiting for background tasks to complete.")
            while self.tasks:
                await asyncio.sleep(0.1)

        if self.lifespan.is_enabled:
            self.logger.info("Waiting for application cleanup.")
            await self.lifespan.wait_shutdown()

        self.loop.stop()
Beispiel #5
0
class UvicornWorker(Worker):
    """
    A worker class for Gunicorn that interfaces with an ASGI consumer callable,
    rather than a WSGI callable.

    We use a couple of packages from MagicStack in order to achieve an
    extremely high-throughput and low-latency implementation:

    * `uvloop` as the event loop policy.
    * `httptools` as the HTTP request parser.

    Subclasses may override ``protocol_class`` and ``ws_protocol_class`` to
    swap the HTTP and WebSocket protocol implementations; both are honored
    when the servers are created.
    """

    protocol_class = HttpToolsProtocol
    ws_protocol_class = WebSocketProtocol
    loop = "uvloop"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (server, state) pairs, one per listening socket; filled by create_servers().
        self.servers = []
        # Process exit status used by run(); handle_abort() sets it to 1.
        self.exit_code = 0
        self.log.level = self.log.loglevel

    def init_process(self):
        """Install the uvloop event loop policy (when selected), then defer to Gunicorn."""
        if self.loop == "uvloop":
            # Close any existing event loop before setting a
            # new policy.
            asyncio.get_event_loop().close()

            # Setup uvloop policy, so that every
            # asyncio.get_event_loop() will create an instance
            # of uvloop event loop.
            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        super().init_process()

    def run(self):
        """Run lifespan startup, start the servers and tick task, and block on the loop.

        Exits the process with ``self.exit_code`` once the loop is stopped
        (by ``tick``).
        """
        app = self.wsgi

        if self.log.level <= logging.DEBUG:
            app = MessageLoggerMiddleware(app)

        loop = asyncio.get_event_loop()

        self.lifespan = Lifespan(app, self.log)
        if self.lifespan.is_enabled:
            loop.create_task(self.lifespan.run())
            loop.run_until_complete(self.lifespan.wait_startup())
        else:
            self.log.debug(
                "Lifespan protocol is not recognized by the application")

        loop.create_task(self.create_servers(loop, app))
        loop.create_task(self.tick(loop))
        loop.run_forever()
        sys.exit(self.exit_code)

    def init_signals(self):
        """Install the worker's signal handlers through the event loop API."""
        loop = asyncio.get_event_loop()

        loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
                                signal.SIGQUIT, None)

        loop.add_signal_handler(signal.SIGTERM, self.handle_exit,
                                signal.SIGTERM, None)

        loop.add_signal_handler(signal.SIGINT, self.handle_quit, signal.SIGINT,
                                None)

        loop.add_signal_handler(signal.SIGWINCH, self.handle_winch,
                                signal.SIGWINCH, None)

        loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1,
                                signal.SIGUSR1, None)

        loop.add_signal_handler(signal.SIGABRT, self.handle_abort,
                                signal.SIGABRT, None)

        # Don't let SIGTERM and SIGUSR1 disturb active requests
        # by interrupting system calls
        signal.siginterrupt(signal.SIGTERM, False)
        signal.siginterrupt(signal.SIGUSR1, False)

    def handle_quit(self, sig, frame):
        """Stop the tick loop and run the configured worker_int hook."""
        self.alive = False
        self.cfg.worker_int(self)

    def handle_abort(self, sig, frame):
        """Stop the tick loop, mark a failing exit code, and run worker_abort."""
        self.alive = False
        self.exit_code = 1
        self.cfg.worker_abort(self)

    async def create_servers(self, loop, app):
        """Create one asyncio server per inherited listening socket.

        Each server gets its own ``state`` dict (tracking ``total_requests``)
        and ``connections`` set. All servers share the SSL context, when SSL
        is configured.
        """
        cfg = self.cfg

        ssl_ctx = self.create_ssl_context(cfg) if cfg.is_ssl else None

        for sock in self.sockets:
            state = {"total_requests": 0}
            connections = set()
            protocol = functools.partial(self.protocol_class,
                                         app=app,
                                         loop=loop,
                                         connections=connections,
                                         state=state,
                                         logger=self.log,
                                         # Honor subclass overrides of the
                                         # class attribute.
                                         ws_protocol_class=self.ws_protocol_class,
                                         timeout_keep_alive=cfg.keepalive)
            server = await loop.create_server(protocol, sock=sock, ssl=ssl_ctx)
            self.servers.append((server, state))

    def create_ssl_context(self, cfg):
        """Build an ``ssl.SSLContext`` from Gunicorn's SSL settings."""
        ctx = ssl.SSLContext(cfg.ssl_version)
        ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
        ctx.verify_mode = cfg.cert_reqs
        if cfg.ca_certs:
            ctx.load_verify_locations(cfg.ca_certs)
        if cfg.ciphers:
            ctx.set_ciphers(cfg.ciphers)
        return ctx

    async def tick(self, loop):
        """Heartbeat loop that runs while ``self.alive`` is true, then shuts down.

        Each cycle calls ``protocol_class.tick()`` and calls ``self.notify()``
        every 10th cycle. It stops when the summed request count exceeds
        ``max_requests`` or the parent PID changes. On exit it closes all
        servers, runs lifespan cleanup (when enabled), and stops the loop.
        """
        cycle = 0

        while self.alive:
            self.protocol_class.tick()

            cycle = (cycle + 1) % 10
            if cycle == 0:
                self.notify()

            req_count = sum(state["total_requests"] for _, state in self.servers)
            if self.max_requests and req_count > self.max_requests:
                self.alive = False
                self.log.info("Max requests exceeded, shutting down: %s", self)
            elif self.ppid != os.getppid():
                self.alive = False
                self.log.info("Parent changed, shutting down: %s", self)
            else:
                await asyncio.sleep(1)

        for server, _ in self.servers:
            server.close()
            await server.wait_closed()

        if self.lifespan.is_enabled:
            # NOTE(review): other Lifespan versions expose wait_shutdown()
            # instead; confirm this matches the Lifespan class in use.
            await self.lifespan.wait_cleanup()

        loop.stop()
Beispiel #6
0
class Server:
    """Serve an ASGI application described by a config object.

    ``config`` supplies the app, loop, logger, limits, lifespan toggle,
    signal-handler toggle and the HTTP protocol class. ``global_state`` is
    shared with every protocol instance and tracks connections, tasks,
    ``total_requests`` and the ``started`` event. A first exit signal
    requests a graceful shutdown; a second one forces it.
    """

    def __init__(self, config, global_state):
        self.config = config
        self.global_state = global_state

        # Frequently used config values, copied onto the instance.
        self.app = config.app
        self.loop = config.loop
        self.logger = config.logger
        self.limit_max_requests = config.limit_max_requests
        self.disable_lifespan = config.disable_lifespan
        self.on_tick = config.http_protocol_class.tick

        # Exit state machine: first signal -> should_exit, second -> force_exit.
        self.should_exit = False
        self.force_exit = False
        self.pid = os.getpid()

        protocol_class = config.http_protocol_class

        def create_protocol():
            return protocol_class(config=config, global_state=global_state)

        self.create_protocol = create_protocol

    def set_signal_handlers(self):
        """Route every signal in HANDLED_SIGNALS to handle_exit, if enabled."""
        if self.config.install_signal_handlers:
            try:
                for signum in HANDLED_SIGNALS:
                    self.loop.add_signal_handler(
                        signum, self.handle_exit, signum, None
                    )
            except NotImplementedError:
                # The loop cannot install handlers (e.g. on Windows), so fall
                # back to the plain signal module.
                for signum in HANDLED_SIGNALS:
                    signal.signal(signum, self.handle_exit)

    def handle_exit(self, sig, frame):
        """First call requests a graceful exit; any later call forces it."""
        self.force_exit = True if self.should_exit else self.force_exit
        self.should_exit = True

    def run(self):
        """Start the lifespan (unless disabled), bind the server, then run the loop.

        Returns early, without binding, if lifespan startup reports an error.
        """
        self.logger.info("Started server process [{}]".format(self.pid))
        self.set_signal_handlers()
        if not self.disable_lifespan and not self._start_lifespan():
            return
        self.loop.run_until_complete(self.create_server())
        self.loop.create_task(self.tick())
        self.global_state.started.set()
        self.loop.run_forever()

    def _start_lifespan(self):
        """Create the Lifespan and wait for startup; return False if startup failed."""
        self.lifespan = Lifespan(self.app, self.logger)
        if not self.lifespan.is_enabled:
            self.logger.debug(
                "Lifespan protocol is not recognized by the application")
            return True
        self.logger.info("Waiting for application startup.")
        self.loop.create_task(self.lifespan.run())
        self.loop.run_until_complete(self.lifespan.wait_startup())
        if self.lifespan.error_occured:
            self.logger.error("Application startup failed. Exiting.")
            return False
        return True

    async def create_server(self):
        """Bind the listening server and log where it can be reached."""
        cfg = self.config

        if cfg.sock is not None:
            # Adopt a socket that was created and bound by the caller.
            self.server = await self.loop.create_server(
                self.create_protocol, sock=cfg.sock)
            template = "Uvicorn running on socket %s (Press CTRL+C to quit)"
            where = str(cfg.sock.getsockname())
        elif cfg.uds is not None:
            # Listen on a UNIX domain socket path.
            self.server = await self.loop.create_unix_server(
                self.create_protocol, path=cfg.uds)
            template = "Uvicorn running on unix socket %s (Press CTRL+C to quit)"
            where = cfg.uds
        else:
            # Default: bind a TCP socket from the host/port pair.
            self.server = await self.loop.create_server(
                self.create_protocol, host=cfg.host, port=cfg.port)
            template = "Uvicorn running on http://%s:%d (Press CTRL+C to quit)"
            where = (cfg.host, cfg.port)
        self.logger.info(template % where)

    async def tick(self):
        """Poll once per second until exit is requested, then shut down.

        Exit happens when ``should_exit`` is set or, if a request limit is
        configured, once ``global_state.total_requests`` reaches it. Shutdown
        closes the server and asks connections to shut down. It then waits
        for connections, background tasks and lifespan shutdown, skipping any
        remaining wait once ``force_exit`` is set. Finally it stops the loop.
        """
        has_request_limit = self.limit_max_requests is not None
        state = self.global_state

        while not self.should_exit:
            if has_request_limit and state.total_requests >= self.limit_max_requests:
                break
            self.on_tick()
            await asyncio.sleep(1)

        self.logger.info("Stopping server process [{}]".format(self.pid))
        self.server.close()
        await self.server.wait_closed()
        for conn in list(state.connections):
            conn.shutdown()

        # Brief grace period before reporting anything still open.
        await asyncio.sleep(0.1)
        if state.connections and not self.force_exit:
            self.logger.info(
                "Waiting for connections to close. (Press CTRL+C to force quit)"
            )
            while state.connections and not self.force_exit:
                await asyncio.sleep(0.1)
        if state.tasks and not self.force_exit:
            self.logger.info(
                "Waiting for background tasks to complete. (Press CTRL+C to force quit)"
            )
            while state.tasks and not self.force_exit:
                await asyncio.sleep(0.1)

        # self.lifespan only exists when lifespan was not disabled.
        lifespan_active = not self.disable_lifespan and self.lifespan.is_enabled
        if lifespan_active and not self.force_exit:
            self.logger.info("Waiting for application shutdown.")
            await self.lifespan.wait_shutdown()

        if self.force_exit:
            self.logger.info("Forced quit.")

        self.loop.stop()