Example #1
 def __init__(self, event_loop=None, route=None, re_route=None):
     self._route = route
     self._re_route = re_route
     self._loop = event_loop
     self._transport = None
     self._parser = HttpRequestParser(self)
     self._request = Request()
Example #2
    def __init__(
            self,
            *,
            loop,
            conns,  # server.conns
            router,  # path handler mgr
            request_limit_size=1024 * 1024 * 1,  # 1M
            request_timeout=60,
            response_timeout=60,
            keep_alive=10):

        self.loop = loop
        self.conns = conns
        self.router = router
        self.transport = None
        self.request = Request(self)
        self.response = Response(self)
        self.response.set_keep_alive()
        self.parser = HttpRequestParser(self)

        self.request_limit_size = request_limit_size
        self.request_cur_size = 0
        self.request_timeout = request_timeout
        self.response_timeout = response_timeout
        self.keep_alive = keep_alive
        self.last_request_time = 0

        self.remote_addr = None
        self.request_timeout_task = None
        self.response_timeout_task = None
        self.conn_timeout_task = None

        self.route_mgr = None
Example #3
 def __init__(self, *, parent, loop):
     self._parent = parent
     self._transport = None
     self.data = None
     self.http_parser = HttpRequestParser(self)
     self.request = None
     self._loop = loop
Example #4
    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # requests count
        self.state['requests_count'] = self.state['requests_count'] + 1

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            message = 'Bad Request'
            if self._debug:
                message += '\n' + traceback.format_exc()
            exception = InvalidUsage(message)
            self.write_error(exception)
Example #5
    def data_received(self, data):
        self._raw.put_nowait(data)  # keep a copy of all raw data
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
        if self._is_proxy:  # when proxying, don't parse any further
            return
        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # requests count
        self.state['requests_count'] = self.state['requests_count'] + 1

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserInvalidMethodError as e:  # CONNECT request
            pass
        except HttpParserUpgrade:  # CONNECT request
            pass
        except HttpParserError:
            message = 'Bad Request'
            if self._debug:
                message += '\n' + traceback.format_exc()
            exception = InvalidUsage(message)
            self.write_error(exception)
Example #6
    def data_received(self, data):
        """接受到HTTP请求时调用 ."""
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        # Limit the total size of the request
        if self._total_request_size > self.request_max_size:
            self.write_error(PayloadTooLarge("Payload Too Large"))

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # requests count
        self.state["requests_count"] = self.state["requests_count"] + 1

        # Parse request chunk or close connection
        try:
            # Parse the HTTP protocol
            self.parser.feed_data(data)
        except HttpParserError:
            # If the data is not valid HTTP, return a 400 error
            message = "Bad Request"
            if self.app.debug:
                message += "\n" + traceback.format_exc()
            self.write_error(InvalidUsage(message))
Example #7
class Server(asyncio.Protocol, HttpParserMixin):
    def __init__(self, loop, handler, app):
        self._loop = loop
        self._app = app
        self._encoding = "utf-8"
        self._url = None
        self._request = None
        self._body = None
        self._request_class = Request
        self._request_handler = handler
        self._request_handler_task = None
        self._transport = None
        self._request_parser = HttpRequestParser(self)
        self._headers = {}

    def connection_made(self, transport):
        self._transport = transport

    def connection_lost(self, *args):
        self._transport = None

    def response_writer(self, response):
        self._transport.write(str(response).encode(self._encoding))
        self._transport.close()

    def data_received(self, data):
        self._request_parser.feed_data(data)
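Note: the HttpParserMixin used by the Server classes in this listing is not shown here. A minimal sketch of what such a mixin could look like, assuming it only implements the httptools callback interface and stores the parsed pieces on the protocol (the real mixin in the source project may do more, e.g. build a Request and dispatch the handler):

class HttpParserMixin:
    # Hypothetical mixin: HttpRequestParser calls these methods on the
    # protocol object passed to its constructor while feed_data() runs.

    def on_url(self, url: bytes):
        self._url = url

    def on_header(self, name: bytes, value: bytes):
        self._headers[name.decode(self._encoding)] = value.decode(self._encoding)

    def on_body(self, body: bytes):
        self._body = body

    def on_message_complete(self):
        # A real implementation would typically build self._request here
        # and schedule self._request_handler on the event loop.
        pass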
Example #8
 def __init__(self, socket):
     self.complete = False
     self.headers_complete = False
     self.parser = HttpRequestParser(self)
     self.request = None
     self.socket = socket
     self.reader = self._reader()
Example #9
class Proto(asyncio.Protocol):
    def connection_made(self, transport):
        self.transport = transport
        self.parser = HttpRequestParser(self)
        self.body = None

    def connection_lost(self, exc):
        self.transport = None

    def data_received(self, data):
        self.parser.feed_data(data)

    def on_body(self, body):
        self.body = body

    def on_message_complete(self):
        request = ujson.loads(self.body)
        response = {'result': request['a'] + request['b']}
        self.send(ujson.dumps(response).encode('utf8'))
        self.body = None

    def send(self, data):
        response = b'HTTP/1.1 200 OK\nContent-Length: ' + str(
            len(data)).encode() + b'\n\n' + data
        self.transport.write(response)

        if not self.parser.should_keep_alive():
            #print('no keep-alive')
            self.transport.close()
            self.transport = None
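A minimal sketch of how the Proto protocol above could be served with asyncio (the host and port are made up for illustration; ujson must be installed):

import asyncio

loop = asyncio.new_event_loop()
# Each incoming connection gets its own Proto instance.
server = loop.run_until_complete(loop.create_server(Proto, '127.0.0.1', 8080))
try:
    loop.run_forever()
finally:
    server.close()
    loop.run_until_complete(server.wait_closed())
    loop.close()

Sending a request whose JSON body is {"a": 1, "b": 2} to this server returns {"result": 3}.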
Example #10
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.parser = HttpRequestParser(self)

        self._url: bytes = b""
        self._headers: tp.Dict[bytes, bytes] = {}
        self._body: bytes = b""
        self._parsed = False
Example #11
 def __init__(self, loop):
     self._loop = loop
     self._encoding = "utf-8"
     self._url = None
     self._headers = {}
     self._body = None
     self._transport = None
     self._request_parser = HttpRequestParser(self)
Example #12
async def handler(conn):
    print(conn)

    request = await readall_from_socket(conn)
    # print(request)

    http = HTTP()
    parser = HttpRequestParser(http)
    parser.feed_data(request)
    method = parser.get_method().decode()
    url_path = http.url
    print(method, url_path)

    # Attempt 1
    # with open('requirements.lock', 'rb') as f:
    #     # only Python 3.5+, and does not support non-blocking sockets
    #     conn.sendfile(f)

    # Attempt 2
    fl = None
    for mp in mapping:
        fl = mp.file(url_path or '')
        if fl:
            break
    if fl:
        if fl.exists():
            filepath = fl.path
            if fl.is_file():
                with open(filepath, 'rb') as f:
                    blocksize = os.path.getsize(filepath)
                    conn.send(b'HTTP/1.1 200 OK\r\n')
                    conn.send(f'Content-Length: {blocksize}\r\n'.encode('ascii'))
                    mime = mimetypes.guess_type(filepath)[0]
                    # mime = "text/plain" if mime else "application/octet-stream"
                    mime = mime or "application/octet-stream"
                    conn.send(
                        f'Content-Type: {mime}; charset=utf-8\r\n'.encode('ascii')
                    )
                    # conn.send(b'Transfer-Encoding: chunked')
                    conn.send(b'\r\n')
                    _ = sendfile(conn.fileno(), f.fileno(), 0, blocksize)
                # File served successfully: close the connection and return
                # so the 404 response below is not written as well.
                conn.close()
                return
            elif fl.is_dir():
                files = fl.listdir()
                body = '<br/>'.join(
                    f'<a href="{url_path.rstrip("/")}/{x.basename}{"/" if x.is_dir() else ""}">{x.basename}{"/" if x.is_dir() else ""}</a>'
                    for x in files
                ).encode('utf8')
                conn.send(b'HTTP/1.1 200 OK\r\n')
                conn.send(f'Content-Length: {len(body)}\r\n'.encode('ascii'))
                conn.send(b'Content-Type: text/html; charset=utf-8\r\n')
                conn.send(b'\r\n')
                conn.sendall(body)
                # Directory listing served: close and return so the 404
                # response below is not written to the same connection.
                conn.close()
                return

    conn.send(b'HTTP/1.1 404 Not Found\r\n')
    conn.send(b'Content-Type: text/plain; charset=utf-8\r\n')
    conn.send(b'\r\n')
    conn.sendall(b'Not Found')
    conn.close()
Example #13
    def handle_client(self):
        #wait for the message to arrive
        if self.parser is None:
            self.headers = []
            self.parser = HttpRequestParser(self)

        
        msg = self.client.recv(1000)
        self.parser.feed_data(msg)
Example #14
 def __init__(self, *, loop, config, ssl_forward=False):
     self.config = config
     self._transport = None
     self.data = None
     self.http_parser = HttpRequestParser(self)
     self.client = None
     self._loop = loop
     self._url = None
     self._headers = None
     self.ssl_forward = ssl_forward
Example #15
 def __init__(self, loop, handler, app):
     self._loop = loop
     self._app = app
     self._encoding = "utf-8"
     self._url = None
     self._request = None
     self._body = None
     self._request_class = Request
     self._request_handler = handler
     self._request_handler_task = None
     self._transport = None
     self._request_parser = HttpRequestParser(self)
     self._headers = {}
Example #16
    async def accept2(self, sock):
        fut = future(self.loop)
        conn, addr = sock.accept()  # Should be ready
        conn.setblocking(False)
        self.recv(fut, conn, 1024*1024*10)
        data = await fut
        print("\ndata:", type(data), '\n', data, '\n')
        self.parser = HttpRequestParser(self)
        self.parser.feed_data(data)
        print('received data: ', self.request.body)
        response = self.handle_request(self.request)
        await self.make_response(conn, self.request, response)

        self.remove_reader(conn)
Example #17
 def data_received(self, data):
     hrp = HttpRequestParser(self)
     try:
         hrp.feed_data(data)
         self.request_version = hrp.get_http_version()
         self.send_response()
     except HttpParserInvalidMethodError:
         self.data = data
         self.error = True
         self.bad_request()
     except HttpParserError:
         self.data = data
         self.error = True
         self.bad_request()
Example #18
    def __init__(self, connection):
        self._connection = connection
        self._parser = HttpRequestParser(self)
        self._body_buffer = b''
        self._headers_complete = False
        self._is_body_complete = False

        self.version = None
        self.keep_alive = None
        self.upgrade = None
        self.address = connection.address
        self.raw_method = b''
        self.raw_headers = CIMultiDict()
        self.raw_query = CIMultiDict()
        self.raw_path = b''

        self.method = ''
        self.host = 'localhost'
        self.port = 80
        self.headers = RequestHeaders()
        self.query = CIMultiDict()
        self.cookies = {}
        self.path = ''
        self.content_type_main = 'application'
        self.content_type_sub = 'octet-stream'
        self.content_type_params = {}
        self.content_charset = 'ascii'

        self._body_length = 2**32
        self._body_position = 0
        self._is_body_complete = False
        self._body = None
        self._text = None
        self._json = None
        self._form = None
Example #19
    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # requests count
        self.state['requests_count'] = self.state['requests_count'] + 1

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            message = 'Bad Request'
            if self._debug:
                message += '\n' + traceback.format_exc()
            exception = InvalidUsage(message)
            self.write_error(exception)
Example #20
 def data_received(self, data):
     pprint(data.decode("utf-8"))
     self.contentlength += len(data)
     if not self.parse:
         # HttpRequestParser invokes the following callbacks on this object:
         # - on_url(url: bytes)
         # - on_header(name: bytes, value: bytes)
         # - on_headers_complete()
         # - on_body(body: bytes)
         # - on_message_complete()
         # It also exposes:
         # - get_http_version() -> str
         # - should_keep_alive() -> bool
         self.parse = HttpRequestParser(self)
     try:
         self.parse.feed_data(data)
     except HttpParserError:
         # Parser failures raise httptools' HttpParserError
         # (HttpRequestParser itself is not an exception class).
         pass
Example #21
    def data_received(self, fd):
        # DEBUG_MSG("data_received, fd: %s, id: %s" % (fd, id(self)))
        if self.parser is None:
            self.parser = HttpRequestParser(self)
            self.headers = []

        while True:
            try:
                data = self._sock.recv(4096)
                # DEBUG_MSG("data_received, fd: %s, data len: %s" %
                #           (fd, len(data)))
                # The client closed the connection
                if not data:
                    ERROR_MSG("data_received, data len is 0, close")
                    self.close()
                    return
                self._read_buffer += data
            except (socket.error, IOError, OSError) as e:
                _errno = errno_from_exception(e)
                # The system call was interrupted by a signal
                if _errno == errno.EINTR:
                    continue
                # recv has drained the socket buffer and raises the exception
                # below. That does not mean the peer's whole message has been
                # read; a single send by the peer may trigger multiple epoll
                # events.
                elif _errno in ERRNO_WOULDBLOCK:
                    DEBUG_MSG("data_received, done")
                    break

                ERROR_MSG("socket recv error: %s" % str(e))
                self.close()
                return
            except Exception as e:
                ERROR_MSG("data_received exception, e: %s" % str(e))
                return

        if self._read_buffer:
            try:
                self.parser.feed_data(bytes(self._read_buffer))
                # Note: clear the buffer only after feeding. A request may be
                # fed in several chunks, so clear it after every feed.
                self._read_buffer.clear()
            except HttpParserError as e:
                ERROR_MSG(
                    "Connection::data_received feed_data error. error: %s"
                    " \n %s \n id: %s" %
                    (str(e), traceback.format_exc(), id(self)))
Example #22
  def data_received(self, data):
    # Check for the request itself getting too large and exceeding
    # memory limits
    # TODO: ^
    self._total_request_size += len(data)

    # Create parser if this is the first time we're receiving data
    if self.parser is None:
      assert self.request is None
      self.headers = panic_datatypes.HTTPHeaders()
      self.parser = HttpRequestParser(self)

    # Parse request chunk or close connection
    try:
      self.parser.feed_data(data)
    except HttpParserError as err:
      exception = panic_exceptions.InvalidUsage('Bad Request')
      self.write_error(exception)
Example #23
    async def _parse_request(self, request_reader: asyncio.StreamReader,
                             response_writer: asyncio.StreamWriter) -> Request:
        """parse data from StreamReader and build the request object
        """
        limit = 2**16
        req = Request()
        parser = HttpRequestParser(req)

        while True:
            data = await request_reader.read(limit)
            parser.feed_data(data)
            if req.finished or not data:
                break
            elif req.needs_write_continue:
                response_writer.write(b'HTTP/1.1 100 (Continue)\r\n\r\n')
                req.reset_state()

        req.method = touni(parser.get_method()).upper()
        return req
Example #24
class Server(asyncio.Protocol, HttpParserMixin):
    def __init__(self, loop):
        self._loop = loop
        self._encoding = "utf-8"
        self._url = None
        self._headers = {}
        self._body = None
        self._transport = None
        self._request_parser = HttpRequestParser(self)

    def connection_made(self, transport):
        self._transport = transport

    def connection_lost(self, *args):
        self._transport = None

    def data_received(self, data):
        # Pass data to our parser
        self._request_parser.feed_data(data)
Example #25
    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)
Example #26
    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            return self.bail_out(
                "Request too large ({}), connection closed".format(
                    self._total_request_size))

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError as e:
            self.bail_out(
                "Invalid request data, connection closed ({})".format(e))
Example #27
  def __init__(self, params: panic_datatypes.ServerParams):
    self.enabled = False
    self.params = params
    self.websocket = None
    self.transport = None
    self.timeout = params.request_timeout or 10
    self.max_size = 2 ** 20
    self.max_queue = 2 ** 5
    self.read_limit = 2 ** 16
    self.write_limit = 2 ** 16

    self.url = None
    self.connections = params.connections

    self.request = None
    self.headers = panic_datatypes.HTTPHeaders()
    self.parser = HttpRequestParser(self)

    self._last_request_time = None

    self._identity = uuid.uuid4()
Example #28
    def data_received(self, data):
        """
        Receive data.
        """
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:    # request too large
            # a PayloadTooLarge error added in `exceptions.py`
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # Create the parser the first time data is received
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse the request
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)
Example #29
class HTTPProtocol():
    def __init__(self, future=None):
        self.parser = HttpRequestParser(self)
        self.headers = {}
        self.body = b""
        self.url = b""
        self.future = future

    def on_url(self, url: bytes):
        self.url = url

    def on_header(self, name: bytes, value: bytes):
        self.headers[name] = value

    def on_body(self, body: bytes):
        self.body = body

    def on_message_complete(self):
        self.future.set_result(self)

    def feed_data(self, data):
        self.parser.feed_data(data)
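Because HttpRequestParser.feed_data invokes the callbacks synchronously, the HTTPProtocol above can be exercised without any socket. A short sketch (the request bytes are made up for illustration):

import asyncio

loop = asyncio.new_event_loop()
fut = loop.create_future()
proto = HTTPProtocol(future=fut)

# Feeding a complete request fires on_url/on_header/on_body and finally
# on_message_complete, which resolves the future with the protocol itself.
proto.feed_data(
    b"POST /sum HTTP/1.1\r\n"
    b"Content-Length: 5\r\n"
    b"\r\n"
    b"hello"
)

parsed = loop.run_until_complete(fut)
print(parsed.url, parsed.headers, parsed.body)
loop.close()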
Example #30
    async def _parse_request(self, request_reader, response_writer):
        limit = 2 ** 16
        req = Request()
        parser = HttpRequestParser(req)

        while True:
            data = await request_reader.read(limit)
            parser.feed_data(data)
            if req.finished or not data:
                break
            elif req.needs_write_continue:
                response_writer.write(b'HTTP/1.1 100 (Continue)\r\n\r\n')
                req.reset_state()

        if req.path is None:
            # connected without a formed HTTP request
            return

        handler, args = self.get_handler(req.path)

        req.method = parser.get_method().decode().upper()
        req.args = args
        return req, handler
Example #31
 def data_received(self, data: bytes) -> None:
     """
     Data was received on the socket.
     """
     # print(data)
     if self._request is None:
         # future = self._loop.create_future()
         self._request = Request(
             cast(asyncio.AbstractEventLoop, self._loop),
             self.complete_handle,
             cast(asyncio.Transport, self._transport),
             charset=self._requset_charset,
         )
         self._request.parser = HttpRequestParser(self._request)
     self._request.feed_data(data)
Example #32
    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)
Example #33
class HttpProtocol(asyncio.Protocol):
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'request_timeout', 'response_timeout',
        'keep_alive_timeout', 'request_max_size', 'request_class',
        'is_request_stream', 'router',
        # enable or disable access log purpose
        'access_log',
        # connection management
        '_total_request_size', '_request_timeout_handler',
        '_response_timeout_handler', '_keep_alive_timeout_handler',
        '_last_request_time', '_last_response_time', '_is_stream_handler')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=Signal(), connections=set(), request_timeout=60,
                 response_timeout=60, keep_alive_timeout=5,
                 request_max_size=None, request_class=None, access_log=True,
                 keep_alive=True, is_request_stream=False, router=None,
                 state=None, debug=False, **kwargs):
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        self.router = router
        self.signal = signal
        self.access_log = access_log
        self.connections = connections
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.response_timeout = response_timeout
        self.keep_alive_timeout = keep_alive_timeout
        self.request_max_size = request_max_size
        self.request_class = request_class or Request
        self.is_request_stream = is_request_stream
        self._is_stream_handler = False
        self._total_request_size = 0
        self._request_timeout_handler = None
        self._response_timeout_handler = None
        self._keep_alive_timeout_handler = None
        self._last_request_time = None
        self._last_response_time = None
        self._request_handler_task = None
        self._request_stream_task = None
        self._keep_alive = keep_alive
        self._header_fragment = b''
        self.state = state if state else {}
        if 'requests_count' not in self.state:
            self.state['requests_count'] = 0
        self._debug = debug

    @property
    def keep_alive(self):
        return (
            self._keep_alive and
            not self.signal.stopped and
            self.parser.should_keep_alive())

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._request_timeout_handler = self.loop.call_later(
            self.request_timeout, self.request_timeout_callback)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        if self._request_timeout_handler:
            self._request_timeout_handler.cancel()
        if self._response_timeout_handler:
            self._response_timeout_handler.cancel()
        if self._keep_alive_timeout_handler:
            self._keep_alive_timeout_handler.cancel()

    def request_timeout_callback(self):
        # See the docstring in the RequestTimeout exception, to see
        # exactly what this timeout is checking for.
        # Check if elapsed time since request initiated exceeds our
        # configured maximum request timeout value
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._request_timeout_handler = (
                self.loop.call_later(time_left,
                                     self.request_timeout_callback)
            )
        else:
            if self._request_stream_task:
                self._request_stream_task.cancel()
            if self._request_handler_task:
                self._request_handler_task.cancel()
            try:
                raise RequestTimeout('Request Timeout')
            except RequestTimeout as exception:
                self.write_error(exception)

    def response_timeout_callback(self):
        # Check if elapsed time since response was initiated exceeds our
        # configured maximum request timeout value
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.response_timeout:
            time_left = self.response_timeout - time_elapsed
            self._response_timeout_handler = (
                self.loop.call_later(time_left,
                                     self.response_timeout_callback)
            )
        else:
            try:
                raise ServiceUnavailable('Response Timeout')
            except ServiceUnavailable as exception:
                self.write_error(exception)

    def keep_alive_timeout_callback(self):
        # Check if elapsed time since last response exceeds our configured
        # maximum keep alive timeout value
        time_elapsed = current_time - self._last_response_time
        if time_elapsed < self.keep_alive_timeout:
            time_left = self.keep_alive_timeout - time_elapsed
            self._keep_alive_timeout_handler = (
                self.loop.call_later(time_left,
                                     self.keep_alive_timeout_callback)
            )
        else:
            logger.info('KeepAlive Timeout. Closing connection.')
            self.transport.close()
            self.transport = None

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # requests count
        self.state['requests_count'] = self.state['requests_count'] + 1

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            message = 'Bad Request'
            if self._debug:
                message += '\n' + traceback.format_exc()
            exception = InvalidUsage(message)
            self.write_error(exception)

    def on_url(self, url):
        if not self.url:
            self.url = url
        else:
            self.url += url

    def on_header(self, name, value):
        self._header_fragment += name

        if value is not None:
            if self._header_fragment == b'Content-Length' \
                    and int(value) > self.request_max_size:
                exception = PayloadTooLarge('Payload Too Large')
                self.write_error(exception)
            try:
                value = value.decode()
            except UnicodeDecodeError:
                value = value.decode('latin_1')
            self.headers.append(
                    (self._header_fragment.decode().casefold(), value))

            self._header_fragment = b''

    def on_headers_complete(self):
        self.request = self.request_class(
            url_bytes=self.url,
            headers=CIDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport
        )
        # Remove any existing KeepAlive handler here,
        # It will be recreated if required on the new request.
        if self._keep_alive_timeout_handler:
            self._keep_alive_timeout_handler.cancel()
            self._keep_alive_timeout_handler = None
        if self.is_request_stream:
            self._is_stream_handler = self.router.is_stream_handler(
                self.request)
            if self._is_stream_handler:
                self.request.stream = asyncio.Queue()
                self.execute_request_handler()

    def on_body(self, body):
        if self.is_request_stream and self._is_stream_handler:
            self._request_stream_task = self.loop.create_task(
                self.request.stream.put(body))
            return
        self.request.body.append(body)

    def on_message_complete(self):
        # Entire request (headers and whole body) is received.
        # We can cancel and remove the request timeout handler now.
        if self._request_timeout_handler:
            self._request_timeout_handler.cancel()
            self._request_timeout_handler = None
        if self.is_request_stream and self._is_stream_handler:
            self._request_stream_task = self.loop.create_task(
                self.request.stream.put(None))
            return
        self.request.body = b''.join(self.request.body)
        self.execute_request_handler()

    def execute_request_handler(self):
        self._response_timeout_handler = self.loop.call_later(
            self.response_timeout, self.response_timeout_callback)
        self._last_request_time = current_time
        self._request_handler_task = self.loop.create_task(
            self.request_handler(
                self.request,
                self.write_response,
                self.stream_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #
    def log_response(self, response):
        if self.access_log:
            extra = {
                'status': getattr(response, 'status', 0),
            }

            if isinstance(response, HTTPResponse):
                extra['byte'] = len(response.body)
            else:
                extra['byte'] = -1

            extra['host'] = 'UNKNOWN'
            if self.request is not None:
                if self.request.ip:
                    extra['host'] = '{0[0]}:{0[1]}'.format(self.request.ip)

                extra['request'] = '{0} {1}'.format(self.request.method,
                                                    self.request.url)
            else:
                extra['request'] = 'nil'

            access_logger.info('', extra=extra)

    def write_response(self, response):
        """
        Writes response content synchronously to the transport.
        """
        if self._response_timeout_handler:
            self._response_timeout_handler.cancel()
            self._response_timeout_handler = None
        try:
            keep_alive = self.keep_alive
            self.transport.write(
                response.output(
                    self.request.version, keep_alive,
                    self.keep_alive_timeout))
            self.log_response(response)
        except AttributeError:
            logger.error('Invalid response object for url %s, '
                         'Expected Type: HTTPResponse, Actual Type: %s',
                         self.url, type(response))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            if self._debug:
                logger.error('Connection lost before response written @ %s',
                             self.request.ip)
            keep_alive = False
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
                self.transport = None
            else:
                self._keep_alive_timeout_handler = self.loop.call_later(
                    self.keep_alive_timeout,
                    self.keep_alive_timeout_callback)
                self._last_response_time = current_time
                self.cleanup()

    async def stream_response(self, response):
        """
        Streams a response to the client asynchronously. Attaches
        the transport to the response so the response consumer can
        write to the response as needed.
        """
        if self._response_timeout_handler:
            self._response_timeout_handler.cancel()
            self._response_timeout_handler = None
        try:
            keep_alive = self.keep_alive
            response.transport = self.transport
            await response.stream(
                self.request.version, keep_alive, self.keep_alive_timeout)
            self.log_response(response)
        except AttributeError:
            logger.error('Invalid response object for url %s, '
                         'Expected Type: HTTPResponse, Actual Type: %s',
                         self.url, type(response))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            if self._debug:
                logger.error('Connection lost before response written @ %s',
                             self.request.ip)
            keep_alive = False
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
                self.transport = None
            else:
                self._keep_alive_timeout_handler = self.loop.call_later(
                    self.keep_alive_timeout,
                    self.keep_alive_timeout_callback)
                self._last_response_time = current_time
                self.cleanup()

    def write_error(self, exception):
        # An error _is_ a response.
        # Don't throw a response timeout, when a response _is_ given.
        if self._response_timeout_handler:
            self._response_timeout_handler.cancel()
            self._response_timeout_handler = None
        response = None
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except RuntimeError:
            if self._debug:
                logger.error('Connection lost before error written @ %s',
                             self.request.ip if self.request else 'Unknown')
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(
                    repr(e)), from_error=True
            )
        finally:
            if self.parser and (self.keep_alive
                                or getattr(response, 'status', 0) == 408):
                self.log_response(response)
            self.transport.close()

    def bail_out(self, message, from_error=False):
        if from_error or self.transport.is_closing():
            logger.error("Transport closed @ %s and exception "
                         "experienced during error handling",
                         self.transport.get_extra_info('peername'))
            logger.debug('Exception:\n%s', traceback.format_exc())
        else:
            exception = ServerError(message)
            self.write_error(exception)
            logger.error(message)

    def cleanup(self):
        """This is called when KeepAlive feature is used,
        it resets the connection in order for it to be able
        to handle receiving another request on the same connection."""
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._request_stream_task = None
        self._total_request_size = 0
        self._is_stream_handler = False

    def close_if_idle(self):
        """Close the connection if a request is not being sent or received

        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False

    def close(self):
        """
        Force close the connection.
        """
        if self.transport is not None:
            self.transport.close()
            self.transport = None
Example #34
class HttpProtocol(asyncio.Protocol):
    """
    This class provides a basic HTTP implementation of the sanic framework.
    """

    __slots__ = (
        # app
        "app",
        # event loop, connection
        "loop",
        "transport",
        "connections",
        "signal",
        # request params
        "parser",
        "request",
        "url",
        "headers",
        # request config
        "request_handler",
        "request_timeout",
        "response_timeout",
        "keep_alive_timeout",
        "request_max_size",
        "request_buffer_queue_size",
        "request_class",
        "is_request_stream",
        "router",
        "error_handler",
        # enable or disable access log purpose
        "access_log",
        # connection management
        "_total_request_size",
        "_request_timeout_handler",
        "_response_timeout_handler",
        "_keep_alive_timeout_handler",
        "_last_request_time",
        "_last_response_time",
        "_is_stream_handler",
        "_not_paused",
        "_request_handler_task",
        "_request_stream_task",
        "_keep_alive",
        "_header_fragment",
        "state",
        "_debug",
    )

    def __init__(
        self,
        *,
        loop,
        app,
        request_handler,
        error_handler,
        signal=Signal(),
        connections=None,
        request_timeout=60,
        response_timeout=60,
        keep_alive_timeout=5,
        request_max_size=None,
        request_buffer_queue_size=100,
        request_class=None,
        access_log=True,
        keep_alive=True,
        is_request_stream=False,
        router=None,
        state=None,
        debug=False,
        **kwargs
    ):
        self.loop = loop
        self.app = app
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        self.router = router
        self.signal = signal
        self.access_log = access_log
        self.connections = connections if connections is not None else set()
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_buffer_queue_size = request_buffer_queue_size
        self.response_timeout = response_timeout
        self.keep_alive_timeout = keep_alive_timeout
        self.request_max_size = request_max_size
        self.request_class = request_class or Request
        self.is_request_stream = is_request_stream
        self._is_stream_handler = False
        self._not_paused = asyncio.Event(loop=loop)
        self._total_request_size = 0
        self._request_timeout_handler = None
        self._response_timeout_handler = None
        self._keep_alive_timeout_handler = None
        self._last_request_time = None
        self._last_response_time = None
        self._request_handler_task = None
        self._request_stream_task = None
        self._keep_alive = keep_alive
        self._header_fragment = b""
        self.state = state if state else {}
        if "requests_count" not in self.state:
            self.state["requests_count"] = 0
        self._debug = debug
        self._not_paused.set()

    @property
    def keep_alive(self):
        """
        Check if the connection needs to be kept alive based on the params
        attached to the `_keep_alive` attribute, :attr:`Signal.stopped`
        and :func:`HttpProtocol.parser.should_keep_alive`

        :return: ``True`` if connection is to be kept alive ``False`` else
        """
        return (
            self._keep_alive
            and not self.signal.stopped
            and self.parser.should_keep_alive()
        )

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._request_timeout_handler = self.loop.call_later(
            self.request_timeout, self.request_timeout_callback
        )
        self.transport = transport
        self._last_request_time = time()

    def connection_lost(self, exc):
        self.connections.discard(self)
        if self._request_handler_task:
            self._request_handler_task.cancel()
        if self._request_stream_task:
            self._request_stream_task.cancel()
        if self._request_timeout_handler:
            self._request_timeout_handler.cancel()
        if self._response_timeout_handler:
            self._response_timeout_handler.cancel()
        if self._keep_alive_timeout_handler:
            self._keep_alive_timeout_handler.cancel()

    def pause_writing(self):
        self._not_paused.clear()

    def resume_writing(self):
        self._not_paused.set()

    def request_timeout_callback(self):
        # See the docstring in the RequestTimeout exception, to see
        # exactly what this timeout is checking for.
        # Check if elapsed time since request initiated exceeds our
        # configured maximum request timeout value
        time_elapsed = time() - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._request_timeout_handler = self.loop.call_later(
                time_left, self.request_timeout_callback
            )
        else:
            if self._request_stream_task:
                self._request_stream_task.cancel()
            if self._request_handler_task:
                self._request_handler_task.cancel()
            self.write_error(RequestTimeout("Request Timeout"))

    def response_timeout_callback(self):
        # Check if elapsed time since response was initiated exceeds our
        # configured maximum request timeout value
        time_elapsed = time() - self._last_request_time
        if time_elapsed < self.response_timeout:
            time_left = self.response_timeout - time_elapsed
            self._response_timeout_handler = self.loop.call_later(
                time_left, self.response_timeout_callback
            )
        else:
            if self._request_stream_task:
                self._request_stream_task.cancel()
            if self._request_handler_task:
                self._request_handler_task.cancel()
            self.write_error(ServiceUnavailable("Response Timeout"))

    def keep_alive_timeout_callback(self):
        """
        Check if elapsed time since last response exceeds our configured
        maximum keep alive timeout value and if so, close the transport
        pipe and let the response writer handle the error.

        :return: None
        """
        time_elapsed = time() - self._last_response_time
        if time_elapsed < self.keep_alive_timeout:
            time_left = self.keep_alive_timeout - time_elapsed
            self._keep_alive_timeout_handler = self.loop.call_later(
                time_left, self.keep_alive_timeout_callback
            )
        else:
            logger.debug("KeepAlive Timeout. Closing connection.")
            self.transport.close()
            self.transport = None

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            self.write_error(PayloadTooLarge("Payload Too Large"))

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # requests count
        self.state["requests_count"] = self.state["requests_count"] + 1

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            message = "Bad Request"
            if self._debug:
                message += "\n" + traceback.format_exc()
            self.write_error(InvalidUsage(message))

    def on_url(self, url):
        if not self.url:
            self.url = url
        else:
            self.url += url

    def on_header(self, name, value):
        self._header_fragment += name

        if value is not None:
            if (
                self._header_fragment == b"Content-Length"
                and int(value) > self.request_max_size
            ):
                self.write_error(PayloadTooLarge("Payload Too Large"))
            try:
                value = value.decode()
            except UnicodeDecodeError:
                value = value.decode("latin_1")
            self.headers.append(
                (self._header_fragment.decode().casefold(), value)
            )

            self._header_fragment = b""

    def on_headers_complete(self):
        self.request = self.request_class(
            url_bytes=self.url,
            headers=CIMultiDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport,
            app=self.app,
        )
        # Remove any existing KeepAlive handler here,
        # It will be recreated if required on the new request.
        if self._keep_alive_timeout_handler:
            self._keep_alive_timeout_handler.cancel()
            self._keep_alive_timeout_handler = None
        if self.is_request_stream:
            self._is_stream_handler = self.router.is_stream_handler(
                self.request
            )
            if self._is_stream_handler:
                self.request.stream = StreamBuffer(
                    self.request_buffer_queue_size
                )
                self.execute_request_handler()

    def on_body(self, body):
        if self.is_request_stream and self._is_stream_handler:
            self._request_stream_task = self.loop.create_task(
                self.body_append(body)
            )
        else:
            self.request.body_push(body)

    async def body_append(self, body):
        if self.request.stream.is_full():
            self.transport.pause_reading()
            await self.request.stream.put(body)
            self.transport.resume_reading()
        else:
            await self.request.stream.put(body)

    def on_message_complete(self):
        # Entire request (headers and whole body) is received.
        # We can cancel and remove the request timeout handler now.
        if self._request_timeout_handler:
            self._request_timeout_handler.cancel()
            self._request_timeout_handler = None
        if self.is_request_stream and self._is_stream_handler:
            self._request_stream_task = self.loop.create_task(
                self.request.stream.put(None)
            )
            return
        self.request.body_finish()
        self.execute_request_handler()

    def execute_request_handler(self):
        """
        Invoke the request handler defined by the
        :func:`sanic.app.Sanic.handle_request` method

        :return: None
        """
        self._response_timeout_handler = self.loop.call_later(
            self.response_timeout, self.response_timeout_callback
        )
        self._last_request_time = time()
        self._request_handler_task = self.loop.create_task(
            self.request_handler(
                self.request, self.write_response, self.stream_response
            )
        )

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #
    def log_response(self, response):
        """
        Helper method provided to enable the logging of responses in case if
        the :attr:`HttpProtocol.access_log` is enabled.

        :param response: Response generated for the current request

        :type response: :class:`sanic.response.HTTPResponse` or
            :class:`sanic.response.StreamingHTTPResponse`

        :return: None
        """
        if self.access_log:
            extra = {"status": getattr(response, "status", 0)}

            if isinstance(response, HTTPResponse):
                extra["byte"] = len(response.body)
            else:
                extra["byte"] = -1

            extra["host"] = "UNKNOWN"
            if self.request is not None:
                if self.request.ip:
                    extra["host"] = "{0}:{1}".format(
                        self.request.ip, self.request.port
                    )

                extra["request"] = "{0} {1}".format(
                    self.request.method, self.request.url
                )
            else:
                extra["request"] = "nil"

            access_logger.info("", extra=extra)

    def write_response(self, response):
        """
        Writes response content synchronously to the transport.
        """
        if self._response_timeout_handler:
            self._response_timeout_handler.cancel()
            self._response_timeout_handler = None
        try:
            keep_alive = self.keep_alive
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.keep_alive_timeout
                )
            )
            self.log_response(response)
        except AttributeError:
            logger.error(
                "Invalid response object for url %s, "
                "Expected Type: HTTPResponse, Actual Type: %s",
                self.url,
                type(response),
            )
            self.write_error(ServerError("Invalid response type"))
        except RuntimeError:
            if self._debug:
                logger.error(
                    "Connection lost before response written @ %s",
                    self.request.ip,
                )
            keep_alive = False
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(repr(e))
            )
        finally:
            if not keep_alive:
                self.transport.close()
                self.transport = None
            else:
                self._keep_alive_timeout_handler = self.loop.call_later(
                    self.keep_alive_timeout, self.keep_alive_timeout_callback
                )
                self._last_response_time = time()
                self.cleanup()

    async def drain(self):
        await self._not_paused.wait()

    def push_data(self, data):
        self.transport.write(data)

    async def stream_response(self, response):
        """
        Streams a response to the client asynchronously. Attaches
        the transport to the response so the response consumer can
        write to the response as needed.
        """
        if self._response_timeout_handler:
            self._response_timeout_handler.cancel()
            self._response_timeout_handler = None

        try:
            keep_alive = self.keep_alive
            response.protocol = self
            await response.stream(
                self.request.version, keep_alive, self.keep_alive_timeout
            )
            self.log_response(response)
        except AttributeError:
            logger.error(
                "Invalid response object for url %s, "
                "Expected Type: HTTPResponse, Actual Type: %s",
                self.url,
                type(response),
            )
            self.write_error(ServerError("Invalid response type"))
        except RuntimeError:
            if self._debug:
                logger.error(
                    "Connection lost before response written @ %s",
                    self.request.ip,
                )
            keep_alive = False
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(repr(e))
            )
        finally:
            if not keep_alive:
                self.transport.close()
                self.transport = None
            else:
                self._keep_alive_timeout_handler = self.loop.call_later(
                    self.keep_alive_timeout, self.keep_alive_timeout_callback
                )
                self._last_response_time = time()
                self.cleanup()

    def write_error(self, exception):
        # An error _is_ a response.
        # Don't throw a response timeout when a response _is_ given.
        if self._response_timeout_handler:
            self._response_timeout_handler.cancel()
            self._response_timeout_handler = None
        response = None
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else "1.1"
            self.transport.write(response.output(version))
        except RuntimeError:
            if self._debug:
                logger.error(
                    "Connection lost before error written @ %s",
                    self.request.ip if self.request else "Unknown",
                )
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(repr(e)),
                from_error=True,
            )
        finally:
            if self.parser and (
                self.keep_alive or getattr(response, "status", 0) == 408
            ):
                self.log_response(response)
            try:
                self.transport.close()
            except AttributeError:
                logger.debug("Connection lost before server could close it.")

    def bail_out(self, message, from_error=False):
        """
        If the transport pipes are closed and the sanic app encounters an
        error while writing data to the transport pipe, log the error with
        proper details.

        :param message: Error message to display
        :param from_error: If the bail out was invoked while handling an
            exception scenario.

        :type message: str
        :type from_error: bool

        :return: None
        """
        if from_error or self.transport is None or self.transport.is_closing():
            logger.error(
                "Transport closed @ %s and exception "
                "experienced during error handling",
                (
                    self.transport.get_extra_info("peername")
                    if self.transport is not None
                    else "N/A"
                ),
            )
            logger.debug("Exception:", exc_info=True)
        else:
            self.write_error(ServerError(message))
            logger.error(message)

    def cleanup(self):
        """This is called when KeepAlive feature is used,
        it resets the connection in order for it to be able
        to handle receiving another request on the same connection."""
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._request_stream_task = None
        self._total_request_size = 0
        self._is_stream_handler = False

    def close_if_idle(self):
        """Close the connection if a request is not being sent or received

        :return: boolean - True if closed, False if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False

    def close(self):
        """
        Force close the connection.
        """
        if self.transport is not None:
            self.transport.close()
            self.transport = None
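A protocol class like the one above only does useful work once an event loop hands it incoming connections. The sketch below is an assumption, not part of any example on this page: it shows one minimal way to serve such an asyncio.Protocol subclass with loop.create_server(); the handler coroutine, host, and port are hypothetical, and any required arguments such as an error_handler object would be passed through protocol_kwargs to match the constructor signatures shown in the examples that follow.

import asyncio
from functools import partial


async def hypothetical_handler(request, write_callback, *args):
    # A real application would build an HTTPResponse here and pass it to
    # write_callback(response); left empty because it is only a placeholder.
    pass


def serve(protocol_class, *, host="127.0.0.1", port=8000, **protocol_kwargs):
    # Bind the protocol class to a listening socket. partial() is used so
    # every accepted connection gets a fresh protocol instance with the
    # same configuration.
    loop = asyncio.new_event_loop()
    factory = partial(
        protocol_class,
        loop=loop,
        request_handler=hypothetical_handler,
        **protocol_kwargs,
    )
    server = loop.run_until_complete(loop.create_server(factory, host, port))
    try:
        loop.run_forever()
    finally:
        server.close()
        loop.run_until_complete(server.wait_closed())
        loop.close()

Each accepted socket then triggers connection_made() on a new protocol instance, which is where the request timeout timer seen in these examples is armed.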
Example #35
class HttpProtocol(asyncio.Protocol):
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler',
        'request_timeout', 'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler',
        '_last_request_time', '_request_handler_task')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=Signal(), connections=set(), request_timeout=60,
                 request_max_size=None):
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        self.signal = signal
        self.connections = connections
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()

    def connection_timeout(self):
        # Check whether the request has timed out; if not, reschedule the timer
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = (
                self.loop.call_later(time_left, self.connection_timeout))
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        if name == b'Content-Length' and int(value) > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        self.headers.append((name.decode().casefold(), value.decode()))

    def on_headers_complete(self):
        self.request = Request(
            url_bytes=self.url,
            headers=CIDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport
        )

    def on_body(self, body):
        self.request.body.append(body)

    def on_message_complete(self):
        if self.request.body:
            self.request.body = b''.join(self.request.body)
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        try:
            keep_alive = (
                self.parser.should_keep_alive() and not self.signal.stopped)
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.request_timeout))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                # Record that we received data
                self._last_request_time = current_time
                self.cleanup()

    def write_error(self, exception):
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except RuntimeError:
            log.error(
                'Connection lost before error written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(e),
                from_error=True)
        finally:
            self.transport.close()

    def bail_out(self, message, from_error=False):
        if from_error and self.transport.is_closing():
            log.error(
                ("Transport closed @ {} and exception "
                 "experienced during error handling").format(
                    self.transport.get_extra_info('peername')))
            log.debug(
                'Exception:\n{}'.format(traceback.format_exc()))
        else:
            exception = ServerError(message)
            self.write_error(exception)
            log.error(message)

    def cleanup(self):
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """
        Close the connection if a request is not being sent or received
        :return: boolean - True if closed, False if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
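Both protocol versions on this page read a module-level current_time instead of calling time() on every request, and expose close_if_idle() for an external sweeper. The driver below is an assumed sketch of that pattern, not code from the examples: a background task refreshes the shared clock so connection_timeout() can compare against it cheaply, and an idle sweep closes connections that are not currently parsing a request.

import asyncio
from time import time

current_time = time()  # shared clock the protocols compare _last_request_time against


async def update_current_time(interval=1):
    # Refresh the module-level clock once per interval so every protocol
    # instance can read it without calling time() per request.
    global current_time
    while True:
        current_time = time()
        await asyncio.sleep(interval)


def close_idle_connections(connections):
    # On shutdown or a periodic sweep, drop connections with no request in
    # flight; close_if_idle() returns True when it actually closed one.
    closed = [conn for conn in tuple(connections) if conn.close_if_idle()]
    return len(closed)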
Example #36
class HttpProtocol(asyncio.Protocol):
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler',
        'request_timeout', 'request_max_size',
        'request_class', 'is_request_stream', 'router',
        # enable or disable access log / error log purpose
        'has_log',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task', '_request_stream_task', '_is_stream_handler',
        '_keep_alive', '_header_fragment', 'state', '_debug')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=Signal(), connections=set(), request_timeout=60,
                 request_max_size=None, request_class=None, has_log=True,
                 keep_alive=True, is_request_stream=False, router=None,
                 state=None, debug=False, **kwargs):
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        self.router = router
        self.signal = signal
        self.has_log = has_log
        self.connections = connections
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size
        self.request_class = request_class or Request
        self.is_request_stream = is_request_stream
        self._is_stream_handler = False
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None
        self._request_stream_task = None
        self._keep_alive = keep_alive
        self._header_fragment = b''
        self.state = state if state else {}
        if 'requests_count' not in self.state:
            self.state['requests_count'] = 0
        self._debug = debug

    @property
    def keep_alive(self):
        return (
            self._keep_alive and
            not self.signal.stopped and
            self.parser.should_keep_alive())

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()

    def connection_timeout(self):
        # Check whether the request has timed out; if not, reschedule the timer
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = (
                self.loop.call_later(time_left, self.connection_timeout))
        else:
            if self._request_stream_task:
                self._request_stream_task.cancel()
            if self._request_handler_task:
                self._request_handler_task.cancel()
            try:
                raise RequestTimeout('Request Timeout')
            except RequestTimeout as exception:
                self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # requests count
        self.state['requests_count'] = self.state['requests_count'] + 1

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            message = 'Bad Request'
            if self._debug:
                message += '\n' + traceback.format_exc()
            exception = InvalidUsage(message)
            self.write_error(exception)

    def on_url(self, url):
        if not self.url:
            self.url = url
        else:
            self.url += url

    def on_header(self, name, value):
        self._header_fragment += name

        if value is not None:
            if self._header_fragment == b'Content-Length' \
                    and int(value) > self.request_max_size:
                exception = PayloadTooLarge('Payload Too Large')
                self.write_error(exception)

            self.headers.append(
                    (self._header_fragment.decode().casefold(),
                     value.decode()))

            self._header_fragment = b''

    def on_headers_complete(self):
        self.request = self.request_class(
            url_bytes=self.url,
            headers=CIDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport
        )
        if self.is_request_stream:
            self._is_stream_handler = self.router.is_stream_handler(
                self.request)
            if self._is_stream_handler:
                self.request.stream = asyncio.Queue()
                self.execute_request_handler()

    def on_body(self, body):
        if self.is_request_stream and self._is_stream_handler:
            self._request_stream_task = self.loop.create_task(
                self.request.stream.put(body))
            return
        self.request.body.append(body)

    def on_message_complete(self):
        if self.is_request_stream and self._is_stream_handler:
            self._request_stream_task = self.loop.create_task(
                self.request.stream.put(None))
            return
        self.request.body = b''.join(self.request.body)
        self.execute_request_handler()

    def execute_request_handler(self):
        self._request_handler_task = self.loop.create_task(
            self.request_handler(
                self.request,
                self.write_response,
                self.stream_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #
    def write_response(self, response):
        """
        Writes response content synchronously to the transport.
        """
        try:
            keep_alive = self.keep_alive
            self.transport.write(
                response.output(
                    self.request.version, keep_alive,
                    self.request_timeout))
            if self.has_log:
                netlog.info('', extra={
                    'status': response.status,
                    'byte': len(response.body),
                    'host': '{0}:{1}'.format(self.request.ip[0],
                                             self.request.ip[1]),
                    'request': '{0} {1}'.format(self.request.method,
                                                self.request.url)
                })
        except AttributeError:
            log.error(
                ('Invalid response object for url {}, '
                 'Expected Type: HTTPResponse, Actual Type: {}').format(
                    self.url, type(response)))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                self._last_request_time = current_time
                self.cleanup()

    async def stream_response(self, response):
        """
        Streams a response to the client asynchronously. Attaches
        the transport to the response so the response consumer can
        write to the response as needed.
        """

        try:
            keep_alive = self.keep_alive
            response.transport = self.transport
            await response.stream(
                self.request.version, keep_alive, self.request_timeout)
            if self.has_log:
                netlog.info('', extra={
                    'status': response.status,
                    'byte': -1,
                    'host': '{0}:{1}'.format(self.request.ip[0],
                                             self.request.ip[1]),
                    'request': '{0} {1}'.format(self.request.method,
                                                self.request.url)
                })
        except AttributeError:
            log.error(
                ('Invalid response object for url {}, '
                 'Expected Type: HTTPResponse, Actual Type: {}').format(
                    self.url, type(response)))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                self._last_request_time = current_time
                self.cleanup()

    def write_error(self, exception):
        response = None
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except RuntimeError:
            log.error(
                'Connection lost before error written @ {}'.format(
                    self.request.ip if self.request else 'Unknown'))
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(repr(e)),
                from_error=True)
        finally:
            if self.has_log:
                extra = dict()
                if isinstance(response, HTTPResponse):
                    extra['status'] = response.status
                    extra['byte'] = len(response.body)
                else:
                    extra['status'] = 0
                    extra['byte'] = -1
                if self.request:
                    extra['host'] = '%s:%d' % self.request.ip
                    extra['request'] = '%s %s' % (self.request.method,
                                                  self.url)
                else:
                    extra['host'] = 'UNKNOWN'
                    extra['request'] = 'nil'
                if self.parser and not (self.keep_alive
                                        and extra['status'] == 408):
                    netlog.info('', extra=extra)
            self.transport.close()

    def bail_out(self, message, from_error=False):
        if from_error or self.transport.is_closing():
            log.error(
                ("Transport closed @ {} and exception "
                 "experienced during error handling").format(
                    self.transport.get_extra_info('peername')))
            log.debug(
                'Exception:\n{}'.format(traceback.format_exc()))
        else:
            exception = ServerError(message)
            self.write_error(exception)
            log.error(message)

    def cleanup(self):
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._request_stream_task = None
        self._total_request_size = 0
        self._is_stream_handler = False

    def close_if_idle(self):
        """Close the connection if a request is not being sent or received

        :return: boolean - True if closed, False if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False

    def close(self):
        """
        Force close the connection.
        """
        if self.transport is not None:
            self.transport.close()
            self.transport = None
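When is_request_stream is enabled and the router reports a stream handler, the protocol above pushes each body chunk into request.stream (an asyncio.Queue) from on_body() and pushes a final None from on_message_complete(). The handler below is a hypothetical consumer of that queue, included only to illustrate the contract; its name and the way it would build a response are assumptions.

async def collect_streamed_body(request, write_response, stream_response):
    # Drain request.stream exactly as on_body()/on_message_complete() fill it:
    # bytes chunks while the body is arriving, then a single None sentinel.
    chunks = []
    while True:
        chunk = await request.stream.get()
        if chunk is None:
            break
        chunks.append(chunk)
    body = b"".join(chunks)
    # A real handler would now build an HTTPResponse from `body` and hand it
    # to write_response(...), or stream one back via stream_response(...).
    return body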