示例#1
0
class Proto(asyncio.Protocol):
    """Minimal JSON "adder" HTTP protocol.

    Each request body is decoded as a JSON object with keys ``a`` and ``b``
    and answered with ``{"result": a + b}``.  ``HttpRequestParser`` drives
    parsing and calls the ``on_*`` callbacks below.
    """

    def connection_made(self, transport):
        self.transport = transport
        self.parser = HttpRequestParser(self)
        self.body = None

    def connection_lost(self, exc):
        self.transport = None

    def data_received(self, data):
        self.parser.feed_data(data)

    def on_body(self, body):
        # A body may arrive in several chunks. Append each one instead of
        # letting it overwrite the previous chunk.
        if self.body is None:
            self.body = body
        else:
            self.body += body

    def on_message_complete(self):
        request = ujson.loads(self.body)
        response = {'result': request['a'] + request['b']}
        self.send(ujson.dumps(response).encode('utf8'))
        self.body = None

    def send(self, data):
        """Write ``data`` as a ``200 OK`` response.

        The connection is closed afterwards unless the client asked for
        keep-alive. This is a no-op once the transport is gone.
        """
        if self.transport is None:
            # Already closed by us or lost by the peer.
            return
        # HTTP/1.1 requires CRLF line terminators in the status line/headers.
        response = (b'HTTP/1.1 200 OK\r\nContent-Length: '
                    + str(len(data)).encode() + b'\r\n\r\n' + data)
        self.transport.write(response)

        if not self.parser.should_keep_alive():
            self.transport.close()
            self.transport = None
示例#2
0
文件: server.py 项目: wmingstar/sanic
class HttpProtocol(asyncio.Protocol):
    """asyncio protocol that parses HTTP requests and writes responses.

    One instance serves one client connection. Incoming bytes are fed to an
    ``HttpRequestParser``, which calls the ``on_*`` callbacks. Once a message
    is complete, the request is handed to ``request_handler`` as a task, and
    the handler answers through ``write_response``.
    """

    # Every attribute assigned in __init__ must be listed here. On Python
    # versions where asyncio.Protocol itself declares ``__slots__ = ()``,
    # instances have no ``__dict__``, so an unlisted attribute cannot be set.
    __slots__ = (
        # event loop, connection
        'loop',
        'transport',
        'connections',
        'signal',
        # request params
        'parser',
        'request',
        'url',
        'headers',
        # request config
        'request_handler',
        'error_handler',
        'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size',
        '_timeout_handler',
        '_last_request_time',
        '_request_handler_task')

    def __init__(self,
                 *,
                 loop,
                 request_handler,
                 error_handler,
                 signal=None,
                 connections=None,
                 request_timeout=60,
                 request_max_size=None):
        """
        :param loop: event loop used to schedule timeouts and handler tasks
        :param request_handler: coroutine function called as
            ``request_handler(request, write_response)``
        :param error_handler: object whose ``response(request, exception)``
            builds an error response
        :param signal: object with a ``stopped`` flag; while it is set,
            responses are sent without keep-alive. Defaults to a new
            ``Signal()``.
        :param connections: mapping where live protocols are registered as
            ``connections[self] = True``. Defaults to a new dict.
        :param request_timeout: idle timeout in seconds
        :param request_max_size: maximum bytes accepted per request;
            ``None`` means unlimited
        """
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        # Create defaults per instance. A default object in the signature
        # would be built once and shared by every protocol instance.
        self.signal = Signal() if signal is None else signal
        self.connections = {} if connections is None else connections
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    def _exceeds_max_size(self, size):
        """Return True if ``size`` bytes is over ``request_max_size``.

        A limit of ``None`` means no request is ever too large.
        """
        return (self.request_max_size is not None
                and size > self.request_max_size)

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections[self] = True
        self._timeout_handler = self.loop.call_later(self.request_timeout,
                                                     self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        # pop() instead of del: never raise if we were already unregistered.
        self.connections.pop(self, None)
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
        self.cleanup()

    def connection_timeout(self):
        # If there was activity within the timeout window, re-arm the timer
        # for the remaining time. Otherwise cancel the running handler and
        # answer with a RequestTimeout error response.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = \
                self.loop.call_later(time_left, self.connection_timeout)
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            response = self.error_handler.response(
                self.request, RequestTimeout('Request Timeout'))
            self.write_response(response)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._exceeds_max_size(self._total_request_size):
            return self.bail_out(
                "Request too large ({}), connection closed".format(
                    self._total_request_size))

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError as e:
            self.bail_out(
                "Invalid request data, connection closed ({})".format(e))

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        # Header names are case-insensitive, so compare in lower case.
        if (name.lower() == b'content-length'
                and self._exceeds_max_size(int(value))):
            return self.bail_out(
                "Request body too large ({}), connection closed".format(value))

        self.headers.append((name.decode(), value.decode('utf-8')))

    def on_headers_complete(self):
        remote_addr = self.transport.get_extra_info('peername')
        if remote_addr:
            self.headers.append(('Remote-Addr', '%s:%s' % remote_addr))

        self.request = Request(url_bytes=self.url,
                               headers=CIMultiDict(self.headers),
                               version=self.parser.get_http_version(),
                               method=self.parser.get_method().decode())

    def on_body(self, body):
        if self.request.body:
            self.request.body += body
        else:
            self.request.body = body

    def on_message_complete(self):
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        """Write ``response``; close the connection unless keep-alive."""
        try:
            keep_alive = self.parser.should_keep_alive() \
                            and not self.signal.stopped
            self.transport.write(
                response.output(self.request.version, keep_alive,
                                self.request_timeout))
            if not keep_alive:
                self.transport.close()
            else:
                # Record that we received data
                self._last_request_time = current_time
                self.cleanup()
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))

    def bail_out(self, message):
        log.debug(message)
        self.transport.close()

    def cleanup(self):
        """Reset per-request state so the connection can serve another one."""
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """
        Close the connection if a request is not being sent or received
        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
示例#3
0
文件: server.py 项目: banjocat/sanic
class HttpProtocol(asyncio.Protocol):
    """asyncio protocol serving HTTP/1.x requests on one client connection.

    Bytes are fed to an ``HttpRequestParser`` that calls the ``on_*``
    callbacks. A complete request is passed to ``request_handler`` as a task,
    together with ``write_response`` and ``stream_response`` for answering.
    Errors are answered through ``error_handler`` and close the connection.
    """

    # Every attribute assigned in __init__ must appear here. On Python
    # versions where asyncio.Protocol declares ``__slots__ = ()``, instances
    # have no ``__dict__``, so assigning an unlisted attribute fails.
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=None, connections=None, request_timeout=60,
                 request_max_size=None):
        """
        :param loop: event loop used for timeouts and handler tasks
        :param request_handler: coroutine function called as
            ``request_handler(request, write_response, stream_response)``
        :param error_handler: object whose ``response(request, exception)``
            builds an error response
        :param signal: object with a ``stopped`` flag that disables
            keep-alive; a new ``Signal()`` when omitted
        :param connections: set where live protocols are registered;
            a new set when omitted
        :param request_timeout: idle timeout in seconds
        :param request_max_size: per-request byte limit; ``None`` = no limit
        """
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        # Build defaults per instance instead of sharing one object created
        # when the signature was evaluated.
        self.signal = Signal() if signal is None else signal
        self.connections = set() if connections is None else connections
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    def _exceeds_max_size(self, size):
        """Return True if ``size`` bytes is over ``request_max_size``.

        A limit of ``None`` means no request is ever too large.
        """
        return (self.request_max_size is not None
                and size > self.request_max_size)

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()

    def connection_timeout(self):
        # If there was activity within the timeout window, re-arm the timer
        # for the remaining time. Otherwise cancel the running handler and
        # answer with a RequestTimeout error.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = (
                self.loop.call_later(time_left, self.connection_timeout))
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._exceeds_max_size(self._total_request_size):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
            # write_error closed the transport; do not parse this chunk.
            return

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        # Header names are case-insensitive, so compare in lower case.
        if (name.lower() == b'content-length'
                and self._exceeds_max_size(int(value))):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
            # The connection is closed; do not record the header.
            return

        self.headers.append((name.decode().casefold(), value.decode()))

    def on_headers_complete(self):
        self.request = Request(
            url_bytes=self.url,
            headers=CIDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport
        )

    def on_body(self, body):
        self.request.body.append(body)

    def on_message_complete(self):
        if self.request.body:
            self.request.body = b''.join(self.request.body)

        self._request_handler_task = self.loop.create_task(
            self.request_handler(
                self.request,
                self.write_response,
                self.stream_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #
    def write_response(self, response):
        """
        Writes response content synchronously to the transport.
        """
        # Bound before the try so the finally clause can always read it. If
        # computing it fails, the connection is closed.
        keep_alive = False
        try:
            keep_alive = (
                self.parser.should_keep_alive() and not self.signal.stopped)

            self.transport.write(
                response.output(
                    self.request.version, keep_alive,
                    self.request_timeout))
        except AttributeError:
            log.error(
                ('Invalid response object for url {}, '
                 'Expected Type: HTTPResponse, Actual Type: {}').format(
                    self.url, type(response)))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                self._last_request_time = current_time
                self.cleanup()

    async def stream_response(self, response):
        """
        Streams a response to the client asynchronously. Attaches
        the transport to the response so the response consumer can
        write to the response as needed.
        """
        # Bound before the try so the finally clause can always read it.
        keep_alive = False
        try:
            keep_alive = (
                self.parser.should_keep_alive() and not self.signal.stopped)

            response.transport = self.transport
            await response.stream(
                self.request.version, keep_alive, self.request_timeout)
        except AttributeError:
            log.error(
                ('Invalid response object for url {}, '
                 'Expected Type: HTTPResponse, Actual Type: {}').format(
                    self.url, type(response)))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                self._last_request_time = current_time
                self.cleanup()

    def write_error(self, exception):
        """Send the error response for ``exception``, then always close."""
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except RuntimeError:
            log.error(
                'Connection lost before error written @ {}'.format(
                    self.request.ip if self.request else 'Unknown'))
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(repr(e)),
                from_error=True)
        finally:
            self.transport.close()

    def bail_out(self, message, from_error=False):
        """Report a failure. Try to send a ServerError response unless this
        already comes from error handling or the transport is closing."""
        if from_error or self.transport.is_closing():
            log.error(
                ("Transport closed @ {} and exception "
                 "experienced during error handling").format(
                    self.transport.get_extra_info('peername')))
            log.debug(
                'Exception:\n{}'.format(traceback.format_exc()))
        else:
            exception = ServerError(message)
            self.write_error(exception)
            log.error(message)

    def cleanup(self):
        """Reset per-request state so the connection can serve another one."""
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """Close the connection if a request is not being sent or received

        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
示例#4
0
class HttpProtocol(asyncio.Protocol):
    """asyncio protocol that parses HTTP requests and writes responses.

    One instance serves one client connection. Incoming bytes go to an
    ``HttpRequestParser`` that calls the ``on_*`` callbacks. A complete
    request is handed to ``request_handler`` as a task, which answers via
    ``write_response``. Errors are answered via ``write_error``.
    """

    # Every attribute assigned in __init__ must be listed here. On Python
    # versions where asyncio.Protocol declares ``__slots__ = ()``, instances
    # have no ``__dict__``, so an unlisted attribute cannot be set.
    __slots__ = (
        # event loop, connection
        'loop',
        'transport',
        'connections',
        'signal',
        # request params
        'parser',
        'request',
        'url',
        'headers',
        # request config
        'request_handler',
        'error_handler',
        'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size',
        '_timeout_handler',
        '_last_request_time',
        '_request_handler_task')

    def __init__(self,
                 *,
                 loop,
                 request_handler,
                 error_handler,
                 signal=None,
                 connections=None,
                 request_timeout=60,
                 request_max_size=None):
        """
        :param loop: event loop used for timeouts and handler tasks
        :param request_handler: coroutine function called as
            ``request_handler(request, write_response)``
        :param error_handler: object whose ``response(request, exception)``
            builds an error response
        :param signal: object with a ``stopped`` flag that disables
            keep-alive; a new ``Signal()`` when omitted
        :param connections: set where live protocols are registered
            (``add``/``discard``); a new set when omitted
        :param request_timeout: idle timeout in seconds
        :param request_max_size: per-request byte limit; ``None`` = no limit
        """
        self.loop = loop
        self.transport = None
        self.request = None  # request being parsed / handled
        self.parser = None
        self.url = None
        self.headers = None  # request headers as (name, value) pairs
        # Defaults are built per instance. The registry is used as a set
        # (add/discard), so its default must be a set, not a dict.
        self.signal = Signal() if signal is None else signal
        self.connections = set() if connections is None else connections
        self.request_handler = request_handler  # request handler
        self.error_handler = error_handler  # error handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    def _exceeds_max_size(self, size):
        """Return True if ``size`` bytes is over ``request_max_size``.

        A limit of ``None`` means no request is ever too large.
        """
        return (self.request_max_size is not None
                and size > self.request_max_size)

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(self.request_timeout,
                                                     self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()
        self.cleanup()

    def connection_timeout(self):
        # If there was activity within the timeout window, re-arm the timer
        # for the remaining time. Otherwise cancel the running handler and
        # answer with a RequestTimeout error.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = \
                self.loop.call_later(time_left, self.connection_timeout)
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._exceeds_max_size(self._total_request_size):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
            # write_error closed the transport; do not parse this chunk.
            return

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        self.url = url

    #
    # HTTP request: record one header
    #   - appends to the headers list
    #
    def on_header(self, name, value):
        # Header names are case-insensitive, so compare in lower case.
        if (name.lower() == b'content-length'
                and self._exceeds_max_size(int(value))):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
            # The connection is closed; do not record the header.
            return

        self.headers.append((name.decode(), value.decode('utf-8')))

    #
    # HTTP request: headers finished
    #
    def on_headers_complete(self):
        remote_addr = self.transport.get_extra_info('peername')
        if remote_addr:
            self.headers.append(('Remote-Addr', '%s:%s' % remote_addr))

        #
        # Build the HTTP request
        #
        self.request = Request(url_bytes=self.url,
                               headers=CIMultiDict(self.headers),
                               version=self.parser.get_http_version(),
                               method=self.parser.get_method().decode())

    #
    # HTTP request: append a body chunk
    #
    def on_body(self, body):
        if self.request.body:
            self.request.body += body
        else:
            self.request.body = body

    def on_message_complete(self):
        #
        # Schedule the request handler as a task
        #
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    #   -  HTTP response part
    # -------------------------------------------- #

    #
    # HTTP response: normal response
    #   - write the HTTP response
    #   - keep-alive: refresh the last-activity time
    #
    def write_response(self, response):
        try:
            keep_alive = self.parser.should_keep_alive(
            ) and not self.signal.stopped
            #
            # Write the HTTP response
            #
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.request_timeout))

            if not keep_alive:  # not keep-alive: close
                self.transport.close()
            else:
                # Record that we received data
                self._last_request_time = current_time
                self.cleanup()
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))

    #
    # HTTP response: error response
    #
    def write_error(self, exception):
        """Send the error response for ``exception``, then always close."""
        try:
            response = self.error_handler.response(self.request, exception)
            # HTTP version of the current request, 1.1 if none was parsed.
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except Exception as e:
            # from_error=True stops bail_out from calling write_error again,
            # which previously recursed without bound while writing failed.
            self.bail_out(
                "Writing error failed, connection closed {}".format(e),
                from_error=True)
        finally:
            self.transport.close()

    #
    # Failure reporting
    #
    def bail_out(self, message, from_error=False):
        """Log ``message``. Unless it came from write_error, first try to
        answer with a ServerError response."""
        if not from_error:
            self.write_error(ServerError(message))
        log.error(message)

    #
    # Cleanup:
    #   - reset per-request fields
    #
    def cleanup(self):
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """
        Close the connection if a request is not being sent or received
        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
示例#5
0
class Protocol(asyncio.Protocol):
    """Responsible for parsing the request and writing the response.

    You can subclass it to set your own `Query`, `Request` or `Response`
    classes.
    """

    # Must name every instance attribute actually assigned, including
    # ``request`` and ``response``. On Python versions where asyncio.Protocol
    # declares ``__slots__ = ()``, there is no per-instance __dict__ to fall
    # back on. 'req' and 'resp' are not used by this class; they are kept so
    # code that assigns them keeps working.
    __slots__ = ('app', 'req', 'parser', 'resp', 'writer',
                 'request', 'response')
    Query = Query
    Request = Request
    Response = Response

    def __init__(self, app):
        self.app = app
        self.parser = HttpRequestParser(self)

    def data_received(self, data: bytes):
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            # Parsing may fail before on_message_begin ran, so there may be
            # no response yet. Build one through the overridable class
            # attribute, as on_message_begin does.
            self.response = self.Response()
            self.response.status = HTTPStatus.BAD_REQUEST
            self.response.body = b'Unparsable request'
            self.write()
            # After a parse error the byte stream can no longer be split into
            # requests, so do not keep the connection open.
            self.writer.close()

    def connection_made(self, transport):
        self.writer = transport

    # All on_xxx methods are in use by httptools parser.
    # See https://github.com/MagicStack/httptools#apis
    def on_header(self, name: bytes, value: bytes):
        self.request.headers[name.decode()] = value.decode()

    def on_body(self, body: bytes):
        self.request.body += body

    def on_url(self, url: bytes):
        self.request.url = url
        parsed = parse_url(url)
        self.request.path = parsed.path.decode()
        self.request.query_string = (parsed.query or b'').decode()
        parsed_qs = parse_qs(self.request.query_string, keep_blank_values=True)
        self.request.query = self.Query(parsed_qs)

    def on_message_begin(self):
        self.request = self.Request()
        self.response = self.Response()

    def on_message_complete(self):
        self.request.method = self.parser.get_method().decode().upper()
        task = self.app.loop.create_task(self.app(self.request, self.response))
        task.add_done_callback(self.write)

    def write(self, *args):
        """Serialize ``self.response`` and send it to the client.

        This is called directly and as a task done-callback. In the callback
        case the finished task is passed as a positional argument, which is
        ignored. The connection is closed unless the client asked for
        keep-alive.
        """
        # Appends bytes for performances.
        payload = b'HTTP/1.1 %a %b\r\n' % (
            self.response.status.value, self.response.status.phrase.encode())
        if not isinstance(self.response.body, bytes):
            self.response.body = self.response.body.encode()
        if 'Content-Length' not in self.response.headers:
            length = len(self.response.body)
            self.response.headers['Content-Length'] = str(length)
        for key, value in self.response.headers.items():
            payload += b'%b: %b\r\n' % (key.encode(), str(value).encode())
        payload += b'\r\n%b' % self.response.body
        self.writer.write(payload)
        if not self.parser.should_keep_alive():
            self.writer.close()
示例#6
0
class Channel:
    """HTTP/1.x request reader over an async socket.

    Bytes from ``socket.recv`` are fed to an ``HttpRequestParser``, whose
    ``on_*`` callbacks populate ``self.request``. Iterating the channel with
    ``async for`` yields one request per message while keep-alive holds.
    ``self.reader`` yields the current message's remaining raw chunks.
    """

    __slots__ = (
        'parser',
        'request',
        'complete',
        'headers_complete',
        'socket',
        'reader',
    )

    def __init__(self, socket):
        self.complete = False
        self.headers_complete = False
        self.parser = HttpRequestParser(self)
        self.request = None
        self.socket = socket
        self.reader = self._reader()

    def data_received(self, data: bytes):
        try:
            self.parser.feed_data(data)
        except HttpParserUpgrade:
            self.request.upgrade = True
        except (HttpParserError, HttpParserInvalidMethodError) as exc:
            # TODO: log exc.
            raise HTTPError(
                HTTPStatus.BAD_REQUEST, 'Unparsable request.') from exc

    async def read(self, parse: bool = True) -> bytes:
        """Receive up to 1024 bytes; feed them to the parser if ``parse``.

        Returns None when the peer sent nothing (connection closed).
        """
        data = await self.socket.recv(1024)
        if data:
            if parse:
                self.data_received(data)
            return data

    async def _reader(self) -> bytes:
        # Yields parsed chunks until the current message is complete.
        while not self.complete:
            data = await self.read()
            if not data:
                break
            yield data

    async def _drainer(self) -> bytes:
        # Yields raw, unparsed chunks until the peer closes the connection.
        while True:
            data = await self.read(parse=False)
            if not data:
                break
            yield data

    async def _finish_message(self):
        """Read and parse until the current message completes or EOF."""
        while not self.complete:
            if not await self.read():
                break

    def on_header(self, name: bytes, value: bytes):
        value = value.decode()
        if value:
            name = name.decode().title()
            if name in self.request.headers:
                self.request.headers[name] += ', {}'.format(value)
            else:
                self.request.headers[name] = value

    def on_body(self, data: bytes):
        self.request.body += data

    def on_message_begin(self):
        self.complete = False
        self.request = Request(self.socket, self.reader)

    def on_message_complete(self):
        self.complete = True

    def on_url(self, url: bytes):
        self.request.url = url
        parsed = parse_url(url)
        self.request.path = unquote(parsed.path.decode())
        self.request.query_string = (parsed.query or b'').decode()

    def on_headers_complete(self):
        self.request.keep_alive = self.parser.should_keep_alive()
        self.request.method = self.parser.get_method().decode().upper()
        self.headers_complete = True

    async def __aiter__(self):
        keep_alive = True
        while keep_alive:
            data = await self.read()
            if data is None:
                break
            if self.headers_complete:
                yield self.request
                keep_alive = self.request.keep_alive
                if keep_alive:
                    if not self.complete:
                        await self.reader.aclose()
                        # The consumer left part of the body unread. Consume
                        # and parse only the rest of THIS message, so the
                        # next request on the connection is not swallowed
                        # (draining raw bytes until EOF would lose it).
                        await self._finish_message()
                    self.request = None
                    self.complete = False
                    self.headers_complete = False
                    self.reader = self._reader()
示例#7
0
文件: server.py 项目: blurrcat/sanic
class HttpProtocol(asyncio.Protocol):
    """asyncio protocol serving HTTP/1.x requests on one client connection.

    Bytes are fed to an ``HttpRequestParser`` that calls the ``on_*``
    callbacks. A complete request is passed to ``request_handler`` as a task
    together with ``write_response``. Errors are answered through
    ``error_handler`` and close the connection.
    """

    # Every attribute assigned in __init__ must appear here. On Python
    # versions where asyncio.Protocol declares ``__slots__ = ()``, instances
    # have no ``__dict__``, so assigning an unlisted attribute fails.
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=None, connections=None, request_timeout=60,
                 request_max_size=None):
        """
        :param loop: event loop used for timeouts and handler tasks
        :param request_handler: coroutine function called as
            ``request_handler(request, write_response)``
        :param error_handler: object whose ``response(request, exception)``
            builds an error response
        :param signal: object with a ``stopped`` flag that disables
            keep-alive; a new ``Signal()`` when omitted
        :param connections: set where live protocols are registered;
            a new set when omitted
        :param request_timeout: idle timeout in seconds
        :param request_max_size: per-request byte limit; ``None`` = no limit
        """
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        # Build defaults per instance instead of sharing one object created
        # when the signature was evaluated.
        self.signal = Signal() if signal is None else signal
        self.connections = set() if connections is None else connections
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    def _exceeds_max_size(self, size):
        """Return True if ``size`` bytes is over ``request_max_size``.

        A limit of ``None`` means no request is ever too large.
        """
        return (self.request_max_size is not None
                and size > self.request_max_size)

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()

    def connection_timeout(self):
        # If there was activity within the timeout window, re-arm the timer
        # for the remaining time. Otherwise cancel the running handler and
        # answer with a RequestTimeout error.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = (
                self.loop.call_later(time_left, self.connection_timeout))
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._exceeds_max_size(self._total_request_size):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
            # write_error closed the transport; do not parse this chunk.
            return

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        # Header names are case-insensitive, so compare in lower case.
        if (name.lower() == b'content-length'
                and self._exceeds_max_size(int(value))):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
            # The connection is closed; do not record the header.
            return

        self.headers.append((name.decode().casefold(), value.decode()))

    def on_headers_complete(self):
        self.request = Request(
            url_bytes=self.url,
            headers=CIDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport
        )

    def on_body(self, body):
        self.request.body.append(body)

    def on_message_complete(self):
        if self.request.body:
            self.request.body = b''.join(self.request.body)
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        """Write ``response``; close the connection unless keep-alive."""
        # Bound before the try so the finally clause can always read it. If
        # computing it fails, the connection is closed.
        keep_alive = False
        try:
            keep_alive = (
                self.parser.should_keep_alive() and not self.signal.stopped)
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.request_timeout))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                # Record that we received data
                self._last_request_time = current_time
                self.cleanup()

    def write_error(self, exception):
        """Send the error response for ``exception``, then always close."""
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except RuntimeError:
            # request may be None (e.g. errors raised before parsing finished).
            log.error(
                'Connection lost before error written @ {}'.format(
                    self.request.ip if self.request else 'Unknown'))
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(e),
                from_error=True)
        finally:
            self.transport.close()

    def bail_out(self, message, from_error=False):
        """Report a failure. Try to send a ServerError response unless this
        already comes from error handling or the transport is closing."""
        # `or`, not `and`: a failure raised while writing an error must never
        # call write_error again, or a persistent failure recurses forever.
        if from_error or self.transport.is_closing():
            log.error(
                ("Transport closed @ {} and exception "
                 "experienced during error handling").format(
                    self.transport.get_extra_info('peername')))
            log.debug(
                'Exception:\n{}'.format(traceback.format_exc()))
        else:
            exception = ServerError(message)
            self.write_error(exception)
            log.error(message)

    def cleanup(self):
        """Reset per-request state so the connection can serve another one."""
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """
        Close the connection if a request is not being sent or received
        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
示例#8
0
class BaseServer(asyncio.Protocol):
    """Minimal HTTP server protocol built on ``HttpRequestParser``.

    One instance serves one client connection. Parser callbacks (``on_url``,
    ``on_header``, ``on_headers_complete``, ``on_body``,
    ``on_message_complete``) accumulate the request. Once the message is
    complete, ``requesthandle(request, write)`` is scheduled as a task and is
    expected to call ``write`` with a response object.
    """

    def __init__(self, loop, requesthandle, toggle, connections=set(), request_timeout=10):
        # NOTE: the default ``connections`` set is created once and shared by
        # every instance that does not pass its own set.
        self.loop = loop
        self.requesthandle = requesthandle
        # toggle[0] truthy disables keep-alive (see write()).
        self.toggle = toggle
        self.connections = connections
        self.request_timeout = request_timeout

        self.transport = None
        # TODO: timehandle is never scheduled; no idle timeout is enforced yet.
        self.timehandle = None
        self.requesthandletask = None
        # Default per-request state
        self.ip = None
        self.parse = None
        self.url = None
        self.headers = {}
        self.body = None
        self.httpversion = None
        self.method = None
        self.request = None
        self.contentlength = 0
        self.keep_alive = False

    #######################
    # Connection set-up / tear-down
    def connection_made(self, transport):
        """Register this connection in the shared set and remember the peer."""
        print("connect start")
        self.connections.add(self)
        self.transport = transport
        self.ip = transport.get_extra_info("peername")

    def connection_lost(self, exc):
        """Unregister this connection and drop any per-request state."""
        print("connect end")
        # discard: tolerate a connection that was never (or already) removed.
        self.connections.discard(self)
        self.clean()

    ########################
    # Parsing
    def data_received(self, data):
        """Feed raw bytes to the request parser, creating it per request.

        A malformed request cannot be recovered, so the connection is closed.
        """
        self.contentlength += len(data)
        if not self.parse:
            # HttpRequestParser calls back into this object:
            # on_url, on_header, on_headers_complete, on_body,
            # on_message_complete.
            self.parse = HttpRequestParser(self)
        try:
            self.parse.feed_data(data)
        except HttpParserError:
            self.transport.close()

    def on_url(self, url):
        """Parser callback: store the raw request target."""
        self.url = url

    def on_header(self, name, value):
        """Parser callback: store one raw header (later values overwrite)."""
        self.headers[name] = value

    def on_headers_complete(self):
        """Parser callback: build the Request once the headers are parsed."""
        self.method = self.parse.get_method()
        self.httpversion = self.parse.get_http_version()
        self.keep_alive = self.parse.should_keep_alive()
        self.request = Request(self.ip, self.url, self.headers, self.method, self.httpversion,
                               self.request_timeout if self.keep_alive else None)

    def on_body(self, body):
        """Parser callback: append a body chunk; bodies may arrive in pieces."""
        self.body = body if self.body is None else self.body + body
        # Presumably setbody replaces the stored body, so pass the whole
        # accumulated body each time (TODO confirm against Request).
        self.request.setbody(self.body)

    def on_message_complete(self):
        """Parser callback: hand the finished request to the handler task."""
        print("parser complete")
        self.requesthandletask = self.loop.create_task(self.requesthandle(self.request, self.write))

    #########################
    # Responding
    def write(self, response):
        """Send ``response``; then keep the connection alive or close it."""
        print("start write")
        self.transport.write(response.make_response())
        print("end write")
        keep_alive = self.keep_alive and not self.toggle[0]
        if not keep_alive:
            self.transport.close()
        else:
            self.clean()

    ########################
    # Clean-up or close the connection
    def clean(self):
        """Reset all per-request state to the values set in __init__."""
        self.requesthandletask = None
        self.parse = None
        self.url = None
        self.headers = {}
        self.body = None
        self.httpversion = None
        self.method = None
        self.request = None
        self.contentlength = 0
        self.keep_alive = False

    def teardown(self):
        """Close the connection if no request is in progress.

        :return: True if the transport was closed, False otherwise
        """
        if not self.parse:
            self.transport.close()
            return True
        return False
示例#9
0
class HttpProtocol(asyncio.Protocol):
  """HTTP/1.x connection protocol configured from a ``ServerParams`` object.

  Incoming bytes are parsed with ``HttpRequestParser``. Once a message is
  complete, ``params.request_handler(request, write_response)`` is scheduled
  as a task. An idle timer of ``request_timeout`` seconds answers with a
  timeout error if it expires.
  """

  def __init__(self, params: 'panic_datatypes.ServerParams'):
    # String annotation: panic_datatypes need not be resolvable when the
    # class is defined.
    self.params = params
    self.loop = params.loop
    self.transport = None
    self.request = None
    self.parser = None
    self.url = None
    self.headers = None
    self.signal = params.signal
    self.connections = params.connections
    self.request_handler = params.request_handler
    self.request_timeout = params.request_timeout
    self._total_request_size = 0
    self._timeout_handler = None
    self._last_request_time = None
    self._request_handler_task = None
    self._identity = uuid.uuid4()

  # -------------------------------------------- #
  # Connection
  # -------------------------------------------- #
  def connection_made(self, transport):
    """Register the connection and arm the idle timer."""
    self.connections.add(self)
    self._timeout_handler = self.loop.call_later(self.request_timeout, self.connection_timeout)
    self.transport = transport
    self._last_request_time = datetime.datetime.utcnow()

  def connection_lost(self, exc):
    """Unregister the connection, stop the idle timer and drop request state."""
    self.connections.discard(self)
    self._timeout_handler.cancel()
    self.cleanup()

  def connection_timeout(self):
    """Re-arm the idle timer, or fail the pending request once it expired."""
    elapsed = (datetime.datetime.utcnow() - self._last_request_time).total_seconds()
    if elapsed < self.request_timeout:
      # Not expired yet: re-arm for the remaining time only and stop here.
      self._timeout_handler = self.loop.call_later(
        self.request_timeout - elapsed, self.connection_timeout)
      return

    if self._request_handler_task:
      self._request_handler_task.cancel()
    exception = panic_exceptions.RequestTimeout('Request Timeout')
    self.write_error(exception)

  # -------------------------------------------- #
  # Parsing
  # -------------------------------------------- #

  def data_received(self, data):
    """Feed ``data`` to the parser, creating one for each new request.

    A malformed request is answered with ``InvalidUsage('Bad Request')``.
    """
    # TODO: enforce a maximum request size using _total_request_size.
    self._total_request_size += len(data)

    # Create parser if this is the first time we're receiving data
    if self.parser is None:
      assert self.request is None
      self.headers = panic_datatypes.HTTPHeaders()
      self.parser = HttpRequestParser(self)

    try:
      self.parser.feed_data(data)
    except HttpParserError:
      exception = panic_exceptions.InvalidUsage('Bad Request')
      self.write_error(exception)

  def on_url(self, url):
    """Parser callback: store the raw request target."""
    self.url = url

  def on_header(self, name, value):
    """Parser callback: append one header, decoded to ``str``."""
    # TODO: reject oversized bodies from Content-Length (PayloadTooLarge).
    self.headers.append(name.decode(), value.decode('utf-8'))

  def on_headers_complete(self):
    """Parser callback: add the peer address and build the Request."""
    remote_addr = self.transport.get_extra_info('peername')
    if isinstance(remote_addr, tuple) and len(remote_addr) >= 2:
      # Expose the peer under a real header name. Only host and port are
      # used, because IPv6 peernames carry extra fields.
      self.headers.append('Remote-Addr', '{}:{}'.format(remote_addr[0], remote_addr[1]))

    self.request = panic_request.Request(
      url = self.url,
      headers = self.headers,
      version = self.parser.get_http_version(),
      method = panic_datatypes.HTTPMethod.Match(self.parser.get_method().decode())
    )

  def on_body(self, body):
    """Parser callback: append a body chunk to the request."""
    self.request.body.append(body)

  def on_message_complete(self):
    """Parser callback: schedule the request handler for the finished request."""
    self._request_handler_task = self.loop.create_task(self.request_handler(self.request, self.write_response))

  # -------------------------------------------- #
  # Responding
  # -------------------------------------------- #

  def write_response(self, response):
    """Send ``response``; then keep the connection alive or close it."""
    if self.parser:
      keep_alive = self.parser.should_keep_alive() and not self.signal.stopped
    else:
      keep_alive = False

    try:
      self.transport.write(response.output(getattr(self.request, 'version', '1.1')))
    except RuntimeError as err:
      logger.error(err)
    except Exception as err:
      # Report the failure to the client (bail_out closes the transport)
      # rather than stopping the event loop in a debugger.
      self.bail_out("Writing response failed, connection closed {}".format(err))
      return

    if keep_alive:
      self._last_request_time = datetime.datetime.utcnow()
      self.cleanup()
    else:
      self.transport.close()

  def write_error(self, exception):
    """Send an error response for ``exception``; always close the transport."""
    try:
      response = self.params.error_handler(self.request, exception)
      version = self.request.version if self.request else '1.1'
      self.transport.write(response.output(float(version)))
    except panic_exceptions.RequestTimeout:
      exception = panic_exceptions.ServerError('RT')
      exception.status = 408
      response = self.params.error_handler(self.request, exception)
      version = self.request.version if self.request else '1.1'
      self.transport.write(response.output(float(version)))
    except Exception as err:
      # Log instead of calling bail_out, which would call write_error again
      # and could fail the same way. Never terminate the whole server here.
      logger.error("Writing error failed, connection closed {}".format(err))
    finally:
      self.transport.close()

  def bail_out(self, message):
    """Answer with a ServerError carrying ``message`` and log it."""
    exception = panic_exceptions.ServerError(message)
    self.write_error(exception)
    logger.error(message)

  def cleanup(self):
    """Drop per-request state so the connection can serve the next request."""
    self.parser = None
    self.request = None
    self.url = None
    self.headers = None
    self._request_handler_task = None
    self._total_request_size = 0

  def close_if_idle(self):
    """
    Close the connection if a request is not being sent or received
    :return: boolean - True if closed, false if staying open
    """
    if not self.parser:
      self.transport.close()
      return True

    return False
示例#10
0
class HttpProtocol(asyncio.Protocol):
    """
    HTTP protocol: parses requests and dispatches them to a request handler.
    """
    # The asyncio protocol base classes declare empty __slots__, so instances
    # have no __dict__: every attribute assigned below must be listed here.
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request parameters
        'parser', 'request', 'url', 'headers',
        # request configuration
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task',
        # kept for backward compatibility; not used by this class
        '_last_communication_time')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=None, connections=None, request_timeout=60,
                 request_max_size=None):
        """
        :param signal: object with a ``stopped`` flag; a new ``Signal()`` if None
        :param connections: shared set of live protocols; a new set if None
        :param request_max_size: byte limit for a request, or None for no limit
        """
        self.loop = loop                            # event loop
        self.transport = None
        self.request = None                         # current request
        self.parser = None
        self.url = None                             # raw request target
        self.headers = None                         # request headers
        # Fresh objects per instance: a shared mutable default leaks state
        # between connections, and ``connections`` must support ``.add``.
        self.signal = Signal() if signal is None else signal
        self.connections = set() if connections is None else connections
        self.request_handler = request_handler      # request handler
        self.error_handler = error_handler          # error handler
        self.request_timeout = request_timeout      # idle timeout
        self.request_max_size = request_max_size    # max request size
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        """
        Register the connection and arm the idle timer.
        """
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        """
        Unregister the connection, cancel the timer and drop request state.
        """
        self.connections.discard(self)
        self._timeout_handler.cancel()
        self.cleanup()

    def connection_timeout(self):
        """
        Re-arm the idle timer, or fail the request with RequestTimeout.
        """
        # current_time is a module-level timestamp maintained elsewhere
        # (presumably seconds, like request_timeout - confirm).
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:   # not yet timed out
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = \
                self.loop.call_later(time_left, self.connection_timeout)
        else:   # timed out
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        """
        Feed received bytes to the parser, enforcing request_max_size.
        """
        self._total_request_size += len(data)
        if (self.request_max_size is not None
                and self._total_request_size > self.request_max_size):
            self.write_error(PayloadTooLarge('Payload Too Large'))
            # write_error closed the transport; stop processing this data.
            return

        # First data of a new request: create the parser
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse the request
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        """
        Parser callback: store the raw request target.
        """
        self.url = url

    def on_header(self, name, value):
        """
        Parser callback: record one header, rejecting oversized bodies early.
        """
        # Header names are case-insensitive, so compare in lower case.
        if (self.request_max_size is not None
                and name.lower() == b'content-length'
                and int(value) > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        self.headers.append((name.decode(), value.decode('utf-8')))

    def on_headers_complete(self):
        """
        Parser callback: add the peer address and build the Request.
        """
        remote_addr = self.transport.get_extra_info('peername')
        if isinstance(remote_addr, tuple) and len(remote_addr) >= 2:
            # Only host and port: IPv6 peernames carry extra fields that
            # a two-placeholder %-format cannot take.
            self.headers.append(
                ('Remote-Addr', '{}:{}'.format(remote_addr[0], remote_addr[1])))

        self.request = Request(
            url_bytes=self.url,
            headers=CIMultiDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode()
        )

    def on_body(self, body):
        """
        Parser callback: append a body chunk to the request body.
        """
        if self.request.body:
            self.request.body += body
        else:
            self.request.body = body

    def on_message_complete(self):
        """
        Parser callback: schedule the request handler as a task.
        """
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        """
        Send ``response``; then keep the connection alive or close it.
        """
        try:
            keep_alive = self.parser.should_keep_alive() \
                            and not self.signal.stopped
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.request_timeout))
            if not keep_alive:
                self.transport.close()
            else:
                # Restart the idle clock for the next request
                self._last_request_time = current_time
                self.cleanup()
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))

    def write_error(self, exception):
        """
        Send an error response for ``exception``; always close the transport.
        """
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except Exception as e:
            # from_error=True stops bail_out from calling write_error again,
            # which could fail the same way and bounce back here indefinitely.
            self.bail_out(
                "Writing error failed, connection closed {}".format(e),
                from_error=True)
        finally:
            self.transport.close()

    def bail_out(self, message, from_error=False):
        """
        Log ``message`` and, when still possible, answer with a ServerError.

        :param from_error: True when called after writing an error response
            failed; then only log, to avoid recursing into write_error.
        """
        if from_error or self.transport.is_closing():
            log.error(message)
            return
        exception = ServerError(message)
        self.write_error(exception)
        log.error(message)

    def cleanup(self):
        """
        Clear per-request fields.
        """
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """
        Close the connection if no request is being sent or received.
        :return: boolean - True if closed, False if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False