Exemple #1
0
    def test_parser_request_chunked_cb_error_1(self):
        """An exception raised by ``on_chunk_header`` surfaces as HttpParserCallbackError.

        The callback's own exception must be kept as the ``__context__`` of the
        parser error so callers can see what actually failed.
        """
        class Error(Exception):
            pass

        callbacks = mock.Mock(**{'on_chunk_header.side_effect': Error()})
        parser = httptools.HttpRequestParser(callbacks)

        try:
            parser.feed_data(CHUNKED_REQUEST1_1)
        except httptools.HttpParserCallbackError as exc:
            self.assertIsInstance(exc.__context__, Error)
        else:
            self.fail('HttpParserCallbackError was not raised')

    def test_parser_request_error_in_cb_on_headers_complete(self):
        """An exception from ``on_headers_complete`` is wrapped in HttpParserCallbackError."""
        class Error(Exception):
            pass

        callbacks = mock.Mock(**{'on_headers_complete.side_effect': Error()})
        parser = httptools.HttpRequestParser(callbacks)

        try:
            parser.feed_data(UPGRADE_REQUEST1)
        except httptools.HttpParserCallbackError as exc:
            # The original callback exception is chained as the context.
            self.assertIsInstance(exc.__context__, Error)
        else:
            self.fail('HttpParserCallbackError was not raised')
Exemple #3
0
    async def _wait(self):
        """
        The main core of the protocol.

        This constructs a new Werkzeug request from the parsed request, runs it
        through the app and writes the outcome back to the connection:

        * a WSGI environ is built from the parsed headers, method, URL, HTTP
          version and (only if any bytes were buffered) ``self.body``;
        * ``self.app.request_class`` wraps that environ, and the request is
          passed to ``self.app.process_request`` while ``self.lock`` is held;
        * on success the result is handed to ``self.write_response``; if
          ``process_request`` raises an ``Exception`` it is logged and
          ``CRITICAL_ERROR_TEXT`` is written raw to the connection instead;
        * in every case the connection is closed unless the parser reports
          keep-alive, and a fresh parser replaces the current one.
        """
        # Check if the body has data in it by asking it to tell us what position it's seeked to.
        # If it's > 0, it has data, so we can use it. Otherwise, it doesn't, so it's useless.
        told = self.body.tell()
        if told:
            self.logger.debug("Read {} bytes of body data from the connection".format(told))
            # Rewind so the request reads the body from its first byte.
            self.body.seek(0)
            body = self.body
        else:
            body = None

        version = self.parser.get_http_version()
        method = self.parser.get_method().decode()

        new_environ = to_wsgi_environment(headers=self.headers, method=method, path=self.full_url,
                                          http_version=version, body=body)

        # Extra environ keys: the protocol object plus server/client address info.
        new_environ["kyoukai.protocol"] = self
        new_environ["SERVER_NAME"] = self.component.get_server_name()
        new_environ["SERVER_PORT"] = str(self.server_port)
        new_environ["REMOTE_ADDR"] = self.ip
        # NOTE(review): SERVER_PORT is stringified above but REMOTE_PORT is stored
        # as-is (possibly an int or None); confirm consumers accept that.
        new_environ["REMOTE_PORT"] = self.client_port

        # Construct a Request object.
        new_r = self.app.request_class(new_environ, False)

        # Invoke the app.
        # The lock serialises request processing on this protocol instance.
        async with self.lock:
            try:
                result = await self.app.process_request(new_r, self.parent_context)
            except Exception:
                # not good!
                # write the scary exception text
                self.logger.exception("Error in Kyoukai request handling!")
                self._raw_write(CRITICAL_ERROR_TEXT.encode("utf-8"))
                # The finally clause below still runs before this return completes.
                return
            else:
                # Write the response.
                self.write_response(result, new_environ)
            finally:
                if not self.parser.should_keep_alive():
                    self.close()
                # Install a fresh parser for the next request on this connection.
                # NOTE(review): self.body, self.headers and self.full_url are not
                # reset here; confirm they are reset elsewhere before reuse.
                self.parser = httptools.HttpRequestParser(self)
Exemple #4
0
 def data_received(self, data):
     """Feed a chunk of raw bytes from the transport into the HTTP parser.

     Any pending keep-alive timeout is cancelled first. When ``self.parser`` is
     None a new parser (and an empty header list) is created before feeding.
     Malformed input is logged and the message is passed to
     ``self.on_response``; the traceback is appended when ``self.debug`` is set.
     Upgrade requests are currently ignored.
     """
     self.cancel_timeout_keep_alive_task()
     try:
         if self.parser is None:
             self.headers = []
             self.parser = httptools.HttpRequestParser(self)
         self.parser.feed_data(data)
     except httptools.parser.errors.HttpParserError:
         detail = "\n" + traceback.format_exc() if self.debug else ""
         msg = "Invalid HTTP request received." + detail
         self.logger.error(msg)
         self.on_response(msg)
     except httptools.HttpParserUpgrade:
         # Protocol upgrades are not handled yet; the request is dropped.
         pass
Exemple #5
0
    def test_parser_request_chunked_3(self):
        """A chunked POST yields the method, URL, version, headers and chunk callbacks."""
        callbacks = mock.Mock()
        parser = httptools.HttpRequestParser(callbacks)
        parser.feed_data(CHUNKED_REQUEST1_3)

        # Request line.
        self.assertEqual(parser.get_method(), b'POST')
        callbacks.on_url.assert_called_once_with(b'/test.php?a=b+c')
        self.assertEqual(parser.get_http_version(), '1.2')

        # The last header reported is Transfer-Encoding; chunk callbacks fired.
        callbacks.on_header.assert_called_with(b'Transfer-Encoding', b'chunked')
        callbacks.on_chunk_header.assert_called_with()
        callbacks.on_chunk_complete.assert_called_with()

        self.assertTrue(callbacks.on_message_complete.called)
Exemple #6
0
    def __init__(
        self,
        app,
        loop=None,
        connections=None,
        tasks=None,
        state=None,
        logger=None,
        ws_protocol_class=None,
        proxy_headers=False,
        root_path="",
        limit_concurrency=None,
        timeout_keep_alive=5,
        timeout_response=60,
    ):
        """Set up an HTTP protocol instance backed by an httptools request parser.

        :param app: the application, stored as ``self.app``.
        :param loop: event loop; ``asyncio.get_event_loop()`` when not given.
        :param connections: shared set of connections; a new empty set if None.
        :param tasks: shared set of tasks; a new empty set if None.
        :param state: shared state dict; ``{"total_requests": 0}`` if None.
        :param logger: logger to use; the root logger when falsy.
        :param ws_protocol_class: stored as-is; presumably the protocol used
            for websocket upgrades (not visible here).
        :param proxy_headers: stored as-is.
        :param root_path: stored as-is.
        :param limit_concurrency: stored as-is.
        :param timeout_keep_alive: keep-alive timeout, default 5 (units not
            visible here; presumably seconds).
        :param timeout_response: response timeout, default 60 (presumably
            seconds).
        """
        self.app = app
        self.loop = loop or asyncio.get_event_loop()
        self.connections = set() if connections is None else connections
        self.tasks = set() if tasks is None else tasks
        self.state = {"total_requests": 0} if state is None else state
        self.logger = logger or logging.getLogger()
        # The parser receives ``self`` as its callback object, so it calls back
        # into this instance's on_* methods while parsing.
        self.parser = httptools.HttpRequestParser(self)
        self.ws_protocol_class = ws_protocol_class
        self.proxy_headers = proxy_headers
        self.root_path = root_path
        self.limit_concurrency = limit_concurrency

        # Timeouts
        self.timeout_keep_alive_task = None
        self.timeout_keep_alive = timeout_keep_alive
        self.timeout_response = timeout_response

        # Per-connection state
        self.transport = None
        self.flow = None
        self.server = None
        self.client = None
        self.scheme = None
        self.pipeline = []

        # Per-request state
        self.url = None
        self.scope = None
        self.headers = None
        self.expect_100_continue = False
        self.cycle = None
        self.message_event = asyncio.Event()
Exemple #7
0
    def data_received(self, data):
        """Accumulate incoming bytes and feed them to the HTTP request parser.

        The running request size is checked first: once it exceeds
        ``self.request_max_size`` the connection is abandoned via ``bail_out``
        (whose return value is passed through). A parser and an empty header
        list are created on the first chunk. Parser errors are reported through
        ``bail_out`` as well.
        """
        # Guard against requests growing past the configured memory limit.
        self._total_request_size += len(data)
        if self._total_request_size > self.request_max_size:
            return self.bail_out(
                f"Request too large ({self._total_request_size}), connection closed")

        if self.parser is None:
            # First chunk of a request: no request object may exist yet.
            assert self.request is None
            self.headers = []
            self.parser = httptools.HttpRequestParser(self)

        try:
            self.parser.feed_data(data)
        except httptools.parser.errors.HttpParserError as err:
            self.bail_out(f"Invalid request data, connection closed ({err})")
Exemple #8
0
    def __init__(self,
                 config,
                 server_state,
                 on_connection_lost: Callable = None,
                 _loop=None):
        """Set up the protocol from ``config`` and the shared ``server_state``.

        :param config: configuration; ``config.load()`` is called first when
            ``config.loaded`` is false. The app, access-log format, websocket
            protocol class, root path, concurrency limit and keep-alive timeout
            are read from it.
        :param server_state: shared state providing ``connections``, ``tasks``
            and ``default_headers``.
        :param on_connection_lost: optional callable stored as-is; presumably
            invoked when the connection is lost (not visible here).
        :param _loop: event loop; ``asyncio.get_event_loop()`` when not given.
        """
        if not config.loaded:
            config.load()

        self.config = config
        self.app = config.loaded_app
        self.on_connection_lost = on_connection_lost
        self.loop = _loop or asyncio.get_event_loop()
        self.logger = logging.getLogger("uvicorn.error")
        self.access_logger = logging.getLogger("uvicorn.access")
        # Access logging is enabled only when the access logger has handlers.
        self.access_log = self.access_logger.hasHandlers()
        self.access_log_format = config.access_log_format
        # The parser is given ``self`` as its callback object.
        self.parser = httptools.HttpRequestParser(self)
        self.ws_protocol_class = config.ws_protocol_class
        self.root_path = config.root_path
        self.limit_concurrency = config.limit_concurrency

        # Timeouts
        self.timeout_keep_alive_task = None
        self.timeout_keep_alive = config.timeout_keep_alive

        # Global state
        self.server_state = server_state
        self.connections = server_state.connections
        self.tasks = server_state.tasks
        self.default_headers = server_state.default_headers

        # Per-connection state
        self.transport = None
        self.flow = None
        self.server = None
        self.client = None
        self.scheme = None
        self.pipeline = []

        # Per-request state
        self.url = None
        self.scope = None
        self.headers = None
        self.expect_100_continue = False
        self.cycle = None
        self.request_start_time = None
Exemple #9
0
 async def _wait_wrapper(self):
     """Run ``self._wait()`` and always reset the waiter and parser afterwards.

     If the instance has no ``_wait`` method, nothing is run. Any ``Exception``
     escaping ``_wait`` is logged at CRITICAL level, ``CRITICAL_ERROR_TEXT`` is
     written to the connection and the connection is closed.

     ``BaseException`` subclasses are no longer swallowed; cleanup still runs,
     then they propagate. These include ``KeyboardInterrupt``, ``SystemExit``
     and, on Python 3.8+, ``asyncio.CancelledError``.
     """
     try:
         if not hasattr(self, "_wait"):
             return
         await self._wait()
     except Exception:
         self.logger.critical("Error in Kyoukai's HTTP handling!", exc_info=True)
         self._raw_write(CRITICAL_ERROR_TEXT.encode())
         self.close()
     finally:
         # We might have changed protocol by now; if so there is no waiter
         # attribute, and nothing here should be touched.
         if hasattr(self, "waiter"):
             # The waiter may already be None (it is reset to None just below),
             # so guard before cancelling instead of raising AttributeError.
             if self.waiter is not None:
                 self.waiter.cancel()
             self.waiter = None
             self.parser = httptools.HttpRequestParser(self)
Exemple #10
0
    def __init__(
        self,
        cli_sock,
        cli_addr,
        header_timeout,
        body_timeout,
        keep_alive_timeout,
        max_header_size,
        max_body_size,
        header_buffer_size,
        body_buffer_size,
    ):
        """Store connection limits/timeouts and initialise request-parsing state.

        :param cli_sock: the client socket, stored as-is.
        :param cli_addr: the client address; its first two items are formatted
            as ``"host:port"`` into ``self._address``.
        :param header_timeout: stored as-is (units not visible here).
        :param body_timeout: stored as-is (units not visible here).
        :param keep_alive_timeout: stored as-is (units not visible here).
        :param max_header_size: maximum header size; summed with
            ``max_body_size`` into ``self.max_request_size``.
        :param max_body_size: maximum body size.
        :param header_buffer_size: initial value of ``self._buffer_size``.
        :param body_buffer_size: stored as-is; presumably the buffer size used
            once headers are complete (not visible here).
        """
        self.cli_sock = cli_sock
        self.cli_addr = cli_addr
        self.header_timeout = header_timeout
        self.body_timeout = body_timeout
        self.keep_alive_timeout = keep_alive_timeout
        self.max_header_size = max_header_size
        self.max_body_size = max_body_size
        self.max_request_size = max_header_size + max_body_size
        self.header_buffer_size = header_buffer_size
        self.body_buffer_size = body_buffer_size

        # public attrs
        self.method = None
        self.url = None
        self.version = None
        self.headers = []
        self.remote_ip = None
        self.protocol = None
        self.keep_alive = False

        # helper attrs
        # str.format ignores extra positional items (e.g. IPv6 flowinfo/scope).
        self._address = '{}:{}'.format(*self.cli_addr)

        # temp attrs
        # The parser is given ``self`` as its callback object.
        self._parser = httptools.HttpRequestParser(self)
        # Reading starts with the header buffer size.
        self._buffer_size = self.header_buffer_size
        self._started = False
        self._headers_completed = False
        self._completed = False
        self._url = b''
        self._header_name = b''
        self._body_chunks = []
        self._readed_size = 0
Exemple #11
0
    def test_parser_request_fragmented_bytes(self):
        """Headers are reassembled correctly when the request arrives one byte at a time."""
        received = {}
        callbacks = mock.Mock()
        callbacks.on_header.side_effect = received.__setitem__
        parser = httptools.HttpRequestParser(callbacks)

        request = (
            b'PUT / HTTP/1.1\r\nHost: localhost:1234\r\nContent-'
            b'Type: text/plain; charset=utf-8\r\n\r\n'
        )

        # Slicing (rather than indexing) keeps each single-byte piece a bytes object.
        for offset in range(len(request)):
            parser.feed_data(request[offset:offset + 1])

        self.assertEqual(
            received,
            {
                b'Host': b'localhost:1234',
                b'Content-Type': b'text/plain; charset=utf-8',
            },
        )
Exemple #12
0
    def __init__(self, component, parent_context: Context,
                 server_ip: str, server_port: int):
        """
        :param component: The :class:`kyoukai.asphalt.KyoukaiComponent` associated with this request.
        :param parent_context: The parent context for this request.
            A new HTTPRequestContext will be derived from this.
        :param server_ip: The server's IP address, stored as ``self.server_ip``.
        :param server_port: The server's port, stored as ``self.server_port``.
        """

        self.component = component
        self.app = component.app
        self.parent_context = parent_context

        self.server_ip = server_ip
        self.server_port = server_port

        # Transport.
        # This is written to by our request when it's done.
        self.transport = None  # type: asyncio.WriteTransport

        # Request lock.
        # This ensures that requests are processed serially, and responded to in the correct order,
        # as the lock is released after processing a request completely.
        self.lock = asyncio.Lock()

        # The parser itself.
        # This is created per connection and is given ``self`` as its callback object.
        self.parser = httptools.HttpRequestParser(self)

        # A waiter that 'waits' on the event to clear.
        # Once the wait is over, it then delegates the request to the app.
        self.waiter = None  # type: asyncio.Task

        # The IP and port of the client.
        self.ip, self.client_port = None, None

        # Intermediary data storage.
        # This is a list because headers are appended as (Name, Value) pairs.
        # In HTTP/1.1, there can be multiple headers with the same name but different values.
        self.headers = []
        self.body = BytesIO()
        self.full_url = ""

        # The event loop is taken from the app.
        self.loop = self.app.loop
        self.logger = logging.getLogger("Kyoukai.HTTP11")
Exemple #13
0
    def __init__(
        self,
        config: Config,
        server_state: ServerState,
        _loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Set up the protocol from ``config`` and the shared ``server_state``.

        :param config: configuration; ``config.load()`` is called first when
            ``config.loaded`` is false. The app, websocket protocol class, root
            path, concurrency limit and keep-alive timeout are read from it.
        :param server_state: shared state providing ``connections``, ``tasks``
            and ``default_headers``.
        :param _loop: event loop; ``asyncio.get_event_loop()`` when not given.
        """
        if not config.loaded:
            config.load()

        self.config = config
        self.app = config.loaded_app
        self.loop = _loop or asyncio.get_event_loop()
        self.logger = logging.getLogger("uvicorn.error")
        self.access_logger = logging.getLogger("uvicorn.access")
        # Access logging is enabled only when the access logger has handlers.
        self.access_log = self.access_logger.hasHandlers()
        # The parser is given ``self`` as its callback object.
        self.parser = httptools.HttpRequestParser(self)
        self.ws_protocol_class = config.ws_protocol_class
        self.root_path = config.root_path
        self.limit_concurrency = config.limit_concurrency

        # Timeouts
        self.timeout_keep_alive_task: Optional[TimerHandle] = None
        self.timeout_keep_alive = config.timeout_keep_alive

        # Global state
        self.server_state = server_state
        self.connections = server_state.connections
        self.tasks = server_state.tasks
        self.default_headers = server_state.default_headers

        # Per-connection state
        # The ``type: ignore`` placeholders are None here and presumably
        # assigned once a connection is made (not visible in this block).
        self.transport: asyncio.Transport = None  # type: ignore[assignment]
        self.flow: FlowControl = None  # type: ignore[assignment]
        self.server: Optional[Tuple[str, int]] = None
        self.client: Optional[Tuple[str, int]] = None
        self.scheme: Optional[Literal["http", "https"]] = None
        self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque()

        # Per-request state
        self.scope: HTTPScope = None  # type: ignore[assignment]
        self.headers: List[Tuple[bytes, bytes]] = None  # type: ignore[assignment]
        self.expect_100_continue = False
        self.cycle: RequestResponseCycle = None  # type: ignore[assignment]
Exemple #14
0
    def __init__(self, config, server_state, _loop=None):
        """Set up the protocol from ``config`` and the shared ``server_state``.

        :param config: configuration; ``config.load()`` is called first when
            ``config.loaded`` is false. The app, logger, access-log flag,
            websocket protocol class, root path, concurrency limit and
            keep-alive timeout are read from it.
        :param server_state: shared state providing ``connections``, ``tasks``
            and ``default_headers``.
        :param _loop: event loop; ``asyncio.get_event_loop()`` when not given.
        """
        if not config.loaded:
            config.load()

        self.config = config
        self.app = config.loaded_app
        self.loop = _loop or asyncio.get_event_loop()
        self.logger = config.logger_instance
        # Access logging needs both the config flag and a logger level that
        # lets INFO records through.
        self.access_log = config.access_log and (self.logger.level <=
                                                 logging.INFO)
        # The parser is given ``self`` as its callback object.
        self.parser = httptools.HttpRequestParser(self)
        self.ws_protocol_class = config.ws_protocol_class
        self.root_path = config.root_path
        self.limit_concurrency = config.limit_concurrency

        # Timeouts
        self.timeout_keep_alive_task = None
        self.timeout_keep_alive = config.timeout_keep_alive

        # Global state
        self.server_state = server_state
        self.connections = server_state.connections
        self.tasks = server_state.tasks
        self.default_headers = server_state.default_headers

        # Per-connection state
        self.transport = None
        self.flow = None
        self.server = None
        self.client = None
        self.scheme = None
        self.pipeline = []

        # Per-request state
        self.url = None
        self.scope = None
        self.headers = None
        self.expect_100_continue = False
        self.cycle = None
        self.message_event = asyncio.Event()
Exemple #15
0
    async def read_http_message(
            self, reader: asyncio.streams.StreamReader) -> Request:
        """Read from ``reader`` until one complete HTTP request has been parsed.

        Data is read in chunks of up to 64 KiB (``2 ** 16`` bytes) and fed to an
        httptools request parser whose callbacks go to a fresh ``ParseProtocol``.

        :param reader: the ``asyncio.StreamReader`` to read the request from.
        :return: the ``Request`` built by ``Request.load_from_parser`` once the
            protocol reports the message as completed, or ``None`` if the
            stream hits EOF before a complete message has arrived.
        :raises HttpException: with status 400 when the client asks for a
            protocol upgrade, which is not supported here.
        """
        protocol = ParseProtocol()
        parser = httptools.HttpRequestParser(protocol)
        while True:
            data = await reader.read(2 ** 16)
            if not data:
                # EOF before the message completed; feeding an empty chunk
                # cannot complete a request, so stop here.
                return None

            try:
                parser.feed_data(data)
            except httptools.HttpParserUpgrade as exc:
                raise HttpException(400) from exc

            if protocol.completed:
                return Request.load_from_parser(parser, protocol)
Exemple #16
0
    def test_parser_request_fragmented_header(self):
        """A header name split across two ``feed_data`` calls is reassembled."""
        collected = {}
        callbacks = mock.Mock()
        callbacks.on_header.side_effect = collected.__setitem__
        parser = httptools.HttpRequestParser(callbacks)

        first, second = (
            b'PUT / HTTP/1.1\r\nHost: localhost:1234\r\nContent-',
            b'Type: text/plain; charset=utf-8\r\n\r\n',
        )

        parser.feed_data(first)
        callbacks.on_message_begin.assert_called_once_with()
        callbacks.on_url.assert_called_once_with(b'/')
        # The partial "Content-" header must not be reported yet.
        self.assertEqual(collected, {b'Host': b'localhost:1234'})

        parser.feed_data(second)
        self.assertEqual(
            collected,
            {
                b'Host': b'localhost:1234',
                b'Content-Type': b'text/plain; charset=utf-8',
            },
        )
Exemple #17
0
    def test_parser_request_chunked_2(self):
        """Headers are collected across two chunked fragments with most callbacks set to None."""
        headers = {}

        # Attributes set to None (rather than auto-created Mock methods) check
        # that the parser copes with those callbacks being unavailable.
        callbacks = mock.Mock(
            on_url=None,
            on_body=None,
            on_headers_complete=None,
            on_chunk_header=None,
            on_chunk_complete=None,
        )
        callbacks.on_header.side_effect = headers.__setitem__

        parser = httptools.HttpRequestParser(callbacks)
        for fragment in (CHUNKED_REQUEST1_1, CHUNKED_REQUEST1_2):
            parser.feed_data(fragment)

        expected = {
            b'User-Agent': b'spam',
            b'Transfer-Encoding': b'chunked',
            b'Host': b'bar',
            b'Vary': b'*',
        }
        self.assertEqual(headers, expected)
Exemple #18
0
    def __init__(self, app, config, server_state):
        """Set up the protocol for ``app`` from ``config`` and shared ``server_state``.

        :param app: the application, stored as ``self.app``.
        :param config: configuration providing the loop, logger, access-log
            flag, websocket protocol, root path, concurrency limit, timeouts and
            extra default headers.
        :param server_state: shared state providing ``connections``, ``tasks``
            and the base ``default_headers``.
        """
        self.app = app
        self.config = config
        self.loop = config.loop
        self.logger = config.logger
        # Access logging needs both the config flag and a logger level that
        # lets INFO records through.
        self.access_log = config.access_log and (self.logger.level <= logging.INFO)
        # The parser is given ``self`` as its callback object.
        self.parser = httptools.HttpRequestParser(self)
        self.ws_protocol = config.ws_protocol
        self.root_path = config.root_path
        self.limit_concurrency = config.limit_concurrency
        # NOTE(review): both keep_alive_timeout and timeout_keep_alive (below)
        # are read from config; confirm both are actually needed.
        self.keep_alive_timeout = config.keep_alive_timeout

        # Timeouts
        self.timeout_keep_alive_task = None
        self.timeout_keep_alive = config.timeout_keep_alive

        # Global state
        self.server_state = server_state
        self.connections = server_state.connections
        self.tasks = server_state.tasks
        # Server-wide default headers followed by the config-specific ones.
        self.default_headers = server_state.default_headers + config.default_headers

        # Per-connection state
        self.transport = None
        self.server = None
        self.client = None
        self.scheme = None
        self.pipeline = []

        # Per-request state
        self.url = None
        self.environ = None
        self.body = b""
        self.more_body = True
        self.headers = []
        self.expect_100_continue = False
        self.message_event = asyncio.Event()
        # The event starts out in the set state.
        self.message_event.set()
Exemple #19
0
    def __init__(self, config, global_state=None):
        """Set up the protocol from ``config`` and optional shared ``global_state``.

        :param config: configuration providing the app, loop, logger,
            access-log flag, websocket protocol class, root path, concurrency
            limit and keep-alive timeout. ``asyncio.get_event_loop()`` and
            ``logging.getLogger("uvicorn")`` are used when the loop or logger
            is falsy.
        :param global_state: shared state providing ``connections`` and
            ``tasks``; a new ``GlobalState()`` is created when None.
        """
        self.config = config
        self.app = config.app
        self.loop = config.loop or asyncio.get_event_loop()
        self.logger = config.logger or logging.getLogger("uvicorn")
        # Access logging needs both the config flag and a logger level that
        # lets INFO records through.
        self.access_log = config.access_log and (self.logger.level <=
                                                 logging.INFO)
        # The parser is given ``self`` as its callback object.
        self.parser = httptools.HttpRequestParser(self)
        self.ws_protocol_class = config.ws_protocol_class
        self.root_path = config.root_path
        self.limit_concurrency = config.limit_concurrency

        # Timeouts
        self.timeout_keep_alive_task = None
        self.timeout_keep_alive = config.timeout_keep_alive

        # Global state
        # Fall back to a private GlobalState when none is shared in.
        if global_state is None:
            global_state = GlobalState()
        self.global_state = global_state
        self.connections = global_state.connections
        self.tasks = global_state.tasks

        # Per-connection state
        self.transport = None
        self.flow = None
        self.server = None
        self.client = None
        self.scheme = None
        self.pipeline = []

        # Per-request state
        self.url = None
        self.scope = None
        self.headers = None
        self.expect_100_continue = False
        self.cycle = None
        self.message_event = asyncio.Event()
Exemple #20
0
    async def _serve_client(self, reader, writer):
        """Serve a single HTTP request from one client connection.

        The request is read line by line and fed to an httptools parser until
        ``protocol.message_complete`` is set. If the client hits EOF first,
        nothing is sent.

        The URL is resolved with ``self._unpack_url``. GET and POST are
        dispatched to the handler's ``get``/``post`` coroutines, other methods
        yield 405, and unknown URLs yield 404.

        A 200 response carries ``response.data`` as a ``text/plain`` body. Any
        other status is sent with an empty body. The writer is always closed,
        even if parsing or a handler raises.
        """
        from email.utils import formatdate
        from http import HTTPStatus

        try:
            protocol = Protocol()
            parser = httptools.HttpRequestParser(protocol)

            while not protocol.message_complete:
                line = await reader.readline()
                if not line:
                    # EOF: readline() would keep returning b'' forever, so stop
                    # instead of spinning on an incomplete request.
                    return
                parser.feed_data(line)

            handler, params = self._unpack_url(protocol.url)
            response = Response()

            if handler is not None:
                method = parser.get_method()
                request = Request()

                if method == b'GET':
                    await handler.get(request, response, *params)
                elif method == b'POST':
                    await handler.post(request, response, *params)
                else:
                    response.status = 405
            else:
                response.status = 404

            # RFC 7231 IMF-fixdate for the current time.
            date = formatdate(usegmt=True)

            if response.status == 200:
                writer.write(('HTTP/1.0 200 OK\r\n'
                              f'Server: HTTPAsync/{__version__}\r\n'
                              f'Date: {date}\r\n'
                              'Content-type: text/plain; charset=utf-8\r\n'
                              f'Content-Length: {len(response.data)}\r\n'
                              '\r\n').encode('utf-8') + response.data)
            else:
                # Previously non-200 statuses were never sent to the client.
                try:
                    reason = HTTPStatus(response.status).phrase
                except ValueError:
                    reason = ''
                writer.write((f'HTTP/1.0 {response.status} {reason}\r\n'
                              f'Server: HTTPAsync/{__version__}\r\n'
                              f'Date: {date}\r\n'
                              'Content-Length: 0\r\n'
                              '\r\n').encode('utf-8'))
        finally:
            writer.close()

    def test_parser_request_fragmented(self):
        """Headers and body split across three feeds are reported incrementally."""
        seen = {}
        callbacks = mock.Mock()
        callbacks.on_header.side_effect = seen.__setitem__
        parser = httptools.HttpRequestParser(callbacks)

        part1, part2, part3 = (
            b'PUT / HTTP/1.1\r\nHost: localhost:1234\r\nContent-Type: text/pl',
            b'ain; charset=utf-8\r\nX-Empty-Header: \r\nConnection: close\r\n',
            b'Content-Length: 10\r\n\r\n1234567890',
        )
        host_only = {b'Host': b'localhost:1234'}

        parser.feed_data(part1)
        callbacks.on_message_begin.assert_called_once_with()
        callbacks.on_url.assert_called_once_with(b'/')
        self.assertEqual(seen, host_only)

        parser.feed_data(part2)
        after_part2 = {
            **host_only,
            b'Content-Type': b'text/plain; charset=utf-8',
            b'X-Empty-Header': b'',
        }
        self.assertEqual(seen, after_part2)

        parser.feed_data(part3)
        self.assertEqual(
            seen,
            {**after_part2, b'Connection': b'close', b'Content-Length': b'10'},
        )
        callbacks.on_message_complete.assert_called_once_with()
Exemple #22
0
 def test_parser_request_4(self):
     """``feed_data`` rejects ``str`` input with a TypeError about bytes-like objects."""
     parser = httptools.HttpRequestParser(None)
     text_payload = 'POST  HTTP/1.2'
     with self.assertRaisesRegex(TypeError, 'a bytes-like object'):
         parser.feed_data(text_payload)
Exemple #23
0
 def test_parser_request_3(self):
     """A request line with an empty URL raises HttpParserInvalidURLError."""
     parser = httptools.HttpRequestParser(None)
     missing_url = b'POST  HTTP/1.2'
     with self.assertRaises(httptools.HttpParserInvalidURLError):
         parser.feed_data(missing_url)
Exemple #24
0
 def test_parser_request_2(self):
     """An unknown method (``SPAM``) raises HttpParserInvalidMethodError."""
     parser = httptools.HttpRequestParser(None)
     bad_method_line = b'SPAM /test.php?a=b+c HTTP/1.2'
     with self.assertRaises(httptools.HttpParserInvalidMethodError):
         parser.feed_data(bad_method_line)
Exemple #25
0
 def __init__(self):
     """Create an httptools request parser that reports its callbacks to ``self``."""
     self.parser = httptools.HttpRequestParser(self)
Exemple #26
0
    def connection_made(self, transport: asyncio.Transport) -> None:
        """Set up per-connection state when the transport connects.

        Stores the transport, creates a request parser that calls back into
        this instance, creates a ``StreamReader`` bound to ``self.loop``, and
        then calls ``self.start_timeout()``.
        """
        self.transport = transport
        self.parser = httptools.HttpRequestParser(self)
        self.reader = asyncio.StreamReader(loop=self.loop)

        self.start_timeout()
Exemple #27
0
 def __init__(self):
     """Initialise connection state: no transport yet, a request parser, and an empty Request."""
     # Presumably assigned once the connection is made (not visible here).
     self._transport = None
     # The parser is given ``self`` as its callback object.
     self._parser = httptools.HttpRequestParser(self)
     self._request = Request()
Exemple #28
0
 def _build_parser(callbacks):
     """Return a new ``httptools.HttpRequestParser`` that reports parse events to ``callbacks``."""
     return httptools.HttpRequestParser(callbacks)
 def data_received(self, data):
     """Feed a chunk of bytes from the transport into the HTTP request parser.

     The parser is created on first use and then reused. A request split across
     several ``data_received`` calls (ordinary TCP fragmentation) is therefore
     parsed as one message rather than each fragment being parsed from scratch.

     If the parser raises, it is discarded so the next chunk starts with a
     fresh parser, and the exception propagates to the caller.
     """
     if getattr(self, "_parser", None) is None:
         self._parser = httptools.HttpRequestParser(self)
     try:
         self._parser.feed_data(data)
     except Exception:
         # A parser that has raised is assumed unusable for further input;
         # drop it so the next chunk gets a new one.
         self._parser = None
         raise

    async def interact(self, client):
        """Parse a client's proxy request and relay it to the target server.

        The request head is read from ``client`` and fed to an httptools parser
        until one of these happens:

        * ``protocol.need_proxy_data`` is cleared;
        * the client closes the connection;
        * the parser raises ``HttpParserUpgrade``, as it does for
          upgrade/CONNECT requests.

        If ``self.auth`` is set, the client's ``Proxy-Authorization`` header
        must carry the matching Basic credentials. Otherwise a 407 response is
        sent and an exception is raised.

        ``CONNECT`` requests get a tunnel. Other requests are forwarded to the
        target after these rewrites:

        * the absolute URL becomes origin-form;
        * ``Proxy-*`` headers are dropped.

        Any buffered client data is then sent upstream, and traffic is relayed
        in both directions.

        :param client: the accepted client connection (``recv``/``sendall``).
        :raises Exception: if proxy authentication fails.
        """
        import hmac

        def _as_bytes(value):
            # Header names/values may be bytes or str; normalise to bytes.
            return value if isinstance(value, bytes) else value.encode()

        protocol = HTTPProxyProtocol()
        parser = httptools.HttpRequestParser(protocol)

        while protocol.need_proxy_data:
            data = await client.recv(65536)
            if not data:
                break
            try:
                parser.feed_data(data)
            except httptools.HttpParserUpgrade:
                # The request head has been parsed; that is all we need here.
                break

        version = parser.get_http_version()
        if version == "0.0":
            # Presumably nothing was parsed (the client sent no request line).
            return
        if self.auth:
            # Clients send credentials in Proxy-Authorization (RFC 7235);
            # Proxy-Authenticate is the server's challenge header.
            pauth = _as_bytes(
                protocol.headers_dict.get(b"Proxy-Authorization") or b"")
            httpauth = b"Basic " + base64.b64encode(b":".join(self.auth))
            # Constant-time comparison so timing does not leak the credentials.
            if not hmac.compare_digest(httpauth, pauth):
                await client.sendall(
                    b"HTTP/" + version.encode() +
                    b" 407 Proxy Authentication Required\r\n"
                    b"Connection: close\r\n"
                    b'Proxy-Authenticate: Basic realm="simple"\r\n\r\n')
                raise Exception("Unauthorized HTTP Required")

        method = parser.get_method()
        if method == b"CONNECT":
            host, _, port = protocol.url.partition(b":")
            self.taddr = (host.decode(), int(port))
        else:
            url = urllib.parse.urlparse(protocol.url)
            if not url.hostname:
                await client.sendall(b"HTTP/1.1 200 OK\r\n"
                                     b"Connection: close\r\n"
                                     b"Content-Type: text/plain\r\n"
                                     b"Content-Length: 2\r\n\r\n"
                                     b"ok")
                return
            self.taddr = (url.hostname.decode(), url.port or 80)
            # Origin-form request target. An empty path must become "/", or the
            # forwarded request line would be malformed.
            newpath = url._replace(scheme=b"", netloc=b"",
                                   path=url.path or b"/").geturl()
        remote_conn = await self.connect_remote()
        async with remote_conn:
            if method == b"CONNECT":
                await client.sendall(
                    b"HTTP/1.1 200 Connection: Established\r\n\r\n")
            else:
                # Forward the client's headers minus the hop-by-hop Proxy-* ones.
                header_block = b"".join(
                    _as_bytes(name) + b": " + _as_bytes(value) + b"\r\n"
                    for name, value in protocol.headers
                    if not _as_bytes(name).lower().startswith(b"proxy-")
                )
                remote_req_headers = b"%s %s HTTP/%s\r\n%s\r\n" % (
                    method,
                    newpath,
                    version.encode(),
                    header_block,
                )
                await remote_conn.sendall(remote_req_headers)
            if protocol.buffer:
                # Everything the protocol saw came from client.recv() above, so
                # buffered bytes belong upstream, not echoed back to the client.
                await remote_conn.sendall(protocol.buffer)
            await self.relay(client, remote_conn)