Exemplo n.º 1
0
class Protocol:
    """Callback sink that records what an HTTP parser reports.

    The instance is passed to ``Parser`` as its first argument; each ``on_*``
    hook stores the reported piece of the message on the instance.
    """

    def __init__(self, Parser, **kwargs):
        """Reset the recorded state and build the parser.

        ``Parser`` is called as ``Parser(self, **kwargs)``; its ``feed_data``
        is re-exported on this object.
        """
        # State is initialised before the parser is built, because the parser
        # is given this object and may use it right away.
        self.headers_complete = False
        self.message_complete = False
        self.url = None
        self.body = b''
        self.headers = CIMultiDict()
        parser = Parser(self, **kwargs)
        self.parser = parser
        self.feed_data = parser.feed_data

    def on_url(self, url):
        """Remember the reported URL as-is."""
        self.url = url

    def on_header(self, name, value):
        """Decode a header name/value pair and append it to ``headers``."""
        decoded_name = name.decode()
        decoded_value = value.decode()
        self.headers.add(decoded_name, decoded_value)

    def on_headers_complete(self):
        """Flag that the header section has been fully reported."""
        self.headers_complete = True

    def on_body(self, body):
        """Append a body chunk to the accumulated body."""
        self.body = self.body + body

    def on_message_complete(self):
        """Flag that the whole message has been reported."""
        self.message_complete = True
Exemplo n.º 2
0
def parse_headers(header_data: bytes, value_encoding: str = 'ascii') -> CIMultiDict:
    """
    Parse a raw, CRLF-separated header block into a case-insensitive multidict.

    Trailing whitespace of the whole block is dropped first. Each line is split
    on its first colon; the name is decoded as ASCII and the value with
    ``value_encoding``, both stripped of surrounding whitespace.

    :param header_data: the raw header block
    :param value_encoding: codec used to decode header values
    :return: the parsed headers (empty when ``header_data`` holds no headers)
    :raises ValueError: if a line contains no colon, or cannot be decoded

    """
    assert check_argument_types()
    headers = CIMultiDict()
    stripped = header_data.rstrip()
    if not stripped:
        # An empty block means "no headers"; splitting it would yield one
        # empty line, which has no colon to unpack on.
        return headers

    for line in stripped.split(b'\r\n'):
        key, value = line.split(b':', 1)
        key = key.strip().decode('ascii')
        value = value.strip().decode(value_encoding)
        headers.add(key, value)

    return headers
Exemplo n.º 3
0
def test_multiple_forwarded_headers():
    """Elements of several ``Forwarded`` headers are exposed in order.

    Parameter names are written in mixed case in the headers and looked up
    in lower case on the parsed result.
    """
    headers = CIMultiDict()
    raw_values = (
        'By=identifier1;for=identifier2, BY=identifier3',
        'By=identifier4;fOr=identifier5',
    )
    for raw_value in raw_values:
        headers.add('Forwarded', raw_value)
    req = make_mocked_request('GET', '/', headers=headers)
    assert len(req.forwarded) == 3
    # (element index, parameter name, expected value)
    expected = (
        (0, 'by', 'identifier1'),
        (0, 'for', 'identifier2'),
        (1, 'by', 'identifier3'),
        (2, 'by', 'identifier4'),
        (2, 'for', 'identifier5'),
    )
    for index, param, wanted in expected:
        assert req.forwarded[index][param] == wanted
Exemplo n.º 4
0
def test_multiple_forwarded_headers_injection():
    """A malformed first ``Forwarded`` header must not shadow the second one."""
    attacker_value = 'for=_injected;by="'
    proxy_value = 'for=_real;by=_actual_proxy'
    headers = CIMultiDict()
    # Could be sent by an attacker hoping the unterminated quote swallows
    # ("shadows") the header that follows.
    headers.add('Forwarded', attacker_value)
    # Added by our trusted reverse proxy.
    headers.add('Forwarded', proxy_value)
    req = make_mocked_request('GET', '/', headers=headers)
    forwarded = req.forwarded
    assert len(forwarded) == 2
    injected = forwarded[0]
    trusted = forwarded[1]
    assert 'by' not in injected
    assert trusted['for'] == '_real'
    assert trusted['by'] == '_actual_proxy'
Exemplo n.º 5
0
 def _prepare_headers(self, headers):
     """Return a new CIMultiDict of the default headers overlaid with ``headers``.

     The first value seen for a name replaces whatever is stored under that
     name; further values for the same name are appended.
     """
     merged = CIMultiDict(self._default_headers)
     if not headers:
         return merged

     # Plain mappings / iterables are normalised to a multidict first.
     if not isinstance(headers, (MultiDictProxy, MultiDict)):
         headers = CIMultiDict(headers)

     # NOTE(review): membership in this set follows the key objects' own
     # equality, which may differ from the case-insensitive lookup of
     # ``merged`` -- confirm against the multidict version in use.
     seen = set()
     for name, value in headers.items():
         if name not in seen:
             merged[name] = value
             seen.add(name)
         else:
             merged.add(name, value)
     return merged
Exemplo n.º 6
0
def test_multiple_forwarded_headers_bad_syntax():
    """Unparseable ``Forwarded`` headers still occupy a slot in the result."""
    headers = CIMultiDict()
    for raw_value in ('for=_1;by=_2', 'invalid value', '', 'for=_3;by=_4'):
        headers.add('Forwarded', raw_value)
    req = make_mocked_request('GET', '/', headers=headers)
    forwarded = req.forwarded
    assert len(forwarded) == 4
    assert forwarded[0]['for'] == '_1'
    # The two malformed headers yield elements without a 'for' parameter.
    assert 'for' not in forwarded[1]
    assert 'for' not in forwarded[2]
    assert forwarded[3]['by'] == '_4'
Exemplo n.º 7
0
def test_multiple_forwarded_headers_bad_syntax():
    """Valid ``Forwarded`` headers are parsed even when invalid ones sit between them."""
    valid_first = 'for=_1;by=_2'
    valid_last = 'for=_3;by=_4'
    headers = CIMultiDict()
    headers.add('Forwarded', valid_first)
    headers.add('Forwarded', 'invalid value')
    headers.add('Forwarded', '')
    headers.add('Forwarded', valid_last)
    req = make_mocked_request('GET', '/', headers=headers)
    assert len(req.forwarded) == 4
    assert req.forwarded[0]['for'] == '_1'
    # Indices 1 and 2 come from the invalid and the empty header.
    for bad_index in (1, 2):
        assert 'for' not in req.forwarded[bad_index]
    assert req.forwarded[3]['by'] == '_4'
Exemplo n.º 8
0
    async def channel_highlights(self, channel):
        # a multi dict allows multiple users to have the same highlight word
        highlight_users = CIMultiDict()
        async for user, highlight in self.cursor(
                """
			SELECT "user", highlight
			FROM highlights
			WHERE
				guild = $1
				AND NOT EXISTS (
					SELECT 1
					FROM blocks
					WHERE
						highlights.user = blocks.user
						AND entity = ANY ($2))
		""", channel.guild.id, (channel.id, getattr(channel.category, 'id', None))):
            highlight_users.add(highlight, user)

        return highlight_users
Exemplo n.º 9
0
def test_multiple_forwarded_headers_bad_syntax() -> None:
    """Malformed ``Forwarded`` headers do not disturb their well-formed neighbours."""
    raw_values = ("for=_1;by=_2", "invalid value", "", "for=_3;by=_4")
    headers = CIMultiDict()
    for raw_value in raw_values:
        headers.add("Forwarded", raw_value)
    req = make_mocked_request("GET", "/", headers=headers)
    assert len(req.forwarded) == 4
    first_element = req.forwarded[0]
    assert first_element["for"] == "_1"
    # Elements produced by the invalid and the empty header carry no "for".
    assert "for" not in req.forwarded[1]
    assert "for" not in req.forwarded[2]
    last_element = req.forwarded[3]
    assert last_element["by"] == "_4"
Exemplo n.º 10
0
class Protocol:
    """Records the events an HTTP parser emits while consuming a message."""

    def __init__(self, Parser, **kwargs):
        """Initialise empty message state, then construct ``Parser(self, **kwargs)``.

        The parser's ``feed_data`` is bound onto this object.
        """
        # Message parts collected so far.
        self.url = None
        self.headers = CIMultiDict()
        self.body = b''
        # Progress flags flipped by the completion hooks.
        self.headers_complete = False
        self.message_complete = False
        # Built last: the parser receives this object as its first argument.
        self.parser = Parser(self, **kwargs)
        self.feed_data = getattr(self.parser, 'feed_data')

    def on_url(self, url):
        """Store the URL reported by the parser."""
        self.url = url

    def on_header(self, name, value):
        """Add one decoded header to the case-insensitive multidict."""
        pair = (name.decode(), value.decode())
        self.headers.add(*pair)

    def on_headers_complete(self):
        """Mark the headers as complete."""
        self.headers_complete = True

    def on_body(self, body):
        """Extend the collected body with another chunk."""
        self.body = self.body + body

    def on_message_complete(self):
        """Mark the message as complete."""
        self.message_complete = True
Exemplo n.º 11
0
def p(request):
    """Serve a photo as ``image/jpeg``, caching its bytes in ``rds`` for 60 seconds.

    The photo is selected by the ``uuid`` query parameter. When the parameter
    is absent, a random entry of ``PHOTO_PATH`` is chosen and its extension
    is stripped.

    Every response carries two headers:

    * ``X-mem-key``   -- the uuid that was looked up
    * ``X-mem-cache`` -- ``HIT`` (served from the cache), ``MISS`` (read from
      disk and stored in the cache) or ``NONE`` (nothing to serve; the body
      is empty)
    """
    if 'uuid' in request.query:
        uuid = request.query['uuid']
    else:
        uuid = os.path.splitext(random.choice(os.listdir(PHOTO_PATH)))[0]

    header = CIMultiDict()
    header.add('X-mem-key', uuid)

    # Security: ``uuid`` comes straight from the query string and is joined
    # into a filesystem path below. Accept only a bare file name so a value
    # such as "../../etc/passwd" cannot escape PHOTO_PATH.
    if uuid in ('', '.', '..') or os.path.basename(uuid) != uuid:
        logging.info('文件不存在:' + uuid)
        header.add('X-mem-cache', 'NONE')
        return web.Response(body=b'',
                            content_type='image/jpeg',
                            headers=header)

    img = rds.get(uuid)
    if img:
        logging.info('读取缓存:' + uuid)
        header.add('X-mem-cache', 'HIT')
        return web.Response(body=img,
                            content_type='image/jpeg',
                            headers=header)

    fp = os.path.join(PHOTO_PATH, uuid)
    # isfile (not exists): a directory name must not reach open().
    if not os.path.isfile(fp):
        logging.info('文件不存在:' + uuid)
        header.add('X-mem-cache', 'NONE')
        return web.Response(body=b'',
                            content_type='image/jpeg',
                            headers=header)

    with open(fp, 'rb') as fo:
        rs = fo.read()

    # Keep the bytes cached for 60 seconds.
    rds.setex(uuid, 60, rs)
    logging.info('写入缓存:' + uuid)
    header.add('X-mem-cache', 'MISS')
    return web.Response(body=rs,
                        content_type='image/jpeg',
                        headers=header)
Exemplo n.º 12
0
    def parse_headers(
            self,
            lines: List[bytes]) -> Tuple['CIMultiDictProxy[str]', RawHeaders]:
        """Parse raw header lines into decoded and raw header collections.

        Parsing starts at ``lines[1]`` (index 0 is not read here) and stops
        at the first empty line reached by the outer loop. Lines that start
        with a space or a tab are treated as continuations of the previous
        header value.

        Returns a pair: a ``CIMultiDictProxy`` of the decoded headers and a
        tuple of ``(name, value)`` byte pairs.

        Raises ``InvalidHeader`` when a line has no ``:`` or the name matches
        ``HDRRE``, and ``LineTooLong`` when the name or the accumulated value
        is longer than ``self.max_field_size``.
        """
        headers = CIMultiDict()  # type: CIMultiDict[str]
        raw_headers = []

        # Index 0 is skipped; parsing begins with the second entry.
        lines_idx = 1
        line = lines[1]
        line_count = len(lines)

        while line:
            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise InvalidHeader(line) from None

            bname = bname.strip(b' \t')
            bvalue = bvalue.lstrip()
            if HDRRE.search(bname):
                raise InvalidHeader(bname)
            if len(bname) > self.max_field_size:
                raise LineTooLong(
                    "request header name {}".format(
                        bname.decode("utf8", "xmlcharrefreplace")),
                    str(self.max_field_size), str(len(bname)))

            # Running size of the value, continuation lines included.
            header_length = len(bvalue)

            # next line
            # NOTE(review): no bounds check here -- this appears to rely on
            # ``lines`` ending with an empty entry; confirm against callers.
            lines_idx += 1
            line = lines[lines_idx]

            # consume continuation lines
            continuation = line and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue_lst = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise LineTooLong(
                            'request header field {}'.format(
                                bname.decode("utf8", "xmlcharrefreplace")),
                            str(self.max_field_size), str(header_length))
                    bvalue_lst.append(line)

                    # next line
                    lines_idx += 1
                    if lines_idx < line_count:
                        line = lines[lines_idx]
                        # NOTE(review): when this line is empty the flag is
                        # left unchanged, so the loop runs once more with the
                        # empty line -- confirm this is intended.
                        if line:
                            continuation = line[0] in (32, 9)  # (' ', '\t')
                    else:
                        # Ran off the end: an empty line ends the outer loop.
                        line = b''
                        break
                bvalue = b''.join(bvalue_lst)
            else:
                if header_length > self.max_field_size:
                    raise LineTooLong(
                        'request header field {}'.format(
                            bname.decode("utf8", "xmlcharrefreplace")),
                        str(self.max_field_size), str(header_length))

            bvalue = bvalue.strip()
            # surrogateescape keeps undecodable bytes instead of raising.
            name = bname.decode('utf-8', 'surrogateescape')
            value = bvalue.decode('utf-8', 'surrogateescape')

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return (CIMultiDictProxy(headers), tuple(raw_headers))
Exemplo n.º 13
0
class BaseResponse:
    """HTTP response that serialises itself into raw bytes.

    Holds the body, status code, content type and headers, and renders the
    status line, header block and body as bytes.
    """

    __slots__ = ("body", "status", "content_type", "headers", "_cookies",
                 "_protocol")

    # Codec used to encode header values.
    charset = 'utf-8'
    max_cookie_size = 4093

    def __init__(self,
                 body=None,
                 status=200,
                 headers=None,
                 content_type="text/plain"):
        """Store the response parts; ``body`` is encoded to bytes right away."""
        self.content_type = content_type
        self.body = self._encode_body(body)
        self.status = status
        self.headers = CIMultiDict(headers or {})
        self._cookies = None
        self._protocol = None

    def set_protocol(self, protocol):
        """Attach the protocol object this response belongs to."""
        self._protocol = protocol

    def has_protocol(self):
        """Return True once a protocol has been attached."""
        return self._protocol is not None

    @staticmethod
    def _encode_body(data):
        """Return ``data`` as bytes.

        Bytes pass through unchanged, objects with ``encode()`` are encoded,
        and anything else is converted with ``str()`` first. Falsy values
        without ``encode()`` (``None``, ``0``) become ``b""``.
        """
        try:
            if not isinstance(data, bytes):
                return data.encode()
            return data
        except AttributeError:
            return str(data or "").encode()

    def _parse_headers(self):
        """Serialise ``self.headers`` into ``Name: value\\r\\n`` byte lines."""
        lines = []
        for name, value in self.headers.items():
            try:
                line = b"%b: %b\r\n" % (
                    name.encode(),
                    value.encode(self.charset),
                )
            except AttributeError:
                # Names/values that are not strings (e.g. the integer
                # Content-Length stored by output()) are stringified first.
                line = b"%b: %b\r\n" % (
                    str(name).encode(),
                    str(value).encode(self.charset),
                )
            lines.append(line)
        # Join once instead of growing a bytes object line by line.
        return b"".join(lines)

    def set_cookie(self,
                   key,
                   value="",
                   max_age=None,
                   expires=None,
                   path="/",
                   domain=None,
                   secure=False,
                   httponly=False,
                   samesite=None):
        """Append a ``Set-Cookie`` header for ``key``.

        Attributes left as ``None`` (or falsy, for the ``secure``,
        ``httponly`` and ``samesite`` arguments) are omitted from the cookie.
        """
        cookie = http.cookies.SimpleCookie()
        cookie[key] = value
        if max_age is not None:
            cookie[key]["max-age"] = max_age  # type: ignore
        if expires is not None:
            cookie[key]["expires"] = expires  # type: ignore
        if path is not None:
            cookie[key]["path"] = path
        if domain is not None:
            cookie[key]["domain"] = domain
        if secure:
            cookie[key]["secure"] = True
        if httponly:
            cookie[key]["httponly"] = True
        if samesite:
            cookie[key]["samesite"] = samesite
        # header="" drops the "Set-Cookie:" prefix; the header name is
        # supplied separately below.
        cookie_val = cookie.output(header="").strip()
        self.headers.add('Set-Cookie', cookie_val)

    def delete_cookie(self, key, path="/", domain=None):
        """Expire cookie ``key`` by re-sending it with zero ``expires``/``max-age``."""
        self.set_cookie(key, expires=0, max_age=0, path=path, domain=domain)

    def get_headers(self,
                    version="1.1",
                    keep_alive=False,
                    keep_alive_timeout=None):
        """Render the status line and the header block as bytes.

        Sets ``Content-Type`` from ``self.content_type`` unless that header
        is already present. For statuses 304 and 412 the headers are passed
        through ``remove_entity_headers`` first. The result ends with the
        blank line that terminates the header block.
        """
        timeout_header = b""
        if keep_alive and keep_alive_timeout is not None:
            timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout
        self.headers["Content-Type"] = self.headers.get(
            "Content-Type", self.content_type)
        if self.status in (304, 412):
            self.headers = remove_entity_headers(self.headers)
        headers = self._parse_headers()
        # Compare by value: identity of int objects is an implementation
        # detail and must not decide which reason phrase is used.
        if self.status == 200:
            description = b"OK"
        else:
            description = STATUS_TEXT.get(self.status, b"UNKNOWN RESPONSE")

        return (b"HTTP/%b %d %b\r\n"
                b"Connection: %b\r\n"
                b"%b"
                b"%b\r\n") % (version.encode(), self.status, description,
                              b"keep-alive" if keep_alive else b"close",
                              timeout_header, headers)

    async def output(self,
                     version="1.1",
                     keep_alive=False,
                     keep_alive_timeout=None):
        """Return the complete response (headers followed by body) as bytes.

        When ``has_message_body`` accepts the status, ``Content-Length`` is
        set (unless already present) and the body is appended; otherwise the
        body is left out.
        """
        if has_message_body(self.status):
            body = self.body
            self.headers["Content-Length"] = self.headers.get(
                "Content-Length", len(self.body))
        else:
            body = b""
        return self.get_headers(version, keep_alive,
                                keep_alive_timeout) + b"%b" % body
Exemplo n.º 14
0
def test_host_by_forwarded_header(make_request):
    """``req.host`` is 'example.com' when a ``Forwarded`` header names that host."""
    forwarded_values = (
        'By=identifier1;for=identifier2, BY=identifier3',
        'by=;for=;host=example.com',
    )
    headers = CIMultiDict()
    for forwarded_value in forwarded_values:
        headers.add('Forwarded', forwarded_value)
    req = make_request('GET', '/', headers=headers)
    resolved_host = req.host
    assert resolved_host == 'example.com'
Exemplo n.º 15
0
class ClientRequest:
    """Outgoing HTTP client request.

    The constructor normalises the URL, headers, cookies, auth, proxy
    settings and body; ``send()`` writes the request line and headers to a
    connection and schedules the body writer.
    """

    # Methods for which Content-Length / Transfer-Encoding are only computed
    # when a body was actually supplied (see __init__).
    GET_METHODS = {
        hdrs.METH_GET,
        hdrs.METH_HEAD,
        hdrs.METH_OPTIONS,
        hdrs.METH_TRACE,
    }
    # Methods that get a default Content-Type in send().
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE})

    # Headers added automatically unless already present or skipped.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method: str, url: URL, *,
                 params: Optional[Mapping[str, str]]=None,
                 headers: Optional[LooseHeaders]=None,
                 skip_auto_headers: Iterable[str]=frozenset(),
                 data: Any=None,
                 cookies: Optional[LooseCookies]=None,
                 auth: Optional[BasicAuth]=None,
                 version: http.HttpVersion=http.HttpVersion11,
                 compress: Optional[str]=None,
                 chunked: Optional[bool]=None,
                 expect100: bool=False,
                 loop: Optional[asyncio.AbstractEventLoop]=None,
                 response_class: Optional[Type['ClientResponse']]=None,
                 proxy: Optional[URL]=None,
                 proxy_auth: Optional[BasicAuth]=None,
                 timer: Optional[BaseTimerContext]=None,
                 session: Optional['ClientSession']=None,
                 ssl: Union[SSLContext, bool, Fingerprint, None]=None,
                 proxy_headers: Optional[LooseHeaders]=None,
                 traces: Optional[List['Trace']]=None):
        """Build the request state from the given parts.

        ``params`` are merged into the query of ``url``; the fragment is
        removed from ``self.url`` but kept in ``self.original_url``. The
        ``update_*`` helpers are then run in a fixed order to fill in the
        headers, cookies, auth, proxy settings and body.
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy
        # FIXME: session is None in tests only, need to fix tests
        # assert session is not None
        self._session = cast('ClientSession', session)
        if params:
            # Keep the query already present in the URL and append params.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        self.original_url = url
        self.url = url.with_fragment(None)
        self.method = method.upper()
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.length = None
        if response_class is None:
            real_response_class = ClientResponse
        else:
            real_response_class = response_class
        self.response_class = real_response_class  # type: Type[ClientResponse]
        self._timer = timer if timer is not None else TimerNoop()
        self._ssl = ssl

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Order matters: later steps read headers/attributes set by earlier
        # ones (e.g. update_auth uses self.auth set by update_host).
        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth, proxy_headers)

        self.update_body_from_data(data)
        if data or self.method not in self.GET_METHODS:
            self.update_transfer_encoding()
        self.update_expect_continue(expect100)
        if traces is None:
            traces = []
        self._traces = traces

    def is_ssl(self) -> bool:
        """Return True when the URL scheme is ``https`` or ``wss``."""
        return self.url.scheme in ('https', 'wss')

    @property
    def ssl(self) -> Union['SSLContext', None, bool, Fingerprint]:
        """The ``ssl`` argument passed to the constructor."""
        return self._ssl

    @property
    def connection_key(self) -> ConnectionKey:
        """Key built from host, port, SSL settings and proxy settings."""
        proxy_headers = self.proxy_headers
        if proxy_headers:
            # Only a hash of the proxy headers goes into the key.
            h = hash(tuple((k, v) for k, v in proxy_headers.items()))  # type: Optional[int]  # noqa
        else:
            h = None
        return ConnectionKey(self.host, self.port, self.is_ssl(),
                             self.ssl,
                             self.proxy, self.proxy_auth, h)

    @property
    def host(self) -> str:
        """Host part of the request URL (asserted to be present)."""
        ret = self.url.host
        assert ret is not None
        return ret

    @property
    def port(self) -> Optional[int]:
        """Port of the request URL."""
        return self.url.port

    @property
    def request_info(self) -> RequestInfo:
        """Snapshot of URL, method and a read-only view of the headers."""
        headers = CIMultiDictProxy(self.headers)  # type: CIMultiDictProxy[str]
        return RequestInfo(self.url, self.method,
                           headers, self.original_url)

    def update_host(self, url: URL) -> None:
        """Validate the URL host and pick up credentials embedded in the URL.

        Raises InvalidURL when ``url`` has no host.
        """
        # get host/port
        if not url.host:
            raise InvalidURL(url)

        # basic auth info
        username, password = url.user, url.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')

    def update_version(self, version: Union[http.HttpVersion, str]) -> None:
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        Raises ValueError when a string version cannot be parsed.
        """
        if isinstance(version, str):
            v = [l.strip() for l in version.split('.', 1)]
            try:
                version = http.HttpVersion(int(v[0]), int(v[1]))
            except (ValueError, IndexError):
                # IndexError: no '.' in the string (e.g. '1'), so there is no
                # minor part; report it the same way as a non-numeric part.
                raise ValueError(
                    'Can not parse http version number: {}'
                    .format(version)) from None
        self.version = version

    def update_headers(self, headers: Optional[LooseHeaders]) -> None:
        """Reset the headers to ``Host`` plus the caller-supplied ones."""
        self.headers = CIMultiDict()  # type: CIMultiDict[str]

        # add host
        netloc = cast(str, self.url.raw_host)
        if helpers.is_ipv6_address(netloc):
            netloc = '[{}]'.format(netloc)
        if not self.url.is_default_port():
            netloc += ':' + str(self.url.port)
        self.headers[hdrs.HOST] = netloc

        if headers:
            if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
                headers = headers.items()  # type: ignore

            for key, value in headers:
                # A special case for Host header
                # (it replaces the computed value instead of adding a second one)
                if key.lower() == 'host':
                    self.headers[key] = value
                else:
                    self.headers.add(key, value)

    def update_auto_headers(self, skip_auto_headers: Iterable[str]) -> None:
        """Add the default headers and User-Agent unless present or skipped."""
        self.skip_auto_headers = CIMultiDict(
            (hdr, None) for hdr in sorted(skip_auto_headers))
        # Names that must not be auto-added: already set or explicitly skipped.
        used_headers = self.headers.copy()
        used_headers.extend(self.skip_auto_headers)  # type: ignore

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE

    def update_cookies(self, cookies: Optional[LooseCookies]) -> None:
        """Update request cookies header."""
        if not cookies:
            return

        c = SimpleCookie()
        if hdrs.COOKIE in self.headers:
            # Merge with cookies already present in the Cookie header.
            c.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        if isinstance(cookies, Mapping):
            iter_cookies = cookies.items()
        else:
            iter_cookies = cookies  # type: ignore
        for name, value in iter_cookies:
            if isinstance(value, Morsel):
                # Preserve coded_value
                mrsl_val = value.get(value.key, Morsel())
                mrsl_val.set(value.key, value.value, value.coded_value)  # type: ignore  # noqa
                c[name] = mrsl_val
            else:
                c[name] = value  # type: ignore

        self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()

    def update_content_encoding(self, data: Any) -> None:
        """Set request content encoding.

        Raises ValueError when both ``compress`` and a Content-Encoding
        header are given.
        """
        if not data:
            return

        enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if enc:
            if self.compress:
                raise ValueError(
                    'compress can not be set '
                    'if Content-Encoding header is set')
        elif self.compress:
            if not isinstance(self.compress, str):
                self.compress = 'deflate'
            self.headers[hdrs.CONTENT_ENCODING] = self.compress
            self.chunked = True  # enable chunked, no need to deal with length

    def update_transfer_encoding(self) -> None:
        """Analyze transfer-encoding header.

        Raises ValueError when ``chunked`` conflicts with an explicit
        ``Transfer-Encoding: chunked`` or ``Content-Length`` header.
        """
        te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()

        if 'chunked' in te:
            if self.chunked:
                raise ValueError(
                    'chunked can not be set '
                    'if "Transfer-Encoding: chunked" header is set')

        elif self.chunked:
            if hdrs.CONTENT_LENGTH in self.headers:
                raise ValueError(
                    'chunked can not be set '
                    'if Content-Length header is set')

            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
        else:
            if hdrs.CONTENT_LENGTH not in self.headers:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_auth(self, auth: Optional[BasicAuth]) -> None:
        """Set basic auth.

        Falls back to ``self.auth`` when ``auth`` is None. Raises TypeError
        when the value is not a ``BasicAuth``.
        """
        if auth is None:
            auth = self.auth
        if auth is None:
            return

        if not isinstance(auth, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = auth.encode()

    def update_body_from_data(self, body: Any) -> None:
        """Wrap ``body`` in a payload and derive the body-related headers."""
        if not body:
            return

        # FormData
        if isinstance(body, FormData):
            body = body()

        try:
            body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
        except payload.LookupError:
            # No registered payload for this type: treat it as form fields.
            body = FormData(body)()

        self.body = body

        # enable chunked encoding if needed
        if not self.chunked:
            if hdrs.CONTENT_LENGTH not in self.headers:
                size = body.size
                if size is None:
                    self.chunked = True
                else:
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)

        # set content-type
        if (hdrs.CONTENT_TYPE not in self.headers and
                hdrs.CONTENT_TYPE not in self.skip_auto_headers):
            self.headers[hdrs.CONTENT_TYPE] = body.content_type

        # copy payload headers
        if body.headers:
            for (key, value) in body.headers.items():
                if key not in self.headers:
                    self.headers[key] = value

    def update_expect_continue(self, expect: bool=False) -> None:
        """Create the '100 Continue' waiter when the request expects one."""
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue':
            expect = True

        if expect:
            self._continue = self.loop.create_future()

    def update_proxy(self, proxy: Optional[URL],
                     proxy_auth: Optional[BasicAuth],
                     proxy_headers: Optional[LooseHeaders]) -> None:
        """Validate and store the proxy settings.

        Raises ValueError for a non-``http`` proxy scheme or a ``proxy_auth``
        that is not a ``BasicAuth``.
        """
        if proxy and proxy.scheme != 'http':
            raise ValueError("Only http proxies are supported")
        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")
        self.proxy = proxy
        self.proxy_auth = proxy_auth
        self.proxy_headers = proxy_headers

    def keep_alive(self) -> bool:
        """Return whether the connection should be kept alive after the request."""
        if self.version < HttpVersion10:
            # keep alive not supported at all
            return False
        if self.version == HttpVersion10:
            if self.headers.get(hdrs.CONNECTION) == 'keep-alive':
                return True
            else:  # no headers means we close for Http 1.0
                return False
        elif self.headers.get(hdrs.CONNECTION) == 'close':
            return False

        return True

    async def write_bytes(self, writer: AbstractStreamWriter,
                          conn: 'Connection') -> None:
        """Support coroutines that yields bytes objects."""
        # 100 response
        if self._continue is not None:
            await writer.drain()
            await self._continue

        protocol = conn.protocol
        assert protocol is not None
        try:
            if isinstance(self.body, payload.Payload):
                await self.body.write(writer)
            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)  # type: ignore

                for chunk in self.body:
                    await writer.write(chunk)  # type: ignore

            await writer.write_eof()
        except OSError as exc:
            new_exc = ClientOSError(
                exc.errno,
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            protocol.set_exception(new_exc)
        except asyncio.CancelledError as exc:
            if not conn.closed:
                protocol.set_exception(exc)
        except Exception as exc:
            # Errors are reported through the protocol rather than raised
            # from this task.
            protocol.set_exception(exc)
        finally:
            self._writer = None

    async def send(self, conn: 'Connection') -> 'ClientResponse':
        """Write the request line and headers, start the body writer task
        and return the (not yet started) response object."""
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            path = '{}:{}'.format(self.url.raw_host, self.url.port)
        elif self.proxy and not self.is_ssl():
            path = str(self.url)
        else:
            path = self.url.raw_path
            if self.url.raw_query_string:
                path += '?' + self.url.raw_query_string

        protocol = conn.protocol
        assert protocol is not None
        writer = StreamWriter(
            protocol, self.loop,
            on_chunk_sent=self._on_chunk_request_sent
        )

        if self.compress:
            writer.enable_compression(self.compress)

        # NOTE(review): this is also true for an explicit chunked=False;
        # confirm whether `if self.chunked:` was intended.
        if self.chunked is not None:
            writer.enable_chunking()

        # set default content-type
        if (self.method in self.POST_METHODS and
                hdrs.CONTENT_TYPE not in self.skip_auto_headers and
                hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        # set the connection header
        connection = self.headers.get(hdrs.CONNECTION)
        if not connection:
            if self.keep_alive():
                if self.version == HttpVersion10:
                    connection = 'keep-alive'
            else:
                if self.version == HttpVersion11:
                    connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

        # status + headers
        status_line = '{0} {1} HTTP/{2[0]}.{2[1]}'.format(
            self.method, path, self.version)
        await writer.write_headers(status_line, self.headers)

        # The body is written by a separate task.
        self._writer = self.loop.create_task(self.write_bytes(writer, conn))

        response_class = self.response_class
        assert response_class is not None
        self.response = response_class(
            self.method, self.original_url,
            writer=self._writer, continue100=self._continue, timer=self._timer,
            request_info=self.request_info,
            traces=self._traces,
            loop=self.loop,
            session=self._session
        )
        return self.response

    async def close(self) -> None:
        """Wait for the body writer task (if any) to finish."""
        if self._writer is not None:
            try:
                await self._writer
            finally:
                self._writer = None

    def terminate(self) -> None:
        """Cancel the body writer task (if any) without waiting for it."""
        if self._writer is not None:
            if not self.loop.is_closed():
                self._writer.cancel()
            self._writer = None

    async def _on_chunk_request_sent(self, chunk: bytes) -> None:
        """Notify every trace that a request chunk was sent."""
        for trace in self._traces:
            await trace.send_request_chunk_sent(chunk)
Exemplo n.º 16
0
class HttpMessage(PayloadWriter):
    """Writer for an HTTP message head (status line + headers) and payload.

    ``HOP_HEADERS`` and ``status_line`` are not defined here and are
    expected to be supplied elsewhere (e.g. by a subclass).
    """

    # Container of header names that add_header() must drop; subclasses
    # have to override this before add_header() is used.
    HOP_HEADERS = None

    SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
        sys.version_info, aiohttp.__version__)

    # Flags derived from the headers passed to add_header().
    upgrade = False  # a "Connection" value containing "upgrade" was seen
    websocket = False  # an "Upgrade" value containing "websocket" was seen
    has_chunked_hdr = False  # "Transfer-Encoding" value equals "chunked"

    def __init__(self, transport, version, close, loop=None):
        """Initialise header state on top of the base writer.

        ``close`` seeds ``self.closing``; ``keepalive`` stays ``None`` until
        a Connection header or ``force_close()`` decides it.
        """
        super().__init__(transport, loop)

        self.headers = CIMultiDict()
        self.headers_sent = False
        self.length = None
        self.keepalive = None
        self.closing = close
        self.version = version

    @property
    def body_length(self):
        """Expose the writer's ``output_length`` attribute."""
        return self.output_length

    def force_close(self):
        """Mark the connection as closing and disable keep-alive."""
        self.keepalive = False
        self.closing = True

    def keep_alive(self):
        """Return whether the connection should be kept alive.

        An explicit ``keepalive`` decision wins.  Otherwise versions below
        HTTP/1.0 never keep alive, HTTP/1.0 requires an exact
        ``Connection: keep-alive`` header value, and later versions keep
        alive unless ``closing`` is set.
        """
        if self.keepalive is not None:
            return self.keepalive

        if self.version < HttpVersion10:
            # keep alive not supported at all
            return False

        if self.version == HttpVersion10:
            # no matching header means we close for HTTP/1.0
            return self.headers.get(hdrs.CONNECTION) == 'keep-alive'

        return not self.closing

    def is_headers_sent(self):
        """Return True once send_headers() has been called."""
        return self.headers_sent

    def add_header(self, name, value):
        """Record one header and update the state derived from it.

        Tracks content length, chunked transfer encoding, connection
        upgrade / keep-alive and websocket upgrade.  Connection headers are
        only interpreted, never stored; hop-by-hop headers are dropped.
        """
        assert not self.headers_sent, 'headers have been sent already'
        assert isinstance(name, str), \
            'Header name should be a string, got {!r}'.format(name)
        assert set(name).issubset(ASCIISET), \
            'Header name should contain ASCII chars, got {!r}'.format(name)
        assert isinstance(value, str), \
            'Header {!r} should have string value, got {!r}'.format(
                name, value)

        header = istr(name)
        text = value.strip()

        # These two checks only record state; the header itself is still
        # routed through the dispatch below.
        if header == hdrs.CONTENT_LENGTH:
            self.length = int(text)

        if header == hdrs.TRANSFER_ENCODING:
            self.has_chunked_hdr = text.lower() == 'chunked'

        if header == hdrs.CONNECTION:
            lowered = text.lower()
            if 'upgrade' in lowered:
                # handle websocket
                self.upgrade = True
            elif 'close' in lowered:
                self.keepalive = False
            elif 'keep-alive' in lowered:
                self.keepalive = True
            return

        if header == hdrs.UPGRADE:
            if 'websocket' in text.lower():
                self.websocket = True
            # Upgrade replaces any previous value instead of appending.
            self.headers[header] = text
            return

        if header not in self.HOP_HEADERS:
            # hop-by-hop headers are ignored
            self.headers.add(header, text)

    def add_headers(self, *headers):
        """Feed every ``(name, value)`` pair to add_header()."""
        for pair in headers:
            name, value = pair
            self.add_header(name, value)

    def send_headers(self, _sep=': ', _end='\r\n'):
        """Serialise the status line and headers and buffer them.

        Chunking is switched on first when ``autochunked()`` asks for it,
        and the default Connection header is filled in before serialising.
        """
        assert not self.headers_sent, 'headers have been sent already'
        self.headers_sent = True

        # enable_chunking() may flip self.chunked, so it is re-read below.
        if not self.chunked and self.autochunked():
            self.enable_chunking()

        if self.chunked:
            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

        self._add_default_headers()

        # status line followed by one "name: value" line per header
        head = [self.status_line]
        head.extend(k + _sep + v + _end for k, v in self.headers.items())

        self.buffer_data(''.join(head).encode('utf-8') + b'\r\n')

    def _add_default_headers(self):
        """Set the Connection header implied by the current state."""
        if self.upgrade:
            connection = 'Upgrade'
        else:
            if self.keepalive is None:
                stay_open = not self.closing
            else:
                stay_open = self.keepalive

            connection = None
            if stay_open:
                # only HTTP/1.0 needs keep-alive spelled out
                if self.version == HttpVersion10:
                    connection = 'keep-alive'
            elif self.version == HttpVersion11:
                # only HTTP/1.1 needs close spelled out
                connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection
Exemplo n.º 17
0
class ClientRequest:
    """Build and send a single HTTP client request.

    The constructor normalises URL, headers, cookies, auth, proxy and body
    through the ``update_*`` helpers; ``send()`` then writes the request to
    a connection and returns the response object.
    """

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    # Added by update_auto_headers() unless supplied or explicitly skipped.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method, url, *,
                 params=None, headers=None, skip_auto_headers=frozenset(),
                 data=None, cookies=None,
                 auth=None, version=http.HttpVersion11, compress=None,
                 chunked=None, expect100=False,
                 loop=None, response_class=None,
                 proxy=None, proxy_auth=None, proxy_from_env=False,
                 timer=None, session=None, auto_decompress=True):
        """Prepare the request; nothing is sent until ``send()``.

        ``url`` must be a ``URL`` instance and ``proxy`` a ``URL`` or
        ``None``.  ``params`` are merged into the URL's existing query; the
        fragment is stripped from ``self.url`` but kept in
        ``self.original_url``.
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy
        self._session = session
        if params:
            # Keep the URL's own query and append the extra parameters.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        self.url = url.with_fragment(None)
        self.original_url = url
        self.method = method.upper()
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.length = None
        self.response_class = response_class or ClientResponse
        self._timer = timer if timer is not None else TimerNoop()
        self._auto_decompress = auto_decompress

        if loop.get_debug():
            # Must stay in __init__: the frame depth is relative to it.
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # The order matters: later steps read headers set by earlier ones.
        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth, proxy_from_env)

        self.update_body_from_data(data)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    @property
    def host(self):
        """Host component of the request URL."""
        return self.url.host

    @property
    def port(self):
        """Port component of the request URL."""
        return self.url.port

    @property
    def request_info(self):
        """``RequestInfo`` built from the URL, method and headers."""
        return RequestInfo(self.url, self.method, self.headers)

    def update_host(self, url):
        """Validate the URL host and derive auth and ssl settings from it.

        Raises ``ValueError`` when the URL has no host.  Credentials
        embedded in the URL become ``self.auth``; ``self.ssl`` is true for
        the ``https`` and ``wss`` schemes.
        """
        # get host/port
        if not url.host:
            raise ValueError(
                "Could not parse hostname from URL '{}'".format(url))

        # basic auth info
        username, password = url.user, url.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')

        scheme = url.scheme
        self.ssl = scheme in ('https', 'wss')

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        Non-string values are stored unchanged.  Raises ``ValueError`` for
        any string that is not ``<int>.<int>``, including one without a
        dot (which previously escaped as ``IndexError``).
        """
        if isinstance(version, str):
            # A missing dot leaves ``minor`` empty, which int() rejects
            # with the same ValueError as any other malformed part.
            major, _, minor = version.partition('.')
            try:
                version = int(major), int(minor)
            except ValueError:
                raise ValueError(
                    'Can not parse http version number: {}'
                    .format(version)) from None
        self.version = version

    def update_headers(self, headers):
        """Reset ``self.headers`` and fill it from ``headers``.

        Accepts a mapping/multidict or an iterable of ``(key, value)``
        pairs; repeated keys are kept.
        """
        self.headers = CIMultiDict()
        if headers:
            if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
                headers = headers.items()

            for key, value in headers:
                self.headers.add(key, value)

    def update_auto_headers(self, skip_auto_headers):
        """Add default, Host and User-Agent headers where still missing.

        A header counts as present when the caller supplied it or listed
        it in ``skip_auto_headers``; the lookup is case-insensitive.
        """
        self.skip_auto_headers = CIMultiDict(
            (hdr, None) for hdr in sorted(skip_auto_headers))
        used_headers = self.headers.copy()
        used_headers.extend(self.skip_auto_headers)

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        # add host
        if hdrs.HOST not in used_headers:
            netloc = self.url.raw_host
            if not self.url.is_default_port():
                netloc += ':' + str(self.url.port)
            self.headers[hdrs.HOST] = netloc

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Update request cookies header.

        Cookies already present in the Cookie header are loaded first, so
        values passed in ``cookies`` override them by name.
        """
        if not cookies:
            return

        c = SimpleCookie()
        if hdrs.COOKIE in self.headers:
            c.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        for name, value in cookies.items():
            if isinstance(value, Morsel):
                # Preserve coded_value
                mrsl_val = value.get(value.key, Morsel())
                mrsl_val.set(value.key, value.value, value.coded_value)
                c[name] = mrsl_val
            else:
                c[name] = value

        self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Set request content encoding.

        Only applies when there is a body.  Raises ``ValueError`` when
        ``compress`` is combined with an explicit Content-Encoding header.
        """
        if not data:
            return

        enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if enc:
            if self.compress:
                raise ValueError(
                    'compress can not be set '
                    'if Content-Encoding header is set')
        elif self.compress:
            # Any truthy non-string value selects the default algorithm.
            if not isinstance(self.compress, str):
                self.compress = 'deflate'
            self.headers[hdrs.CONTENT_ENCODING] = self.compress
            self.chunked = True  # enable chunked, no need to deal with length

    def update_transfer_encoding(self):
        """Analyze transfer-encoding header.

        Raises ``ValueError`` when ``chunked`` conflicts with an explicit
        chunked Transfer-Encoding or a Content-Length header.  Without
        chunking, Content-Length defaults to ``len(self.body)``.
        """
        te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()

        if 'chunked' in te:
            if self.chunked:
                raise ValueError(
                    'chunked can not be set '
                    'if "Transfer-Encoding: chunked" header is set')

        elif self.chunked:
            if hdrs.CONTENT_LENGTH in self.headers:
                raise ValueError(
                    'chunked can not be set '
                    'if Content-Length header is set')

            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
        else:
            if hdrs.CONTENT_LENGTH not in self.headers:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_auth(self, auth):
        """Set basic auth.

        Falls back to credentials taken from the URL.  Raises ``TypeError``
        for anything that is not a ``BasicAuth``.
        """
        if auth is None:
            auth = self.auth
        if auth is None:
            return

        if not isinstance(auth, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = auth.encode()

    def update_body_from_data(self, body):
        """Wrap ``body`` in a payload and set the headers that follow from it.

        Values the payload registry does not recognise are sent as form
        data.  Headers already present are never overwritten.
        """
        if not body:
            return

        # FormData
        if isinstance(body, FormData):
            body = body()

        try:
            body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
        except payload.LookupError:
            body = FormData(body)()

        self.body = body

        # enable chunked encoding if needed
        if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers:
            size = body.size
            if size is None:
                self.chunked = True
            else:
                self.headers[hdrs.CONTENT_LENGTH] = str(size)

        # set content-type
        if (hdrs.CONTENT_TYPE not in self.headers and
                hdrs.CONTENT_TYPE not in self.skip_auto_headers):
            self.headers[hdrs.CONTENT_TYPE] = body.content_type

        # copy payload headers
        if body.headers:
            for (key, value) in body.headers.items():
                if key not in self.headers:
                    self.headers[key] = value

    def update_expect_continue(self, expect=False):
        """Arm the '100 Continue' waiter when requested.

        The waiter is also armed when the caller already supplied an
        ``Expect: 100-continue`` header.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue':
            expect = True

        if expect:
            self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth, proxy_from_env):
        """Store proxy settings, optionally read from the environment.

        Raises ``ValueError`` for a non-http proxy or a ``proxy_auth`` that
        is not a ``BasicAuth``.
        """
        if proxy_from_env and not proxy:
            proxy_url = getproxies().get(self.original_url.scheme)
            proxy = URL(proxy_url) if proxy_url else None
        if proxy and proxy.scheme != 'http':
            raise ValueError("Only http proxies are supported")
        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")
        self.proxy = proxy
        self.proxy_auth = proxy_auth

    def keep_alive(self):
        """Return whether the connection may be reused after this request."""
        if self.version < HttpVersion10:
            # keep alive not supported at all
            return False
        if self.version == HttpVersion10:
            if self.headers.get(hdrs.CONNECTION) == 'keep-alive':
                return True
            else:  # no headers means we close for Http 1.0
                return False
        elif self.headers.get(hdrs.CONNECTION) == 'close':
            return False

        return True

    @asyncio.coroutine
    def write_bytes(self, writer, conn):
        """Support coroutines that yields bytes objects.

        Writes the body and the EOF marker.  Failures are reported through
        ``conn.protocol.set_exception`` rather than raised, and the
        ``_writer`` reference is always cleared at the end.
        """
        # 100 response
        if self._continue is not None:
            yield from writer.drain()
            yield from self._continue

        try:
            if isinstance(self.body, payload.Payload):
                yield from self.body.write(writer)
            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)

                for chunk in self.body:
                    writer.write(chunk)

            yield from writer.write_eof()
        except OSError as exc:
            new_exc = ClientOSError(
                exc.errno,
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            conn.protocol.set_exception(new_exc)
        except asyncio.CancelledError as exc:
            if not conn.closed:
                conn.protocol.set_exception(exc)
        except Exception as exc:
            conn.protocol.set_exception(exc)
        finally:
            self._writer = None

    def send(self, conn):
        """Write the request head, start streaming the body, return response.

        The body is written by a background task stored in ``_writer``.
        """
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            path = '{}:{}'.format(self.url.raw_host, self.url.port)
        elif self.proxy and not self.ssl:
            path = str(self.url)
        else:
            path = self.url.raw_path
            if self.url.raw_query_string:
                path += '?' + self.url.raw_query_string

        writer = PayloadWriter(conn.writer, self.loop)

        if self.compress:
            writer.enable_compression(self.compress)

        # NOTE(review): an explicit ``chunked=False`` also passes this test
        # although update_transfer_encoding() then sets Content-Length --
        # confirm whether ``if self.chunked:`` was intended.
        if self.chunked is not None:
            writer.enable_chunking()

        # set default content-type
        if (self.method in self.POST_METHODS and
                hdrs.CONTENT_TYPE not in self.skip_auto_headers and
                hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        # set the connection header
        connection = self.headers.get(hdrs.CONNECTION)
        if not connection:
            if self.keep_alive():
                if self.version == HttpVersion10:
                    connection = 'keep-alive'
            else:
                if self.version == HttpVersion11:
                    connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

        # status + headers
        status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
            self.method, path, self.version)
        writer.write_headers(status_line, self.headers)

        self._writer = helpers.ensure_future(
            self.write_bytes(writer, conn), loop=self.loop)

        self.response = self.response_class(
            self.method, self.original_url,
            writer=self._writer, continue100=self._continue, timer=self._timer,
            request_info=self.request_info,
            auto_decompress=self._auto_decompress
        )

        self.response._post_init(self.loop, self._session)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the body writer task to finish, then forget it."""
        if self._writer is not None:
            try:
                yield from self._writer
            finally:
                self._writer = None

    def terminate(self):
        """Cancel the body writer task (if the loop is open) and forget it."""
        if self._writer is not None:
            if not self.loop.is_closed():
                self._writer.cancel()
            self._writer = None
Exemplo n.º 18
0
def test(request):
    """Return a fixed "all ok" JSON response that allows any origin."""
    cors_headers = CIMultiDict()
    cors_headers.add('Access-Control-Allow-Origin', '*')
    reply = {'code': 0, 'msg': 'ok', 'data': 'all ok'}
    return web.json_response(data=reply, headers=cors_headers)
def function10(function145):
    """Check that the host comes from the ``host=`` Forwarded parameter.

    ``function145`` is called as ``function145('GET', '/', headers=...)``
    and must return an object exposing ``host``.
    """
    forwarded = CIMultiDict()
    for entry in ('By=identifier1;for=identifier2, BY=identifier3',
                  'by=;for=;host=example.com'):
        forwarded.add('Forwarded', entry)
    request = function145('GET', '/', headers=forwarded)
    assert (request.host == 'example.com')
Exemplo n.º 20
0
def test_host_by_forwarded_header(make_request):
    """The request host is taken from the second Forwarded header's host=."""
    first_hop = 'By=identifier1;for=identifier2, BY=identifier3'
    second_hop = 'by=;for=;host=example.com'

    forwarded = CIMultiDict()
    forwarded.add('Forwarded', first_hop)
    forwarded.add('Forwarded', second_hop)

    request = make_request('GET', '/', headers=forwarded)
    assert request.host == 'example.com'
Exemplo n.º 21
0
async def oauth2_callback(request: Request):
    """Handle the OAuth2 authorization-code redirect.

    Exchanges the ``code`` query parameter for an access token, fetches the
    user info with that token and, on success, redirects to ``/frontend``
    with a signed ``ceiaminsession`` JWT cookie.

    Returns a 403 plain-text response when the ``state`` parameter does not
    match the ``ceiamin_oauth_state`` cookie or when any step of the token /
    user-info exchange fails.
    """
    host = request.headers["host"]
    # Behind a TLS-terminating proxy the request scheme is not reliable.
    if request.headers.get("X-Forwarded-Ssl") == 'on':
        protocol = "https"
    else:
        protocol = request.url.scheme

    cid = "ceiamin"

    callback = protocol + "://" + host + "/frontend/oauth2Callback/"

    state = request.query_params['state']
    saved_state = request.cookies["ceiamin_oauth_state"]
    code = request.query_params['code']

    form = aiohttp.FormData({
        "grant_type": "authorization_code",
        "code": code,
        "client_id": cid,
        "redirect_uri": callback
    })

    try:
        if state != saved_state:
            raise Exception("State " + state + " != " + saved_state)

        connector = aiohttp.TCPConnector(ssl=(settings.TEST == 0))
        # The context managers close the session and release both
        # responses on every path, including errors.
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.post(
                    settings.OAUTH2_TOKEN_URL, data=form()) as token_resp:
                token_body = await token_resp.read()
            # The token response is deliberately not printed: it contains
            # the access token.
            json_token = json.loads(token_body)

            auth_headers = CIMultiDict()
            auth_headers.add(
                "Authorization", "Bearer " + json_token["access_token"])
            async with session.get(
                    settings.OAUTH2_USERINFO_URL,
                    headers=auth_headers) as user_resp:
                user_body = await user_resp.read()
        jsonUser = json.loads(user_body)

    except Exception:
        # ``Exception`` rather than a bare except, so task cancellation
        # and interpreter shutdown are not turned into a 403.
        traceback.print_exc()
        return PlainTextResponse("Autorização falhou.", status_code=403)

    sessionId = str(uuid.uuid1())
    encoded_jwt = jwt.encode(
        {
            'session': sessionId,
            'user': jsonUser["preferred_username"],
            "timeini": str(datetime.now().timestamp())
        },
        secret,
        algorithm='HS256')
    # jwt.encode may return either str or bytes; normalise to str.
    if isinstance(encoded_jwt, str):
        encoded_jwt_str = encoded_jwt
    else:
        encoded_jwt_str = encoded_jwt.decode("utf-8")

    response = RedirectResponse("/frontend")
    response.set_cookie("ceiaminsession", encoded_jwt_str, path="/")

    return response
Exemplo n.º 22
0
 def readDatabaseMaps(self):
     """Return a CIMultiDict whose "tableInfo" entry is analyzeTables()."""
     return CIMultiDict([("tableInfo", self.analyzeTables())])
Exemplo n.º 23
0
class ClientRequest:

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE, hdrs.METH_TRACE})

    DEFAULT_HEADERS = {hdrs.ACCEPT: "*/*", hdrs.ACCEPT_ENCODING: "gzip, deflate"}

    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    body = b""
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method, url, *, params=None, headers=None,
                 skip_auto_headers=frozenset(), data=None, cookies=None,
                 auth=None, encoding="utf-8", version=aiohttp.HttpVersion11,
                 compress=None, chunked=None, expect100=False, loop=None,
                 response_class=None):
        """Prepare the request by running every ``update_*`` step.

        The method is upper-cased, the current event loop is used when
        ``loop`` is None, and ``ClientResponse`` is used when no
        ``response_class`` is given.
        """
        if loop is None:
            loop = asyncio.get_event_loop()
        self.loop = loop

        self.method = method.upper()
        self.url = url
        self.encoding = encoding
        self.compress = compress
        self.chunked = chunked
        self.response_class = response_class or ClientResponse

        if loop.get_debug():
            # Must stay in __init__: the frame depth is relative to it.
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # URL / header preparation; later steps read what earlier ones set.
        self.update_version(version)
        self.update_host(url)
        self.update_path(params)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding()
        self.update_auth(auth)

        # Body preparation.
        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    def update_host(self, url):
        """Split ``url`` and store host, port, scheme, netloc and ssl flag.

        Credentials embedded in the URL become ``self.auth`` and are removed
        from the stored netloc.  A missing port falls back to the default
        for the scheme.  Raises ``ValueError`` when the host is missing, the
        port is not convertible or a label cannot be IDNA-encoded.
        """
        parts = urllib.parse.urlsplit(url)

        # both the network location and the bare hostname must be present
        raw_netloc = parts.netloc
        if not raw_netloc or not parts.hostname:
            raise ValueError("Host could not be detected.")
        hostname = parts.hostname

        try:
            port = parts.port
        except ValueError:
            raise ValueError("Port number could not be converted.") from None

        # check domain idna encoding
        try:
            netloc = raw_netloc.encode("idna").decode("utf-8")
            hostname = hostname.encode("idna").decode("utf-8")
        except UnicodeError:
            raise ValueError("URL has an invalid label.")

        # basic auth info
        username = parts.username
        if username:
            self.auth = helpers.BasicAuth(username, parts.password or "")
            netloc = netloc.split("@", 1)[1]

        # the entire netloc is kept for the Host header
        self.netloc = netloc

        scheme = parts.scheme
        self.ssl = scheme in ("https", "wss")

        if not port:
            port = HTTPS_PORT if self.ssl else HTTP_PORT

        self.host, self.port, self.scheme = hostname, port, scheme

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        Non-string values are stored unchanged.  Raises ``ValueError`` for
        any string that is not ``<int>.<int>``, including one without a
        dot (which previously escaped as ``IndexError``).
        """
        if isinstance(version, str):
            # A missing dot leaves ``minor`` empty, which int() rejects
            # with the same ValueError as any other malformed part.
            major, _, minor = version.partition(".")
            try:
                version = int(major), int(minor)
            except ValueError:
                raise ValueError("Can not parse http version number: {}".format(version)) from None
        self.version = version

    def update_path(self, params):
        """Build ``self.path`` and rebuild ``self.url`` from it.

        ``params`` may be a mapping, a sequence of pairs or an already
        encoded string; it is appended to any query present in the URL.
        An empty path becomes ``/``.
        """
        # ``collections.Mapping`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        from collections.abc import Mapping

        # extract path
        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(self.url)
        if not path:
            path = "/"

        if isinstance(params, Mapping):
            params = list(params.items())

        if params:
            if not isinstance(params, str):
                params = urllib.parse.urlencode(params)
            query = "%s&%s" % (query, params) if query else params

        self.path = urllib.parse.urlunsplit(("", "", helpers.requote_uri(path), query, fragment))
        # self.path already carries query and fragment, so none are added.
        self.url = urllib.parse.urlunsplit((scheme, netloc, self.path, "", ""))

    def update_headers(self, headers):
        """Reset ``self.headers`` and fill it from ``headers``.

        Accepts a dict, a multidict (or proxy) or an iterable of
        ``(key, value)`` pairs; repeated keys are kept.
        """
        self.headers = CIMultiDict()
        if not headers:
            return

        if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
            pairs = headers.items()
        else:
            pairs = headers

        for name, value in pairs:
            self.headers.add(name, value)

    def update_auto_headers(self, skip_auto_headers):
        """Add default, Host and User-Agent headers where still missing.

        A header counts as present when the caller supplied it or listed
        it in ``skip_auto_headers``.  The lookup is case-insensitive, so a
        caller-supplied ``host`` header is not overwritten.
        """
        self.skip_auto_headers = skip_auto_headers

        # A case-insensitive multidict instead of a plain set of names, so
        # membership does not depend on how the caller cased a header.
        used_headers = self.headers.copy()
        used_headers.extend([(hdr, None) for hdr in skip_auto_headers])

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        # add host
        if hdrs.HOST not in used_headers:
            self.headers[hdrs.HOST] = self.netloc

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Merge ``cookies`` into the request's Cookie header.

        Cookies already present in the header are loaded first, so entries
        from ``cookies`` override them by name.  ``cookies`` may be a dict
        or an iterable of ``(name, value)`` pairs; a Morsel value is stored
        under its own key.
        """
        if not cookies:
            return

        jar = http.cookies.SimpleCookie()
        if hdrs.COOKIE in self.headers:
            jar.load(self.headers.get(hdrs.COOKIE, ""))
            del self.headers[hdrs.COOKIE]

        pairs = cookies.items() if isinstance(cookies, dict) else cookies

        for name, value in pairs:
            if isinstance(value, http.cookies.Morsel):
                jar[value.key] = value.value
            else:
                jar[name] = value

        self.headers[hdrs.COOKIE] = jar.output(header="", sep=";").strip()

    def update_content_encoding(self):
        """Reconcile ``compress`` with the Content-Encoding header.

        An explicit header wins unless ``compress`` is exactly False;
        otherwise a truthy ``compress`` sets the header (any non-string
        value means "deflate").  Compression always turns chunking on.
        """
        declared = self.headers.get(hdrs.CONTENT_ENCODING, "").lower()

        if declared:
            if self.compress is False:
                return
            self.compress = declared
            # enable chunked, no need to deal with length
            self.chunked = True
            return

        if not self.compress:
            return

        if not isinstance(self.compress, str):
            self.compress = "deflate"
        self.headers[hdrs.CONTENT_ENCODING] = self.compress
        # enable chunked, no need to deal with length
        self.chunked = True

    def update_auth(self, auth):
        """Set the Authorization header from basic-auth credentials.

        Falls back to ``self.auth`` when ``auth`` is None.  A value that is
        not a ``BasicAuth`` is unpacked into one after a
        DeprecationWarning.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if not isinstance(credentials, helpers.BasicAuth):
            warnings.warn("BasicAuth() tuple is required instead ", DeprecationWarning)
            credentials = helpers.BasicAuth(*credentials)

        self.headers[hdrs.AUTHORIZATION] = credentials.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Store ``data`` as the request body and set the matching headers.

        Dispatches on the type of ``data``: text/bytes, stream readers,
        coroutines, file-like objects, ``MultipartWriter`` and finally
        anything else, which is treated as form data.  Content-Type is only
        set when absent from both ``self.headers`` and
        ``skip_auto_headers``.  Falsy ``data`` leaves everything untouched.

        Raises ``ValueError`` for a file object whose ``mode`` is ``"r"``.
        """
        if not data:
            return

        # Text is encoded with the request encoding and then handled as bytes.
        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if hdrs.CONTENT_TYPE not in self.headers and hdrs.CONTENT_TYPE not in skip_auto_headers:
                self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        # Stream readers / data queues are stored as-is; no headers are set.
        elif isinstance(data, (asyncio.StreamReader, streams.DataQueue)):
            self.body = data

        # A coroutine body is chunked unless the caller set Content-Length
        # or chose a chunked value explicitly.
        elif asyncio.iscoroutine(data):
            self.body = data
            if hdrs.CONTENT_LENGTH not in self.headers and self.chunked is None:
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), "attempt to send text data instead of binary"
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined
                size = len(data.getbuffer())
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif not self.chunked and isinstance(data, io.BufferedReader):
                # Not chunking if content-length can be determined
                try:
                    # remaining bytes from the current position to the end
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            # Note: this check runs after the length/chunking state above
            # has already been updated.
            if hasattr(data, "mode"):
                if data.mode == "r":
                    raise ValueError("file {!r} should be open in binary mode" "".format(data))
            # Guess the content type from the file name when one is available.
            if (
                hdrs.CONTENT_TYPE not in self.headers
                and hdrs.CONTENT_TYPE not in skip_auto_headers
                and hasattr(data, "name")
            ):
                mime = mimetypes.guess_type(data.name)[0]
                mime = "application/octet-stream" if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            # Keep a caller-provided truthy chunked value, else use 8192.
            self.chunked = self.chunked or 8192

        else:
            # Everything else is wrapped in (or already is) FormData.
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if hdrs.CONTENT_TYPE not in self.headers and hdrs.CONTENT_TYPE not in skip_auto_headers:
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = self.chunked or 8192
            else:
                if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Reconcile ``self.chunked`` with the Transfer-Encoding header.

        When chunking was requested, Content-Length is removed, the
        Transfer-Encoding header is set to ``chunked`` unless it already
        mentions it, and ``self.chunked`` becomes an int (the given int,
        otherwise 8192).  When chunking was not requested, a ``chunked``
        Transfer-Encoding header turns it on with 8192; otherwise
        ``self.chunked`` is reset to ``None`` and Content-Length is filled in
        from ``len(self.body)`` if absent.
        """
        headers = self.headers
        header_says_chunked = (
            "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()
        )

        if not self.chunked:
            if header_says_chunked:
                self.chunked = 8192
                return
            self.chunked = None
            if hdrs.CONTENT_LENGTH not in headers:
                headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
            return

        # Chunked transfer and an explicit length are mutually exclusive.
        if hdrs.CONTENT_LENGTH in headers:
            del headers[hdrs.CONTENT_LENGTH]
        if not header_says_chunked:
            headers[hdrs.TRANSFER_ENCODING] = "chunked"

        # Exact type check on purpose: ``True`` is a bool, not an int size.
        if type(self.chunked) is not int:
            self.chunked = 8192

    def update_expect_continue(self, expect=False):
        """Arm the ``100-continue`` waiter when the request expects it.

        A truthy *expect* forces the Expect header to ``100-continue``;
        otherwise an already present ``Expect: 100-continue`` header (compared
        case-insensitively) has the same effect.  In both cases
        ``self._continue`` is set to a new future bound to ``self.loop``.
        """
        if expect:
            self.headers[hdrs.EXPECT] = "100-continue"
            wants_continue = True
        else:
            declared = self.headers.get(hdrs.EXPECT, "")
            wants_continue = declared.lower() == "100-continue"

        if not wants_continue:
            return
        self._continue = asyncio.Future(loop=self.loop)

    @asyncio.coroutine
    def write_bytes(self, request, reader):
        """Write ``self.body`` to *request* and finish the message.

        The body may be a coroutine yielding bytes, an
        ``asyncio.StreamReader``, a ``streams.DataQueue``, an ``io.IOBase``
        object, a single bytes/bytearray object or an iterable of chunks.

        Exceptions raised while writing are not propagated: they are wrapped
        in ``aiohttp.ClientRequestError`` and handed to
        ``reader.set_exception``.  ``self._writer`` is reset to ``None`` once
        the method finishes.
        """
        # 100 response
        # Wait for the pending ``self._continue`` future before sending a body.
        if self._continue is not None:
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                request.transport.set_tcp_nodelay(True)
                # Drive the body coroutine by hand: ``value`` is sent into it
                # on the next step, ``exc`` is thrown into it instead when the
                # previously yielded future failed.
                exc = None
                value = None
                stream = self.body

                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as exc:
                        # The coroutine finished; a bytes return value is the
                        # final chunk.
                        if isinstance(exc.value, bytes):
                            yield from request.write(exc.value, drain=True)
                        break
                    except:
                        # Any other failure inside the body coroutine closes
                        # the response before the error is re-raised.
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        # Wait for the future on behalf of the body coroutine
                        # and feed the outcome back on the next iteration.
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result, drain=True)
                        value = None
                    else:
                        raise ValueError("Bytes object is expected, got: %s." % type(result))

            elif isinstance(self.body, asyncio.StreamReader):
                request.transport.set_tcp_nodelay(True)
                # Copy the stream in DEFAULT_LIMIT-sized reads until it is
                # exhausted (an empty read ends the loop).
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                request.transport.set_tcp_nodelay(True)
                # The queue ends either with EOF_MARKER or with EofStream.
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if chunk is EOF_MARKER:
                            break
                        yield from request.write(chunk, drain=True)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                # The read size is ``self.chunked``.
                # NOTE(review): if ``self.chunked`` is None here, ``read(None)``
                # returns everything up to EOF in a single call -- confirm
                # that is intended for large files.
                chunk = self.body.read(self.chunked)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(self.chunked)
                request.transport.set_tcp_nodelay(True)

            else:
                # A single bytes-like body is treated as a one-chunk iterable.
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)

                for chunk in self.body:
                    request.write(chunk)
                request.transport.set_tcp_nodelay(True)

        except Exception as exc:
            # Report the failure through the reader instead of raising.
            new_exc = aiohttp.ClientRequestError("Can not write request body for %s" % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            reader.set_exception(new_exc)
        else:
            # Every branch above is expected to have enabled TCP_NODELAY.
            assert request.transport.tcp_nodelay
            try:
                ret = request.write_eof()
                # NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
                # see bug #170
                if asyncio.iscoroutine(ret) or isinstance(ret, asyncio.Future):
                    yield from ret
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError("Can not write request body for %s" % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                reader.set_exception(new_exc)

        self._writer = None

    def send(self, writer, reader):
        """Send the request head and start streaming the body.

        Corks *writer*, builds an ``aiohttp.Request`` with the configured
        compression and chunking filters, writes all headers, schedules
        ``write_bytes`` as a background task stored in ``self._writer`` and
        returns a freshly created response object (also kept in
        ``self.response``).
        """
        writer.set_tcp_cork(True)
        message = aiohttp.Request(writer, self.method, self.path, self.version)

        if self.compress:
            message.add_compression_filter(self.compress)

        chunk_size = self.chunked
        if chunk_size is not None:
            message.enable_chunked_encoding()
            message.add_chunking_filter(chunk_size)

        # Body-carrying methods get a generic content type unless the caller
        # supplied one or opted out of the automatic header.
        needs_default_type = (
            self.method in self.POST_METHODS
            and hdrs.CONTENT_TYPE not in self.skip_auto_headers
            and hdrs.CONTENT_TYPE not in self.headers
        )
        if needs_default_type:
            self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

        for header_name, header_value in self.headers.items():
            message.add_header(header_name, header_value)
        message.send_headers()

        body_task = helpers.ensure_future(
            self.write_bytes(message, reader), loop=self.loop
        )
        self._writer = body_task

        self.response = self.response_class(
            self.method,
            self.url,
            self.host,
            writer=body_task,
            continue100=self._continue,
        )
        self.response._post_init(self.loop)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the body-writer task, if any, then forget it.

        ``self._writer`` is cleared even when waiting raises.
        """
        pending = self._writer
        if pending is None:
            return
        try:
            yield from pending
        finally:
            self._writer = None

    def terminate(self):
        """Cancel the body-writer task, if any, and forget it.

        The task is cancelled unless the loop exposes ``is_closed()`` and
        reports that it is already closed; ``self._writer`` is cleared either
        way.
        """
        task = self._writer
        if task is None:
            return

        loop = self.loop
        # Loops without ``is_closed`` are treated as still open.
        loop_closed = hasattr(loop, "is_closed") and loop.is_closed()
        if not loop_closed:
            task.cancel()
        self._writer = None
Exemplo n.º 24
0
    def parse_headers(self, lines):
        """Parse RFC 5322 style header lines.

        Parsing starts at ``lines[1]`` (index 0 is skipped -- presumably the
        start line, confirm against the caller) and stops at the first empty
        line.  Continuation lines, i.e. lines starting with a space or a tab,
        are appended to the value of the preceding header.

        Returns a 6-tuple ``(headers, raw_headers, close_conn, encoding,
        upgrade, chunked)``:

        * ``headers`` -- ``CIMultiDictProxy`` of decoded names and values;
        * ``raw_headers`` -- tuple of ``(name, value)`` byte pairs;
        * ``close_conn`` -- ``True``/``False`` for ``Connection: close`` /
          ``keep-alive``, ``None`` otherwise;
        * ``encoding`` -- lower-cased Content-Encoding if it is ``gzip``,
          ``deflate`` or ``br``, else ``None``;
        * ``upgrade`` -- ``True`` for ``Connection: upgrade``;
        * ``chunked`` -- ``True`` if Transfer-Encoding mentions ``chunked``.

        Raises ``InvalidHeader`` for a line without ``:`` or a name matched
        by ``HDRRE``, and ``LineTooLong`` when a name or a (folded) value
        exceeds ``self.max_field_size``.
        """
        headers = CIMultiDict()
        raw_headers = []

        lines_idx = 1
        line = lines[1]
        line_count = len(lines)

        while line:
            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise InvalidHeader(line) from None

            bname = bname.strip(b' \t')
            bvalue = bvalue.lstrip()
            if HDRRE.search(bname):
                raise InvalidHeader(bname)
            if len(bname) > self.max_field_size:
                raise LineTooLong(
                    "request header name {}".format(
                        bname.decode("utf8", "xmlcharrefreplace")),
                    self.max_field_size, len(bname))

            header_length = len(bvalue)

            # next line
            lines_idx += 1
            line = lines[lines_idx]

            # consume continuation lines
            continuation = line and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise LineTooLong(
                            'request header field {}'.format(
                                bname.decode("utf8", "xmlcharrefreplace")),
                            self.max_field_size, header_length)
                    bvalue.append(line)

                    # next line
                    lines_idx += 1
                    if lines_idx < line_count:
                        line = lines[lines_idx]
                        # An empty line terminates the header block and
                        # therefore the folded value as well.  Without
                        # resetting the flag here the loop would step over
                        # the terminator and resume header parsing on
                        # whatever lines follow it.
                        continuation = (
                            bool(line) and line[0] in (32, 9))  # (' ', '\t')
                    else:
                        line = b''
                        break
                bvalue = b''.join(bvalue)
            else:
                if header_length > self.max_field_size:
                    raise LineTooLong(
                        'request header field {}'.format(
                            bname.decode("utf8", "xmlcharrefreplace")),
                        self.max_field_size, header_length)

            bvalue = bvalue.strip()
            # surrogateescape keeps undecodable bytes round-trippable.
            name = bname.decode('utf-8', 'surrogateescape')
            value = bvalue.decode('utf-8', 'surrogateescape')

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        close_conn = None
        encoding = None
        upgrade = False
        chunked = False
        raw_headers = tuple(raw_headers)
        headers = CIMultiDictProxy(headers)

        # keep-alive
        conn = headers.get(hdrs.CONNECTION)
        if conn:
            v = conn.lower()
            if v == 'close':
                close_conn = True
            elif v == 'keep-alive':
                close_conn = False
            elif v == 'upgrade':
                upgrade = True

        # encoding
        enc = headers.get(hdrs.CONTENT_ENCODING)
        if enc:
            enc = enc.lower()
            if enc in ('gzip', 'deflate', 'br'):
                encoding = enc

        # chunking
        te = headers.get(hdrs.TRANSFER_ENCODING)
        if te and 'chunked' in te.lower():
            chunked = True

        return headers, raw_headers, close_conn, encoding, upgrade, chunked
Exemplo n.º 25
0
    async def handleRequest(self, reader, writer):
        """Serve one incoming HTTP connection.

        Reads the request line and headers from *reader*, resolves the route
        from the Host header, answers the built-in ``/_conductor/...``
        endpoints directly, checks the ``authorization`` cookie against the
        route and finally forwards the request to the route's UNIX socket
        via ``proxy``.  Error responses (400/403/404/502) are written to
        *writer*, which is then closed.

        *reader* and *writer* are used like asyncio stream reader/writer
        objects (``readline``, ``write``, ``close``).

        Raises ``BadRequest`` on premature EOF or too many headers and
        ``ValueError`` when the request line does not have three parts.
        """
        # connection id, for debugging
        connid = self.nextConnId
        self.nextConnId += 1

        logger.debug(f'{connid}: new incoming connection')
        # simple HTTP parsing, yes this is a terrible idea.
        line = await reader.readline()
        if line == bytes():
            raise BadRequest(f'{connid}: unexpected eof')
        line = line.rstrip(b'\r\n')
        try:
            method, rawPath, proto = line.split(b' ')
            logger.debug(f'{connid}: got {method} {rawPath} {proto}')
        except ValueError:
            logger.error(f'{connid}: cannot split line {line}')
            raise
        reqUrl = furl(rawPath.decode('utf-8'))

        headers = CIMultiDict()
        while True:
            if len(headers) > 100:
                raise BadRequest('too many headers')

            line = await reader.readline()
            if line == bytes():
                raise BadRequest(f'{connid}: unexpected eof in headers')
            line = line.rstrip(b'\r\n')
            logger.debug(f'{connid}: got header line {line!r}')
            # end of headers?
            if line == bytes():
                break
            try:
                key, value = line.decode('utf-8').split(':', 1)
                headers.add(key.strip(), value.strip())
            except ValueError:
                # Malformed header lines are logged and skipped on purpose.
                logger.error(f'cannot parse {line}')

        logger.debug(f'{connid}: {rawPath} {method} got headers {headers}')

        route = None
        try:
            netloc = headers['host']
            reqUrl = reqUrl.set(scheme='http', netloc=netloc)
            logger.debug(f'got request url {reqUrl}')
            routeKey = None
            for d in self.domain:
                m = parse(d, reqUrl.netloc)
                if m is not None:
                    routeKey = RouteKey(key=m['key'], user=m['user'])
                    logger.debug(f'{connid}: got route key {routeKey}')
                    break
            route = self.routes[routeKey]
        except (KeyError, ValueError):
            logger.info(f'{connid}: cannot find route for {reqUrl}')
            self.status['noroute'] += 1
            # error is written to client later

        # is this a non-forwarded request?
        segments = reqUrl.path.segments
        if len(segments) > 0 and segments[0] == '_conductor':
            # Paths such as "/_conductor" or "/_conductor/auth" lack the
            # later segments; treat them as unknown endpoints instead of
            # failing with IndexError.
            action = segments[1] if len(segments) > 1 else None
            token = segments[2] if len(segments) > 2 else None
            if action == 'auth' and token is not None:
                logger.info(f'authorization request for {reqUrl.netloc}')
                try:
                    nextLoc = reqUrl.query.params['next']
                except KeyError:
                    nextLoc = '/'
                # Both values come from the (percent-decoded) request URL and
                # end up in response headers; refuse line breaks so a client
                # cannot inject additional headers.
                if any(ch in text
                       for text in (token, nextLoc) for ch in '\r\n'):
                    writer.write(
                        b'HTTP/1.0 400 Bad Request\r\nConnection: close\r\n\r\n'
                    )
                else:
                    writer.write(b'\r\n'.join([
                        b'HTTP/1.0 302 Found',
                        b'Location: ' + nextLoc.encode('utf-8'),
                        b'Set-Cookie: authorization=' +
                        token.encode('utf-8') + b'; HttpOnly; Path=/',
                        b'Cache-Control: no-store', b'',
                        b'Follow the white rabbit.'
                    ]))
            elif action == 'status':
                writer.write(
                    b'HTTP/1.0 200 OK\r\nContent-Type: application/json\r\n\r\n'
                )
                self.status['routesTotal'] = len(self.routes)
                writer.write(
                    json.dumps(self.status, ensure_ascii=True).encode('ascii'))
            else:
                writer.write(
                    b'HTTP/1.0 404 Not Found\r\nContent-Type: plain/text\r\n\r\nNot found'
                )
            writer.close()
            return

        if not route:
            writer.write(
                b'HTTP/1.0 404 Not Found\r\nConnection: close\r\n\r\n')
            writer.close()
            return

        # check authorization
        cookies = BaseCookie()
        try:
            cookies.load(headers['Cookie'])
        except KeyError:
            # will be rejected later
            pass
        authorized = False
        for c in cookies.values():
            # Only hashed authorization is available to server.
            if c.key == 'authorization' and self.hashKey(
                    c.value) == route.auth:
                authorized = True
                break
        try:
            # do not forward auth cookie to the application, so it can't leak it.
            del cookies['authorization']
            # Cookie pairs must be separated by ";"; joining them with an
            # empty separator would merge several cookies into one value.
            remaining = cookies.output(header='', sep=';').strip()
            if remaining:
                headers['Cookie'] = remaining
            else:
                # Nothing left to forward: drop the header rather than
                # sending an empty one.
                del headers['Cookie']
        except KeyError:
            # nonexistent cookie is fine
            pass

        if not authorized:
            logger.info(
                f'{connid}-{reqUrl}: not authorized, cookies sent {cookies.values()}'
            )
            writer.write(
                b'HTTP/1.0 403 Unauthorized\r\nContent-Type: plain/text\r\nConnection: close\r\n\r\nUnauthorized'
            )
            writer.close()
            self.status['unauthorized'] += 1
            return

        # try opening the socket
        try:
            start = time.time()
            sockreader, sockwriter = await asyncio.open_unix_connection(
                path=route.socket)
            end = time.time()
            logger.debug(f'opening socket took {end-start}s')
        except (ConnectionRefusedError, FileNotFoundError, PermissionError):
            logger.info(f'{connid}-{reqUrl}: route {routeKey} is broken')
            writer.write(
                b'HTTP/1.0 502 Bad Gateway\r\nConnection: close\r\n\r\n')
            writer.close()
            self.status['broken'] += 1
            return

        # some headers are fixed
        # not parsing body, so we cannot handle more than one request per connection
        # XXX: this is super-inefficient
        if 'Upgrade' not in headers or headers['Upgrade'].lower(
        ) != 'websocket':
            headers['Connection'] = 'close'

        # write http banner plus headers
        sockwriter.write(method + b' ' + rawPath + b' ' + proto + b'\r\n')
        for k, v in headers.items():
            sockwriter.write(f'{k}: {v}\r\n'.encode('utf-8'))
        sockwriter.write(b'\r\n')

        async def beforeABClose(result):
            # Passed to proxy(); a result of 0 is answered with a 502.
            if result == 0:
                # no response received from client
                logger.info(
                    f'{connid}-{reqUrl}: route {routeKey} got no result from server'
                )
                writer.write(
                    b'HTTP/1.0 502 Bad Gateway\r\nConnection: close\r\n\r\n')

        await proxy((sockreader, sockwriter, 'sock'), (reader, writer, 'web'),
                    logger=logger,
                    logPrefix=connid,
                    beforeABClose=beforeABClose)
Exemplo n.º 26
0
class ClientRequest:
    """Outgoing HTTP client request.

    The constructor normalises URL, headers, cookies, authentication, proxy
    settings and body through the ``update_*`` methods; ``send()`` writes
    the request to a connection and returns the response object.
    """

    # HTTP method groups.  ``send()`` adds a default Content-Type for
    # POST_METHODS when none was supplied.
    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    # Headers added by update_auto_headers() unless already present or
    # listed in skip_auto_headers.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    # Used as the default User-Agent value.
    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    # Class-level defaults for attributes set per instance later on.
    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method, url, *,
                 params=None, headers=None, skip_auto_headers=frozenset(),
                 data=None, cookies=None, auth=None, encoding='utf-8',
                 version=http.HttpVersion11, compress=None, chunked=None,
                 expect100=False, loop=None, response_class=None,
                 proxy=None, proxy_auth=None, timer=None):
        """Build a request for *method* and *url*.

        *url* must be a ``URL`` and *proxy* a ``URL`` or ``None``.  Query
        parameters from *params* are appended to those already in *url*.
        The remaining arguments are applied through the ``update_*``
        methods, in an order that matters: headers first, the body next to
        last, transfer encoding and Expect handling at the end.
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert proxy is None or isinstance(proxy, URL), proxy

        if params:
            # Keep the URL's own query and append the extra parameters.
            merged_query = MultiDict(url.query)
            merged_query.extend(url.with_query(params).query)
            url = url.with_query(merged_query)

        # The fragment is never sent; the original URL is kept for the
        # response object.
        self.original_url = url
        self.url = url.with_fragment(None)
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        if response_class:
            self.response_class = response_class
        else:
            self.response_class = ClientResponse
        if timer is None:
            self._timer = _TimeServiceTimeoutNoop()
        else:
            self._timer = timer

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth)

        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    @property
    def host(self):
        return self.url.host

    @property
    def port(self):
        return self.url.port

    def update_host(self, url):
        """Update destination host, port and connection type (ssl)."""
        # get host/port
        if not url.host:
            raise ValueError('Host could not be detected.')

        # basic auth info
        username, password = url.user, url.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')

        # Record entire netloc for usage in host header

        scheme = url.scheme
        self.ssl = scheme in ('https', 'wss')

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)
        """
        if isinstance(version, str):
            v = [l.strip() for l in version.split('.', 1)]
            try:
                version = int(v[0]), int(v[1])
            except ValueError:
                raise ValueError(
                    'Can not parse http version number: {}'.format(
                        version)) from None
        self.version = version

    def update_headers(self, headers):
        """Reset ``self.headers`` and fill it from *headers*.

        *headers* may be a dict, a multidict (or its proxy) or any iterable
        of ``(name, value)`` pairs; a falsy value leaves the header set
        empty.  Repeated names are kept as separate entries.
        """
        collected = CIMultiDict()
        self.headers = collected
        if not headers:
            return

        pairs = headers
        if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
            pairs = headers.items()

        for name, content in pairs:
            collected.add(name, content)

    def update_auto_headers(self, skip_auto_headers):
        """Add the automatic headers that the caller did not provide.

        Headers already present, and names listed in *skip_auto_headers*,
        are left alone.  The candidates are ``DEFAULT_HEADERS``, Host (with
        the port appended when it is not the scheme default) and User-Agent.
        """
        self.skip_auto_headers = skip_auto_headers
        headers = self.headers
        # Snapshot taken before anything is added below.
        taken = set(headers) | skip_auto_headers

        for default_name, default_value in self.DEFAULT_HEADERS.items():
            if default_name in taken:
                continue
            headers.add(default_name, default_value)

        # add host
        if hdrs.HOST not in taken:
            url = self.url
            host_value = url.raw_host
            if not url.is_default_port():
                host_value = host_value + ':' + str(url.port)
            headers[hdrs.HOST] = host_value

        if hdrs.USER_AGENT not in taken:
            headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Merge *cookies* into the request's Cookie header.

        Cookies already present in the header are kept and overridden by
        same-named entries from *cookies* (a mapping of name to value or
        ``Morsel``).  Nothing happens when *cookies* is falsy.
        """
        if not cookies:
            return

        jar = SimpleCookie()
        headers = self.headers
        if hdrs.COOKIE in headers:
            # Start from whatever the caller already put in the header.
            jar.load(headers.get(hdrs.COOKIE, ''))
            del headers[hdrs.COOKIE]

        for cookie_name, cookie in cookies.items():
            if not isinstance(cookie, Morsel):
                jar[cookie_name] = cookie
                continue
            # Preserve coded_value
            morsel = cookie.get(cookie.key, Morsel())
            morsel.set(cookie.key, cookie.value, cookie.coded_value)
            jar[cookie_name] = morsel

        headers[hdrs.COOKIE] = jar.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Decide on body compression; compression implies chunking.

        Without *data* nothing changes.  An explicit Content-Encoding
        header wins unless ``self.compress`` is ``False``; otherwise a
        truthy ``self.compress`` selects the algorithm (``'deflate'`` when
        it is not a string) and the header is set accordingly.
        """
        if not data:
            return

        declared = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if declared:
            if self.compress is False:
                # Compression explicitly disabled: leave everything as is.
                return
            self.compress = declared
            # enable chunked, no need to deal with length
            self.chunked = True
            return

        if not self.compress:
            return

        algorithm = self.compress
        if not isinstance(algorithm, str):
            algorithm = 'deflate'
        self.compress = algorithm
        self.headers[hdrs.CONTENT_ENCODING] = algorithm
        # enable chunked, no need to deal with length
        self.chunked = True

    def update_auth(self, auth):
        """Set the Authorization header from basic-auth credentials.

        Falls back to ``self.auth`` when *auth* is ``None``; does nothing
        when neither is set.  Raises ``TypeError`` if the credentials are
        not a ``helpers.BasicAuth``.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if not isinstance(credentials, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = credentials.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Set ``self.body`` and the related headers from *data*.

        Supported inputs, tried in this order:

        * ``str`` (encoded with ``self.encoding``), ``bytes``, ``bytearray``
          -- sent as is, with Content-Length unless chunked;
        * stream readers and data queues -- stored as the body;
        * coroutines -- chunked unless a Content-Length header was given or
          ``chunked`` was set explicitly;
        * binary file objects -- Content-Length when the size can be
          determined, chunked otherwise; Content-Type is guessed from the
          ``name`` attribute;
        * ``MultipartWriter`` -- serialized and chunked;
        * anything else -- wrapped in ``helpers.FormData``.

        A Content-Type header is only added when it is neither present nor
        listed in *skip_auto_headers*.  Falsy *data* leaves the request
        untouched.

        Raises ``ValueError`` for file objects opened in text mode.
        """
        if not data:
            return

        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(
                data,
            (asyncio.StreamReader, streams.StreamReader, streams.DataQueue)):
            self.body = data

        elif asyncio.iscoroutine(data):
            self.body = data
            if (hdrs.CONTENT_LENGTH not in self.headers
                    and self.chunked is None):
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), \
                'attempt to send text data instead of binary'
            # Text-mode files yield str, which cannot be written to the
            # connection.  Checking the type covers every text mode
            # ('r', 'rt', 'r+', ...); the mode comparison is kept for
            # file-like objects that only advertise a ``mode`` attribute.
            if (isinstance(data, io.TextIOBase)
                    or getattr(data, 'mode', None) == 'r'):
                raise ValueError('file {!r} should be open in binary mode'
                                 ''.format(data))
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined
                size = len(data.getbuffer())
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif (not self.chunked
                  and isinstance(data,
                                 (io.BufferedReader, io.BufferedRandom))):
                # Not chunking if content-length can be determined
                try:
                    # Only the part after the current position is sent.
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers
                    and hasattr(data, 'name')):
                mime = mimetypes.guess_type(data.name)[0]
                mime = 'application/octet-stream' if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            self.chunked = True

        else:
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = True
            else:
                if (hdrs.CONTENT_LENGTH not in self.headers
                        and not self.chunked):
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Reconcile ``self.chunked`` with the Transfer-Encoding header.

        With chunking requested, Content-Length is dropped and the
        Transfer-Encoding header is set to ``chunked`` unless it already
        mentions it.  Otherwise a ``chunked`` Transfer-Encoding header turns
        chunking on; failing that, ``self.chunked`` becomes ``None`` and
        Content-Length is filled in from ``len(self.body)`` if absent.
        """
        headers = self.headers
        header_says_chunked = (
            'chunked' in headers.get(hdrs.TRANSFER_ENCODING, '').lower())

        if not self.chunked:
            if header_says_chunked:
                self.chunked = True
                return
            self.chunked = None
            if hdrs.CONTENT_LENGTH not in headers:
                headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
            return

        # Chunked transfer and an explicit length are mutually exclusive.
        if hdrs.CONTENT_LENGTH in headers:
            del headers[hdrs.CONTENT_LENGTH]
        if not header_says_chunked:
            headers[hdrs.TRANSFER_ENCODING] = 'chunked'

    def update_expect_continue(self, expect=False):
        """Arm the ``100-continue`` waiter when the request expects it.

        A truthy *expect* forces the Expect header to ``100-continue``;
        otherwise an already present ``Expect: 100-continue`` header
        (compared case-insensitively) has the same effect.  In both cases
        ``self._continue`` is set to a new future for ``self.loop``.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
            wants_continue = True
        else:
            declared = self.headers.get(hdrs.EXPECT, '')
            wants_continue = declared.lower() == '100-continue'

        if not wants_continue:
            return
        self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth):
        if proxy and not proxy.scheme == 'http':
            raise ValueError("Only http proxies are supported")
        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")
        self.proxy = proxy
        self.proxy_auth = proxy_auth

    @asyncio.coroutine
    def write_bytes(self, request, conn):
        """Write ``self.body`` to *request* and finish the message.

        The body may be a coroutine yielding bytes, a stream reader, a
        ``streams.DataQueue``, an ``io.IOBase`` object, a single
        bytes/bytearray object or an iterable of chunks.

        Exceptions raised while writing are not propagated: they are wrapped
        in ``aiohttp.ClientRequestError`` and handed to
        ``conn.protocol.set_exception``.  ``self._writer`` is reset to
        ``None`` once the method finishes.
        """
        # 100 response
        if self._continue is not None:
            # Flush what was written so far, then wait for the pending
            # ``self._continue`` future before sending the body.
            yield from request.drain()
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                # Drive the body coroutine by hand: ``value`` is sent into it
                # on the next step, ``exc`` is thrown into it instead when
                # the previously yielded future failed.
                exc = None
                value = None
                stream = self.body

                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as exc:
                        # A bytes return value is the final chunk.
                        if isinstance(exc.value, bytes):
                            yield from request.write(exc.value)
                        break
                    except BaseException:
                        # Any other failure closes the response before the
                        # error is re-raised.
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result)
                        value = None
                    else:
                        raise ValueError('Bytes object is expected, got: %s.' %
                                         type(result))

            elif isinstance(self.body,
                            (asyncio.StreamReader, streams.StreamReader)):
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                # The queue ends with a falsy chunk or with EofStream.
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if not chunk:
                            break
                        yield from request.write(chunk)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                # Every read uses the same fixed size.  ``self.chunked`` is a
                # flag (True/None) at this point, not a size: passing it to
                # read() would mean one-byte reads or a single read of the
                # whole remaining file.
                chunk = self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(streams.DEFAULT_LIMIT)
            else:
                # A single bytes-like body is treated as a one-chunk iterable.
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body, )

                for chunk in self.body:
                    request.write(chunk)

        except Exception as exc:
            # Report the failure through the connection's protocol instead
            # of raising.
            new_exc = aiohttp.ClientRequestError(
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            conn.protocol.set_exception(new_exc)
        else:
            try:
                yield from request.write_eof()
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError(
                    'Can not write request body for %s' % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                conn.protocol.set_exception(new_exc)

        self._writer = None

    def send(self, conn):
        """Send the request head over *conn* and start streaming the body.

        Writes the request line and headers through ``conn.writer``,
        schedules ``write_bytes`` as a background task stored in
        ``self._writer`` and returns a freshly created response object
        (also kept in ``self.response``).
        """
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        url = self.url
        if self.method == hdrs.METH_CONNECT:
            target = '{}:{}'.format(url.raw_host, url.port)
        elif self.proxy and not self.ssl:
            target = str(url)
        else:
            target = url.raw_path
            query = url.raw_query_string
            if query:
                target = target + '?' + query

        message = http.Request(
            conn.writer, self.method, target, self.version, loop=self.loop)

        if self.compress:
            message.enable_compression(self.compress)

        if self.chunked is not None:
            message.enable_chunking()

        # Body-carrying methods get a generic content type unless the caller
        # supplied one or opted out of the automatic header.
        needs_default_type = (
            self.method in self.POST_METHODS
            and hdrs.CONTENT_TYPE not in self.skip_auto_headers
            and hdrs.CONTENT_TYPE not in self.headers)
        if needs_default_type:
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        for header_name, header_value in self.headers.items():
            message.add_header(header_name, header_value)
        message.send_headers()

        body_task = helpers.ensure_future(
            self.write_bytes(message, conn), loop=self.loop)
        self._writer = body_task

        self.response = self.response_class(
            self.method,
            self.original_url,
            writer=body_task,
            continue100=self._continue,
            timer=self._timer)

        self.response._post_init(self.loop)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the body-writer task, if any, then forget it.

        ``self._writer`` is cleared even when waiting raises.
        """
        pending = self._writer
        if pending is None:
            return
        try:
            yield from pending
        finally:
            self._writer = None

    def terminate(self):
        if self._writer is not None:
            if not self.loop.is_closed():
                self._writer.cancel()
            self._writer = None
Exemplo n.º 27
0
class ClientRequest:

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method, url, *,
                 params=None, headers=None, skip_auto_headers=frozenset(),
                 data=None, cookies=None,
                 auth=None, version=http.HttpVersion11, compress=None,
                 chunked=None, expect100=False,
                 loop=None, response_class=None,
                 proxy=None, proxy_auth=None, proxy_from_env=False,
                 timer=None):
        """Build a client request for *method* and *url*.

        *url* must be a ``URL`` instance and *proxy* a ``URL`` or None (both
        are asserted).  *params*, when given, are appended to the query that
        is already present in *url*.  *loop* defaults to the current event
        loop, *response_class* to ``ClientResponse`` and *timer* to a
        ``TimerNoop``.

        The remaining keyword arguments are handed to the ``update_*``
        methods, which fill in ``self.headers``, ``self.body`` and the
        related settings.
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy

        if params:
            # Merge: query already in the URL first, then the extra params.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        # self.url has the fragment stripped; original_url keeps it.
        self.url = url.with_fragment(None)
        self.original_url = url
        self.method = method.upper()
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.length = None
        self.response_class = response_class or ClientResponse
        self._timer = timer if timer is not None else TimerNoop()

        if loop.get_debug():
            # Remember where the request was created (debug mode only).
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Order matters: update_headers() creates self.headers, which the
        # following steps read and extend.
        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth, proxy_from_env)

        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    @property
    def host(self):
        """Host component of the request URL."""
        return self.url.host

    @property
    def port(self):
        """Port component of the request URL."""
        return self.url.port

    @property
    def request_info(self):
        """A ``RequestInfo`` built from the URL, method and headers."""
        return RequestInfo(self.url, self.method, self.headers)

    def update_host(self, url):
        """Validate the URL host and derive auth and TLS settings from it.

        Raises ValueError when *url* has no host.  Credentials embedded in
        the URL become ``self.auth``; ``self.ssl`` is set for the ``https``
        and ``wss`` schemes.
        """
        if not url.host:
            message = "Could not parse hostname from URL '{}'".format(url)
            raise ValueError(message)

        # Credentials carried inside the URL itself.
        user = url.user
        secret = url.password
        if user:
            self.auth = helpers.BasicAuth(user, secret or '')

        self.ssl = url.scheme in ('https', 'wss')

    def update_version(self, version):
        """Normalise the request HTTP version.

        A string such as ``'1.1'`` is parsed into the tuple ``(1, 1)``; any
        non-string value is stored unchanged in ``self.version``.

        Raises ValueError if a string is not of the form ``major.minor``
        with integer components.
        """
        if isinstance(version, str):
            parts = [part.strip() for part in version.split('.', 1)]
            try:
                version = int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                # IndexError: the string has no '.' separator (e.g. '1').
                # Report it like any other unparsable version.
                raise ValueError(
                    'Can not parse http version number: {}'
                    .format(version)) from None
        self.version = version

    def update_headers(self, headers):
        """Reset ``self.headers`` and fill it from *headers*.

        *headers* may be a dict, MultiDict or MultiDictProxy, or an iterable
        of ``(name, value)`` pairs.  A falsy value leaves the header set
        empty.
        """
        collected = CIMultiDict()
        self.headers = collected
        if not headers:
            return

        mapping_types = (dict, MultiDictProxy, MultiDict)
        if isinstance(headers, mapping_types):
            pairs = headers.items()
        else:
            pairs = headers

        for name, value in pairs:
            collected.add(name, value)

    def update_auto_headers(self, skip_auto_headers):
        """Add the automatic headers the caller neither set nor skipped.

        Covers the entries of ``DEFAULT_HEADERS`` plus ``Host`` and
        ``User-Agent``.  The skip set is remembered on
        ``self.skip_auto_headers``.
        """
        self.skip_auto_headers = skip_auto_headers
        taken = set(self.headers) | skip_auto_headers

        for name, default in self.DEFAULT_HEADERS.items():
            if name in taken:
                continue
            self.headers.add(name, default)

        if hdrs.HOST not in taken:
            # Append the port only when the URL reports a non-default one.
            host_value = self.url.raw_host
            if not self.url.is_default_port():
                host_value = host_value + ':' + str(self.url.port)
            self.headers[hdrs.HOST] = host_value

        if hdrs.USER_AGENT not in taken:
            self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Merge *cookies* into the request ``Cookie`` header.

        Does nothing when *cookies* is falsy.  Cookies already present in
        the ``Cookie`` header are loaded first, so entries from *cookies*
        with the same name replace them.  *cookies* must provide
        ``items()``; values may be plain values or ``Morsel`` objects.
        """
        if not cookies:
            return

        c = SimpleCookie()
        if hdrs.COOKIE in self.headers:
            # Start from the existing header, then replace it below.
            c.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        for name, value in cookies.items():
            if isinstance(value, Morsel):
                # Preserve coded_value
                mrsl_val = value.get(value.key, Morsel())
                mrsl_val.set(value.key, value.value, value.coded_value)
                c[name] = mrsl_val
            else:
                c[name] = value

        # header='' drops the "Set-Cookie:" prefix from each output line.
        self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Reconcile ``self.compress`` with the Content-Encoding header.

        Does nothing when there is no *data*.  When compression is requested
        and no Content-Encoding header exists, the header is set (a
        non-string ``compress`` means ``'deflate'``) and chunked transfer is
        switched on.

        Raises ValueError if both the header and ``compress`` are set.
        """
        if not data:
            return

        declared = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if declared and self.compress:
            raise ValueError(
                'compress can not be set '
                'if Content-Encoding header is set')
        if declared or not self.compress:
            return

        if not isinstance(self.compress, str):
            self.compress = 'deflate'
        self.headers[hdrs.CONTENT_ENCODING] = self.compress
        # Chunked transfer: no need to deal with the compressed length.
        self.chunked = True

    def update_transfer_encoding(self):
        """Settle on chunked framing or a Content-Length header.

        - A "Transfer-Encoding: chunked" header already present conflicts
          with ``self.chunked`` (ValueError).
        - With ``self.chunked`` set, the Transfer-Encoding header is added;
          an existing Content-Length header is a conflict (ValueError).
        - Otherwise Content-Length defaults to ``len(self.body)``.
        """
        declared = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()

        if 'chunked' in declared:
            if self.chunked:
                raise ValueError(
                    'chunked can not be set '
                    'if "Transfer-Encoding: chunked" header is set')
            return

        if not self.chunked:
            if hdrs.CONTENT_LENGTH not in self.headers:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
            return

        if hdrs.CONTENT_LENGTH in self.headers:
            raise ValueError(
                'chunked can not be set '
                'if Content-Length header is set')

        self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

    def update_auth(self, auth):
        """Set the Authorization header from *auth* or ``self.auth``.

        *auth* takes precedence; ``self.auth`` is the fallback.  Nothing
        happens when neither is set.

        Raises TypeError if the credentials are not a ``BasicAuth``.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if isinstance(credentials, helpers.BasicAuth):
            self.headers[hdrs.AUTHORIZATION] = credentials.encode()
            return

        raise TypeError('BasicAuth() tuple is required instead')

    def update_body_from_data(self, body, skip_auto_headers):
        """Turn *body* into a payload object and derive headers from it.

        A falsy *body* is ignored.  ``FormData`` is called to produce its
        payload; anything the payload registry does not recognise is
        wrapped in ``FormData`` first.  Content-Length (or chunked mode
        when the size is unknown), Content-Type and the payload's own
        headers are filled in unless already present.
        """
        if not body:
            return

        if isinstance(body, FormData):
            body = body()

        try:
            body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
        except payload.LookupError:
            body = FormData(body)()

        self.body = body
        headers = self.headers

        # Without an explicit length, use the payload size or go chunked.
        if not self.chunked and hdrs.CONTENT_LENGTH not in headers:
            known_size = body.size
            if known_size is None:
                self.chunked = True
            else:
                headers[hdrs.CONTENT_LENGTH] = str(known_size)

        content_type_set = hdrs.CONTENT_TYPE in headers
        if not content_type_set and hdrs.CONTENT_TYPE not in skip_auto_headers:
            headers[hdrs.CONTENT_TYPE] = body.content_type

        # Payload-provided headers never override the request's own.
        if body.headers:
            for name, value in body.headers.items():
                if name not in headers:
                    headers[name] = value

    def update_expect_continue(self, expect=False):
        """Prepare the waiter used for an ``Expect: 100-continue`` exchange.

        A truthy *expect* sets the header.  Otherwise an existing
        ``Expect: 100-continue`` header enables the behaviour.  In both
        cases a future is stored in ``self._continue``.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
            wanted = True
        else:
            existing = self.headers.get(hdrs.EXPECT, '')
            wanted = existing.lower() == '100-continue'

        if wanted:
            self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth, proxy_from_env):
        """Validate and store the proxy settings.

        With *proxy_from_env* and no explicit *proxy*, the proxy is looked
        up via ``getproxies()`` for the scheme of the original URL.

        Raises ValueError for a non-``http`` proxy or for a *proxy_auth*
        that is not a ``BasicAuth``.
        """
        if proxy_from_env and not proxy:
            env_value = getproxies().get(self.original_url.scheme)
            if env_value:
                proxy = URL(env_value)
            else:
                proxy = None

        if proxy and proxy.scheme != 'http':
            raise ValueError("Only http proxies are supported")

        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")

        self.proxy, self.proxy_auth = proxy, proxy_auth

    def keep_alive(self):
        """Tell whether the connection may be reused after this request.

        Below HTTP/1.0 never.  On HTTP/1.0 only with an explicit
        ``Connection: keep-alive`` header.  On later versions unless the
        header is ``Connection: close``.
        """
        version = self.version
        if version < HttpVersion10:
            # keep alive not supported at all
            return False

        connection = self.headers.get(hdrs.CONNECTION)
        if version == HttpVersion10:
            # no header means we close for HTTP/1.0
            return connection == 'keep-alive'
        return connection != 'close'

    @asyncio.coroutine
    def write_bytes(self, writer, conn):
        """Write the request body to *writer* and finish the message.

        When a '100 Continue' waiter exists, ``writer.drain()`` and then the
        waiter are awaited before any body data is written.  The body is
        either a ``payload.Payload`` (which writes itself) or an iterable of
        chunks; a single bytes/bytearray is wrapped in a tuple.

        Errors are not raised to the caller: they are reported through
        ``conn.protocol.set_exception`` (an OSError is wrapped in
        ``ClientOSError``).  ``self._writer`` is cleared in every case.
        """
        # 100 response
        if self._continue is not None:
            yield from writer.drain()
            yield from self._continue

        try:
            if isinstance(self.body, payload.Payload):
                yield from self.body.write(writer)
            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)

                for chunk in self.body:
                    writer.write(chunk)

            yield from writer.write_eof()
        except OSError as exc:
            new_exc = ClientOSError(
                exc.errno,
                'Can not write request body for %s' % self.url)
            # Keep the original OSError attached as context and cause.
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            conn.protocol.set_exception(new_exc)
        except Exception as exc:
            conn.protocol.set_exception(exc)
        finally:
            self._writer = None

    def send(self, conn):
        """Write the request line and headers to *conn* and start the body.

        The request target depends on the kind of request:

        - CONNECT sends the authority form (``host:port``);
        - a non-SSL request through a proxy sends the absolute URL;
        - everything else sends the origin form (path plus query).

        The body is written by a background task stored in
        ``self._writer``.

        Returns the response object built from ``self.response_class``.
        """
        if self.method == hdrs.METH_CONNECT:
            path = '{}:{}'.format(self.url.raw_host, self.url.port)
        elif self.proxy and not self.ssl:
            path = str(self.url)
        else:
            path = self.url.raw_path
            if self.url.raw_query_string:
                path += '?' + self.url.raw_query_string

        writer = PayloadWriter(conn.writer, self.loop)

        if self.compress:
            writer.enable_compression(self.compress)

        # Chunk the body only when chunking was actually requested.  An
        # explicit chunked=False must not switch it on: for that case
        # update_transfer_encoding() advertised Content-Length rather than
        # "Transfer-Encoding: chunked".
        if self.chunked:
            writer.enable_chunking()

        # set default content-type
        if (self.method in self.POST_METHODS and
                hdrs.CONTENT_TYPE not in self.skip_auto_headers and
                hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        # Set the connection header only where it differs from the
        # protocol default: keep-alive on HTTP/1.0, close on HTTP/1.1.
        connection = self.headers.get(hdrs.CONNECTION)
        if not connection:
            if self.keep_alive():
                if self.version == HttpVersion10:
                    connection = 'keep-alive'
            else:
                if self.version == HttpVersion11:
                    connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

        # status + headers
        status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
            self.method, path, self.version)
        writer.write_headers(status_line, self.headers)

        self._writer = helpers.ensure_future(
            self.write_bytes(writer, conn), loop=self.loop)

        self.response = self.response_class(
            self.method, self.original_url,
            writer=self._writer, continue100=self._continue, timer=self._timer,
            request_info=self.request_info
        )

        self.response._post_init(self.loop)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the pending body-writer task, then forget it.

        Does nothing when no writer task is active.  The reference is
        cleared even if waiting on the task raises.
        """
        pending = self._writer
        if pending is None:
            return
        try:
            yield from pending
        finally:
            self._writer = None

    def terminate(self):
        """Abort the pending body-writer task without waiting for it.

        The task is cancelled only while the event loop is still open; the
        reference is cleared in either case.
        """
        task = self._writer
        if task is None:
            return
        if not self.loop.is_closed():
            task.cancel()
        self._writer = None
Exemplo n.º 28
0
def handlerWebApp(request):
    """Relay *request* to the upstream site and return the rewritten reply.

    Request headers and cookies are rewritten so that the upstream site
    (``g_proxy_site``) sees its own host name instead of the local one
    (``g_str_dns``); ``Set-Cookie`` domains in the reply are rewritten the
    other way round.  Paths registered in ``g_hook_url`` are not forwarded
    but answered by their hook function, and the body of a path registered
    in ``g_hook_result`` may be post-processed by its hook.

    This is a generator-based coroutine (it uses ``yield from``).

    Returns the ``web.Response`` to send to the client.
    """
    objResp = web.Response()
    objRespHead = CIMultiDict()
    proxyHeader = {}
    proxyCookie = {}

    # Cookie domain of the local site: g_str_dns from its first '.' on,
    # without a ':port' suffix.  Computed once; used for both directions.
    tempDns = g_str_dns[g_str_dns.find('.'):]
    tempIndex = tempDns.find(":")
    if tempIndex >= 0:
        tempDns = tempDns[:tempIndex]

    for var_head, var_value in request.headers.items():
        if var_head == "Host":
            proxyHeader[var_head] = g_proxy_site_no_prefix_http
        elif var_head in ("Origin", "Referer"):
            # Point these at the upstream host instead of the local one.
            proxyHeader[var_head] = var_value.replace(
                g_str_dns, g_proxy_site_no_prefix_http)
        else:
            proxyHeader[var_head] = var_value

    listCookie = request.headers.get('cookie', None)
    if listCookie is not None:
        for var_cookie in listCookie.split(";"):
            # partition() splits on the first '=' only, so '=' characters
            # inside the value (e.g. base64 padding) are kept.
            key, sep, value = var_cookie.partition('=')
            if not sep:
                # Not a name=value pair, e.g. the empty piece after a
                # trailing ';'.
                continue
            value = value.strip().replace("Domain={}".format(tempDns),
                                          "Domain=.tender88.com")
            proxyCookie[key.strip()] = value

    print(proxyCookie)
    print("############### cookie end ###############")
    print(proxyHeader)
    print("################ header end #####################")
    ret = yield from request.read()

    print("################ request begin ###############")

    if request.path not in g_hook_url:
        body = b''
        if request.method == "GET":
            print("#######get url [{}]".format(request.raw_path))
            # NOTE: paths containing "SWJIYLWA=" are the ones whose reply
            # reports the cross-origin error; a body filter for them used
            # to live here and is currently disabled.
            forward_kwargs = {'headers': proxyHeader, 'verify_ssl': False}
            if len(ret) > 0:
                # Attach a body only when the incoming GET carried one.
                forward_kwargs['data'] = ret

            with aiohttp.ClientSession(cookies=proxyCookie) as session:
                resp = yield from session.get(
                    g_proxy_site + request.raw_path, **forward_kwargs)
                body = yield from resp.read()
        else:
            # Every non-GET method is forwarded as a POST.
            print("#########post url [{}]".format(request.raw_path))
            with aiohttp.ClientSession(cookies=proxyCookie) as session:
                resp = yield from session.post(
                    g_proxy_site + request.raw_path,
                    data=ret,
                    headers=proxyHeader,
                    verify_ssl=False)
                body = yield from resp.read()

        objResp._status = resp.status
        print("^^^^^^^^^^^^^^^^^^^^^^^^^^")
        print(resp.status)

        for head_key, head_value in resp.headers.items():
            if head_key in ("Content-Encoding", "Content-Length",
                            "Transfer-Encoding"):
                # Framing headers are not copied; presumably they would no
                # longer match the body as returned by resp.read().
                continue
            if head_key == "Set-Cookie":
                # Map the upstream cookie domain back to the local one.
                objRespHead.add(
                    head_key,
                    head_value.replace("Domain=.tender88.com",
                                       "Domain={}".format(tempDns)))
            else:
                objRespHead.add(head_key, head_value)

        if request.path in g_hook_result:
            # The result hook only fires on an exact match of the body.
            if (g_hook_result[request.path]["content"]) == body:
                body = g_hook_result[request.path]["func"](body)

        objResp.body = body

    else:
        # Hooked path: the hook produces headers and body itself.
        objResp._status = 200
        objHead, objResp.body = yield from g_hook_url[request.path]["func"](
            request, g_hook_url[request.path]["header"])
        objRespHead.extend(objHead)

    objRespHead.add("Access-Control-Allow-Origin", "*")
    objResp.headers.extend(objRespHead)

    # Log the status that is actually returned, not a fixed 200.
    print("####### return code[{}] head[{}] #######".format(
        objResp.status, objRespHead))
    return objResp
Exemplo n.º 29
0
class HttpMessage(PayloadWriter):
    """Payload writer that also collects and sends HTTP headers."""

    HOP_HEADERS = None  # Must be set by subclass.

    SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
        sys.version_info, aiohttp.__version__)

    upgrade = False  # Connection: UPGRADE
    websocket = False  # Upgrade: WEBSOCKET
    has_chunked_hdr = False  # Transfer-encoding: chunked

    def __init__(self, transport, version, close, loop=None):
        """Set up an empty message for *transport* at HTTP *version*.

        *close* is the initial value of ``self.closing``.
        """
        super().__init__(transport, loop)

        self.version = version
        self.closing = close
        self.keepalive = None
        self.length = None
        self.headers = CIMultiDict()
        self.headers_sent = False

    @property
    def body_length(self):
        """Alias for ``self.output_length``."""
        return self.output_length

    def force_close(self):
        """Mark the connection as closing and disable keep-alive."""
        self.closing = True
        self.keepalive = False

    def keep_alive(self):
        """Tell whether the connection should stay open.

        An explicit ``self.keepalive`` wins.  Otherwise: never below
        HTTP/1.0, on HTTP/1.0 only with ``Connection: keep-alive`` among
        the stored headers, on later versions unless ``self.closing``.
        """
        if self.keepalive is not None:
            return self.keepalive

        if self.version < HttpVersion10:
            # keep alive not supported at all
            return False
        if self.version == HttpVersion10:
            # no header means we close for HTTP/1.0
            return self.headers.get(hdrs.CONNECTION) == 'keep-alive'
        return not self.closing

    def is_headers_sent(self):
        """Return True once send_headers() has run."""
        return self.headers_sent

    def add_header(self, name, value):
        """Record one header and update the flags it implies.

        Content-Length and Transfer-Encoding update ``self.length`` and
        ``self.has_chunked_hdr``.  Connection only sets the upgrade and
        keep-alive flags and is not stored.  Upgrade is stored and flags a
        websocket upgrade.  Any other header is stored unless it is listed
        in ``HOP_HEADERS``.
        """
        assert not self.headers_sent, 'headers have been sent already'
        assert isinstance(name, str), \
            'Header name should be a string, got {!r}'.format(name)
        assert set(name).issubset(ASCIISET), \
            'Header name should contain ASCII chars, got {!r}'.format(name)
        assert isinstance(value, str), \
            'Header {!r} should have string value, got {!r}'.format(
                name, value)

        header = istr(name)
        text = value.strip()
        lowered = text.lower()

        if header == hdrs.CONTENT_LENGTH:
            self.length = int(text)

        if header == hdrs.TRANSFER_ENCODING:
            self.has_chunked_hdr = lowered == 'chunked'

        if header == hdrs.CONNECTION:
            if 'upgrade' in lowered:
                # handle websocket
                self.upgrade = True
            elif 'close' in lowered:
                self.keepalive = False
            elif 'keep-alive' in lowered:
                self.keepalive = True
            return

        if header == hdrs.UPGRADE:
            if 'websocket' in lowered:
                self.websocket = True
            self.headers[header] = text
            return

        if header not in self.HOP_HEADERS:
            # hop-by-hop headers are ignored
            self.headers.add(header, text)

    def add_headers(self, *headers):
        """Add several ``(name, value)`` pairs via add_header()."""
        for pair in headers:
            name, value = pair
            self.add_header(name, value)

    def send_headers(self, _sep=': ', _end='\r\n'):
        """Serialise the status line and headers and buffer them.

        Chunking is enabled when it is not already on and ``autochunked()``
        asks for it; the Transfer-Encoding header follows the final
        chunking state.
        """
        assert not self.headers_sent, 'headers have been sent already'
        self.headers_sent = True

        if not self.chunked and self.autochunked():
            self.enable_chunking()

        if self.chunked:
            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

        self._add_default_headers()

        # status + headers
        lines = [key + _sep + val + _end for key, val in self.headers.items()]
        head = self.status_line + ''.join(lines)
        encoded = head.encode('utf-8') + b'\r\n'

        self.buffer_data(encoded)

    def _add_default_headers(self):
        """Set the Connection header where the protocol default differs."""
        if self.keepalive is None:
            stay_open = not self.closing
        else:
            stay_open = self.keepalive

        connection = None
        if self.upgrade:
            connection = 'Upgrade'
        elif stay_open:
            if self.version == HttpVersion10:
                connection = 'keep-alive'
        elif self.version == HttpVersion11:
            connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection
Exemplo n.º 30
0
class HttpMessage(ABC):
    """HttpMessage allows to write headers and payload to a stream.

    For example, lets say we want to read file then compress it with deflate
    compression and then send it with chunked transfer encoding, code may look
    like this:

       >>> response = aiohttp.Response(transport, 200)

    We have to use deflate compression first:

      >>> response.add_compression_filter('deflate')

    Then we want to split output stream into chunks of 1024 bytes size:

      >>> response.add_chunking_filter(1024)

    We can add headers to response with add_headers() method. add_headers()
    does not send data to transport, send_headers() sends request/response
    line and then sends headers:

      >>> response.add_headers(
      ...     ('Content-Disposition', 'attachment; filename="..."'))
      >>> response.send_headers()

    Now we can use chunked writer to write stream to a network stream.
    First call to write() method sends response status line and headers,
    add_header() and add_headers() method unavailable at this stage:

    >>> with open('...', 'rb') as fp:
    ...     chunk = fp.read(8192)
    ...     while chunk:
    ...         response.write(chunk)
    ...         chunk = fp.read(8192)

    >>> response.write_eof()

    """

    writer = None

    # 'filter' is being used for altering write() behaviour,
    # add_chunking_filter splits incoming data into chunks and
    # add_compression_filter adds deflate/gzip compression.
    filter = None

    HOP_HEADERS = None  # Must be set by subclass.

    SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
        sys.version_info, aiohttp.__version__)

    upgrade = False  # Connection: UPGRADE
    websocket = False  # Upgrade: WEBSOCKET
    has_chunked_hdr = False  # Transfer-encoding: chunked

    # subclass can enable auto sending headers with write() call,
    # this is useful for wsgi's start_response implementation.
    _send_headers = False

    def __init__(self, transport, version, close):
        """Set up an empty message for *transport* at HTTP *version*.

        *close* is the initial value of ``self.closing``.
        """
        self.transport = transport
        self._version = version
        self.closing = close
        # None means "not decided by a Connection header yet".
        self.keepalive = None
        self.chunked = False
        # Set from a Content-Length header by add_header().
        self.length = None
        self.headers = CIMultiDict()
        self.headers_sent = False
        # Bytes written to the transport, headers included.
        self.output_length = 0
        self.headers_length = 0
        # Bytes written since the last drain returned by write().
        self._output_size = 0

    @property
    @abstractmethod
    def status_line(self):
        """First line of the message, provided by the subclass.

        NOTE(review): send_headers() concatenates this with ``str`` header
        lines, so implementations presumably return ``str`` although this
        placeholder returns bytes -- confirm.
        """
        return b''

    @abstractmethod
    def autochunked(self):
        """Tell send_headers() whether to use chunked encoding by default."""
        return False

    @property
    def version(self):
        """HTTP version given to the constructor."""
        return self._version

    @property
    def body_length(self):
        """Bytes written so far, not counting the header block."""
        return self.output_length - self.headers_length

    def force_close(self):
        """Mark the connection as closing and disable keep-alive."""
        self.closing = True
        self.keepalive = False

    def enable_chunked_encoding(self):
        """Request chunked transfer encoding for the payload."""
        self.chunked = True

    def keep_alive(self):
        """Tell whether the connection should stay open.

        An explicit ``self.keepalive`` wins.  Otherwise: never below
        HTTP/1.0, on HTTP/1.0 only with ``Connection: keep-alive`` among
        the stored headers, on later versions unless ``self.closing``.
        """
        if self.keepalive is not None:
            return self.keepalive

        if self.version < HttpVersion10:
            # keep alive not supported at all
            return False
        if self.version == HttpVersion10:
            # no header means we close for HTTP/1.0
            return self.headers.get(hdrs.CONNECTION) == 'keep-alive'
        return not self.closing

    def is_headers_sent(self):
        """Return True once send_headers() has run."""
        return self.headers_sent

    def add_header(self, name, value):
        """Record one header and update the flags it implies.

        *name* must be an ASCII ``str`` and *value* a ``str``; both are
        asserted, as is that headers have not been sent yet.

        - Content-Length sets ``self.length``; Transfer-Encoding sets
          ``self.has_chunked_hdr``.
        - Connection only updates the upgrade / keep-alive flags and is not
          stored; the header is generated again by _add_default_headers().
        - Upgrade is always stored and additionally flags a websocket
          upgrade.
        - Any other header is stored unless listed in ``HOP_HEADERS``.
        """
        assert not self.headers_sent, 'headers have been sent already'
        assert isinstance(name, str), \
            'Header name should be a string, got {!r}'.format(name)
        assert set(name).issubset(ASCIISET), \
            'Header name should contain ASCII chars, got {!r}'.format(name)
        assert isinstance(value, str), \
            'Header {!r} should have string value, got {!r}'.format(
                name, value)

        name = upstr(name)
        value = value.strip()

        if name == hdrs.CONTENT_LENGTH:
            self.length = int(value)

        if name == hdrs.TRANSFER_ENCODING:
            self.has_chunked_hdr = value.lower() == 'chunked'

        if name == hdrs.CONNECTION:
            val = value.lower()
            # handle websocket
            if 'upgrade' in val:
                self.upgrade = True
            # connection keep-alive
            elif 'close' in val:
                self.keepalive = False
            elif 'keep-alive' in val:
                self.keepalive = True

        elif name == hdrs.UPGRADE:
            if 'websocket' in value.lower():
                self.websocket = True
            # Keep the header for every protocol, not just websocket, so
            # other Upgrade values are not silently dropped.
            self.headers[name] = value

        elif name not in self.HOP_HEADERS:
            # ignore hop-by-hop headers
            self.headers.add(name, value)

    def add_headers(self, *headers):
        """Add several ``(name, value)`` pairs via add_header()."""
        for pair in headers:
            name, value = pair
            self.add_header(name, value)

    def send_headers(self, _sep=': ', _end='\r\n'):
        """Write the status line and headers; set up the payload writer.

        The payload writer generator is chosen in this order: chunked (when
        requested or when ``autochunked()`` says so), fixed length (when a
        Content-Length was recorded), otherwise write-until-EOF.  May be
        called only once per message.
        """
        # Chunked response is only for HTTP/1.1 clients or newer
        # and there is no Content-Length header is set.
        # Do not use chunked responses when the response is guaranteed to
        # not have a response body (304, 204).
        assert not self.headers_sent, 'headers have been sent already'
        self.headers_sent = True

        if self.chunked or self.autochunked():
            self.writer = self._write_chunked_payload()
            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

        elif self.length is not None:
            self.writer = self._write_length_payload(self.length)

        else:
            self.writer = self._write_eof_payload()

        # Advance the generator to its first ``yield`` so send() works.
        next(self.writer)

        self._add_default_headers()

        # status + headers
        headers = self.status_line + ''.join(
            [k + _sep + v + _end for k, v in self.headers.items()])
        headers = headers.encode('utf-8') + b'\r\n'

        self.output_length += len(headers)
        self.headers_length = len(headers)
        self.transport.write(headers)

    def _add_default_headers(self):
        """Set the Connection header where the protocol default differs."""
        if self.keepalive is None:
            stay_open = not self.closing
        else:
            stay_open = self.keepalive

        connection = None
        if self.upgrade:
            connection = 'upgrade'
        elif stay_open:
            if self.version == HttpVersion10:
                connection = 'keep-alive'
        elif self.version == HttpVersion11:
            connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

    def write(self,
              chunk,
              *,
              drain=False,
              EOF_MARKER=EOF_MARKER,
              EOL_MARKER=EOL_MARKER):
        """Write a chunk of data to the stream through the payload writer.

        When a filter is installed, the chunk is fed through it and every
        non-empty piece it produces is passed to the writer.
        write_eof() indicates end of stream; the writer can't be used after
        write_eof() has been called.

        Returns ``self.transport.drain()`` when *drain* is true and more
        than 64 KiB have been written since the last drain, otherwise an
        empty tuple.
        """
        assert (isinstance(chunk, (bytes, bytearray))
                or chunk is EOF_MARKER), chunk

        size = self.output_length

        # Subclasses may opt in to sending headers on the first write().
        if self._send_headers and not self.headers_sent:
            self.send_headers()

        assert self.writer is not None, 'send_headers() is not called.'

        if self.filter:
            # Pull pieces out of the filter until it signals end of this
            # input (EOL_MARKER) or end of stream (EOF_MARKER).
            chunk = self.filter.send(chunk)
            while chunk not in (EOF_MARKER, EOL_MARKER):
                if chunk:
                    self.writer.send(chunk)
                chunk = next(self.filter)
        else:
            if chunk is not EOF_MARKER:
                self.writer.send(chunk)

        self._output_size += self.output_length - size

        if self._output_size > 64 * 1024:
            if drain:
                self._output_size = 0
                return self.transport.drain()

        return ()

    def write_eof(self):
        """Flush the filter, finish the payload writer and drain.

        Returns ``self.transport.drain()``.
        """
        self.write(EOF_MARKER)
        try:
            # The payload writers terminate when EofStream is thrown in.
            self.writer.throw(aiohttp.EofStream())
        except StopIteration:
            pass

        return self.transport.drain()

    def _write_chunked_payload(self):
        """Write data in chunked transfer encoding.

        Generator: each value sent in is written as one chunk; throwing
        ``aiohttp.EofStream`` writes the terminating zero-size chunk.
        """
        while True:
            try:
                chunk = yield
            except aiohttp.EofStream:
                self.transport.write(b'0\r\n\r\n')
                self.output_length += 5
                break

            chunk = bytes(chunk)
            chunk_len = '{:x}\r\n'.format(len(chunk)).encode('ascii')
            self.transport.write(chunk_len + chunk + b'\r\n')
            self.output_length += len(chunk_len) + len(chunk) + 2

    def _write_length_payload(self, length):
        """Write at most *length* bytes to the stream.

        Generator: data sent in beyond the remaining length is dropped
        (the chunk that crosses the limit is truncated).  Ends when
        ``aiohttp.EofStream`` is thrown in.
        """
        while True:
            try:
                chunk = yield
            except aiohttp.EofStream:
                break

            if length:
                l = len(chunk)
                if length >= l:
                    self.transport.write(chunk)
                    self.output_length += l
                    length = length - l
                else:
                    self.transport.write(chunk[:length])
                    self.output_length += length
                    length = 0

    def _write_eof_payload(self):
        """Write every chunk unchanged until ``EofStream`` is thrown in."""
        while True:
            try:
                chunk = yield
            except aiohttp.EofStream:
                break

            self.transport.write(chunk)
            self.output_length += len(chunk)

    @wrap_payload_filter
    def add_chunking_filter(self,
                            chunk_size=16 * 1024,
                            *,
                            EOF_MARKER=EOF_MARKER,
                            EOL_MARKER=EOL_MARKER):
        """Split incoming stream into chunks.

        Generator filter: input is buffered and yielded in pieces of exactly
        *chunk_size* bytes, followed by EOL_MARKER once the buffer holds
        less than one piece.  On EOF_MARKER the remaining buffer is yielded,
        then EOF_MARKER.
        """
        buf = bytearray()
        chunk = yield

        while True:
            if chunk is EOF_MARKER:
                if buf:
                    yield buf

                # NOTE(review): ``chunk`` is not rebound here, so resuming
                # the generator after this point re-enters the EOF branch;
                # presumably the filter is never driven past EOF -- confirm.
                yield EOF_MARKER

            else:
                buf.extend(chunk)

                while len(buf) >= chunk_size:
                    chunk = bytes(buf[:chunk_size])
                    del buf[:chunk_size]
                    yield chunk

                chunk = yield EOL_MARKER

    @wrap_payload_filter
    def add_compression_filter(self,
                               encoding='deflate',
                               *,
                               EOF_MARKER=EOF_MARKER,
                               EOL_MARKER=EOL_MARKER):
        """Compress incoming stream with deflate or gzip encoding.

        Generator filter: each input chunk yields its compressed output
        followed by EOL_MARKER; EOF_MARKER yields the flushed remainder
        followed by EOF_MARKER.
        """
        # wbits: 16 + MAX_WBITS for 'gzip', otherwise negative MAX_WBITS
        # (per the zlib docs: gzip container vs. raw deflate stream).
        zlib_mode = (16 +
                     zlib.MAX_WBITS if encoding == 'gzip' else -zlib.MAX_WBITS)
        zcomp = zlib.compressobj(wbits=zlib_mode)

        chunk = yield
        while True:
            if chunk is EOF_MARKER:
                yield zcomp.flush()
                chunk = yield EOF_MARKER

            else:
                yield zcomp.compress(chunk)
                chunk = yield EOL_MARKER
Exemplo n.º 31
0
class Class150:
    """Client-side HTTP request (identifiers in this listing are machine-renamed).

    The constructor normalises the URL, headers, cookies, auth, body and
    transfer settings; ``function516`` writes the request to a connection
    and returns the response object.
    """
    # HTTP methods grouped by kind; ``var2166`` is consulted by
    # ``function516`` when choosing a default Content-Type.
    var3318 = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    var2166 = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    var3357 = var3318.union(var2166).union({hdrs.METH_DELETE, hdrs.METH_TRACE})
    # Headers added by ``function1475`` unless already present or skipped.
    var3979 = {hdrs.ACCEPT: '*/*', hdrs.ACCEPT_ENCODING: 'gzip, deflate', }
    # Default request body; its length is used for Content-Length.
    var1407 = b''
    # Default credentials used by ``function2129`` when none are passed.
    var450 = None
    # Response object returned by ``function516``.
    var3271 = None
    # Callable used by ``function516`` to build the response.
    var3118 = None
    # Body-writer task checked by ``function1516`` / ``function1574``.
    var3221 = None
    # Checked by ``function189`` before the body is written.
    var4430 = None

    def __init__(self, arg1092, arg1935, *, params=None, headers=None,
                 skip_auto_headers=frozenset(), data=None, cookies=None,
                 auth=None, version=http.HttpVersion11, compress=None,
                 chunked=None, expect100=False, loop=None,
                 response_class=None, proxy=None, proxy_auth=None,
                 proxy_from_env=False, timer=None):
        """Prepare a request for HTTP method ``arg1092`` and URL ``arg1935``.

        :param arg1092: HTTP method name; stored upper-cased.
        :param arg1935: request URL; must be a ``URL`` instance.
        :param params: extra query parameters merged into the URL query.
        :param headers: initial request headers.
        :param skip_auto_headers: header names that must not be auto-added.
        :param data: request body.
        :param cookies: cookies to send in the Cookie header.
        :param auth: credentials for the Authorization header.
        :param version: HTTP version, tuple-like or ``'major.minor'`` string.
        :param compress: request body compression.
        :param chunked: request chunked transfer encoding.
        :param expect100: send ``Expect: 100-continue``.
        :param loop: event loop; defaults to ``asyncio.get_event_loop()``.
        :param response_class: response factory; defaults to
            ``ClientResponse``.
        :param proxy: proxy ``URL`` or ``None``.
        :param proxy_auth: proxy credentials.
        :param proxy_from_env: look the proxy up in the environment when
            ``proxy`` is not given.
        :param timer: timer object; defaults to ``TimerNoop()``.
        """
        if loop is None:
            loop = asyncio.get_event_loop()
        assert isinstance(arg1935, URL), arg1935
        assert isinstance(proxy, (URL, type(None))), proxy
        if params:
            # Merge ``params`` into the query already present in the URL.
            var1877 = MultiDict(arg1935.query)
            var330 = arg1935.with_query(params)
            var1877.extend(var330.query)
            arg1935 = arg1935.with_query(var1877)
        self.attribute1914 = arg1935.with_fragment(None)
        self.attribute683 = arg1935
        self.attribute669 = arg1092.upper()
        self.attribute126 = chunked
        self.attribute1089 = compress
        self.attribute1757 = loop
        self.attribute1138 = None
        # Stored under the class-level name that ``function516`` calls.
        self.var3118 = (response_class or ClientResponse)
        self.attribute2226 = (timer if (timer is not None) else TimerNoop())
        if loop.get_debug():
            self.attribute2397 = traceback.extract_stack(sys._getframe(1))
        # The order of the update steps matters: later steps read the headers
        # and flags produced by earlier ones.
        self.function2543(version)
        self.function2715(arg1935)
        self.function229(headers)
        self.function1475(skip_auto_headers)
        self.function355(cookies)
        self.function964(data)
        self.function2129(auth)
        self.function150(proxy, proxy_auth, proxy_from_env)
        self.function71(data, skip_auto_headers)
        self.function1156()
        self.function2358(expect100)

    @property
    def function1365(self):
        """Host component of the request URL (``self.attribute1914``)."""
        return self.attribute1914.host

    @property
    def function251(self):
        """Port component of the request URL (``self.attribute1914``)."""
        return self.attribute1914.port

    @property
    def function1813(self):
        """Summary object built from the request URL, method and headers."""
        target = self.attribute1914
        verb = self.attribute669
        header_map = self.attribute1936
        return var1191(target, verb, header_map)

    def function2715(self, arg1829):
        """Update destination host, port and connection type (ssl).

        :param arg1829: request URL.
        :raises ValueError: if the URL has no host.
        """
        if not arg1829.host:
            raise ValueError("Could not parse hostname from URL '{}'".format(arg1829))
        var4435, var1335 = arg1829.user, arg1829.password
        if var4435:
            # Stored under the class-level name that ``function2129`` falls
            # back to when no explicit credentials are passed.
            self.var450 = helpers.BasicAuth(var4435, var1335 or '')
        self.attribute720 = arg1829.scheme in ('https', 'wss')

    def function2543(self, arg2302):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        Non-string values are stored unchanged.

        :raises ValueError: if a string version is not ``'<int>.<int>'``.
        """
        if isinstance(arg2302, str):
            var4344 = [var1609.strip() for var1609 in arg2302.split('.', 1)]
            try:
                arg2302 = (int(var4344[0]), int(var4344[1]))
            except (ValueError, IndexError):
                # IndexError: no '.' in the string, e.g. '1'.
                raise ValueError('Can not parse http version number: {}'.format(arg2302)) from None
        self.attribute2303 = arg2302

    def function229(self, arg2260):
        """Update request headers.

        Resets ``self.attribute1936`` to an empty ``CIMultiDict`` and fills
        it from ``arg2260``, which may be a mapping or an iterable of
        ``(name, value)`` pairs.
        """
        self.attribute1936 = CIMultiDict()
        if not arg2260:
            return
        if isinstance(arg2260, (dict, MultiDictProxy, MultiDict)):
            arg2260 = arg2260.items()
        for var3643, var1822 in arg2260:
            self.attribute1936.add(var3643, var1822)

    def function1475(self, arg1591):
        """Add default headers that are neither present nor explicitly skipped.

        :param arg1591: set of header names that must not be auto-generated.
        """
        self.attribute1742 = arg1591
        var547 = set(self.attribute1936) | arg1591
        for var679, var1771 in self.var3979.items():
            if var679 not in var547:
                self.attribute1936.add(var679, var1771)
        if hdrs.HOST not in var547:
            var986 = self.attribute1914.raw_host
            # Only a non-default port is spelled out in the Host header.
            if not self.attribute1914.is_default_port():
                var986 += ':' + str(self.attribute1914.port)
            self.attribute1936[hdrs.HOST] = var986
        if hdrs.USER_AGENT not in var547:
            self.attribute1936[hdrs.USER_AGENT] = SERVER_SOFTWARE

    def function355(self, arg709):
        """Update request cookies header.

        Cookies already present in the Cookie header are merged with
        ``arg709`` (a mapping of name to value or ``Morsel``) and written
        back as a single Cookie header.
        """
        if not arg709:
            return
        var557 = SimpleCookie()
        if hdrs.COOKIE in self.attribute1936:
            var557.load(self.attribute1936.get(hdrs.COOKIE, ''))
            del self.attribute1936[hdrs.COOKIE]
        for var4295, var2855 in arg709.items():
            if isinstance(var2855, Morsel):
                # Copy through ``set`` so the already-coded value is kept.
                var519 = var2855.get(var2855.key, Morsel())
                var519.set(var2855.key, var2855.value, var2855.coded_value)
                var557[var4295] = var519
            else:
                var557[var4295] = var2855
        self.attribute1936[hdrs.COOKIE] = var557.output(header='', sep=';').strip()

    def function964(self, arg152):
        """Set request content encoding.

        Does nothing when there is no body (``arg152`` is falsy).

        :raises ValueError: if compression is requested while a
            Content-Encoding header is already set.
        """
        if not arg152:
            return
        declared = self.attribute1936.get(hdrs.CONTENT_ENCODING, '').lower()
        requested = self.attribute1089
        if declared and requested:
            raise ValueError('compress can not be set if Content-Encoding header is set')
        if declared or not requested:
            return
        # Any non-string truthy value selects deflate.
        if not isinstance(requested, str):
            self.attribute1089 = 'deflate'
        self.attribute1936[hdrs.CONTENT_ENCODING] = self.attribute1089
        self.attribute126 = True

    def function1156(self):
        """Analyze transfer-encoding header.

        :raises ValueError: if chunking is requested together with a
            ``Transfer-Encoding: chunked`` or ``Content-Length`` header.
        """
        headers = self.attribute1936
        declared = headers.get(hdrs.TRANSFER_ENCODING, '').lower()
        if 'chunked' in declared:
            if self.attribute126:
                raise ValueError('chunked can not be set if "Transfer-Encoding: chunked" header is set')
            return
        if self.attribute126:
            if hdrs.CONTENT_LENGTH in headers:
                raise ValueError('chunked can not be set if Content-Length header is set')
            headers[hdrs.TRANSFER_ENCODING] = 'chunked'
            return
        # Not chunked: make sure a Content-Length is present.
        if hdrs.CONTENT_LENGTH not in headers:
            headers[hdrs.CONTENT_LENGTH] = str(len(self.var1407))

    def function2129(self, var450):
        """Set basic auth.

        Falls back to ``self.var450`` when ``var450`` is ``None``; does
        nothing when both are ``None``.

        :raises TypeError: if the credentials are not ``helpers.BasicAuth``.
        """
        credentials = self.var450 if var450 is None else var450
        if credentials is None:
            return
        if not isinstance(credentials, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')
        self.attribute1936[hdrs.AUTHORIZATION] = credentials.encode()

    def function71(self, var1407, arg2305):
        """Convert the request data into a payload and derive body headers.

        :param var1407: request data; falsy data leaves the request unchanged.
        :param arg2305: header names that must not be auto-generated.
        """
        if not var1407:
            return
        if isinstance(var1407, FormData):
            var1407 = var1407()
        try:
            var1407 = payload.PAYLOAD_REGISTRY.get(var1407, disposition=None)
        except payload.LookupError:
            # No registered payload type: treat the data as form fields.
            var1407 = FormData(var1407)()
        # Stored under the class-level body name that the writer and the
        # Content-Length computation read.
        self.var1407 = var1407
        if not self.attribute126 and hdrs.CONTENT_LENGTH not in self.attribute1936:
            var2818 = var1407.size
            if var2818 is None:
                # Unknown length: fall back to chunked transfer.
                self.attribute126 = True
            else:
                self.attribute1936[hdrs.CONTENT_LENGTH] = str(var2818)
        if ((hdrs.CONTENT_TYPE not in self.attribute1936)
                and (hdrs.CONTENT_TYPE not in arg2305)):
            self.attribute1936[hdrs.CONTENT_TYPE] = var1407.content_type
        if var1407.headers:
            # Payload headers never override headers already set.
            for var4636, var3781 in var1407.headers.items():
                if var4636 not in self.attribute1936:
                    self.attribute1936[var4636] = var3781

    def function2358(self, arg144=False):
        """Enable ``Expect: 100-continue`` handling.

        Handling is enabled when ``arg144`` is true (the header is then
        added) or when the header is already present.  In both cases a
        future is stored in ``self.var4430``.
        """
        if arg144:
            self.attribute1936[hdrs.EXPECT] = '100-continue'
        elif self.attribute1936.get(hdrs.EXPECT, '').lower() == '100-continue':
            arg144 = True
        if arg144:
            # Stored under the class-level name that the body writer checks.
            self.var4430 = helpers.create_future(self.attribute1757)

    def function150(self, arg2369, arg455, arg1310):
        """Validate and store proxy settings.

        :param arg2369: proxy URL or ``None``.
        :param arg455: proxy credentials or ``None``.
        :param arg1310: when true and no proxy is given, look one up via
            ``getproxies()`` for the scheme of the request URL.
        :raises ValueError: for a non-http proxy or credentials that are not
            ``helpers.BasicAuth``.
        """
        if arg1310 and not arg2369:
            var3723 = getproxies().get(self.attribute683.scheme)
            arg2369 = URL(var3723) if var3723 else None
        if arg2369 and arg2369.scheme != 'http':
            raise ValueError('Only http proxies are supported')
        if arg455 and not isinstance(arg455, helpers.BasicAuth):
            raise ValueError('proxy_auth must be None or BasicAuth() tuple')
        self.attribute1478 = arg2369
        self.attribute563 = arg455

    def function707(self):
        """Return True if the connection should be kept alive.

        Versions below ``HttpVersion10`` never keep alive; ``HttpVersion10``
        only with ``Connection: keep-alive``; later versions unless
        ``Connection: close`` is set.
        """
        var_version = self.attribute2303
        if var_version < HttpVersion10:
            return False
        var_connection = self.attribute1936.get(hdrs.CONNECTION)
        if var_version == HttpVersion10:
            return var_connection == 'keep-alive'
        return var_connection != 'close'

    @asyncio.coroutine
    def function189(self, arg1083, arg270):
        """Write the request body to ``arg1083`` and finish the message.

        :param arg1083: payload writer.
        :param arg270: connection; write errors are reported through
            ``arg270.protocol.set_exception``.
        """
        if self.var4430 is not None:
            # Flush the headers, then wait on the future created for
            # ``Expect: 100-continue``.
            yield from arg1083.drain()
            yield from self.var4430
        try:
            if isinstance(self.var1407, payload.Payload):
                yield from self.var1407.write(arg1083)
            else:
                if isinstance(self.var1407, (bytes, bytearray)):
                    # Wrap so the loop writes one chunk instead of iterating
                    # over the individual byte values.
                    self.var1407 = (self.var1407,)
                for var3335 in self.var1407:
                    arg1083.write(var3335)
            yield from arg1083.write_eof()
        except OSError as var3753:
            var520 = ClientOSError(var3753.errno, ('Can not write request body for %s' % self.attribute1914))
            var520.__context__ = var3753
            var520.__cause__ = var3753
            arg270.protocol.set_exception(var520)
        except Exception as var3753:
            arg270.protocol.set_exception(var3753)
        finally:
            # The writer task is done; drop the reference to it.
            self.var3221 = None

    def function516(self, arg270):
        """Write the request line and headers, start the body writer.

        :param arg270: connection to send the request over.
        :returns: the response object created for this request.
        """
        # Request target: authority form for CONNECT, absolute URL when going
        # through a proxy without SSL, otherwise path plus query string.
        if self.attribute669 == hdrs.METH_CONNECT:
            var3263 = '{}:{}'.format(self.attribute1914.raw_host, self.attribute1914.port)
        elif self.attribute1478 and not self.attribute720:
            var3263 = str(self.attribute1914)
        else:
            var3263 = self.attribute1914.raw_path
            if self.attribute1914.raw_query_string:
                var3263 += '?' + self.attribute1914.raw_query_string
        arg1083 = PayloadWriter(arg270.writer, self.attribute1757)
        if self.attribute1089:
            arg1083.enable_compression(self.attribute1089)
        if self.attribute126 is not None:
            arg1083.enable_chunking()
        # Default Content-Type for body-carrying methods unless suppressed.
        if ((self.attribute669 in self.var2166)
                and (hdrs.CONTENT_TYPE not in self.attribute1742)
                and (hdrs.CONTENT_TYPE not in self.attribute1936)):
            self.attribute1936[hdrs.CONTENT_TYPE] = 'application/octet-stream'
        var648 = self.attribute1936.get(hdrs.CONNECTION)
        if not var648:
            # Only state the Connection header when it differs from the
            # default of the protocol version in use.
            if self.function707():
                if self.attribute2303 == HttpVersion10:
                    var648 = 'keep-alive'
            elif self.attribute2303 == HttpVersion11:
                var648 = 'close'
        if var648 is not None:
            self.attribute1936[hdrs.CONNECTION] = var648
        var1203 = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(self.attribute669, var3263, self.attribute2303)
        arg1083.write_headers(var1203, self.attribute1936)
        self.var3221 = helpers.ensure_future(self.function189(arg1083, arg270), loop=self.attribute1757)
        self.var3271 = self.var3118(self.attribute669, self.attribute683, writer=self.var3221, continue100=self.var4430, timer=self.attribute2226, request_info=self.function1813)
        self.var3271._post_init(self.attribute1757)
        return self.var3271

    @asyncio.coroutine
    def function1516(self):
        """Wait for the body-writer task, if any, then drop the reference."""
        if self.var3221 is not None:
            try:
                yield from self.var3221
            finally:
                self.var3221 = None

    def function1574(self):
        """Cancel the body-writer task, if any, and drop the reference."""
        if self.var3221 is not None:
            # Cancelling on a closed loop is skipped.
            if not self.attribute1757.is_closed():
                self.var3221.cancel()
            self.var3221 = None
Exemplo n.º 32
0
class ClientRequest:
    """Client-side HTTP request.

    The constructor normalises URL, headers, cookies, auth, body and
    transfer settings through the ``update_*`` methods; ``send`` writes the
    request and returns the response object.
    """

    # HTTP methods grouped by kind; POST_METHODS is consulted by ``send``
    # when choosing a default Content-Type.
    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    # Headers added by ``update_auto_headers`` unless present or skipped.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    # Value used for the auto-generated User-Agent header.
    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    # Per-instance defaults, overridden by the ``update_*`` methods and
    # ``send``.
    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding a __del__ method that closes self._writer doesn't make sense:
    # _writer runs an instance method, thus it keeps a reference to self.
    # Until the writer has finished, the finalizer will not be called.

    def __init__(self, method, url, *,
                 params=None, headers=None, skip_auto_headers=frozenset(),
                 data=None, cookies=None, auth=None, encoding='utf-8',
                 version=aiohttp.HttpVersion11, compress=None, chunked=None,
                 expect100=False, loop=None, response_class=None,
                 proxy=None, proxy_auth=None, timeout=5 * 60):
        """Prepare a request for ``method`` and ``url``.

        Plain settings are stored first; the ``update_*`` steps then derive
        host, path, headers and body settings.  Their order matters because
        later steps read the headers and flags set by earlier ones.
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        # Plain settings.
        self.url = url
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.response_class = response_class if response_class else ClientResponse
        self._timeout = timeout

        if loop.get_debug():
            # Remember where the request was created, for diagnostics.
            caller_frame = sys._getframe(1)
            self._source_traceback = traceback.extract_stack(caller_frame)

        # Connection target and request line.
        self.update_version(version)
        self.update_host(url)
        self.update_path(params)

        # Headers.
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth)

        # Body and transfer settings.
        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    def update_host(self, url):
        """Update destination host, port and connection type (ssl).

        Sets ``host``, ``port``, ``scheme``, ``ssl`` and ``netloc``, plus
        ``auth`` when the URL carries a user name.

        :raises ValueError: if the host is missing, the port is not a
            number, or the host is not IDNA-encodable.
        """
        parts = urllib.parse.urlsplit(url)

        # A URL without a network location or host name cannot be targeted.
        if not parts.netloc:
            raise ValueError('Host could not be detected.')
        hostname = parts.hostname
        if not hostname:
            raise ValueError('Host could not be detected.')

        try:
            port = parts.port
        except ValueError:
            raise ValueError('Port number could not be converted.') from None

        # Internationalised domain names go on the wire IDNA-encoded.
        try:
            hostname = hostname.encode('idna').decode('utf-8')
            netloc = self.make_netloc(hostname, parts.port)
        except UnicodeError:
            raise ValueError('URL has an invalid label.')

        # Credentials embedded in the URL become basic auth.
        user, secret = parts.username, parts.password
        if user:
            self.auth = helpers.BasicAuth(user, secret or '')

        # The entire netloc is kept for the Host header.
        self.netloc = netloc

        scheme = parts.scheme
        self.ssl = scheme in ('https', 'wss')

        # Fall back to the scheme's default port.
        if not port:
            port = HTTPS_PORT if self.ssl else HTTP_PORT

        self.host, self.port, self.scheme = hostname, port, scheme

    def make_netloc(self, host, port):
        """Return ``host`` or ``host:port`` for use as a URL authority.

        An IPv6 literal (a host containing ``:``) is wrapped in square
        brackets so that the port separator stays unambiguous.

        :param host: host name or IP address.
        :param port: port number; omitted from the result when falsy.
        """
        ret = host
        if ':' in ret and not ret.startswith('['):
            ret = '[' + ret + ']'
        if port:
            ret = ret + ':' + str(port)
        return ret

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        Non-string values are stored unchanged.

        :raises ValueError: if a string version is not ``'<int>.<int>'``.
        """
        if isinstance(version, str):
            parts = [part.strip() for part in version.split('.', 1)]
            try:
                version = int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                # IndexError: no '.' in the string, e.g. '1'.
                raise ValueError(
                    'Can not parse http version number: {}'.format(
                        version)) from None
        self.version = version

    def update_path(self, params):
        """Build path.

        Sets ``self.path`` (quoted path plus query string) and rewrites
        ``self.url`` to use it.

        :param params: extra query parameters: a mapping, a sequence of
            pairs, or an already encoded string.  They are appended to the
            query string of the URL.
        """
        # ``collections.Mapping`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        from collections.abc import Mapping

        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(self.url)
        if not path:
            path = '/'

        if isinstance(params, Mapping):
            params = list(params.items())

        if params:
            if not isinstance(params, str):
                params = urllib.parse.urlencode(params)
            if query:
                query = '%s&%s' % (query, params)
            else:
                query = params

        self.path = urllib.parse.urlunsplit(
            ('', '', helpers.requote_uri(path), query, ''))
        # ``self.path`` already carries the query, so the query slot is empty.
        self.url = urllib.parse.urlunsplit(
            (scheme, netloc, self.path, '', fragment))

    def update_headers(self, headers):
        """Update request headers.

        Resets ``self.headers`` to an empty ``CIMultiDict`` and fills it
        from ``headers``, a mapping or an iterable of ``(name, value)``
        pairs.
        """
        self.headers = CIMultiDict()
        if not headers:
            return

        if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
            headers = headers.items()

        for name, val in headers:
            self.headers.add(name, val)

    def update_auto_headers(self, skip_auto_headers):
        """Add default headers that are neither present nor explicitly skipped.

        :param skip_auto_headers: set of header names that must not be
            auto-generated.
        """
        self.skip_auto_headers = skip_auto_headers
        taken = set(self.headers) | skip_auto_headers

        for name, default in self.DEFAULT_HEADERS.items():
            if name in taken:
                continue
            self.headers.add(name, default)

        # Host and User-Agent are derived rather than taken from the table.
        if hdrs.HOST not in taken:
            self.headers[hdrs.HOST] = self.netloc

        if hdrs.USER_AGENT not in taken:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Update request cookies header.

        Cookies already present in the Cookie header are merged with
        ``cookies`` (a dict or an iterable of ``(name, value)`` pairs) and
        written back as a single Cookie header.
        """
        if not cookies:
            return

        jar = http.cookies.SimpleCookie()
        if hdrs.COOKIE in self.headers:
            jar.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        pairs = cookies.items() if isinstance(cookies, dict) else cookies

        for name, value in pairs:
            if isinstance(value, http.cookies.Morsel):
                # A Morsel is stored under its own key, not the given name.
                jar[value.key] = value.value
            else:
                jar[name] = value

        self.headers[hdrs.COOKIE] = jar.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Set request content encoding.

        Does nothing without a body.  A Content-Encoding header already set
        selects the compression (unless ``compress`` is ``False``);
        otherwise a truthy ``compress`` sets the header.  Whenever
        compression ends up enabled, chunked transfer is enabled too.
        """
        if not data:
            return

        declared = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if declared:
            if self.compress is False:
                return
            self.compress = declared
        elif self.compress:
            # Any non-string truthy value selects deflate.
            if not isinstance(self.compress, str):
                self.compress = 'deflate'
            self.headers[hdrs.CONTENT_ENCODING] = self.compress
        else:
            return

        # enable chunked, no need to deal with length
        self.chunked = True

    def update_auth(self, auth):
        """Set basic auth.

        Falls back to ``self.auth`` when ``auth`` is ``None``; does nothing
        when both are ``None``.

        :raises TypeError: if the credentials are not ``helpers.BasicAuth``.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if not isinstance(credentials, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = credentials.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Store ``data`` as the request body and derive body headers.

        The handling depends on the type of ``data``; the checks run in
        this order: ``str`` (encoded with ``self.encoding``), bytes-like,
        stream reader / data queue, coroutine, file-like object,
        ``MultipartWriter``, and finally anything else, which is treated as
        form data.  Depending on the branch, Content-Type and Content-Length
        are set and ``self.chunked`` is adjusted.

        :param data: request body; falsy data leaves the request unchanged.
        :param skip_auto_headers: header names that must not be
            auto-generated.
        """
        if not data:
            return

        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(
                data,
            (asyncio.StreamReader, streams.StreamReader, streams.DataQueue)):
            # Streams are stored as-is; no headers are derived here.
            self.body = data

        elif asyncio.iscoroutine(data):
            self.body = data
            # Length unknown: default to chunked unless the caller decided.
            if (hdrs.CONTENT_LENGTH not in self.headers
                    and self.chunked is None):
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), \
                'attempt to send text data instead of binary'
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined
                size = len(data.getbuffer())
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif not self.chunked and isinstance(data, io.BufferedReader):
                # Not chunking if content-length can be determined
                try:
                    # Remaining bytes: file size minus the current position.
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            # Reject files opened in text mode 'r'.
            if hasattr(data, 'mode'):
                if data.mode == 'r':
                    raise ValueError('file {!r} should be open in binary mode'
                                     ''.format(data))
            # Guess the Content-Type from the file name when available.
            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers
                    and hasattr(data, 'name')):
                mime = mimetypes.guess_type(data.name)[0]
                mime = 'application/octet-stream' if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            # Keep an explicit chunk setting, otherwise use 8192.
            self.chunked = self.chunked or 8192

        else:
            # Anything else is treated as form fields.
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = self.chunked or 8192
            else:
                if (hdrs.CONTENT_LENGTH not in self.headers
                        and not self.chunked):
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Analyze transfer-encoding header.

        After the call ``self.chunked`` is either an int chunk size or
        ``None``, and the Transfer-Encoding / Content-Length headers agree
        with it.
        """
        declared = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()

        if self.chunked:
            # Chunked transfer and Content-Length are mutually exclusive.
            if hdrs.CONTENT_LENGTH in self.headers:
                del self.headers[hdrs.CONTENT_LENGTH]
            if 'chunked' not in declared:
                self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

            # Exact type check on purpose: ``True`` must become 8192.
            if type(self.chunked) is not int:
                self.chunked = 8192
            return

        if 'chunked' in declared:
            self.chunked = 8192
            return

        self.chunked = None
        if hdrs.CONTENT_LENGTH not in self.headers:
            self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_expect_continue(self, expect=False):
        """Enable ``Expect: 100-continue`` handling.

        Handling is enabled when ``expect`` is true (the header is then
        added) or when the header is already present.  In both cases a
        waiter future is stored in ``self._continue``.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        elif self.headers.get(hdrs.EXPECT, '').lower() != '100-continue':
            # Neither requested nor announced by the headers.
            return

        self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth):
        """Validate and store proxy settings.

        :param proxy: proxy URL string starting with ``http://``, or a
            falsy value for no proxy.
        :param proxy_auth: ``helpers.BasicAuth`` credentials or a falsy
            value.
        :raises ValueError: if either argument is not acceptable.
        """
        proxy_ok = not proxy or proxy.startswith('http://')
        if not proxy_ok:
            raise ValueError("Only http proxies are supported")

        auth_ok = not proxy_auth or isinstance(proxy_auth, helpers.BasicAuth)
        if not auth_ok:
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")

        self.proxy, self.proxy_auth = proxy, proxy_auth

    @asyncio.coroutine
    def write_bytes(self, request, reader):
        """Write ``self.body`` to ``request`` and finish the message.

        The body may be a coroutine that yields bytes objects, a stream
        reader, a data queue, a file-like object, a bytes-like object or an
        iterable of chunks.  Errors raised while writing are wrapped in
        ``aiohttp.ClientRequestError`` and set on ``reader`` instead of
        propagating.

        :param request: request object the body is written to.
        :param reader: receives the wrapped exception through
            ``set_exception`` if writing fails.
        """
        # 100 response
        if self._continue is not None:
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                request.transport.set_tcp_nodelay(True)
                exc = None
                value = None
                stream = self.body

                # Drive the body coroutine by hand: feed it the result (or
                # the exception) of the previously awaited future and look
                # at what it yields next.
                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as exc:
                        # The coroutine finished; bytes it returned are
                        # written as the final piece.
                        if isinstance(exc.value, bytes):
                            yield from request.write(exc.value, drain=True)
                        break
                    except:
                        # Close the response on any failure, then re-raise.
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        # Wait for the future on behalf of the coroutine; a
                        # failure is thrown back into it on the next turn.
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result, drain=True)
                        value = None
                    else:
                        raise ValueError('Bytes object is expected, got: %s.' %
                                         type(result))

            elif isinstance(self.body,
                            (asyncio.StreamReader, streams.StreamReader)):
                request.transport.set_tcp_nodelay(True)
                # Copy the stream until read() returns an empty result.
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                request.transport.set_tcp_nodelay(True)
                # The queue ends with EOF_MARKER or by raising EofStream.
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if chunk is EOF_MARKER:
                            break
                        yield from request.write(chunk, drain=True)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                # File-like body: read ``self.chunked`` bytes per call.
                chunk = self.body.read(self.chunked)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(self.chunked)
                request.transport.set_tcp_nodelay(True)

            else:
                # Wrap a single bytes-like body so the loop writes it whole.
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body, )

                for chunk in self.body:
                    request.write(chunk)
                request.transport.set_tcp_nodelay(True)

        except Exception as exc:
            new_exc = aiohttp.ClientRequestError(
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            reader.set_exception(new_exc)
        else:
            # Every branch above switches TCP_NODELAY on before getting here.
            assert request.transport.tcp_nodelay
            try:
                ret = request.write_eof()
                # NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
                # see bug #170
                if (asyncio.iscoroutine(ret)
                        or isinstance(ret, asyncio.Future)):
                    yield from ret
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError(
                    'Can not write request body for %s' % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                reader.set_exception(new_exc)

        # Writing is over; drop the reference to the writer task.
        self._writer = None

    def send(self, writer, reader):
        """Write the request line and headers, start the body writer.

        :param writer: transport-level writer the request is sent through.
        :param reader: passed on to ``write_bytes``.
        :returns: the response object created for this request.
        """
        writer.set_tcp_cork(True)
        request = aiohttp.Request(writer, self.method, self.path, self.version)

        if self.compress:
            request.add_compression_filter(self.compress)

        if self.chunked is not None:
            request.enable_chunked_encoding()
            request.add_chunking_filter(self.chunked)

        # set default content-type
        needs_default_type = (
            self.method in self.POST_METHODS
            and hdrs.CONTENT_TYPE not in self.skip_auto_headers
            and hdrs.CONTENT_TYPE not in self.headers)
        if needs_default_type:
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        for name, val in self.headers.items():
            request.add_header(name, val)
        request.send_headers()

        # The body is streamed by a separate task.
        self._writer = helpers.ensure_future(
            self.write_bytes(request, reader), loop=self.loop)

        response = self.response_class(
            self.method, self.url, self.host,
            writer=self._writer,
            continue100=self._continue,
            timeout=self._timeout)
        self.response = response
        response._post_init(self.loop)
        return response

    @asyncio.coroutine
    def close(self):
        """Wait for the body-writer task, if any, then drop the reference."""
        if self._writer is None:
            return
        try:
            yield from self._writer
        finally:
            self._writer = None

    def terminate(self):
        """Cancel the body-writer task, if any, and drop the reference."""
        task = self._writer
        if task is None:
            return
        # Cancelling on a closed loop is skipped.
        if not self.loop.is_closed():
            task.cancel()
        self._writer = None
Exemplo n.º 33
0
 def function767(self, arg211):
     """Parse RFC 5322 style header lines.

     Line continuations are supported.

     :param arg211: list of raw lines (bytes).  Index 0 is skipped; the
         headers start at index 1 and end at the first empty line.
     :returns: a 6-tuple ``(headers, raw_headers, close_conn, encoding,
         upgrade, chunked)``: the parsed ``CIMultiDict``, a tuple of raw
         ``(name, value)`` byte pairs, ``True``/``False``/``None`` for
         ``Connection: close``/``keep-alive``/neither, the content encoding
         when it is gzip or deflate, whether ``Connection: upgrade`` was
         sent, and whether the transfer encoding is chunked.
     :raises InvalidHeader: for a line without ``:`` or a bad header name.
     :raises LineTooLong: if a header exceeds ``self.attribute1698`` bytes.
     """
     var3066 = CIMultiDict()
     var2086 = []
     var4022 = 1
     var2842 = arg211[1]
     var1431 = len(arg211)
     while var2842:
         var1685 = len(var2842)
         # Split the first line of the header into name and value.
         try:
             (var3564, var724) = var2842.split(b':', 1)
         except ValueError:
             raise InvalidHeader(var2842) from None
         var3564 = var3564.strip(b' \t')
         if var4019.search(var3564):
             raise InvalidHeader(var3564)
         # Advance to the next line.
         var4022 += 1
         var2842 = arg211[var4022]
         # A line starting with space (32) or tab (9) continues the value.
         var3115 = (var2842 and (var2842[0] in (32, 9)))
         if var3115:
             var724 = [var724]
             while var3115:
                 var1685 += len(var2842)
                 if (var1685 > self.attribute1698):
                     raise LineTooLong(
                         'request header field {}'.format(
                             var3564.decode('utf8', 'xmlcharrefreplace')),
                         self.attribute1698)
                 var724.append(var2842)
                 var4022 += 1
                 if (var4022 < var1431):
                     var2842 = arg211[var4022]
                     if var2842:
                         var3115 = (var2842[0] in (32, 9))
                 else:
                     var2842 = b''
                     break
             var724 = b''.join(var724)
         elif (var1685 > self.attribute1698):
             raise LineTooLong(
                 'request header field {}'.format(
                     var3564.decode('utf8', 'xmlcharrefreplace')),
                 self.attribute1698)
         var724 = var724.strip()
         var4704 = istr(var3564.decode('utf-8', 'surrogateescape'))
         var3249 = var724.decode('utf-8', 'surrogateescape')
         var3066.add(var4704, var3249)
         var2086.append((var3564, var724))
     var4441 = None
     var1752 = None
     var4099 = False
     var1171 = False
     var2086 = tuple(var2086)
     # Connection handling.
     var2967 = var3066.get(hdrs.CONNECTION)
     if var2967:
         var502 = var2967.lower()
         if (var502 == 'close'):
             var4441 = True
         elif (var502 == 'keep-alive'):
             var4441 = False
         elif (var502 == 'upgrade'):
             var4099 = True
     # Only gzip and deflate content encodings are reported.
     var2201 = var3066.get(hdrs.CONTENT_ENCODING)
     if var2201:
         var2201 = var2201.lower()
         if (var2201 in ('gzip', 'deflate')):
             var1752 = var2201
     var768 = var3066.get(hdrs.TRANSFER_ENCODING)
     if (var768 and ('chunked' in var768.lower())):
         var1171 = True
     return (var3066, var2086, var4441, var1752, var4099, var1171)
Exemplo n.º 34
0
class ClientRequest:
    """Client-side HTTP request.

    The constructor normalizes its arguments (URL, headers, cookies, auth,
    body, proxy) into ``self.headers`` and ``self.body`` through the
    ``update_*`` methods.  ``send()`` then writes the request line and the
    headers and schedules ``write_bytes()`` to stream the body.
    """

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    # Headers added by update_auto_headers() unless they are already
    # present or listed in skip_auto_headers.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    # Used as the default User-Agent value.
    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method, url, *,
                 params=None, headers=None, skip_auto_headers=frozenset(),
                 data=None, cookies=None,
                 auth=None, encoding='utf-8',
                 version=aiohttp.HttpVersion11, compress=None,
                 chunked=None, expect100=False,
                 loop=None, response_class=None,
                 proxy=None, proxy_auth=None,
                 timeout=5*60):
        """Build the request state from the given arguments.

        ``url`` must be a ``URL`` instance and ``proxy`` a ``URL`` or None.
        ``params`` are merged into the query string of ``url``; the URL
        fragment is dropped.  When ``loop`` is None the current event loop
        is used.  The remaining arguments are handed to the matching
        ``update_*`` methods, which may raise ValueError or TypeError for
        invalid values.
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy

        if params:
            # Keep the query already present in the URL and append params.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        self.url = url.with_fragment(None)
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.response_class = response_class or ClientResponse
        self._timeout = timeout

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Order matters: later steps read headers set by earlier ones.
        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth)

        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    @property
    def host(self):
        """Host part of the request URL."""
        return self.url.host

    @property
    def port(self):
        """Port of the request URL."""
        return self.url.port

    def update_host(self, url):
        """Validate the destination host and derive auth and ssl flag.

        Raises ValueError when ``url`` has no host.
        """
        if not url.host:
            raise ValueError('Host could not be detected.')

        # basic auth info embedded in the URL
        username, password = url.user, url.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')

        scheme = url.scheme
        self.ssl = scheme in ('https', 'wss')

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        Non-string values are stored unchanged.  Raises ValueError when a
        string version cannot be parsed.
        """
        if isinstance(version, str):
            parts = [part.strip() for part in version.split('.', 1)]
            try:
                version = int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                # IndexError: the string has no '.', e.g. '1'.
                raise ValueError(
                    'Can not parse http version number: {}'
                    .format(version)) from None
        self.version = version

    def update_headers(self, headers):
        """Update request headers.

        Resets ``self.headers`` to an empty CIMultiDict and adds every
        ``(key, value)`` pair from ``headers``, which may be a mapping or
        an iterable of pairs.
        """
        self.headers = CIMultiDict()
        if headers:
            if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
                headers = headers.items()

            for key, value in headers:
                self.headers.add(key, value)

    def update_auto_headers(self, skip_auto_headers):
        """Add default, Host and User-Agent headers that were not supplied.

        A header is skipped when it is already set or when it is listed
        in ``skip_auto_headers``.
        """
        self.skip_auto_headers = skip_auto_headers
        used_headers = set(self.headers) | skip_auto_headers

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        # add host
        if hdrs.HOST not in used_headers:
            netloc = self.url.host
            if not self.url.is_default_port():
                netloc += ':' + str(self.url.port)
            self.headers[hdrs.HOST] = netloc

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Update request cookies header.

        Cookies already present in the Cookie header are kept and merged
        with ``cookies``; entries in ``cookies`` win on name clashes.
        """
        if not cookies:
            return

        c = http.cookies.SimpleCookie()
        if hdrs.COOKIE in self.headers:
            c.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        for name, value in cookies.items():
            if isinstance(value, http.cookies.Morsel):
                c[value.key] = value.value
            else:
                c[name] = value

        self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Set request content encoding.

        Does nothing when there is no ``data``.  An explicit
        Content-Encoding header selects the compression unless
        ``compress`` is False; otherwise a truthy ``compress`` sets the
        header (non-string values mean 'deflate').  Enabling compression
        also enables chunked transfer.
        """
        if not data:
            return

        enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if enc:
            if self.compress is not False:
                self.compress = enc
                # enable chunked, no need to deal with length
                self.chunked = True
        elif self.compress:
            if not isinstance(self.compress, str):
                self.compress = 'deflate'
            self.headers[hdrs.CONTENT_ENCODING] = self.compress
            self.chunked = True  # enable chunked, no need to deal with length

    def update_auth(self, auth):
        """Set basic auth.

        Falls back to the auth taken from the URL when ``auth`` is None.
        Raises TypeError when the value is not a ``helpers.BasicAuth``.
        """
        if auth is None:
            auth = self.auth
        if auth is None:
            return

        if not isinstance(auth, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = auth.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Store ``data`` as the request body and set the related headers.

        Supported inputs: str/bytes/bytearray, stream readers and data
        queues, coroutines, binary file objects, ``MultipartWriter`` and
        anything accepted by ``helpers.FormData``.  Content-Type and
        Content-Length are only filled in when not already set (and, for
        Content-Type, not listed in ``skip_auto_headers``).

        Raises ValueError for a file object opened in text mode 'r'.
        """
        if not data:
            return

        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if (hdrs.CONTENT_TYPE not in self.headers and
                    hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(data, (asyncio.StreamReader, streams.StreamReader,
                               streams.DataQueue)):
            self.body = data

        elif asyncio.iscoroutine(data):
            self.body = data
            if (hdrs.CONTENT_LENGTH not in self.headers and
                    self.chunked is None):
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), \
                'attempt to send text data instead of binary'
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined.
                # write_bytes() reads from the current position, so bytes
                # already consumed must not be counted.
                size = max(len(data.getbuffer()) - data.tell(), 0)
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif (not self.chunked and
                  isinstance(data, (io.BufferedReader, io.BufferedRandom))):
                # Not chunking if content-length can be determined
                try:
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            if hasattr(data, 'mode'):
                if data.mode == 'r':
                    raise ValueError('file {!r} should be open in binary mode'
                                     ''.format(data))
            if (hdrs.CONTENT_TYPE not in self.headers and
                hdrs.CONTENT_TYPE not in skip_auto_headers and
                    hasattr(data, 'name')):
                mime = mimetypes.guess_type(data.name)[0]
                mime = 'application/octet-stream' if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            self.chunked = self.chunked or 8192

        else:
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if (hdrs.CONTENT_TYPE not in self.headers and
                    hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = self.chunked or 8192
            else:
                if (hdrs.CONTENT_LENGTH not in self.headers and
                        not self.chunked):
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Analyze transfer-encoding header.

        Reconciles ``self.chunked`` with the Transfer-Encoding and
        Content-Length headers.  Afterwards ``self.chunked`` is either a
        chunk size (int) or None, and a non-chunked request always has a
        Content-Length header.
        """
        te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()

        if self.chunked:
            # Content-Length and chunked encoding are mutually exclusive.
            if hdrs.CONTENT_LENGTH in self.headers:
                del self.headers[hdrs.CONTENT_LENGTH]
            if 'chunked' not in te:
                self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

            # ``True`` is not an int for this purpose: use the default size.
            self.chunked = self.chunked if type(self.chunked) is int else 8192
        else:
            if 'chunked' in te:
                self.chunked = 8192
            else:
                self.chunked = None
                if hdrs.CONTENT_LENGTH not in self.headers:
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_expect_continue(self, expect=False):
        """Set up '100-continue' handling.

        Enabled by ``expect`` or by an existing ``Expect: 100-continue``
        header; creates the future that ``write_bytes()`` waits on before
        sending the body.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue':
            expect = True

        if expect:
            self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth):
        """Store proxy settings.

        Raises ValueError for a non-http proxy URL or when ``proxy_auth``
        is neither None nor a ``helpers.BasicAuth``.
        """
        if proxy and proxy.scheme != 'http':
            raise ValueError("Only http proxies are supported")
        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")
        self.proxy = proxy
        self.proxy_auth = proxy_auth

    def _set_write_error(self, reader, exc):
        """Report a body-write failure to ``reader`` as ClientRequestError."""
        new_exc = aiohttp.ClientRequestError(
            'Can not write request body for %s' % self.url)
        new_exc.__context__ = exc
        new_exc.__cause__ = exc
        reader.set_exception(new_exc)

    @asyncio.coroutine
    def write_bytes(self, request, reader):
        """Support coroutines that yields bytes objects.

        Streams ``self.body`` into ``request`` according to its type and
        finishes with ``write_eof()``.  Exceptions derived from Exception
        are not propagated: they are wrapped in
        ``aiohttp.ClientRequestError`` and set on ``reader``.
        """
        # 100 response
        if self._continue is not None:
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                request.transport.set_tcp_nodelay(True)
                exc = None
                value = None
                stream = self.body

                # Drive the body coroutine by hand: futures it yields are
                # awaited here, bytes it yields are written to the request.
                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as stop:
                        if isinstance(stop.value, bytes):
                            yield from request.write(stop.value, drain=True)
                        break
                    except BaseException:
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            # Re-thrown into the body coroutine next turn.
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result, drain=True)
                        value = None
                    else:
                        raise ValueError(
                            'Bytes object is expected, got: %s.' %
                            type(result))

            elif isinstance(self.body, (asyncio.StreamReader,
                                        streams.StreamReader)):
                request.transport.set_tcp_nodelay(True)
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                request.transport.set_tcp_nodelay(True)
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if chunk is EOF_MARKER:
                            break
                        yield from request.write(chunk, drain=True)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                chunk = self.body.read(self.chunked)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(self.chunked)
                request.transport.set_tcp_nodelay(True)

            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)

                for chunk in self.body:
                    request.write(chunk)
                request.transport.set_tcp_nodelay(True)

        except Exception as exc:
            self._set_write_error(reader, exc)
        else:
            assert request.transport.tcp_nodelay
            try:
                ret = request.write_eof()
                # NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
                # see bug #170
                if (asyncio.iscoroutine(ret) or
                        isinstance(ret, asyncio.Future)):
                    yield from ret
            except Exception as exc:
                self._set_write_error(reader, exc)

        self._writer = None

    def send(self, writer, reader):
        """Write the request line and headers and start the body writer.

        Returns the response object (an instance of
        ``self.response_class``) bound to the scheduled writer task.
        """
        writer.set_tcp_cork(True)
        path = self.url.raw_path
        if self.url.raw_query_string:
            path += '?' + self.url.raw_query_string
        request = aiohttp.Request(writer, self.method, path,
                                  self.version)

        if self.compress:
            request.add_compression_filter(self.compress)

        if self.chunked is not None:
            request.enable_chunked_encoding()
            request.add_chunking_filter(self.chunked)

        # set default content-type
        if (self.method in self.POST_METHODS and
                hdrs.CONTENT_TYPE not in self.skip_auto_headers and
                hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        for k, value in self.headers.items():
            request.add_header(k, value)
        request.send_headers()

        self._writer = helpers.ensure_future(
            self.write_bytes(request, reader), loop=self.loop)

        self.response = self.response_class(
            self.method, self.url,
            writer=self._writer, continue100=self._continue,
            timeout=self._timeout)
        self.response._post_init(self.loop)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the body writer task, if any, to finish."""
        if self._writer is not None:
            try:
                yield from self._writer
            finally:
                self._writer = None

    def terminate(self):
        """Cancel the body writer task, if any, without waiting for it."""
        if self._writer is not None:
            if not self.loop.is_closed():
                self._writer.cancel()
            self._writer = None
Exemplo n.º 35
0
    def parse_headers(self, lines):
        """Parse RFC 5322 style headers from a list of byte lines.

        Parsing starts at ``lines[1]`` and stops at the first empty line.
        Continuation lines (starting with a space or a tab) are folded
        into the preceding header value.

        Returns a ``(headers, raw_headers, close_conn, encoding)`` tuple:
        ``headers`` is a CIMultiDict of decoded names and values,
        ``raw_headers`` is a list of ``(name, value)`` byte pairs with the
        name upper-cased, ``close_conn`` is True/False for a Connection
        header of 'close'/'keep-alive' (None otherwise) and ``encoding``
        is 'gzip' or 'deflate' when Content-Encoding names one of them
        (None otherwise).

        Raises errors.InvalidHeader for a line without ':' or a name
        matching HDRRE, and errors.LineTooLong when a header is longer
        than ``self.max_field_size``.
        """
        close_conn = None
        encoding = None
        headers = CIMultiDict()
        raw_headers = []

        lines_idx = 1
        line = lines[1]

        while line:
            header_length = len(line)

            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise errors.InvalidHeader(line) from None

            bname = bname.strip(b' \t').upper()
            if HDRRE.search(bname):
                raise errors.InvalidHeader(bname)

            # next line
            lines_idx += 1
            line = lines[lines_idx]

            # consume continuation lines
            continuation = line and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise errors.LineTooLong(
                            'limit request headers fields size')
                    bvalue.append(line)

                    # next line
                    lines_idx += 1
                    line = lines[lines_idx]
                    # The empty terminator line must end the folding
                    # instead of raising IndexError on line[0].
                    continuation = line and line[0] in (32, 9)
                bvalue = b'\r\n'.join(bvalue)
            else:
                if header_length > self.max_field_size:
                    raise errors.LineTooLong(
                        'limit request headers fields size')

            bvalue = bvalue.strip()

            name = bname.decode('utf-8', 'surrogateescape')
            value = bvalue.decode('utf-8', 'surrogateescape')

            # keep-alive and encoding
            if name == hdrs.CONNECTION:
                v = value.lower()
                if v == 'close':
                    close_conn = True
                elif v == 'keep-alive':
                    close_conn = False
            elif name == hdrs.CONTENT_ENCODING:
                enc = value.lower()
                if enc in ('gzip', 'deflate'):
                    encoding = enc

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return headers, raw_headers, close_conn, encoding
Exemplo n.º 36
0
    def parse_headers(
            self,
            lines: List[bytes]
    ) -> Tuple['CIMultiDictProxy[str]', RawHeaders]:
        """Parse header lines into a multidict and a tuple of raw pairs.

        Parsing starts at ``lines[1]`` and stops at the first empty line
        or when ``lines`` is exhausted.  Continuation lines (starting
        with a space or a tab) are appended to the preceding value.

        Returns ``(headers, raw_headers)``: a read-only CIMultiDictProxy
        of decoded names and values, and a tuple of ``(name, value)``
        byte pairs.

        Raises InvalidHeader for a line without ':' or a name matching
        HDRRE, and LineTooLong when a header name or value is longer than
        ``self.max_field_size``.
        """
        headers = CIMultiDict()  # type: CIMultiDict[str]
        raw_headers = []

        lines_idx = 1
        line = lines[1]
        line_count = len(lines)

        while line:
            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise InvalidHeader(line) from None

            bname = bname.strip(b' \t')
            bvalue = bvalue.lstrip()
            if HDRRE.search(bname):
                raise InvalidHeader(bname)
            if len(bname) > self.max_field_size:
                raise LineTooLong(
                    "request header name {}".format(
                        bname.decode("utf8", "xmlcharrefreplace")),
                    str(self.max_field_size),
                    str(len(bname)))

            header_length = len(bvalue)

            # next line; running out of lines is treated like the
            # terminating empty line, as in the continuation loop below
            lines_idx += 1
            line = lines[lines_idx] if lines_idx < line_count else b''

            # consume continuation lines
            continuation = line and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue_lst = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise LineTooLong(
                            'request header field {}'.format(
                                bname.decode("utf8", "xmlcharrefreplace")),
                            str(self.max_field_size),
                            str(header_length))
                    bvalue_lst.append(line)

                    # next line
                    lines_idx += 1
                    if lines_idx >= line_count:
                        line = b''
                        break
                    line = lines[lines_idx]
                    # An empty line ends the header block: it must stop
                    # the folding rather than be swallowed by it.
                    continuation = bool(line) and line[0] in (32, 9)
                bvalue = b''.join(bvalue_lst)
            else:
                if header_length > self.max_field_size:
                    raise LineTooLong(
                        'request header field {}'.format(
                            bname.decode("utf8", "xmlcharrefreplace")),
                        str(self.max_field_size),
                        str(header_length))

            bvalue = bvalue.strip()
            name = bname.decode('utf-8', 'surrogateescape')
            value = bvalue.decode('utf-8', 'surrogateescape')

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return (CIMultiDictProxy(headers), tuple(raw_headers))
Exemplo n.º 37
0
class WsgiProtocol:
    """Pure python WSGI protocol implementation.

    Receives the HTTP parser callbacks (``on_url``, ``on_header``, ...),
    fills a WSGI ``environ`` dictionary from them, and provides the
    ``start_response`` / ``write`` pair used to send the response.
    """
    keep_alive = False
    parsed_url = None
    headers_sent = None  # response headers once they have been written
    status = None  # response status line set by start_response()

    def __init__(self, protocol, cfg, FileWrapper):
        """Create the base ``environ`` and the parser for ``protocol``."""
        connection = protocol.connection
        server_address = connection.transport.get_extra_info('sockname')
        self.environ = {
            'wsgi.async': True,
            'wsgi.timestamp': protocol.producer.current_time,
            'wsgi.errors': sys.stderr,
            'wsgi.version': (1, 0),
            'wsgi.run_once': False,
            'wsgi.multithread': True,
            'wsgi.multiprocess': True,
            'SCRIPT_NAME': OS_SCRIPT_NAME,
            'SERVER_SOFTWARE': protocol.producer.server_software,
            'wsgi.file_wrapper': FileWrapper,
            'CONTENT_TYPE': '',
            'SERVER_NAME': server_address[0],
            'SERVER_PORT': str(server_address[1]),
            PULSAR_CACHE: protocol
        }
        self.body_reader = protocol.body_reader(self.environ)
        self.environ['wsgi.input'] = self.body_reader
        self.cfg = cfg
        self.headers = CIMultiDict()
        self.protocol = protocol
        self.connection = connection
        self.client_address = connection.address
        self.parser = protocol.create_parser(self)
        self.header_wsgi = Headers()

    def on_url(self, url):
        """Parser callback: record URL, method, query string and scheme."""
        proto = self.protocol
        transport = self.connection.transport
        parsed_url = proto.parse_url(url)
        query = parsed_url.query or b''
        scheme = ('https'
                  if transport.get_extra_info('sslcontext') else URL_SCHEME)
        # An absolute request URL overrides the transport-derived scheme.
        if parsed_url.schema:
            scheme = parsed_url.schema.decode(CHARSET)

        self.parsed_url = parsed_url
        self.environ.update((
            ('RAW_URI', url.decode(CHARSET)),
            ('REQUEST_METHOD', self.parser.get_method().decode(CHARSET)),
            ('QUERY_STRING', query.decode(CHARSET)),
            ('wsgi.url_scheme', scheme),
        ))

    def on_header(self, name, value):
        """Parser callback: store one request header in the environ.

        Hop headers are also copied to ``self.headers``, except a
        ``Connection: keep-alive`` sent with a protocol other than
        HTTP/1.0.  For other headers a handler on ``self.header_wsgi``
        named after the header may consume the value, in which case no
        ``HTTP_*`` key is added.
        """
        header = istr(name.decode(CHARSET))
        header_value = value.decode(CHARSET)
        header_env = header.upper().replace('-', '_')

        if 'SERVER_PROTOCOL' not in self.environ:
            self.environ['SERVER_PROTOCOL'] = ("HTTP/%s" %
                                               self.parser.get_http_version())

        if header in HOP_HEADERS:
            if header == CONNECTION:
                if (self.environ['SERVER_PROTOCOL'] == 'HTTP/1.0'
                        or header_value.lower() != 'keep-alive'):
                    self.headers[header] = header_value
            else:
                self.headers[header] = header_value
        else:
            hnd = getattr(self.header_wsgi, header_env, None)
            if hnd and hnd(self.environ, header_value):
                return

        self.environ['HTTP_%s' % header_env] = header_value

    def on_headers_complete(self):
        """Parser callback: finalize the environ once all headers are in.

        Sets SERVER_PROTOCOL (if missing), HTTPS, REMOTE_ADDR/REMOTE_PORT
        and PATH_INFO, then adds the protocol to the connection pipeline.
        """
        if 'SERVER_PROTOCOL' not in self.environ:
            self.environ['SERVER_PROTOCOL'] = ("HTTP/%s" %
                                               self.parser.get_http_version())

        forward = self.headers.get(X_FORWARDED_FOR)
        client_address = self.client_address

        if self.environ.get('wsgi.url_scheme') == 'https':
            self.environ['HTTPS'] = 'on'
        if forward:
            # we only took the last one
            # http://en.wikipedia.org/wiki/X-Forwarded-For
            if forward.find(",") >= 0:
                forward = forward.rsplit(",", 1)[1].strip()
            client_address = forward.split(":")
            # No port in the forwarded value: fall back to '80'.
            if len(client_address) < 2:
                client_address.append('80')
        self.environ['REMOTE_ADDR'] = client_address[0]
        self.environ['REMOTE_PORT'] = str(client_address[1])

        if self.parsed_url.path is not None:
            path_info = self.parsed_url.path.decode(CHARSET)
            script_name = self.environ['SCRIPT_NAME']
            if script_name:
                path_info = path_info.split(script_name, 1)[1]
            self.environ['PATH_INFO'] = unquote(path_info)

        # add the protocol to the pipeline
        self.connection.pipeline(self.protocol)

    def on_body(self, body):
        """Parser callback: feed a chunk of request body to the reader."""
        self.body_reader.feed_data(body)

    def on_message_complete(self):
        """Parser callback: the request has been fully read."""
        self.body_reader.feed_eof()
        self.protocol.finished_reading()

    def start_response(self, status, response_headers, exc_info=None):
        """WSGI ``start_response`` callable.

        Stores ``status`` and the non-hop ``response_headers`` (hop
        headers are dropped with a warning), adds Server and Date, and
        returns the ``write`` callable.

        Raises RuntimeError when called a second time without
        ``exc_info``; re-raises ``exc_info`` when the headers have
        already been sent.
        """
        if exc_info:
            try:
                if self.headers_sent:
                    # if exc_info is provided, and the HTTP headers have
                    # already been sent, start_response must raise an error,
                    # and should re-raise using the exc_info tuple
                    reraise(exc_info[0], exc_info[1], exc_info[2])
            finally:
                # Avoid circular reference
                exc_info = None
        elif self.status:
            # Headers already set. Raise error
            raise RuntimeError("Response headers already set!")
        else:
            self.keep_alive = self.parser.should_keep_alive()
        self.status = status
        for header, value in response_headers:
            if header in HOP_HEADERS:
                # These features are the exclusive province of this class,
                # this should be considered a fatal error for an application
                # to attempt sending them, but we don't raise an error,
                # just log a warning
                self.protocol.producer.logger.warning(
                    'Application passing hop header "%s"', header)
                continue
            self.headers.add(header, value)
        producer = self.protocol.producer
        self.headers[SERVER] = producer.server_software
        self.headers[DATE] = fast_http_date(producer.current_time)
        return self.write

    def write(self, data, force=False):
        """Write ``data`` to the connection, sending the headers first.

        On the first call the status line and headers are serialized and
        the 'on_headers' event is fired.  With chunked encoding ``data``
        is framed by ``http_chunks``; ``force`` with empty ``data`` writes
        the closing chunk.  Returns the result of ``connection.write``
        when something was written, otherwise None.
        """
        buffer = None
        env = self.environ
        proto = self.protocol

        if not self.headers_sent:
            self.headers_sent = self.get_headers()
            buffer = bytearray(
                ('%s %s\r\n' %
                 (env['SERVER_PROTOCOL'], self.status)).encode(CHARSET))
            for k, v in self.headers_sent.items():
                buffer.extend(('%s: %s\r\n' % (k, v)).encode(CHARSET))
            buffer.extend(CRLF)
            proto.event('on_headers').fire(data=buffer)

        if data:
            if not buffer:
                buffer = bytearray()
            if self.chunked:
                http_chunks(buffer, data)
            else:
                buffer.extend(data)

        elif force and self.chunked:
            if not buffer:
                buffer = bytearray()
            http_chunks(buffer, data, True)

        if buffer:
            return self.connection.write(buffer)

    def get_headers(self):
        """Finalize and return the response headers.

        Decides between Content-Length and chunked transfer encoding,
        disables keep-alive for status codes >= 400, and adds
        ``Connection: close`` when the connection is not kept alive.
        Sets ``self.chunked`` as a side effect.

        Raises RuntimeError when ``start_response`` has not been called.
        """
        # Must be checked first: the status is parsed right below.
        if not self.status:
            # we are sending headers but the start_response was not called
            raise RuntimeError('Headers not set.')

        headers = self.headers
        chunked = headers.get(TRANSFER_ENCODING) == 'chunked'
        content_length = CONTENT_LENGTH in headers
        status = int(self.status.split()[0])
        empty = has_empty_content(status, self.environ['REQUEST_METHOD'])

        if status >= 400:
            self.keep_alive = False

        if (content_length or empty
                or self.environ['SERVER_PROTOCOL'] == 'HTTP/1.0'):
            chunked = False
            headers.pop(TRANSFER_ENCODING, None)
        elif not chunked and not content_length:
            chunked = True
            headers[TRANSFER_ENCODING] = 'chunked'

        if not self.keep_alive:
            headers[CONNECTION] = 'close'

        self.chunked = chunked
        return headers
Exemplo n.º 38
0
class HttpMessage(ABC):
    """HttpMessage allows to write headers and payload to a stream.

    For example, let's say we want to read a file, compress it with deflate
    compression and then send it with chunked transfer encoding. The code
    may look like this:

       >>> response = aiohttp.Response(transport, 200)

    We have to use deflate compression first:

      >>> response.add_compression_filter('deflate')

    Then we want to split output stream into chunks of 1024 bytes size:

      >>> response.add_chunking_filter(1024)

    We can add headers to response with add_headers() method. add_headers()
    does not send data to transport, send_headers() sends request/response
    line and then sends headers:

      >>> response.add_headers(
      ...     ('Content-Disposition', 'attachment; filename="..."'))
      >>> response.send_headers()

    Now we can use chunked writer to write stream to a network stream.
    First call to write() method sends response status line and headers,
    add_header() and add_headers() methods are unavailable at this stage:

    >>> with open('...', 'rb') as fp:
    ...     chunk = fp.read(8192)
    ...     while chunk:
    ...         response.write(chunk)
    ...         chunk = fp.read(8192)

    >>> response.write_eof()

    """

    # Payload writer generator; created by send_headers().
    writer = None

    # 'filter' is being used for altering write() behaviour,
    # add_compression_filter adds deflate/gzip compression and
    # add_chunking_filter splits incoming data into chunks.
    filter = None

    HOP_HEADERS = None  # Must be set by subclass.

    SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
        sys.version_info, aiohttp.__version__)

    upgrade = False  # Connection: UPGRADE
    websocket = False  # Upgrade: WEBSOCKET
    has_chunked_hdr = False  # Transfer-encoding: chunked

    # subclass can enable auto sending headers with write() call,
    # this is useful for wsgi's start_response implementation.
    _send_headers = False

    def __init__(self, transport, version, close):
        """Store the transport and reset all per-message bookkeeping.

        ``transport`` is the object headers and payload are written to,
        ``version`` is the HTTP version later compared against
        ``HttpVersion10``/``HttpVersion11``, and ``close`` is the initial
        value of the ``closing`` flag.
        """
        self.transport = transport
        self._version = version
        self.closing = close
        # Explicit keep-alive decision; None until a Connection header or
        # force_close() settles it.
        self.keepalive = None
        self.chunked = False
        # Content-Length value recorded by add_header(); None when unset.
        self.length = None
        self.headers = CIMultiDict()
        self.headers_sent = False
        # Total bytes handed to the transport (headers included).
        self.output_length = 0
        # Size of the serialised header block, set by send_headers().
        self.headers_length = 0
        # Bytes written since the last drain requested through write().
        self._output_size = 0
        # NOTE(review): not referenced anywhere else in this class.
        self._cache = {}

    @property
    @abstractmethod
    def status_line(self):
        """First line of the message; must be provided by the subclass."""
        return b''

    @abstractmethod
    def autochunked(self):
        """Tell send_headers() whether to pick chunked encoding by itself."""
        return False

    @property
    def version(self):
        """HTTP version passed to the constructor."""
        return self._version

    @property
    def body_length(self):
        """Number of bytes written so far, excluding the header block."""
        return self.output_length - self.headers_length

    def force_close(self):
        """Mark the connection as closing and disable keep-alive."""
        self.closing = True
        self.keepalive = False

    def enable_chunked_encoding(self):
        """Request chunked transfer encoding for the payload."""
        self.chunked = True

    def keep_alive(self):
        """Return True if the connection should stay open after the message.

        An explicit ``keepalive`` value wins. Otherwise the answer depends
        on the HTTP version: never below 1.0, only with a
        ``Connection: keep-alive`` header for 1.0, and ``not closing`` for
        anything newer.
        """
        if self.keepalive is None:
            if self.version < HttpVersion10:
                # keep alive not supported at all
                return False
            if self.version == HttpVersion10:
                if self.headers.get(hdrs.CONNECTION) == 'keep-alive':
                    return True
                else:  # no headers means we close for Http 1.0
                    return False
            else:
                return not self.closing
        else:
            return self.keepalive

    def is_headers_sent(self):
        """Return True once send_headers() has been called."""
        return self.headers_sent

    def add_header(self, name, value):
        """Record a single header and update the message state it implies.

        Content-Length, Transfer-Encoding, Connection and Upgrade headers
        update the matching attributes (``length``, ``has_chunked_hdr``,
        ``upgrade``/``keepalive``, ``websocket``). Connection itself is
        never stored, Upgrade is stored only for websocket upgrades, and
        every other header is stored unless it is listed in HOP_HEADERS.

        The arguments are validated with ``assert`` statements only, so
        the checks disappear when Python runs with ``-O``.
        """
        assert not self.headers_sent, 'headers have been sent already'
        assert isinstance(name, str), \
            'Header name should be a string, got {!r}'.format(name)
        assert set(name).issubset(ASCIISET), \
            'Header name should contain ASCII chars, got {!r}'.format(name)
        assert isinstance(value, str), \
            'Header {!r} should have string value, got {!r}'.format(
                name, value)

        name = istr(name)
        value = value.strip()

        if name == hdrs.CONTENT_LENGTH:
            self.length = int(value)

        if name == hdrs.TRANSFER_ENCODING:
            self.has_chunked_hdr = value.lower().strip() == 'chunked'

        if name == hdrs.CONNECTION:
            val = value.lower()
            # handle websocket
            if 'upgrade' in val:
                self.upgrade = True
            # connection keep-alive
            elif 'close' in val:
                self.keepalive = False
            elif 'keep-alive' in val:
                self.keepalive = True

        elif name == hdrs.UPGRADE:
            # Only a websocket upgrade is recorded; any other Upgrade value
            # is dropped.
            if 'websocket' in value.lower():
                self.websocket = True
                self.headers[name] = value

        elif name not in self.HOP_HEADERS:
            # ignore hop-by-hop headers
            self.headers.add(name, value)

    def add_headers(self, *headers):
        """Adds headers to a HTTP message.

        Each positional argument is a ``(name, value)`` pair that is passed
        to add_header().
        """
        for name, value in headers:
            self.add_header(name, value)

    def send_headers(self, _sep=': ', _end='\r\n'):
        """Writes headers to a stream. Constructs payload writer.

        The writer is chosen in this order: chunked (explicitly enabled or
        requested by autochunked()), fixed length when a Content-Length was
        recorded, otherwise write-until-EOF.
        """
        # Chunked response is only for HTTP/1.1 clients or newer
        # and there is no Content-Length header is set.
        # Do not use chunked responses when the response is guaranteed to
        # not have a response body (304, 204).
        assert not self.headers_sent, 'headers have been sent already'
        self.headers_sent = True

        if self.chunked or self.autochunked():
            self.writer = self._write_chunked_payload()
            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

        elif self.length is not None:
            self.writer = self._write_length_payload(self.length)

        else:
            self.writer = self._write_eof_payload()

        # Advance the generator to its first ``yield`` so send() works.
        next(self.writer)

        self._add_default_headers()

        # status + headers
        headers = self.status_line + ''.join(
            [k + _sep + v + _end for k, v in self.headers.items()])
        headers = headers.encode('utf-8') + b'\r\n'

        self.output_length += len(headers)
        self.headers_length = len(headers)
        self.transport.write(headers)

    def _add_default_headers(self):
        """Set the Connection header implied by the message state."""
        # set the connection header
        connection = None
        if self.upgrade:
            connection = 'Upgrade'
        elif not self.closing if self.keepalive is None else self.keepalive:
            # Connection is kept alive: only HTTP/1.0 needs to say so.
            if self.version == HttpVersion10:
                connection = 'keep-alive'
        else:
            # Connection is closed: only HTTP/1.1 needs to say so.
            if self.version == HttpVersion11:
                connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

    def write(self, chunk, *,
              drain=False, EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
        """Writes chunk of data to a stream by using different writers.

        writer uses filter to modify chunk of data.
        write_eof() indicates end of stream.
        writer can't be used after write_eof() method being called.

        Returns ``self.transport.drain()`` when ``drain`` is true and more
        than 64 KiB were written since the last drain; otherwise returns an
        empty tuple.
        """
        assert (isinstance(chunk, (bytes, bytearray)) or
                chunk is EOF_MARKER), chunk

        size = self.output_length

        if self._send_headers and not self.headers_sent:
            self.send_headers()

        if self.filter:
            # Feed the chunk to the filter and forward everything it
            # produces until it signals end-of-line or end-of-stream.
            chunk = self.filter.send(chunk)
            while chunk not in (EOF_MARKER, EOL_MARKER):
                if chunk:
                    self.writer.send(chunk)
                chunk = next(self.filter)
        else:
            if chunk is not EOF_MARKER:
                self.writer.send(chunk)

        self._output_size += self.output_length - size

        if self._output_size > 64 * 1024:
            if drain:
                self._output_size = 0
                return self.transport.drain()

        return ()

    def write_eof(self):
        """Flush the filter, terminate the payload writer and drain.

        Returns the result of ``self.transport.drain()``.
        """
        self.write(EOF_MARKER)
        try:
            # The writers finish by leaving their loop on EofStream, which
            # surfaces here as StopIteration.
            self.writer.throw(aiohttp.EofStream())
        except StopIteration:
            pass

        return self.transport.drain()

    def _write_chunked_payload(self):
        """Write data in chunked transfer encoding."""
        while True:
            try:
                chunk = yield
            except aiohttp.EofStream:
                # Terminating zero-length chunk (5 bytes).
                self.transport.write(b'0\r\n\r\n')
                self.output_length += 5
                break

            chunk = bytes(chunk)
            chunk_len = '{:x}\r\n'.format(len(chunk)).encode('ascii')
            self.transport.write(chunk_len + chunk + b'\r\n')
            self.output_length += len(chunk_len) + len(chunk) + 2

    def _write_length_payload(self, length):
        """Write specified number of bytes to a stream.

        Data beyond ``length`` bytes is silently discarded.
        """
        while True:
            try:
                chunk = yield
            except aiohttp.EofStream:
                break

            if length:
                l = len(chunk)
                if length >= l:
                    self.transport.write(chunk)
                    self.output_length += l
                    length = length-l
                else:
                    # Only the remaining ``length`` bytes fit.
                    self.transport.write(chunk[:length])
                    self.output_length += length
                    length = 0

    def _write_eof_payload(self):
        """Write every chunk unchanged until EofStream is thrown in."""
        while True:
            try:
                chunk = yield
            except aiohttp.EofStream:
                break

            self.transport.write(chunk)
            self.output_length += len(chunk)

    @wrap_payload_filter
    def add_chunking_filter(self, chunk_size=16*1024, *,
                            EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
        """Split incoming stream into chunks.

        Data is buffered and re-emitted in pieces of exactly ``chunk_size``
        bytes; the remainder is emitted when EOF_MARKER is received.
        """
        buf = bytearray()
        chunk = yield

        while True:
            if chunk is EOF_MARKER:
                if buf:
                    yield buf

                yield EOF_MARKER

            else:
                buf.extend(chunk)

                while len(buf) >= chunk_size:
                    chunk = bytes(buf[:chunk_size])
                    del buf[:chunk_size]
                    yield chunk

                # Nothing more to emit for this input; wait for the next.
                chunk = yield EOL_MARKER

    @wrap_payload_filter
    def add_compression_filter(self, encoding='deflate', *,
                               EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
        """Compress incoming stream with deflate or gzip encoding.

        ``encoding == 'gzip'`` selects the gzip container; any other value
        produces a raw deflate stream.
        """
        zlib_mode = (16 + zlib.MAX_WBITS
                     if encoding == 'gzip' else -zlib.MAX_WBITS)
        zcomp = zlib.compressobj(wbits=zlib_mode)

        chunk = yield
        while True:
            if chunk is EOF_MARKER:
                yield zcomp.flush()
                chunk = yield EOF_MARKER

            else:
                yield zcomp.compress(chunk)
                chunk = yield EOL_MARKER
Exemplo n.º 39
0
    def parse_headers(self, lines):
        """Parse RFC 5322 style header lines from an already split message.

        ``lines`` is a sequence of ``bytes`` lines.  Index 0 is skipped
        (presumably the request/status line -- confirm against the caller)
        and parsing stops at the first empty line or at the end of the
        sequence.  Folded continuation lines (starting with a space or a
        tab) are appended to the value of the preceding header.

        Header names keep the case they were received in; lookups on the
        returned multidict are case-insensitive.

        :return: a tuple ``(headers, raw_headers, close_conn, encoding,
            upgrade, chunked)`` where ``headers`` is a ``CIMultiDict`` of
            decoded values, ``raw_headers`` a tuple of ``(bytes, bytes)``
            pairs, ``close_conn`` is ``True``/``False``/``None`` depending
            on the Connection header, ``encoding`` is ``'gzip'``,
            ``'deflate'`` or ``None``, and ``upgrade``/``chunked`` are
            booleans.
        :raises InvalidHeader: for a line without ``:`` or a name rejected
            by ``HDRRE``.
        :raises LineTooLong: when a header (continuations included) is
            longer than ``self.max_field_size``.
        """
        headers = CIMultiDict()
        raw_headers = []

        lines_idx = 1
        line = lines[1]
        line_count = len(lines)

        while line:
            header_length = len(line)

            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise InvalidHeader(line) from None

            bname = bname.strip(b' \t')
            if HDRRE.search(bname):
                raise InvalidHeader(bname)

            # next line; running off the end counts as the blank terminator
            lines_idx += 1
            line = lines[lines_idx] if lines_idx < line_count else b''

            # consume continuation lines
            continuation = bool(line) and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise LineTooLong(
                            'request header field {}'.format(
                                bname.decode("utf8", "xmlcharrefreplace")),
                            self.max_field_size)
                    bvalue.append(line)

                    # next line
                    lines_idx += 1
                    line = lines[lines_idx] if lines_idx < line_count else b''
                    # A blank line always ends the folded value; keeping
                    # the previous flag here would swallow the header
                    # terminator and parse whatever follows it as headers.
                    continuation = bool(line) and line[0] in (32, 9)
                bvalue = b''.join(bvalue)
            elif header_length > self.max_field_size:
                raise LineTooLong(
                    'request header field {}'.format(
                        bname.decode("utf8", "xmlcharrefreplace")),
                    self.max_field_size)

            bvalue = bvalue.strip()
            name = bname.decode('utf-8', 'surrogateescape')
            value = bvalue.decode('utf-8', 'surrogateescape')

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        close_conn = None
        encoding = None
        upgrade = False
        chunked = False
        raw_headers = tuple(raw_headers)

        # keep-alive
        conn = headers.get(hdrs.CONNECTION)
        if conn:
            v = conn.lower()
            if v == 'close':
                close_conn = True
            elif v == 'keep-alive':
                close_conn = False
            elif v == 'upgrade':
                upgrade = True

        # encoding
        enc = headers.get(hdrs.CONTENT_ENCODING)
        if enc:
            enc = enc.lower()
            if enc in ('gzip', 'deflate'):
                encoding = enc

        # chunking
        te = headers.get(hdrs.TRANSFER_ENCODING)
        if te and 'chunked' in te.lower():
            chunked = True

        return headers, raw_headers, close_conn, encoding, upgrade, chunked
Exemplo n.º 40
0
    def parse_headers(self, lines):
        """Parse RFC 5322 style header lines from an already split message.

        ``lines`` is a sequence of ``bytes`` lines terminated by an empty
        line.  Index 0 is skipped (presumably the request/status line --
        confirm against the caller).  Folded continuation lines (starting
        with a space or a tab) are joined to the preceding value with
        CRLF.  Header names are converted to upper case.

        :return: a tuple ``(headers, raw_headers, close_conn, encoding)``
            where ``headers`` is a ``CIMultiDict`` of decoded values,
            ``raw_headers`` a list of ``(bytes, bytes)`` pairs,
            ``close_conn`` is ``True``/``False``/``None`` depending on the
            Connection header and ``encoding`` is ``'gzip'``,
            ``'deflate'`` or ``None``.
        :raises errors.InvalidHeader: for a line without ``:`` or a name
            rejected by ``HDRRE``.
        :raises errors.LineTooLong: when a header (continuations included)
            is longer than ``self.max_field_size``.
        """
        close_conn = None
        encoding = None
        headers = CIMultiDict()
        raw_headers = []

        lines_idx = 1
        line = lines[1]

        while line:
            header_length = len(line)

            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise errors.InvalidHeader(line) from None

            bname = bname.strip(b' \t').upper()
            if HDRRE.search(bname):
                raise errors.InvalidHeader(bname)

            # next line
            lines_idx += 1
            line = lines[lines_idx]

            # consume continuation lines
            continuation = line and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise errors.LineTooLong(
                            'request header field {}'.format(
                                bname.decode("utf8", "xmlcharrefreplace")),
                            self.max_field_size)
                    bvalue.append(line)

                    # next line
                    lines_idx += 1
                    line = lines[lines_idx]
                    # The blank terminator is empty, so it must be checked
                    # before indexing; otherwise a folded header at the end
                    # of the block raises IndexError.
                    continuation = line and line[0] in (32, 9)
                bvalue = b'\r\n'.join(bvalue)
            else:
                if header_length > self.max_field_size:
                    raise errors.LineTooLong(
                        'request header field {}'.format(
                            bname.decode("utf8", "xmlcharrefreplace")),
                        self.max_field_size)

            bvalue = bvalue.strip()

            name = istr(bname.decode('utf-8', 'surrogateescape'))
            value = bvalue.decode('utf-8', 'surrogateescape')

            # keep-alive and encoding
            if name == hdrs.CONNECTION:
                v = value.lower()
                if v == 'close':
                    close_conn = True
                elif v == 'keep-alive':
                    close_conn = False
            elif name == hdrs.CONTENT_ENCODING:
                enc = value.lower()
                if enc in ('gzip', 'deflate'):
                    encoding = enc

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return headers, raw_headers, close_conn, encoding
Exemplo n.º 41
0
class ClientRequest:
    """A single outgoing HTTP request: headers, body and wire serialisation.

    The constructor normalises the URL and builds the complete header set
    (Host, default headers, cookies, auth, content and transfer encoding,
    Expect).  ``send()`` then writes the request line and headers to a
    connection and starts a task that streams the body.
    """

    GET_METHODS = {
        hdrs.METH_GET,
        hdrs.METH_HEAD,
        hdrs.METH_OPTIONS,
        hdrs.METH_TRACE,
    }
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE})

    # Headers added automatically unless already present or explicitly
    # skipped through ``skip_auto_headers``.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self,
                 method: str,
                 url: URL,
                 *,
                 params: Optional[Mapping[str, str]] = None,
                 headers: Optional[LooseHeaders] = None,
                 skip_auto_headers: Iterable[str] = frozenset(),
                 data: Any = None,
                 cookies: Optional[LooseCookies] = None,
                 auth: Optional[BasicAuth] = None,
                 version: http.HttpVersion = http.HttpVersion11,
                 compress: Optional[str] = None,
                 chunked: Optional[bool] = None,
                 expect100: bool = False,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 response_class: Optional[Type['ClientResponse']] = None,
                 proxy: Optional[URL] = None,
                 proxy_auth: Optional[BasicAuth] = None,
                 timer: Optional[BaseTimerContext] = None,
                 session: Optional['ClientSession'] = None,
                 ssl: Union[SSLContext, bool, Fingerprint, None] = None,
                 proxy_headers: Optional[LooseHeaders] = None,
                 traces: Optional[List['Trace']] = None):
        """Build the request; the ``update_*`` helpers run in a fixed order.

        ``params`` are merged into the query string of ``url``, the
        fragment is dropped from ``self.url`` (kept in
        ``self.original_url``) and the method is upper-cased.

        :raises InvalidURL: when ``url`` has no host.
        :raises ValueError: for an unparsable ``version`` string or
            conflicting compress/chunked/Content-Length settings.
        :raises TypeError: when ``auth`` is not a ``BasicAuth``.
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy
        # FIXME: session is None in tests only, need to fix tests
        # assert session is not None
        self._session = cast('ClientSession', session)
        if params:
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        self.original_url = url
        self.url = url.with_fragment(None)
        self.method = method.upper()
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.length = None
        if response_class is None:
            real_response_class = ClientResponse
        else:
            real_response_class = response_class
        self.response_class = real_response_class  # type: Type[ClientResponse]
        self._timer = timer if timer is not None else TimerNoop()
        self._ssl = ssl

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth, proxy_headers)

        self.update_body_from_data(data)
        if data or self.method not in self.GET_METHODS:
            self.update_transfer_encoding()
        self.update_expect_continue(expect100)
        if traces is None:
            traces = []
        self._traces = traces

    def is_ssl(self) -> bool:
        """Return True for the ``https`` and ``wss`` URL schemes."""
        return self.url.scheme in ('https', 'wss')

    @property
    def ssl(self) -> Union['SSLContext', None, bool, Fingerprint]:
        """The ``ssl`` argument given to the constructor."""
        return self._ssl

    @property
    def connection_key(self) -> ConnectionKey:
        """Key identifying the connection this request can be sent on.

        Proxy headers take part in the key through a hash of their items.
        """
        proxy_headers = self.proxy_headers
        if proxy_headers:
            h = hash(tuple(
                (k, v) for k, v in
                proxy_headers.items()))  # type: Optional[int]  # noqa
        else:
            h = None
        return ConnectionKey(self.host, self.port, self.is_ssl(), self.ssl,
                             self.proxy, self.proxy_auth, h)

    @property
    def host(self) -> str:
        """Host part of the request URL."""
        ret = self.url.host
        assert ret is not None
        return ret

    @property
    def port(self) -> Optional[int]:
        """Port of the request URL."""
        return self.url.port

    @property
    def request_info(self) -> RequestInfo:
        """Read-only snapshot (URL, method, headers, original URL)."""
        headers = CIMultiDictProxy(self.headers)  # type: CIMultiDictProxy[str]
        return RequestInfo(self.url, self.method, headers, self.original_url)

    def update_host(self, url: URL) -> None:
        """Validate the destination host and pick up URL credentials.

        :raises InvalidURL: when ``url`` has no host.
        """
        # get host/port
        if not url.host:
            raise InvalidURL(url)

        # basic auth info
        username, password = url.user, url.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')

    def update_version(self, version: Union[http.HttpVersion, str]) -> None:
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        :raises ValueError: when a string version is not of the form
            ``'<int>.<int>'``.
        """
        if isinstance(version, str):
            parts = [part.strip() for part in version.split('.', 1)]
            try:
                version = http.HttpVersion(int(parts[0]), int(parts[1]))
            except (ValueError, IndexError):
                # IndexError covers a string without a dot such as '1',
                # which should be reported like any other malformed version.
                raise ValueError(
                    'Can not parse http version number: {}'.format(
                        version)) from None
        self.version = version

    def update_headers(self, headers: Optional[LooseHeaders]) -> None:
        """Update request headers.

        Starts from a fresh multidict holding only the Host header derived
        from the URL, then adds the user supplied headers.  A user supplied
        Host replaces the derived one instead of being added next to it.
        """
        self.headers = CIMultiDict()  # type: CIMultiDict[str]

        # add host
        netloc = cast(str, self.url.raw_host)
        if helpers.is_ipv6_address(netloc):
            netloc = '[{}]'.format(netloc)
        if not self.url.is_default_port():
            netloc += ':' + str(self.url.port)
        self.headers[hdrs.HOST] = netloc

        if headers:
            if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
                headers = headers.items()  # type: ignore

            for key, value in headers:
                # A special case for Host header
                if key.lower() == 'host':
                    self.headers[key] = value
                else:
                    self.headers.add(key, value)

    def update_auto_headers(self, skip_auto_headers: Iterable[str]) -> None:
        """Add DEFAULT_HEADERS and User-Agent unless present or skipped."""
        self.skip_auto_headers = CIMultiDict(
            (hdr, None) for hdr in sorted(skip_auto_headers))
        used_headers = self.headers.copy()
        used_headers.extend(self.skip_auto_headers)  # type: ignore

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE

    def update_cookies(self, cookies: Optional[LooseCookies]) -> None:
        """Update request cookies header.

        Cookies already present in a Cookie header are merged with
        ``cookies`` and written back as a single header.
        """
        if not cookies:
            return

        c = SimpleCookie()
        if hdrs.COOKIE in self.headers:
            c.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        if isinstance(cookies, Mapping):
            iter_cookies = cookies.items()
        else:
            iter_cookies = cookies  # type: ignore
        for name, value in iter_cookies:
            if isinstance(value, Morsel):
                # Preserve coded_value
                mrsl_val = value.get(value.key, Morsel())
                mrsl_val.set(value.key, value.value,
                             value.coded_value)  # type: ignore  # noqa
                c[name] = mrsl_val
            else:
                c[name] = value  # type: ignore

        self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()

    def update_content_encoding(self, data: Any) -> None:
        """Set request content encoding.

        :raises ValueError: when both ``compress`` and a Content-Encoding
            header are given.
        """
        if not data:
            return

        enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if enc:
            if self.compress:
                raise ValueError('compress can not be set '
                                 'if Content-Encoding header is set')
        elif self.compress:
            if not isinstance(self.compress, str):
                self.compress = 'deflate'
            self.headers[hdrs.CONTENT_ENCODING] = self.compress
            self.chunked = True  # enable chunked, no need to deal with length

    def update_transfer_encoding(self) -> None:
        """Analyze transfer-encoding header.

        Ends with either ``Transfer-Encoding: chunked`` or a Content-Length
        header set.

        :raises ValueError: when ``chunked`` conflicts with an explicit
            Transfer-Encoding or Content-Length header.
        """
        te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()

        if 'chunked' in te:
            if self.chunked:
                raise ValueError(
                    'chunked can not be set '
                    'if "Transfer-Encoding: chunked" header is set')

        elif self.chunked:
            if hdrs.CONTENT_LENGTH in self.headers:
                raise ValueError('chunked can not be set '
                                 'if Content-Length header is set')

            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
        else:
            if hdrs.CONTENT_LENGTH not in self.headers:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_auth(self, auth: Optional[BasicAuth]) -> None:
        """Set basic auth.

        Falls back to the credentials found in the URL when ``auth`` is
        ``None``.

        :raises TypeError: when the credentials are not a ``BasicAuth``.
        """
        if auth is None:
            auth = self.auth
        if auth is None:
            return

        if not isinstance(auth, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = auth.encode()

    def update_body_from_data(self, body: Any) -> None:
        """Wrap ``body`` in a payload object and derive body headers.

        Sets Content-Length (or switches to chunked when the payload size
        is unknown), Content-Type and any payload supplied headers that
        are not already present.
        """
        if not body:
            return

        # FormData
        if isinstance(body, FormData):
            body = body()

        try:
            body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
        except payload.LookupError:
            body = FormData(body)()

        self.body = body

        # enable chunked encoding if needed
        if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers:
            size = body.size
            if size is None:
                self.chunked = True
            else:
                self.headers[hdrs.CONTENT_LENGTH] = str(size)

        # set content-type
        if (hdrs.CONTENT_TYPE not in self.headers
                and hdrs.CONTENT_TYPE not in self.skip_auto_headers):
            self.headers[hdrs.CONTENT_TYPE] = body.content_type

        # copy payload headers
        if body.headers:
            for (key, value) in body.headers.items():
                if key not in self.headers:
                    self.headers[key] = value

    def update_expect_continue(self, expect: bool = False) -> None:
        """Arm the '100 Continue' waiter when requested.

        The waiter is created when ``expect`` is true or when an
        ``Expect: 100-continue`` header is already present.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue':
            expect = True

        if expect:
            self._continue = self.loop.create_future()

    def update_proxy(self, proxy: Optional[URL],
                     proxy_auth: Optional[BasicAuth],
                     proxy_headers: Optional[LooseHeaders]) -> None:
        """Validate and store the proxy settings.

        :raises ValueError: for a non-``http`` proxy URL or a
            ``proxy_auth`` that is not a ``BasicAuth``.
        """
        if proxy and not proxy.scheme == 'http':
            raise ValueError("Only http proxies are supported")
        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")
        self.proxy = proxy
        self.proxy_auth = proxy_auth
        self.proxy_headers = proxy_headers

    def keep_alive(self) -> bool:
        """Return True if the connection may be reused after this request."""
        if self.version < HttpVersion10:
            # keep alive not supported at all
            return False
        if self.version == HttpVersion10:
            if self.headers.get(hdrs.CONNECTION) == 'keep-alive':
                return True
            else:  # no headers means we close for Http 1.0
                return False
        elif self.headers.get(hdrs.CONNECTION) == 'close':
            return False

        return True

    async def write_bytes(self, writer: AbstractStreamWriter,
                          conn: 'Connection') -> None:
        """Support coroutines that yields bytes objects.

        Streams the body through ``writer``.  Failures are not raised but
        reported through ``protocol.set_exception``.
        """
        # 100 response
        if self._continue is not None:
            await writer.drain()
            await self._continue

        protocol = conn.protocol
        assert protocol is not None
        try:
            if isinstance(self.body, payload.Payload):
                await self.body.write(writer)
            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body, )  # type: ignore

                for chunk in self.body:
                    await writer.write(chunk)  # type: ignore

            await writer.write_eof()
        except OSError as exc:
            new_exc = ClientOSError(
                exc.errno, 'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            protocol.set_exception(new_exc)
        except asyncio.CancelledError as exc:
            if not conn.closed:
                protocol.set_exception(exc)
        except Exception as exc:
            protocol.set_exception(exc)
        finally:
            self._writer = None

    async def send(self, conn: 'Connection') -> 'ClientResponse':
        """Write the request line and headers, then start streaming the body.

        :return: the response object created for this request (not yet
            started).
        """
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            path = '{}:{}'.format(self.url.raw_host, self.url.port)
        elif self.proxy and not self.is_ssl():
            path = str(self.url)
        else:
            path = self.url.raw_path
            if self.url.raw_query_string:
                path += '?' + self.url.raw_query_string

        protocol = conn.protocol
        assert protocol is not None
        writer = StreamWriter(protocol,
                              self.loop,
                              on_chunk_sent=self._on_chunk_request_sent)

        if self.compress:
            writer.enable_compression(self.compress)

        # Use truthiness, matching update_transfer_encoding(): an explicit
        # ``chunked=False`` has produced a Content-Length header there, so
        # the body must not be chunk-framed.
        if self.chunked:
            writer.enable_chunking()

        # set default content-type
        if (self.method in self.POST_METHODS
                and hdrs.CONTENT_TYPE not in self.skip_auto_headers
                and hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        # set the connection header
        connection = self.headers.get(hdrs.CONNECTION)
        if not connection:
            if self.keep_alive():
                if self.version == HttpVersion10:
                    connection = 'keep-alive'
            else:
                if self.version == HttpVersion11:
                    connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

        # status + headers
        status_line = '{0} {1} HTTP/{2[0]}.{2[1]}'.format(
            self.method, path, self.version)
        await writer.write_headers(status_line, self.headers)

        self._writer = self.loop.create_task(self.write_bytes(writer, conn))

        response_class = self.response_class
        assert response_class is not None
        self.response = response_class(self.method,
                                       self.original_url,
                                       writer=self._writer,
                                       continue100=self._continue,
                                       timer=self._timer,
                                       request_info=self.request_info,
                                       traces=self._traces,
                                       loop=self.loop,
                                       session=self._session)
        return self.response

    async def close(self) -> None:
        """Wait for the body writer task to finish, if one is running."""
        if self._writer is not None:
            try:
                await self._writer
            finally:
                self._writer = None

    def terminate(self) -> None:
        """Cancel the body writer task without waiting for it."""
        if self._writer is not None:
            if not self.loop.is_closed():
                self._writer.cancel()
            self._writer = None

    async def _on_chunk_request_sent(self, chunk: bytes) -> None:
        """Notify every trace that a body chunk has been sent."""
        for trace in self._traces:
            await trace.send_request_chunk_sent(chunk)
Exemplo n.º 42
0
class Request(object):
    def on_message_begin(self):
        pass

    def on_url(self, url: bytes):
        self.raw_path += url

    def on_header(self, name: bytes, value: bytes):
        str_name = name.decode('ascii')
        str_value = value.decode('ascii')

        self.raw_headers.add(str_name, value)
        self.headers.add(str_name, str_value)

    def on_headers_complete(self):
        """Populate the request-line fields and query maps once headers are parsed.

        Copies version, keep-alive, upgrade and method from the parser, splits
        ``raw_path`` into path and query, fills ``raw_query`` (bytes values) and
        ``query`` (ASCII-decoded values), lets the headers object post-process
        itself against this request, and adopts a truthy ``content_length`` as
        the expected body length.
        """
        target = parse_url(self.raw_path)
        parser = self._parser

        self.version = parser.get_http_version()
        self.keep_alive = parser.should_keep_alive()
        self.upgrade = parser.should_upgrade()
        self.raw_method = parser.get_method()
        self.method = self.raw_method.decode('ascii')
        self.path = target.path.decode('ascii')

        query_items = parse_qs(target.query).items()
        for raw_key, raw_values in query_items:
            for raw_value in raw_values:
                self.raw_query.add(raw_key.decode('ascii'), raw_value)
                self.query.add(raw_key.decode('ascii'), raw_value.decode('ascii'))

        self.headers._post_process(self)
        self._headers_complete = True

        # A missing or zero length keeps the default set in __init__.
        declared_length = self.headers.content_length
        if declared_length:
            self._body_length = declared_length

    def on_body(self, body: bytes):
        self._body_buffer += body

    def on_message_complete(self):
        self._is_body_complete = True

    def on_chunk_header(self):
        pass

    def on_chunk_complete(self):
        pass

    async def _try_read_headers(self):
        while not self._headers_complete:
            data = await self._connection.read_request(64 * 1024)

            if not data:
                return False

            self._parser.feed_data(data)

        return True

    async def _read_body(self, max_length=64*1024):
        while (not self._is_body_complete) or (len(self._body_buffer) > 0):
            chunk_length = min(max_length, self._body_length - (self._body_position + len(self._body_buffer))) 
            data = await self._connection.read_request(chunk_length)

            if not data:
                self._is_body_complete = True
            else:
                self._body_position += len(data)
                self._parser.feed_data(data)

            if self._body_position == self._body_length:
                self._is_body_complete = True

            if len(self._body_buffer) > 0:
                result = self._body_buffer
                self._body_buffer = b''

                return result

        return b''

    def __init__(self, connection):
        """Create an empty request bound to ``connection``.

        Every parsed field starts with a placeholder default; the ``on_*``
        callbacks and the ``read_*`` helpers of this class overwrite them later.
        """
        self._connection = connection
        # The parser is handed this object, presumably as the target of the
        # ``on_*`` callbacks defined on this class.
        self._parser = HttpRequestParser(self)

        # Parsing progress flags and the pending body bytes.
        self._body_buffer = b''
        self._headers_complete = False
        self._is_body_complete = False

        # Connection-level facts; the first three are set in on_headers_complete.
        self.version = self.keep_alive = self.upgrade = None
        self.address = connection.address

        # Undecoded (bytes) views of the request.
        self.raw_method = b''
        self.raw_headers = CIMultiDict()
        self.raw_query = CIMultiDict()
        self.raw_path = b''

        # Decoded views of the request.
        self.method = ''
        self.host = 'localhost'
        self.port = 80
        self.headers = RequestHeaders()
        self.query = CIMultiDict()
        self.cookies = {}
        self.path = ''
        self.content_type_main, self.content_type_sub = 'application', 'octet-stream'
        self.content_type_params = {}
        self.content_charset = 'ascii'

        # Body bookkeeping; 2**32 stands in until on_headers_complete sees a
        # truthy content length.
        self._body_length = 2**32
        self._body_position = 0

        # Lazily filled caches for read_body / read_text / read_json / read_form.
        self._body = self._text = self._json = self._form = None

    async def read_body(self):
        if self._body is None:
            async with self.open_body() as stream:
                self._body = await stream.readall()

        return self._body

    async def read_text(self):
        if self._text is None:
            charset = self.content_type_params.get('charset', 'ascii')

            self._text = self.read_body().decode(charset)

        return self._text

    async def read_json(self):
        if self._json is None:
            self._json = loads(self.read_text())

        return self._json

    async def read_form(self):
        """Return the submitted form fields as a ``CIMultiDict``, caching the result.

        Two content types are understood:

        * ``application/x-www-form-urlencoded``: the body is split on ``&``;
          each pair is split on the first ``=``, ``+`` is turned into a space,
          percent-escapes are resolved and the result is decoded as UTF-8.
          Empty segments are ignored and a key without ``=`` gets an empty value.
        * ``multipart/form-data``: the body is streamed into an email feed
          parser.  Named parts with a non-empty payload are stored either as
          UTF-8 text (content type ``application/form-data``) or as a
          ``RequestFilePart``.

        Any other content type yields an empty form.  The form is cached only
        after parsing succeeded, so a failed parse does not leave a partial
        result behind.
        """
        if self._form is not None:
            return self._form

        form = CIMultiDict()
        content_type = self.headers.content_type

        if (content_type.type == 'application') and (content_type.subtype == 'x-www-form-urlencoded'):
            body = await self.read_body()

            for parameter in body.split(b'&'):
                if not parameter:
                    # Empty body, or stray separators such as ``a=1&&b=2``.
                    continue

                # Split on the first '=' only; the rest belongs to the value.
                name, _, value = parameter.partition(b'=')
                # '+' encodes a space and must be handled before percent-decoding,
                # otherwise an escaped '%2B' would be turned into a space as well.
                form.add(
                    unquote_to_bytes(name.replace(b'+', b' ')).decode('utf-8'),
                    unquote_to_bytes(value.replace(b'+', b' ')).decode('utf-8'))
        elif (content_type.type == 'multipart') and (content_type.subtype == 'form-data'):
            # TODO: replace with multifruits
            parser = BytesFeedParser(policy=HTTP, _factory=message_factory)
            # The email parser needs the Content-Type header to learn the boundary.
            parser.feed(b'Content-Type: %s\r\n\r\n' % self.raw_headers['Content-Type'])

            async with self.open_body() as stream:
                while True:
                    data = await stream.read()

                    if not data:
                        break

                    parser.feed(data)

            message = parser.close()

            for part in message.walk():
                if part.get_content_type() == 'multipart/form-data':
                    # The container itself carries no field.
                    continue

                # failobj keeps parts without a Content-Disposition header from
                # crashing dict(); they simply have no name and are skipped.
                params = dict(part.get_params(failobj=(), header='content-disposition'))
                name = params.get('name')
                if not name:
                    continue

                payload = part.get_payload(decode=True)
                if not payload:
                    continue

                if part.get_content_type() == 'application/form-data':
                    form.add(name, payload.decode('utf-8'))
                else:
                    form.add(
                        name,
                        RequestFilePart(params.get('filename'), part.items(), payload))

        self._form = form
        return self._form

    def open_body(self):
        """Return a new ``RequestBodyStreamContext`` for this request.

        ``read_body`` and ``read_form`` use the result with ``async with`` to
        obtain the body stream.
        """
        body_context = RequestBodyStreamContext(self)
        return body_context