示例#1
0
class Protocol:
    """Callback sink that records what a parser reports about one message.

    The parser created in ``__init__`` receives this object as its first
    argument; each ``on_*`` hook below stores the value it is given on the
    instance.
    """

    def __init__(self, Parser, **kwargs):
        """Reset the recorded state and build the parser.

        :param Parser: callable invoked as ``Parser(self, **kwargs)``; the
            object it returns must expose a ``feed_data`` attribute.
        :param kwargs: forwarded unchanged to ``Parser``.
        """
        self.url = None
        self.headers = CIMultiDict()
        self.body = b''
        self.headers_complete = self.message_complete = False
        # The parser is created last so it sees a fully initialised sink.
        self.parser = Parser(self, **kwargs)
        # Expose the parser's entry point directly on the protocol.
        self.feed_data = self.parser.feed_data

    def on_url(self, url):
        """Remember the URL handed over by the parser."""
        self.url = url

    def on_header(self, name, value):
        """Decode one header name/value pair and add it to ``headers``."""
        decoded_name = name.decode()
        decoded_value = value.decode()
        self.headers.add(decoded_name, decoded_value)

    def on_headers_complete(self):
        """Flag that the header section has been fully received."""
        self.headers_complete = True

    def on_body(self, body):
        """Append a body chunk to the data collected so far."""
        self.body += body

    def on_message_complete(self):
        """Flag that the whole message has been received."""
        self.message_complete = True
示例#2
0
文件: message.py 项目: sangoma/aiosip
    def from_raw_headers(cls, raw_headers):
        """Build a ``Response`` or ``Request`` from raw header bytes.

        The decoded text is split on ``utils.EOL``.  The first line is
        matched against the response pattern and then the request pattern;
        every following line is parsed as a ``name: value`` header.

        :param raw_headers: bytes holding the first line and the headers.
        :return: a ``Response`` or a ``Request`` carrying the parsed headers.
        :raises ValueError: if the first line matches neither pattern, or a
            header line does not contain ``': '``.
        """
        lines = raw_headers.decode().split(utils.EOL)
        first_line = lines[0]

        headers = CIMultiDict()
        for raw_line in lines[1:]:
            name, value = raw_line.split(': ', 1)
            if name not in headers:
                headers[name] = value
                continue
            # Repeated header: gather every value for the name in one list.
            collected = headers.setdefault(name, [])
            if not isinstance(collected, list):
                collected = [collected]
            collected.append(value)
            headers[name] = collected

        response_match = FIRST_LINE_PATTERN['response']['regex'].match(first_line)
        if response_match:
            fields = response_match.groupdict()
            return Response(status_code=int(fields['status_code']),
                            status_message=fields['status_message'],
                            headers=headers,
                            first_line=first_line)

        request_match = FIRST_LINE_PATTERN['request']['regex'].match(first_line)
        if request_match:
            fields = request_match.groupdict()
            # CSeq must split into exactly two parts; only the first is used.
            cseq, _ = headers['CSeq'].split()
            return Request(method=fields['method'],
                           headers=headers,
                           cseq=int(cseq),
                           first_line=first_line)

        LOG.debug(lines)
        raise ValueError('Not a SIP message')
示例#3
0
class HttpTunnel(RequestBase):
    """Request whose first line is an HTTP ``CONNECT`` to ``key.address``."""
    first_line = None
    data = None
    decompress = False
    method = 'CONNECT'

    def __init__(self, client, req):
        """Store the client and request key and copy the tunnel headers.

        :param client: object providing ``DEFAULT_TUNNEL_HEADERS``.
        :param req: stored as ``key``; ``encode`` reads its ``netloc`` and
            ``address`` attributes.
        """
        self.client = client
        self.key = req
        self.headers = CIMultiDict(client.DEFAULT_TUNNEL_HEADERS)

    def __repr__(self):
        return 'Tunnel %s' % self.url
    __str__ = __repr__

    def encode(self):
        """Return the request line and headers as bytes.

        Sets the ``host`` header and ``first_line`` as a side effect.
        """
        self.headers['host'] = self.key.netloc
        self.first_line = 'CONNECT http://%s:%s HTTP/1.1' % self.key.address
        chunks = [self.first_line.encode('ascii'), b'\r\n']
        for name, value in self.headers.items():
            header_line = '%s: %s\r\n' % (name, value)
            chunks.append(header_line.encode(CHARSET))
        # An empty line terminates the header section.
        chunks.append(b'\r\n')
        return b''.join(chunks)

    def has_header(self, header_name):
        """Return ``True`` if *header_name* is present in the headers."""
        return header_name in self.headers

    def get_header(self, header_name, default=None):
        """Return the value of *header_name*, or *default* if absent."""
        return self.headers.get(header_name, default)

    def remove_header(self, header_name):
        """Remove *header_name* from the headers; absent names are ignored."""
        self.headers.pop(header_name, None)
示例#4
0
 def get_headers(self, request, headers):
     """Return a new ``CIMultiDict`` combining ``self.headers`` with *headers*.

     :param request: not used by this implementation.
     :param headers: optional extra headers applied on top of the defaults
         with ``update``; ignored when falsy.
     """
     # The copy is built from ``items()`` instead of ``self.headers.copy()``
     # because of a bug in CIMultiDict (see the original TODO).
     combined = CIMultiDict(self.headers.items())
     if not headers:
         return combined
     combined.update(headers)
     return combined
示例#5
0
 def __init__(self, proxies=None, headers=None, verify=True,
              cookies=None, store_cookies=True, cert=None,
              max_redirects=10, decompress=True, version=None,
              websocket_handler=None, parser=None, trust_env=True,
              loop=None, client_version=None, timeout=None, stream=False,
              pool_size=10, frame_parser=None, logger=None,
              close_connections=False, keep_alive=None):
     """Set up the client-wide defaults and register the request hooks.

     Falsy ``keep_alive``, ``logger``, ``client_version``, ``version``,
     ``parser`` and ``frame_parser`` fall back to, respectively, the
     ``http_keep_alive`` config value, ``LOGGER``, the class attributes of
     the same name and the default parsers.  In *headers*, a ``None``
     value removes the matching default header; any other value
     overrides it.
     """
     connection_factory = partial(Connection, HttpResponse)
     super().__init__(
         connection_factory,
         loop=loop,
         keep_alive=keep_alive or cfg_value('http_keep_alive')
     )
     self.logger = logger or LOGGER
     self.client_version = client_version or self.client_version
     self.connection_pools = {}
     self.pool_size = pool_size
     self.trust_env = trust_env
     self.timeout = timeout
     self.store_cookies = store_cookies
     self.max_redirects = max_redirects
     self.cookies = cookiejar_from_dict(cookies)
     self.decompress = decompress
     self.version = version or self.version
     # Default for SSL verification.
     self.verify = verify
     # Default SSL client certificate: a string is the path to a .pem
     # file, a tuple is a ('cert', 'key') pair.
     self.cert = cert
     self.stream = stream
     self.close_connections = close_connections

     # Default headers, then the caller's overrides on top.
     default_headers = CIMultiDict(self.DEFAULT_HTTP_HEADERS)
     default_headers['user-agent'] = self.client_version
     if headers:
         for name, value in mapping_iterator(headers):
             if value is not None:
                 default_headers[name] = value
             else:
                 default_headers.pop(name, None)
     self.headers = default_headers

     # Explicit proxies win; otherwise read them from the environment
     # when allowed.
     self.proxies = dict(proxies or ())
     if not self.proxies and self.trust_env:
         self.proxies = get_environ_proxies()
         if 'no' not in self.proxies:
             self.proxies['no'] = ','.join(self.no_proxy)

     self.websocket_handler = websocket_handler
     self.http_parser = parser or http.HttpResponseParser
     self.frame_parser = frame_parser or websocket.frame_parser

     # Hooks, bound in the same order as before.
     self.event('on_headers').bind(handle_cookies)
     self.event('pre_request').bind(WebSocket())
     self.event('post_request').bind(Expect())
     self.event('post_request').bind(Redirect())
     self._decompressors = {
         'gzip': GzipDecompress(),
         'deflate': DeflateDecompress(),
     }
示例#6
0
def parse_headers(header_data: bytes, value_encoding: str = 'ascii') -> CIMultiDict:
    """Parse CRLF-separated ``name: value`` header lines.

    Trailing whitespace of *header_data* is ignored and blank lines are
    skipped, so an empty input yields an empty mapping instead of failing.

    :param header_data: raw header bytes, one header per ``\\r\\n``
        separated line.
    :param value_encoding: codec used to decode header values; header
        names are always decoded as ASCII.
    :return: a ``CIMultiDict`` with one entry per header line (repeated
        names are kept).
    :raises ValueError: if a non-blank line has no ``:`` separator, or a
        name or value cannot be decoded.
    """
    assert check_argument_types()
    headers = CIMultiDict()
    for line in header_data.rstrip().split(b'\r\n'):
        # ``b''.split(...)`` still yields one empty element; without this
        # guard an empty header block failed while unpacking the split.
        if not line.strip():
            continue
        key, separator, value = line.partition(b':')
        if not separator:
            raise ValueError('malformed header line: %r' % line)
        headers.add(key.strip().decode('ascii'),
                    value.strip().decode(value_encoding))

    return headers
示例#7
0
def test_multiple_forwarded_headers_injection():
    """A malformed first ``Forwarded`` header must not shadow the second."""
    headers = CIMultiDict([
        # This could be sent by an attacker, hoping to "shadow" the
        # second header.
        ('Forwarded', 'for=_injected;by="'),
        # This is added by our trusted reverse proxy.
        ('Forwarded', 'for=_real;by=_actual_proxy'),
    ])
    req = make_mocked_request('GET', '/', headers=headers)
    forwarded = req.forwarded
    assert len(forwarded) == 2
    assert 'by' not in forwarded[0]
    trusted = forwarded[1]
    assert trusted['for'] == '_real'
    assert trusted['by'] == '_actual_proxy'
示例#8
0
文件: drivers.py 项目: Fahreeve/aiovk
 async def start(self, connection, read_until_eof=False):
     """Start the response, then rewrite ``#`` to ``?`` in ``Location``.

     vk.com returns URLs such as ``http://REDIRECT_URI#access_token=...``
     but aiohttp by default removes everything after ``#``; turning the
     fragment marker into a query marker keeps those parameters.

     :return: ``self``, with ``headers`` and ``raw_headers`` replaced.
     """
     await super().start(connection, read_until_eof)
     patched = CIMultiDict(self.headers)
     target = patched.get(hdrs.LOCATION, None)
     if target:
         patched[hdrs.LOCATION] = target.replace('#', '?')
     self.headers = CIMultiDictProxy(patched)
     self.raw_headers = tuple(patched.items())
     return self
示例#9
0
def test_multiple_forwarded_headers():
    """Elements of several ``Forwarded`` headers are parsed in order.

    Parameter names are expected in lower case whatever case was sent.
    """
    headers = CIMultiDict([
        ('Forwarded', 'By=identifier1;for=identifier2, BY=identifier3'),
        ('Forwarded', 'By=identifier4;fOr=identifier5'),
    ])
    req = make_mocked_request('GET', '/', headers=headers)
    forwarded = req.forwarded
    assert len(forwarded) == 3
    assert forwarded[0]['by'] == 'identifier1'
    assert forwarded[0]['for'] == 'identifier2'
    assert forwarded[1]['by'] == 'identifier3'
    assert forwarded[2]['by'] == 'identifier4'
    assert forwarded[2]['for'] == 'identifier5'
示例#10
0
def test_multiple_forwarded_headers_bad_syntax():
    """Invalid or empty ``Forwarded`` values still occupy a slot each."""
    headers = CIMultiDict()
    for value in ('for=_1;by=_2', 'invalid value', '', 'for=_3;by=_4'):
        headers.add('Forwarded', value)
    req = make_mocked_request('GET', '/', headers=headers)
    forwarded = req.forwarded
    assert len(forwarded) == 4
    assert forwarded[0]['for'] == '_1'
    # The two malformed entries carry no ``for`` parameter.
    assert 'for' not in forwarded[1]
    assert 'for' not in forwarded[2]
    assert forwarded[3]['by'] == '_4'
示例#11
0
 def __init__(self,
              value: Any,
              headers: Optional[
                  Union[
                      _CIMultiDict,
                      Dict[str, str],
                      Iterable[Tuple[str, str]]
                  ]
              ] = None,
              content_type: Optional[str]=sentinel,
              filename: Optional[str]=None,
              encoding: Optional[str]=None,
              **kwargs: Any) -> None:
     """Store the payload value and set up its headers.

     The ``Content-Type`` header is the explicit *content_type* when one
     is given (neither ``sentinel`` nor ``None``); otherwise the type
     guessed from *filename*; otherwise ``self._default_content_type``.
     *headers* are applied afterwards with ``update``.
     """
     self._encoding = encoding
     self._filename = filename
     self._headers = CIMultiDict()  # type: _CIMultiDict
     self._value = value

     if content_type is sentinel or content_type is None:
         guessed = None
         if self._filename is not None:
             guessed = mimetypes.guess_type(self._filename)[0]
         if guessed is None:
             content_type = self._default_content_type
         else:
             content_type = guessed
     self._headers[hdrs.CONTENT_TYPE] = content_type

     self._headers.update(headers or {})
示例#12
0
文件: message.py 项目: sangoma/aiosip
    def __init__(self,
                 headers=None,
                 payload=None,
                 from_details=None,
                 to_details=None,
                 contact_details=None,
                 ):
        """Initialise the message headers, endpoint details and payload.

        :param headers: header mapping; a falsy value is replaced by a new
            ``CIMultiDict``.
        :param payload: stored as ``_payload``.
        :param from_details: stored as ``_from_details`` when truthy.
        :param to_details: stored as ``_to_details`` when truthy.
        :param contact_details: stored as ``_contact_details`` when truthy.
        :raises ValueError: if neither the details nor the matching header
            is available for ``From`` or ``To``.
        """
        self.headers = headers if headers else CIMultiDict()

        if not from_details:
            if 'From' not in self.headers:
                raise ValueError('From header or from_details is required')
        else:
            self._from_details = from_details

        if not to_details:
            if 'To' not in self.headers:
                raise ValueError('To header or to_details is required')
        else:
            self._to_details = to_details

        if contact_details:
            self._contact_details = contact_details

        self._payload = payload
        self._raw_payload = None

        if 'Via' in self.headers:
            return

        # ``%(protocol)s`` is left in the value as a literal placeholder.
        host_and_port = utils.format_host_and_port(
            self.contact_details['uri']['host'],
            self.contact_details['uri']['port'])
        branch_suffix = ';branch=%s' % utils.gen_branch(10)
        self.headers['Via'] = 'SIP/2.0/%(protocol)s ' + host_and_port + branch_suffix
示例#13
0
文件: client.py 项目: cynecx/aiohttp
 def _prepare_headers(self, headers):
     """Return the default headers overlaid with *headers*.

     The first occurrence of a name in *headers* replaces the default
     value; further occurrences of the same name are added next to it.
     Anything that is not a ``MultiDict``/``MultiDictProxy`` is first
     converted to a ``CIMultiDict``.
     """
     merged = CIMultiDict(self._default_headers)
     if not headers:
         return merged

     if not isinstance(headers, (MultiDictProxy, MultiDict)):
         headers = CIMultiDict(headers)

     seen = set()
     for name, value in headers.items():
         if name in seen:
             merged.add(name, value)
             continue
         merged[name] = value
         seen.add(name)
     return merged
示例#14
0
 def __init__(self, client, url, method, inp_params=None, headers=None,
              data=None, files=None, json=None, history=None, auth=None,
              charset=None, max_redirects=10, source_address=None,
              allow_redirects=False, decompress=True, version=None,
              wait_continue=False, websocket_handler=None, cookies=None,
              params=None, stream=False, proxies=None, verify=True,
              cert=None, **extra):
     """Store the request options, then derive URL, key, headers and body.

     A truthy *auth* that is not an ``Auth`` instance is unpacked into
     ``HTTPBasicAuth``.  *charset* defaults to ``'utf-8'`` and
     *inp_params* to an empty dict when falsy.  ``extra`` keyword
     arguments are accepted and not used here.
     """
     self.client = client
     self.method = method.upper()
     self.inp_params = inp_params if inp_params else {}
     self.unredirected_headers = CIMultiDict()
     self.history = history
     self.wait_continue = wait_continue
     self.max_redirects = max_redirects
     self.allow_redirects = allow_redirects
     self.charset = charset if charset else 'utf-8'
     self.version = version
     self.decompress = decompress
     self.websocket_handler = websocket_handler
     self.source_address = source_address
     self.stream = stream
     self.verify = verify
     self.cert = cert

     needs_wrapping = auth and not isinstance(auth, Auth)
     self.auth = HTTPBasicAuth(*auth) if needs_wrapping else auth

     # The steps below receive ``self`` and run after the plain
     # attributes above are in place; their relative order is kept.
     self.url = full_url(url, params, method=self.method)
     self._set_proxy(proxies)
     self.key = RequestKey.create(self)
     self.headers = client.get_headers(self, headers)
     self.body = self._encode_body(data, files, json)
     self.unredirected_headers['host'] = self.key.netloc

     jar = cookiejar_from_dict(client.cookies, cookies)
     if jar:
         jar.add_cookie_header(self)
示例#15
0
 def __init__(self, Parser, **kwargs):
     """Reset the recorded message state and build the parser.

     :param Parser: callable invoked as ``Parser(self, **kwargs)``; the
         object it returns must expose a ``feed_data`` attribute.
     :param kwargs: forwarded unchanged to ``Parser``.
     """
     self.url = None
     self.headers = CIMultiDict()
     self.body = b''
     self.headers_complete = self.message_complete = False
     # Created last so the parser receives a fully initialised object.
     self.parser = Parser(self, **kwargs)
     # Expose the parser's entry point directly on this object.
     self.feed_data = self.parser.feed_data
示例#16
0
 def test_CacheControl(self):
     """Check the ``cache-control`` value produced for several options.

     The same headers object is reused for every ``CacheControl`` call.
     """
     headers = CIMultiDict()

     def rendered():
         # Current cache-control values joined the way the test compares.
         return ', '.join(headers.getall('cache-control'))

     default_control = CacheControl()
     self.assertFalse(default_control.private)
     self.assertFalse(default_control.maxage)
     default_control(headers)
     self.assertEqual(rendered(), 'no-cache')

     cases = (
         (dict(maxage=3600),
          'max-age=3600, public'),
         (dict(maxage=3600, private=True),
          'max-age=3600, private'),
         (dict(maxage=3600, must_revalidate=True),
          'max-age=3600, public, must-revalidate'),
         (dict(maxage=3600, proxy_revalidate=True),
          'max-age=3600, public, proxy-revalidate'),
         (dict(maxage=3600, proxy_revalidate=True, nostore=True),
          'no-store, no-cache, must-revalidate, max-age=0'),
     )
     for options, expected in cases:
         control = CacheControl(**options)
         control(headers)
         self.assertEqual(rendered(), expected)
示例#17
0
    def __init__(self, *, status=200, reason=None, headers=None):
        """Initialise the response state, headers and status.

        :param status: passed to ``set_status``.
        :param reason: passed to ``set_status``.
        :param headers: optional initial headers, copied into a new
            ``CIMultiDict``.
        """
        self._body = self._keep_alive = None
        self._chunked = self._compression = self._compression_force = False
        self._cookies = SimpleCookie()

        self._req = self._payload_writer = None
        self._eof_sent = False

        self._headers = (CIMultiDict() if headers is None
                         else CIMultiDict(headers))

        self.set_status(status, reason)
示例#18
0
    def update_headers(self, headers):
        """Replace the request headers with the entries of *headers*.

        ``self.headers`` always becomes a new ``CIMultiDict``; *headers*
        may be a mapping/multidict or an iterable of ``(key, value)``
        pairs, and is ignored when falsy.
        """
        self.headers = CIMultiDict()
        if not headers:
            return

        mapping_types = (dict, MultiDictProxy, MultiDict)
        pairs = headers.items() if isinstance(headers, mapping_types) else headers
        for key, value in pairs:
            self.headers.add(key, value)
示例#19
0
    def __init__(self, transport, version, close, loop=None):
        """Initialise the base class and the per-message state.

        :param transport: forwarded to the base class.
        :param version: stored as ``version``.
        :param close: stored as ``closing``.
        :param loop: forwarded to the base class.
        """
        super().__init__(transport, loop)

        self.version = version
        self.closing = close
        # Not known until the message is processed.
        self.keepalive = self.length = None
        self.headers = CIMultiDict()
        self.headers_sent = False
示例#20
0
 def __init__(self, status_code=200, content=None, response_headers=None,
              content_type=None, encoding=None, can_store_cookies=True):
     """Initialise status, headers, content and cookie handling.

     :param response_headers: initial headers; a falsy value gives an
         empty ``CIMultiDict``.
     :param content_type: assigned to ``self.content_type`` only when it
         is not ``None``.
     """
     # Assignment order is kept as-is: some of these names may be
     # properties on the class -- TODO confirm.
     self.status_code = status_code
     self.encoding = encoding
     initial_headers = response_headers or ()
     self.headers = CIMultiDict(initial_headers)
     self.content = content
     self._cookies = None
     self._can_store_cookies = can_store_cookies
     if content_type is None:
         return
     self.content_type = content_type
示例#21
0
 def __init__(self, subtype='mixed', boundary=None):
     """Create an empty writer with a ``multipart/<subtype>`` content type.

     :param subtype: multipart subtype placed in the content type.
     :param boundary: part delimiter; a random hex string is generated
         when ``None``.
     :raises ValueError: if *boundary* contains non-ASCII characters.
     """
     if boundary is None:
         boundary = uuid.uuid4().hex
     try:
         boundary.encode('us-ascii')
     except UnicodeEncodeError:
         raise ValueError('boundary should contains ASCII only chars')
     self.headers = CIMultiDict()
     content_type = 'multipart/{}; boundary="{}"'.format(subtype, boundary)
     self.headers[CONTENT_TYPE] = content_type
     self.parts = []
示例#22
0
文件: message.py 项目: Eyepea/aiosip
    def from_raw_message(cls, raw_message):
        """Build a ``Response`` or ``Request`` from a raw message string.

        The message is split on ``utils.EOL``.  Lines containing ``': '``
        are parsed as headers until the first line without it; every line
        after that is treated as payload.

        :param raw_message: the complete message text.
        :return: a ``Response`` or a ``Request``; ``payload`` is ``None``
            when no payload text was found.
        :raises ValueError: if the first line matches neither the response
            nor the request pattern.
        """
        first_line, *lines = raw_message.split(utils.EOL)

        headers = CIMultiDict()
        payload_chunks = []
        in_headers = True
        for line in lines:
            if not in_headers:
                # @todo: use content length to read payload
                # NOTE(review): payload lines are concatenated without the
                # EOL that separated them -- confirm this is intended.
                payload_chunks.append(line)
                continue
            if ': ' not in line:
                # First non-header line ends the header section.
                in_headers = False
                continue
            name, value = line.split(': ', 1)
            if name not in headers:
                headers[name] = value
                continue
            # Repeated header: gather every value for the name in one list.
            collected = headers.setdefault(name, [])
            if not isinstance(collected, list):
                collected = [collected]
            collected.append(value)
            headers[name] = collected

        payload = ''.join(payload_chunks) or None

        response_match = FIRST_LINE_PATTERN['response']['regex'].match(first_line)
        if response_match:
            fields = response_match.groupdict()
            return Response(status_code=int(fields['status_code']),
                            status_message=fields['status_message'],
                            headers=headers,
                            payload=payload)

        request_match = FIRST_LINE_PATTERN['request']['regex'].match(first_line)
        if request_match:
            fields = request_match.groupdict()
            return Request(method=fields['method'],
                           headers=headers,
                           payload=payload)

        raise ValueError('Not a SIP message')
示例#23
0
 def __init__(self, transport, version, close):
     """Store the transport and reset the per-message state.

     :param transport: stored as ``transport``.
     :param version: stored as ``_version``.
     :param close: stored as ``closing``.
     """
     self.transport = transport
     self._version = version
     self.closing = close
     self.keepalive = None
     self.chunked = False
     self.length = None
     self.headers = CIMultiDict()
     self.headers_sent = False
     # All size counters start at zero.
     self.output_length = self.headers_length = self._output_size = 0
示例#24
0
文件: aiohttp.py 项目: Tygs/tygs
    def maker(method, path, headers=None, *,
              version=HttpVersion(1, 1), closing=False,
              sslcontext=None,
              secure_proxy_ssl_header=None):
        """Build a ``Request`` wired to mock app, payload and transport.

        :param headers: optional headers; a ``HOST`` header of
            ``test.local`` is added when none is present.
        :param version: versions below HTTP/1.1 force ``closing`` to True.
        :param sslcontext: value the mock transport returns for the
            ``'sslcontext'`` extra-info key (``None`` for any other key).
        """
        if version < HttpVersion(1, 1):  # noqa
            closing = True

        headers = CIMultiDict({} if headers is None else headers)
        if "HOST" not in headers:
            headers["HOST"] = "test.local"  # noqa

        app = mock.Mock()
        app._debug = False
        app.on_response_prepare = Signal(app)

        raw_headers = [(name.encode('utf-8'), value.encode('utf-8'))
                       for name, value in headers.items()]
        message = RawRequestMessage(method, path, version, headers,
                                    raw_headers, closing, False)

        payload = mock.Mock()
        transport = mock.Mock()
        transport.get_extra_info.side_effect = (
            lambda key: sslcontext if key == 'sslcontext' else None)  # noqa
        writer = mock.Mock()
        reader = mock.Mock()

        return Request(app, message, payload,
                       transport, reader, writer,
                       secure_proxy_ssl_header=secure_proxy_ssl_header)
示例#25
0
    def __init__(self, *,
                 status: int=200,
                 reason: Optional[str]=None,
                 headers: Optional[LooseHeaders]=None) -> None:
        """Initialise the response state, headers and status.

        :param status: passed to ``set_status``.
        :param reason: passed to ``set_status``.
        :param headers: optional initial headers, copied into a new
            ``CIMultiDict``.
        """
        self._body = None
        self._keep_alive = None  # type: Optional[bool]
        self._chunked = False
        self._compression = False
        self._compression_force = None  # type: Optional[ContentCoding]
        self._cookies = SimpleCookie()

        self._req = None  # type: Optional[BaseRequest]
        self._payload_writer = None  # type: Optional[AbstractStreamWriter]
        self._eof_sent = False
        self._body_length = 0
        self._state = {}  # type: Dict[str, Any]

        # Always a fresh mapping, seeded from *headers* when given.
        self._headers = (
            CIMultiDict() if headers is None else CIMultiDict(headers)
        )  # type: CIMultiDict[str]

        self.set_status(status, reason)
示例#26
0
    def __init__(self, *,
                 status: int=200,
                 reason: Optional[str]=None,
                 headers: Optional[LooseHeaders]=None) -> None:
        """Initialise the response state, headers and status.

        :param status: passed to ``set_status``.
        :param reason: passed to ``set_status``.
        :param headers: optional initial headers, copied into a new
            ``CIMultiDict``.
        """
        self._body = self._keep_alive = None
        self._chunked = self._compression = False
        self._compression_force = None
        self._cookies = SimpleCookie()

        self._req = self._payload_writer = None
        self._eof_sent = False
        self._body_length = 0
        self._state = {}  # type: Mapping

        # Always a fresh mapping, seeded from *headers* when given.
        self._headers = (
            CIMultiDict() if headers is None else CIMultiDict(headers)
        )  # type: CIMultiDict

        self.set_status(status, reason)
示例#27
0
文件: message.py 项目: Eyepea/aiosip
    def __init__(self,
                 # from_uri,
                 # to_uri,
                 content_type=None,
                 headers=None,
                 payload=None):
        """Initialise the message headers and payload, filling in defaults.

        :param content_type: when truthy, stored as the ``Content-Type``
            header.
        :param headers: header mapping; a falsy value is replaced by a new
            ``CIMultiDict``.
        :param payload: stored as ``self.payload``; its encoded length is
            used for a missing ``Content-Length`` header.
        :raises ValueError: if neither a ``From``/``To`` header nor the
            matching ``from_details``/``to_details`` attribute exists.
        """
        # self.from_uri = from_uri
        # self.to_uri = to_uri
        if headers:
            self.headers = headers
        else:
            self.headers = CIMultiDict()

        # Keep headers and the ``<direction>_details`` attributes in sync:
        # a header without an attribute is parsed into the attribute, an
        # attribute without a header is rendered into the header.
        for direction in ('From', 'To', 'Contact'): # parse From, To, and Contact headers
            direction_attribute = '%s_details' % direction.lower()
            if direction in self.headers:
                if not hasattr(self, direction_attribute):
                    setattr(self,
                            direction_attribute,
                            Contact.from_header(self.headers[direction]))
            elif hasattr(self, direction_attribute):
                contact = getattr(self, direction_attribute)
                self.headers[direction] = str(contact)
            elif direction != 'Contact':
                # Only ``Contact`` is optional.
                raise(ValueError('You must have a "%s" header or details.' % direction))

            # NOTE(review): this sits inside the loop, so it runs once per
            # direction; it also fixes where Content-Type lands in the
            # header order -- confirm before moving it.
            if content_type:
                self.headers['Content-Type'] = content_type
        self.payload = payload

        # Build the message
        if 'Via' not in self.headers:
            # ``%(protocol)s`` stays in the value as a literal placeholder;
            # only the right-hand string is formatted here.
            self.headers['Via'] = 'SIP/2.0/%(protocol)s '+'%s:%s;branch=%s' % (self.contact_details['uri']['host'],
                                                                               self.contact_details['uri']['port'],
                                                                               utils.gen_branch(10))
        if 'Max-Forwards' not in self.headers:
            self.headers['Max-Forwards'] = '70'
        if 'Call-ID' not in self.headers:
            # NOTE(review): stores a ``UUID`` object rather than a string.
            self.headers['Call-ID'] = uuid.uuid4()
        if 'User-Agent' not in self.headers:
            self.headers['User-Agent'] = 'Python/{0[0]}.{0[1]}.{0[2]} aiosip/{1}'.format(
                sys.version_info, aiosip.__version__)
        if 'Content-Length' not in self.headers:
            # Length of the encoded payload, 0 when there is no payload;
            # stored as an ``int``.
            payload_len = len(self.payload.encode()) if self.payload else 0
            self.headers['Content-Length'] = payload_len
示例#28
0
    def __init__(self, *, status=200, reason=None, headers=None):
        """Initialise the response state, status and headers.

        :param status: passed to ``set_status``.
        :param reason: passed to ``set_status``.
        :param headers: optional headers appended to the (initially
            empty) header mapping.

        ``Content-Type`` defaults to ``application/octet-stream`` when the
        supplied headers do not set it.
        """
        self._body = self._keep_alive = None
        self._chunked = False
        self._chunk_size = None
        self._compression = self._compression_force = False
        self._headers = CIMultiDict()
        self._cookies = http.cookies.SimpleCookie()
        self.set_status(status, reason)

        self._req = self._resp_impl = None
        self._eof_sent = False

        if headers is not None:
            self._headers.extend(headers)
        self._headers.setdefault(hdrs.CONTENT_TYPE,
                                 'application/octet-stream')
示例#29
0
    def __init__(self, *, status=200, reason=None, headers=None):
        """Initialise the response state, status and headers.

        :param status: passed to ``set_status``.
        :param reason: passed to ``set_status``.
        :param headers: optional headers appended to the (initially
            empty) header mapping.
        """
        self._body = self._keep_alive = None
        self._chunked = False
        self._chunk_size = None
        self._compression = self._compression_force = False
        self._headers = CIMultiDict()
        self._cookies = http.cookies.SimpleCookie()
        self.set_status(status, reason)

        self._req = self._resp_impl = None
        self._eof_sent = False
        self._tcp_nodelay = True
        self._tcp_cork = False

        if headers is None:
            return
        self._headers.extend(headers)
示例#30
0
    def __init__(self, *, status=200, reason=None, headers=None):
        """Initialise the response state, status and headers.

        :param status: passed to ``set_status``.
        :param reason: passed to ``set_status``.
        :param headers: optional headers appended to the (initially
            empty) header mapping.

        ``Content-Type`` defaults to ``application/octet-stream`` when the
        supplied headers do not set it.
        """
        self._body = self._keep_alive = None
        self._chunked = False
        self._chunk_size = None
        self._compression = self._compression_force = False
        self._headers = CIMultiDict()
        self._cookies = SimpleCookie()
        self.set_status(status, reason)

        self._req = self._resp_impl = None
        self._eof_sent = False

        self._task = None

        if headers is not None:
            # TODO: optimize CIMultiDict extending
            self._headers.extend(headers)
        if hdrs.CONTENT_TYPE not in self._headers:
            self._headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
示例#31
0
def make_request(app, method, path, headers=None):
    """Build a request for *app* from a synthetic HTTP/1.1 message.

    :param app: passed to ``request_from_message``.
    :param method: HTTP method of the raw message.
    :param path: request path of the raw message.
    :param headers: header multidict; a new empty ``CIMultiDict`` is
        created for each call when omitted.
    :return: whatever ``request_from_message`` returns for the message.
    """
    if headers is None:
        # A default instance in the signature would be created once and
        # shared by every call, leaking header changes between requests.
        headers = CIMultiDict()
    raw_headers = [(name.encode('utf-8'), value.encode('utf-8'))
                   for name, value in headers.items()]
    message = RawRequestMessage(method, path, HttpVersion11, headers,
                                raw_headers, False, False)
    return request_from_message(message, app)
示例#32
0
 def append_json(self, obj, headers=None):
     """Append *obj* as a part whose content type is ``application/json``.

     A truthy *headers* mapping is used as-is (and gets the content type
     set on it); otherwise a new ``CIMultiDict`` is created.
     """
     part_headers = headers if headers else CIMultiDict()
     part_headers[CONTENT_TYPE] = 'application/json'
     return self.append(obj, part_headers)
示例#33
0
class MultipartWriter(object):
    """Multipart body writer."""

    #: Body part reader class for non multipart/* content types.
    part_writer_cls = BodyPartWriter

    def __init__(self, subtype='mixed', boundary=None):
        """Create an empty writer with a ``multipart/<subtype>`` type.

        :param subtype: multipart subtype placed in the content type.
        :param boundary: part delimiter; a random hex string is generated
            when ``None``.
        :raises ValueError: if *boundary* contains non-ASCII characters.
        """
        if boundary is None:
            boundary = uuid.uuid4().hex
        try:
            boundary.encode('us-ascii')
        except UnicodeEncodeError:
            raise ValueError('boundary should contains ASCII only chars')
        self.headers = CIMultiDict()
        content_type = 'multipart/{}; boundary="{}"'.format(subtype, boundary)
        self.headers[CONTENT_TYPE] = content_type
        self.parts = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def __iter__(self):
        return iter(self.parts)

    def __len__(self):
        return len(self.parts)

    @property
    def boundary(self):
        """Boundary from the content-type header, encoded as ASCII bytes."""
        content_type = self.headers.get(CONTENT_TYPE)
        *_unused, params = parse_mimetype(content_type)
        return params['boundary'].encode('us-ascii')

    def append(self, obj, headers=None):
        """Add a body part and return it.

        An *obj* that is already a ``part_writer_cls`` instance is stored
        as-is (its headers updated from *headers*); anything else is
        wrapped in a new ``part_writer_cls``.
        """
        if not isinstance(obj, self.part_writer_cls):
            if not headers:
                headers = CIMultiDict()
            self.parts.append(self.part_writer_cls(obj, headers))
        else:
            if headers:
                obj.headers.update(headers)
            self.parts.append(obj)
        return self.parts[-1]

    def append_json(self, obj, headers=None):
        """Append *obj* as an ``application/json`` part."""
        part_headers = headers if headers else CIMultiDict()
        part_headers[CONTENT_TYPE] = 'application/json'
        return self.append(obj, part_headers)

    def append_form(self, obj, headers=None):
        """Append *obj* as an ``application/x-www-form-urlencoded`` part."""
        part_headers = headers if headers else CIMultiDict()
        part_headers[CONTENT_TYPE] = 'application/x-www-form-urlencoded'
        assert isinstance(obj, (Sequence, Mapping))
        return self.append(obj, part_headers)

    def serialize(self):
        """Yield the multipart body as byte chunks.

        Without parts a single empty chunk is produced.  Otherwise every
        part is preceded by its delimiter line, and the closing delimiter
        plus a final empty chunk follow the last part.
        """
        if not self.parts:
            yield b''
            return

        for part in self.parts:
            yield b'--' + self.boundary + b'\r\n'
            yield from part.serialize()
        # Closing delimiter after the last part.
        yield b'--' + self.boundary + b'--\r\n'

        yield b''
示例#34
0
def test_host_by_host_header() -> None:
    """``req.host`` reflects the ``Host`` header of the request."""
    headers = CIMultiDict({"Host": "example.com"})
    req = make_mocked_request("GET", "/", headers=headers)
    assert req.host == "example.com"
示例#35
0
    def parse_headers(self, lines):
        """Parse RFC 5322 style headers, starting at ``lines[1]``.

        Parsing stops at the first empty line.  Continuation lines (lines
        starting with a space or a tab) are joined to the value of the
        header they follow.

        :param lines: sequence of ``bytes`` lines.  Index 0 is skipped
            (presumably the start line -- confirm against the caller) and
            an empty line must terminate the header section.
        :return: ``(headers, raw_headers, close_conn, encoding, upgrade,
            chunked)``.  ``headers`` is a ``CIMultiDict`` of decoded
            values; ``raw_headers`` is a list of ``(name, value)`` bytes
            pairs with upper-cased names; ``close_conn`` is ``True``,
            ``False`` or ``None`` depending on the ``Connection`` header.
        :raises errors.InvalidHeader: if a line has no ``:`` separator or
            the header name matches ``HDRRE``.
        :raises errors.LineTooLong: if a header, including its
            continuation lines, exceeds ``self.max_field_size``.
        """
        headers = CIMultiDict()
        raw_headers = []

        lines_idx = 1
        line = lines[1]

        while line:
            header_length = len(line)

            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise errors.InvalidHeader(line) from None

            bname = bname.strip(b' \t').upper()
            if HDRRE.search(bname):
                raise errors.InvalidHeader(bname)

            # next line
            lines_idx += 1
            line = lines[lines_idx]

            # consume continuation lines
            continuation = line and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise errors.LineTooLong(
                            'request header field {}'.format(
                                bname.decode("utf8", "xmlcharrefreplace")),
                            self.max_field_size)
                    bvalue.append(line)

                    # next line
                    lines_idx += 1
                    line = lines[lines_idx]
                    # The terminating empty line may directly follow a
                    # continuation line; indexing it would raise
                    # IndexError, so check for emptiness first.
                    continuation = bool(line) and line[0] in (32, 9)
                bvalue = b'\r\n'.join(bvalue)
            else:
                if header_length > self.max_field_size:
                    raise errors.LineTooLong(
                        'request header field {}'.format(
                            bname.decode("utf8", "xmlcharrefreplace")),
                        self.max_field_size)

            bvalue = bvalue.strip()

            name = istr(bname.decode('utf-8', 'surrogateescape'))
            value = bvalue.decode('utf-8', 'surrogateescape')

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        close_conn = None
        encoding = None
        upgrade = False
        chunked = False

        # keep-alive
        conn = headers.get(hdrs.CONNECTION)
        if conn:
            v = conn.lower()
            if v == 'close':
                close_conn = True
            elif v == 'keep-alive':
                close_conn = False
            elif v == 'upgrade':
                upgrade = True

        # encoding: only gzip and deflate are reported
        enc = headers.get(hdrs.CONTENT_ENCODING)
        if enc:
            enc = enc.lower()
            if enc in ('gzip', 'deflate'):
                encoding = enc

        # chunking
        te = headers.get(hdrs.TRANSFER_ENCODING)
        if te and 'chunked' in te.lower():
            chunked = True

        return headers, raw_headers, close_conn, encoding, upgrade, chunked
示例#36
0
class ClientRequest:
    GET_METHODS = {
        hdrs.METH_GET,
        hdrs.METH_HEAD,
        hdrs.METH_OPTIONS,
        hdrs.METH_TRACE,
    }
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE})

    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self,
                 method,
                 url,
                 *,
                 params=None,
                 headers=None,
                 skip_auto_headers=frozenset(),
                 data=None,
                 cookies=None,
                 auth=None,
                 version=http.HttpVersion11,
                 compress=None,
                 chunked=None,
                 expect100=False,
                 loop=None,
                 response_class=None,
                 proxy=None,
                 proxy_auth=None,
                 timer=None,
                 session=None,
                 ssl=None,
                 proxy_headers=None,
                 traces=None):
        """Store the request parameters and build the outgoing headers.

        *url* must be a ``URL`` instance and *proxy* a ``URL`` or ``None``
        (both are checked with ``assert``).  When *params* is given, its
        query items are appended to the query already present in *url*.
        The URL including its fragment is kept as ``original_url``;
        ``self.url`` has the fragment removed.  *method* is upper-cased.

        After the plain attributes are stored, the ``update_*`` helpers are
        called in a fixed order to fill ``self.headers`` and ``self.body``.
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy
        self._session = session
        if params:
            # keep the query already in the URL and append the extra params
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        self.original_url = url
        self.url = url.with_fragment(None)
        self.method = method.upper()
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.length = None
        self.response_class = response_class or ClientResponse
        # fall back to a TimerNoop instance when no timer is supplied
        self._timer = timer if timer is not None else TimerNoop()
        self._ssl = ssl

        if loop.get_debug():
            # remember the caller's stack, starting one frame above this one
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # update_headers() creates self.headers, so it has to run before the
        # helpers below that read or modify the headers
        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth, proxy_headers)

        self.update_body_from_data(data)
        # transfer-encoding / content-length handling is skipped only for
        # body-less requests whose method is listed in GET_METHODS
        if data or self.method not in self.GET_METHODS:
            self.update_transfer_encoding()
        self.update_expect_continue(expect100)
        if traces is None:
            traces = []
        self._traces = traces

    def is_ssl(self):
        """Return True when the URL scheme is ``https`` or ``wss``."""
        secure_schemes = ('https', 'wss')
        return self.url.scheme in secure_schemes

    @property
    def ssl(self):
        """The ``ssl`` argument passed to the constructor, unchanged."""
        return self._ssl

    @property
    def connection_key(self):
        """Build a ``ConnectionKey`` from this request's connection settings.

        The key is made of host, port, ``is_ssl()``, the ``ssl`` option,
        proxy, proxy auth and a hash of the proxy header items (``None``
        when there are no proxy headers).
        """
        headers_hash = None
        extra_headers = self.proxy_headers
        if extra_headers:
            pairs = tuple(
                (name, value) for name, value in extra_headers.items())
            headers_hash = hash(pairs)
        return ConnectionKey(self.host, self.port, self.is_ssl(), self.ssl,
                             self.proxy, self.proxy_auth, headers_hash)

    @property
    def host(self):
        """Host of the request URL (``self.url.host``)."""
        return self.url.host

    @property
    def port(self):
        """Port of the request URL (``self.url.port``)."""
        return self.url.port

    @property
    def request_info(self):
        """Return a ``RequestInfo`` describing this request.

        It is built from the fragment-less URL, the method, the current
        headers and the original URL, in that order.
        """
        info = RequestInfo(
            self.url,
            self.method,
            self.headers,
            self.original_url,
        )
        return info

    def update_host(self, url):
        """Check that *url* has a host and pick up credentials embedded in it.

        Raises ``InvalidURL`` when ``url.host`` is empty.  When the URL
        carries a user name, ``self.auth`` becomes a ``BasicAuth`` built
        from that user name and the URL password (an empty string when
        the URL has no password).
        """
        if not url.host:
            raise InvalidURL(url)

        user, secret = url.user, url.password
        if not user:
            return
        self.auth = helpers.BasicAuth(user, secret or '')

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        A value that is not a string is stored unchanged.

        Raises ``ValueError`` when a string cannot be parsed, including a
        string without a minor part such as ``'1'``.
        """
        if isinstance(version, str):
            parts = [part.strip() for part in version.split('.', 1)]
            try:
                version = int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                # IndexError: the string has no '.', so parts[1] is missing;
                # report it the same way as a non-numeric component
                raise ValueError(
                    'Can not parse http version number: {}'.format(
                        version)) from None
        self.version = version

    def update_headers(self, headers):
        """Reset ``self.headers`` and fill it from *headers*.

        *headers* may be a ``dict``, ``MultiDict`` or ``MultiDictProxy``
        (its items are used) or any iterable of ``(name, value)`` pairs.
        Every pair is added, so repeated names are kept.
        """
        self.headers = CIMultiDict()
        if not headers:
            return

        if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
            pairs = headers.items()
        else:
            pairs = headers

        for name, val in pairs:
            self.headers.add(name, val)

    def update_auto_headers(self, skip_auto_headers):
        """Add the automatic headers that are neither set nor skipped.

        ``self.skip_auto_headers`` becomes a ``CIMultiDict`` keyed by the
        sorted names in *skip_auto_headers*.  The entries of
        ``DEFAULT_HEADERS``, ``Host`` and ``User-Agent`` are then added
        unless the name is already in the headers or in the skip set.
        """
        self.skip_auto_headers = CIMultiDict(
            (name, None) for name in sorted(skip_auto_headers))

        # snapshot taken before anything is added below
        taken = self.headers.copy()
        taken.extend(self.skip_auto_headers)

        for name, default in self.DEFAULT_HEADERS.items():
            if name in taken:
                continue
            self.headers.add(name, default)

        # add host
        if hdrs.HOST not in taken:
            host_value = self.url.raw_host
            if helpers.is_ipv6_address(host_value):
                host_value = '[{}]'.format(host_value)
            if not self.url.is_default_port():
                host_value += ':' + str(self.url.port)
            self.headers[hdrs.HOST] = host_value

        if hdrs.USER_AGENT not in taken:
            self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Merge *cookies* into the ``Cookie`` request header.

        Cookies already present in the header are loaded first, then the
        entries of *cookies* (a mapping) are set on top of them and the
        header is rewritten.  Nothing happens when *cookies* is empty.
        """
        if not cookies:
            return

        jar = SimpleCookie()
        if hdrs.COOKIE in self.headers:
            jar.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        for name, value in cookies.items():
            if not isinstance(value, Morsel):
                jar[name] = value
                continue
            # Preserve coded_value
            morsel = value.get(value.key, Morsel())
            morsel.set(value.key, value.value, value.coded_value)
            jar[name] = morsel

        self.headers[hdrs.COOKIE] = jar.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Apply the ``compress`` option to the request headers.

        Does nothing without *data* or without ``self.compress``.  Raises
        ``ValueError`` when ``compress`` is set together with an explicit
        Content-Encoding header.  Otherwise a non-string ``compress``
        becomes ``'deflate'``, the Content-Encoding header is set to it and
        chunked transfer is switched on.
        """
        if not data:
            return

        declared = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if not self.compress:
            return

        if declared:
            raise ValueError('compress can not be set '
                             'if Content-Encoding header is set')

        if not isinstance(self.compress, str):
            self.compress = 'deflate'
        self.headers[hdrs.CONTENT_ENCODING] = self.compress
        # enable chunked, no need to deal with length
        self.chunked = True

    def update_transfer_encoding(self):
        """Reconcile ``self.chunked`` with the framing headers.

        Raises ``ValueError`` when ``chunked`` is set together with a
        Transfer-Encoding header containing ``chunked`` or together with a
        Content-Length header.  With ``chunked`` set and neither header
        present, ``Transfer-Encoding: chunked`` is added.  With ``chunked``
        unset and no chunked header, a missing Content-Length is filled in
        from ``len(self.body)``.
        """
        transfer_encoding = self.headers.get(
            hdrs.TRANSFER_ENCODING, '').lower()

        if 'chunked' in transfer_encoding:
            if self.chunked:
                raise ValueError(
                    'chunked can not be set '
                    'if "Transfer-Encoding: chunked" header is set')
            return

        if self.chunked:
            if hdrs.CONTENT_LENGTH in self.headers:
                raise ValueError('chunked can not be set '
                                 'if Content-Length header is set')
            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
            return

        if hdrs.CONTENT_LENGTH not in self.headers:
            self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_auth(self, auth):
        """Set the Authorization header from basic-auth credentials.

        *auth* falls back to ``self.auth`` when it is ``None``; if both are
        ``None`` the headers are left alone.  Raises ``TypeError`` when the
        credentials are not a ``helpers.BasicAuth``.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if not isinstance(credentials, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = credentials.encode()

    def update_body_from_data(self, body):
        """Turn *body* into a payload object and derive headers from it.

        A falsy *body* leaves everything untouched.  A ``FormData`` is
        called to produce its payload; anything else is looked up in
        ``payload.PAYLOAD_REGISTRY`` and, when no payload type matches,
        wrapped in ``FormData`` first.  The result is stored in
        ``self.body``.  Content-Length (or chunked mode when the payload
        size is ``None``), Content-Type and the payload's own headers are
        then filled in without overriding headers that are already set.
        """
        if not body:
            return

        # FormData
        if isinstance(body, FormData):
            body = body()

        try:
            body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
        except payload.LookupError:
            body = FormData(body)()

        self.body = body

        # enable chunked encoding if needed
        if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers:
            size = body.size
            if size is None:
                self.chunked = True
            else:
                self.headers[hdrs.CONTENT_LENGTH] = str(size)

        # set content-type
        if not (hdrs.CONTENT_TYPE in self.headers
                or hdrs.CONTENT_TYPE in self.skip_auto_headers):
            self.headers[hdrs.CONTENT_TYPE] = body.content_type

        # copy payload headers
        if body.headers:
            for name, value in body.headers.items():
                if name in self.headers:
                    continue
                self.headers[name] = value

    def update_expect_continue(self, expect=False):
        """Arrange for waiting on a ``100 Continue`` response when requested.

        A truthy *expect* sets the ``Expect: 100-continue`` header.
        Otherwise an existing Expect header with that value (compared
        case-insensitively) turns the behaviour on.  In both cases a future
        is created on the loop and stored in ``self._continue``.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        else:
            requested = self.headers.get(hdrs.EXPECT, '').lower()
            expect = requested == '100-continue'

        if expect:
            self._continue = self.loop.create_future()

    def update_proxy(self, proxy, proxy_auth, proxy_headers):
        """Validate and store the proxy settings.

        Raises ``ValueError`` when *proxy* is given with a scheme other
        than ``http`` or when *proxy_auth* is given but is not a
        ``helpers.BasicAuth``.
        """
        if proxy and proxy.scheme != 'http':
            raise ValueError("Only http proxies are supported")

        auth_is_valid = isinstance(proxy_auth, helpers.BasicAuth)
        if proxy_auth and not auth_is_valid:
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")

        self.proxy, self.proxy_auth = proxy, proxy_auth
        self.proxy_headers = proxy_headers

    def keep_alive(self):
        """Return True when the connection should stay open after the request.

        Versions below HTTP/1.0 never keep the connection alive.  HTTP/1.0
        keeps it only when the Connection header is ``keep-alive``.  Later
        versions keep it unless the Connection header is ``close``.  The
        header value is compared case-insensitively.
        """
        if self.version < HttpVersion10:
            # keep alive not supported at all
            return False

        # Connection options are case-insensitive, so normalise before
        # comparing; a missing header is treated as an empty value.
        connection = (self.headers.get(hdrs.CONNECTION) or '').lower()

        if self.version == HttpVersion10:
            # no headers means we close for Http 1.0
            return connection == 'keep-alive'

        return connection != 'close'

    async def write_bytes(self, writer, conn):
        """Write the request body to *writer* and finish with ``write_eof()``.

        When a ``100 Continue`` waiter exists, the writer is drained and the
        waiter awaited before any body data is written.  Errors raised
        while writing are not propagated to the caller; they are set on
        ``conn.protocol`` instead.  ``self._writer`` is reset to ``None``
        when the method finishes, whatever the outcome.
        """
        # 100 response
        if self._continue is not None:
            await writer.drain()
            await self._continue

        try:
            if isinstance(self.body, payload.Payload):
                await self.body.write(writer)
            else:
                # a single bytes-like body is wrapped so the loop below
                # writes it as one chunk
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body, )

                for chunk in self.body:
                    await writer.write(chunk)

            await writer.write_eof()
        except OSError as exc:
            # wrap OS-level failures in ClientOSError, keeping the errno and
            # linking the original exception
            new_exc = ClientOSError(
                exc.errno, 'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            conn.protocol.set_exception(new_exc)
        except asyncio.CancelledError as exc:
            # cancellation is reported only while the connection is open
            if not conn.closed:
                conn.protocol.set_exception(exc)
        except Exception as exc:
            conn.protocol.set_exception(exc)
        finally:
            self._writer = None

    async def send(self, conn):
        """Write the request line and headers, then start sending the body.

        The body is written by a separate task created from
        ``write_bytes()`` and stored in ``self._writer``.  Returns a new
        ``response_class`` instance, also stored in ``self.response``.
        """
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            path = '{}:{}'.format(self.url.raw_host, self.url.port)
        elif self.proxy and not self.is_ssl():
            path = str(self.url)
        else:
            path = self.url.raw_path
            if self.url.raw_query_string:
                path += '?' + self.url.raw_query_string

        writer = StreamWriter(conn.protocol,
                              self.loop,
                              on_chunk_sent=self._on_chunk_request_sent)

        if self.compress:
            writer.enable_compression(self.compress)

        # NOTE(review): the test is "is not None", so an explicit
        # chunked=False also enables chunking here -- confirm this is intended
        if self.chunked is not None:
            writer.enable_chunking()

        # set default content-type
        if (self.method in self.POST_METHODS
                and hdrs.CONTENT_TYPE not in self.skip_auto_headers
                and hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        # set the connection header
        connection = self.headers.get(hdrs.CONNECTION)
        if not connection:
            # only state the non-default choice for the protocol version:
            # keep-alive on HTTP/1.0, close on HTTP/1.1
            if self.keep_alive():
                if self.version == HttpVersion10:
                    connection = 'keep-alive'
            else:
                if self.version == HttpVersion11:
                    connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

        # status + headers
        status_line = '{0} {1} HTTP/{2[0]}.{2[1]}'.format(
            self.method, path, self.version)
        await writer.write_headers(status_line, self.headers)

        # the body is streamed by a separate task
        self._writer = self.loop.create_task(self.write_bytes(writer, conn))

        self.response = self.response_class(self.method,
                                            self.original_url,
                                            writer=self._writer,
                                            continue100=self._continue,
                                            timer=self._timer,
                                            request_info=self.request_info,
                                            traces=self._traces,
                                            loop=self.loop,
                                            session=self._session)
        return self.response

    async def close(self):
        """Wait for the body-writer task, if any, then forget it.

        ``self._writer`` is reset to ``None`` even when awaiting the task
        raises.
        """
        if self._writer is None:
            return
        try:
            await self._writer
        finally:
            self._writer = None

    def terminate(self):
        """Cancel the body-writer task, if any, and forget it.

        The task is cancelled only while the event loop is not closed; the
        reference is dropped either way.
        """
        if self._writer is None:
            return
        if not self.loop.is_closed():
            self._writer.cancel()
        self._writer = None

    async def _on_chunk_request_sent(self, chunk):
        """Pass *chunk* to ``send_request_chunk_sent`` of every trace, in order."""
        for tracer in self._traces:
            await tracer.send_request_chunk_sent(chunk)
示例#37
0
def test_clone_headers_dict() -> None:
    """Cloning with a ``headers`` dict yields exactly those headers."""
    original = make_mocked_request("GET", "/path", headers={"A": "B"})
    cloned = original.clone(headers={"B": "C"})

    expected_headers = CIMultiDict({"B": "C"})
    expected_raw = ((b"B", b"C"),)
    assert cloned.headers == expected_headers
    assert cloned.raw_headers == expected_raw
示例#38
0
class ClientRequest:

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method, url, *,
                 params=None, headers=None, skip_auto_headers=frozenset(),
                 data=None, cookies=None,
                 auth=None, encoding='utf-8',
                 version=aiohttp.HttpVersion11, compress=None,
                 chunked=None, expect100=False,
                 loop=None, response_class=None):
        """Store the request parameters and build path, headers and body.

        *method* is upper-cased and *loop* defaults to the current event
        loop.  After the plain attributes are stored, the ``update_*``
        helpers run in a fixed order to derive host, port, path, headers
        and body from *url*, *params*, *headers*, *cookies*, *auth* and
        *data*.
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        self.url = url
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.response_class = response_class or ClientResponse

        if loop.get_debug():
            # remember the caller's stack, starting one frame above this one
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # update_host() sets self.netloc, which update_auto_headers() uses
        # for the Host header; update_headers() creates self.headers, which
        # the later helpers read and modify
        self.update_version(version)
        self.update_host(url)
        self.update_path(params)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)

        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    def update_host(self, url):
        """Update destination host, port and connection type (ssl).

        Parses *url* and sets ``self.host``, ``self.port``, ``self.scheme``,
        ``self.ssl`` and ``self.netloc`` (the network location without
        credentials, used for the Host header).  Credentials embedded in
        the URL are stored in ``self.auth``.  When the URL has no port,
        ``HTTPS_PORT`` or ``HTTP_PORT`` is used depending on the scheme.

        Raises ``ValueError`` when the host is missing, the port cannot be
        converted, or the host name cannot be IDNA-encoded.
        """
        url_parsed = urllib.parse.urlsplit(url)

        # check for network location part
        netloc = url_parsed.netloc
        if not netloc:
            raise ValueError('Host could not be detected.')

        # get host/port
        host = url_parsed.hostname
        if not host:
            raise ValueError('Host could not be detected.')

        try:
            port = url_parsed.port
        except ValueError:
            raise ValueError(
                'Port number could not be converted.') from None

        # check domain idna encoding
        try:
            netloc = netloc.encode('idna').decode('utf-8')
            host = host.encode('idna').decode('utf-8')
        except UnicodeError:
            raise ValueError('URL has an invalid label.')

        # basic auth info
        username, password = url_parsed.username, url_parsed.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')
            # urlsplit separates the credentials from the host at the LAST
            # '@', so a password may itself contain '@'; cut at the same
            # place to keep only the host[:port] part
            netloc = netloc.rpartition('@')[2]

        # Record entire netloc for usage in host header
        self.netloc = netloc

        scheme = url_parsed.scheme
        self.ssl = scheme in ('https', 'wss')

        # set port number if it isn't already set
        if not port:
            if self.ssl:
                port = HTTPS_PORT
            else:
                port = HTTP_PORT

        self.host, self.port, self.scheme = host, port, scheme

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        A value that is not a string is stored unchanged.

        Raises ``ValueError`` when a string cannot be parsed, including a
        string without a minor part such as ``'1'``.
        """
        if isinstance(version, str):
            parts = [part.strip() for part in version.split('.', 1)]
            try:
                version = int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                # IndexError: the string has no '.', so parts[1] is missing;
                # report it the same way as a non-numeric component
                raise ValueError(
                    'Can not parse http version number: {}'
                    .format(version)) from None
        self.version = version

    def update_path(self, params):
        """Build the request path and normalise ``self.url``.

        *params* may be a mapping, a sequence of pairs or an already
        encoded query string; it is appended to any query already present
        in the URL.  ``self.path`` becomes the requoted path plus query and
        fragment, and ``self.url`` is rebuilt from scheme, netloc and that
        path.
        """
        # ``collections.Mapping`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.  Imported locally so the block does not
        # depend on the file's import section.
        from collections.abc import Mapping

        # extract path
        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(self.url)
        if not path:
            path = '/'

        if isinstance(params, Mapping):
            params = list(params.items())

        if params:
            if not isinstance(params, str):
                params = urllib.parse.urlencode(params)
            if query:
                query = '%s&%s' % (query, params)
            else:
                query = params

        self.path = urllib.parse.urlunsplit(('', '', helpers.requote_uri(path),
                                             query, fragment))
        self.url = urllib.parse.urlunsplit(
            (scheme, netloc, self.path, '', ''))

    def update_headers(self, headers):
        """Reset ``self.headers`` and fill it from *headers*.

        *headers* may be a ``dict``, ``MultiDict`` or ``MultiDictProxy``
        (its items are used) or any iterable of ``(name, value)`` pairs.
        Every pair is added, so repeated names are kept.
        """
        self.headers = CIMultiDict()
        if not headers:
            return

        if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
            pairs = headers.items()
        else:
            pairs = headers

        for name, val in pairs:
            self.headers.add(name, val)

    def update_auto_headers(self, skip_auto_headers):
        """Add the automatic headers that are neither set nor skipped.

        Stores *skip_auto_headers* and then adds the entries of
        ``DEFAULT_HEADERS``, ``Host`` (from ``self.netloc``) and
        ``User-Agent`` unless the name is already among the headers or in
        the skip set.
        """
        self.skip_auto_headers = skip_auto_headers
        # snapshot taken before anything is added below
        taken = set(self.headers) | skip_auto_headers

        for name, default in self.DEFAULT_HEADERS.items():
            if name in taken:
                continue
            self.headers.add(name, default)

        # add host
        if hdrs.HOST not in taken:
            self.headers[hdrs.HOST] = self.netloc

        if hdrs.USER_AGENT not in taken:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Merge *cookies* into the ``Cookie`` request header.

        Cookies already present in the header are loaded first.  *cookies*
        may be a ``dict`` or an iterable of ``(name, value)`` pairs; a
        ``Morsel`` value is stored under its own key.  Nothing happens when
        *cookies* is empty.
        """
        if not cookies:
            return

        jar = http.cookies.SimpleCookie()
        if hdrs.COOKIE in self.headers:
            jar.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        pairs = cookies.items() if isinstance(cookies, dict) else cookies

        for name, value in pairs:
            if isinstance(value, http.cookies.Morsel):
                jar[value.key] = value.value
                continue
            jar[name] = value

        self.headers[hdrs.COOKIE] = jar.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Reconcile ``self.compress`` with the Content-Encoding header.

        Does nothing without *data*.  An explicit Content-Encoding header
        becomes the compression to use unless ``compress`` is ``False``.
        Without the header, a truthy ``compress`` sets it (a non-string
        value means ``'deflate'``).  Whenever compression ends up enabled,
        chunked transfer is switched on.
        """
        if not data:
            return

        declared = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()

        if declared:
            if self.compress is False:
                return
            self.compress = declared
            # enable chunked, no need to deal with length
            self.chunked = True
            return

        if not self.compress:
            return

        if not isinstance(self.compress, str):
            self.compress = 'deflate'
        self.headers[hdrs.CONTENT_ENCODING] = self.compress
        # enable chunked, no need to deal with length
        self.chunked = True

    def update_auth(self, auth):
        """Set the Authorization header from basic-auth credentials.

        *auth* falls back to ``self.auth`` when it is ``None``; if both are
        ``None`` the headers are left alone.  Credentials that are not a
        ``helpers.BasicAuth`` trigger a ``DeprecationWarning`` and are
        unpacked into one.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if not isinstance(credentials, helpers.BasicAuth):
            warnings.warn(
                'BasicAuth() tuple is required instead ', DeprecationWarning)
            credentials = helpers.BasicAuth(*credentials)

        self.headers[hdrs.AUTHORIZATION] = credentials.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Store *data* as the request body and derive headers from its type.

        A falsy *data* leaves everything untouched.  ``str`` is encoded
        with ``self.encoding`` and then handled as bytes.  The supported
        kinds are handled in this order: bytes-like, ``StreamReader`` /
        ``DataQueue``, coroutine, file-like (``io.IOBase``),
        ``MultipartWriter``, and finally anything else, which is wrapped in
        ``helpers.FormData``.  Content-Type is only set when it is neither
        present already nor listed in *skip_auto_headers*.
        """
        if not data:
            return

        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if (hdrs.CONTENT_TYPE not in self.headers and
                    hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(data, (asyncio.StreamReader, streams.DataQueue)):
            # stored as is; no header is derived from it here
            self.body = data

        elif asyncio.iscoroutine(data):
            self.body = data
            # switch to chunked only when the caller gave neither a length
            # nor an explicit chunked setting
            if (hdrs.CONTENT_LENGTH not in self.headers and
                    self.chunked is None):
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), \
                'attempt to send text data instead of binary'
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined
                size = len(data.getbuffer())
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif not self.chunked and isinstance(data, io.BufferedReader):
                # Not chunking if content-length can be determined
                try:
                    # bytes left to read: file size minus current position
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            if hasattr(data, 'mode'):
                if data.mode == 'r':
                    raise ValueError('file {!r} should be open in binary mode'
                                     ''.format(data))
            if (hdrs.CONTENT_TYPE not in self.headers and
                hdrs.CONTENT_TYPE not in skip_auto_headers and
                    hasattr(data, 'name')):
                # guess the type from the file name, with a binary fallback
                mime = mimetypes.guess_type(data.name)[0]
                mime = 'application/octet-stream' if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            # keep a truthy chunked value, otherwise use 8192
            self.chunked = self.chunked or 8192

        else:
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if (hdrs.CONTENT_TYPE not in self.headers and
                    hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                # keep a truthy chunked value, otherwise use 8192
                self.chunked = self.chunked or 8192
            else:
                if (hdrs.CONTENT_LENGTH not in self.headers and
                        not self.chunked):
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Reconcile ``self.chunked`` with the framing headers.

        With a truthy ``chunked``, Content-Length is removed, a
        ``Transfer-Encoding: chunked`` header is ensured and ``chunked``
        becomes a chunk size (8192 unless it already is an ``int``).
        Otherwise a Transfer-Encoding header containing ``chunked`` sets
        the size to 8192; without one, ``chunked`` becomes ``None`` and a
        missing Content-Length is filled in from ``len(self.body)``.
        """
        transfer_encoding = self.headers.get(
            hdrs.TRANSFER_ENCODING, '').lower()
        header_is_chunked = 'chunked' in transfer_encoding

        if self.chunked:
            if hdrs.CONTENT_LENGTH in self.headers:
                del self.headers[hdrs.CONTENT_LENGTH]
            if not header_is_chunked:
                self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

            # exact type check on purpose: ``True`` is an ``int`` subclass
            # instance and must be replaced by the default size as well
            if type(self.chunked) is not int:
                self.chunked = 8192
            return

        if header_is_chunked:
            self.chunked = 8192
            return

        self.chunked = None
        if hdrs.CONTENT_LENGTH not in self.headers:
            self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_expect_continue(self, expect=False):
        """Arrange for waiting on a ``100 Continue`` response when requested.

        A truthy *expect* sets the ``Expect: 100-continue`` header.
        Otherwise an existing Expect header with that value (compared
        case-insensitively) turns the behaviour on.  In both cases a future
        from ``helpers.create_future`` is stored in ``self._continue``.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        else:
            requested = self.headers.get(hdrs.EXPECT, '').lower()
            expect = requested == '100-continue'

        if expect:
            self._continue = helpers.create_future(self.loop)

    @asyncio.coroutine
    def write_bytes(self, request, reader):
        """Write the request body to *request* and finish with ``write_eof()``.

        The body may be a coroutine yielding bytes objects, a
        ``StreamReader``, a ``DataQueue``, a file-like object, a bytes-like
        object or an iterable of chunks.  When a ``100 Continue`` waiter
        exists it is awaited first.  Errors raised while writing are
        wrapped in ``aiohttp.ClientRequestError`` and set on *reader*
        rather than propagated.  ``self._writer`` is reset to ``None`` at
        the end.
        """
        # 100 response
        if self._continue is not None:
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                request.transport.set_tcp_nodelay(True)
                exc = None
                value = None
                stream = self.body

                # Drive the body coroutine by hand: futures it yields are
                # awaited here and their result (or exception) is sent back
                # in; bytes it yields are written to the request.
                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as exc:
                        # a bytes return value is written as the last chunk
                        if isinstance(exc.value, bytes):
                            yield from request.write(exc.value, drain=True)
                        break
                    except:
                        # close the response, then let the error propagate
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            # thrown into the body coroutine on the next turn
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result, drain=True)
                        value = None
                    else:
                        raise ValueError(
                            'Bytes object is expected, got: %s.' %
                            type(result))

            elif isinstance(self.body, asyncio.StreamReader):
                request.transport.set_tcp_nodelay(True)
                # copy until read() returns an empty chunk
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                request.transport.set_tcp_nodelay(True)
                # copy until EOF_MARKER is read or EofStream is raised
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if chunk is EOF_MARKER:
                            break
                        yield from request.write(chunk, drain=True)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                # read the file in pieces of self.chunked
                chunk = self.body.read(self.chunked)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(self.chunked)
                request.transport.set_tcp_nodelay(True)

            else:
                # a single bytes-like body is wrapped so the loop below
                # writes it as one chunk
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)

                for chunk in self.body:
                    request.write(chunk)
                request.transport.set_tcp_nodelay(True)

        except Exception as exc:
            new_exc = aiohttp.ClientRequestError(
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            reader.set_exception(new_exc)
        else:
            # every branch above calls set_tcp_nodelay(True)
            assert request.transport.tcp_nodelay
            try:
                ret = request.write_eof()
                # NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
                # see bug #170
                if (asyncio.iscoroutine(ret) or
                        isinstance(ret, asyncio.Future)):
                    yield from ret
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError(
                    'Can not write request body for %s' % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                reader.set_exception(new_exc)

        self._writer = None

    def send(self, writer, reader):
        """Send the headers and start writing the body in the background.

        Builds an ``aiohttp.Request`` on *writer*, applies the compression
        and chunking filters, sends all headers and schedules
        ``write_bytes()`` with ``helpers.ensure_future``.  Returns a new
        ``response_class`` instance, also stored in ``self.response``.
        """
        writer.set_tcp_cork(True)
        request = aiohttp.Request(writer, self.method, self.path, self.version)

        if self.compress:
            request.add_compression_filter(self.compress)

        if self.chunked is not None:
            request.enable_chunked_encoding()
            request.add_chunking_filter(self.chunked)

        # set default content-type
        needs_default_type = (
            self.method in self.POST_METHODS and
            hdrs.CONTENT_TYPE not in self.skip_auto_headers and
            hdrs.CONTENT_TYPE not in self.headers)
        if needs_default_type:
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        for name, value in self.headers.items():
            request.add_header(name, value)
        request.send_headers()

        body_writer = self.write_bytes(request, reader)
        self._writer = helpers.ensure_future(body_writer, loop=self.loop)

        self.response = self.response_class(
            self.method,
            self.url,
            self.host,
            writer=self._writer,
            continue100=self._continue,
        )
        self.response._post_init(self.loop)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the body-writer task, if any, then forget it.

        ``self._writer`` is reset to ``None`` even when waiting raises.
        """
        if self._writer is None:
            return
        try:
            yield from self._writer
        finally:
            self._writer = None

    def terminate(self):
        """Cancel the body-writer task, if any, and forget it.

        The task is cancelled unless the loop has an ``is_closed`` method
        that reports it closed; the reference is dropped either way.
        """
        if self._writer is None:
            return

        # loops without is_closed() are treated as still open
        loop_closed = (hasattr(self.loop, 'is_closed') and
                       self.loop.is_closed())
        if not loop_closed:
            self._writer.cancel()
        self._writer = None
示例#39
0
def test_single_forwarded_header_long_quoted_string() -> None:
    """A quoted ``for`` value of 5000 escaped backslashes is unescaped."""
    escaped_backslashes = "\\\\" * 5000
    header = f'for="{escaped_backslashes}"'
    forwarded_headers = CIMultiDict({"Forwarded": header})
    req = make_mocked_request("GET", "/", headers=forwarded_headers)
    first_element = req.forwarded[0]
    assert first_element["for"] == "\\" * 5000
 async def send(self, request, **config):
     """Send via the parent transport and normalise the response headers.

     When the parent returns headers that are not a ``CIMultiDictProxy``,
     they are rebuilt from ``response.internal_response.headers`` and
     ``content_type`` is refreshed from the rebuilt mapping.
     """
     response = await super(AiohttpTestTransport, self).send(request, **config)
     if isinstance(response.headers, CIMultiDictProxy):
         return response

     rebuilt = CIMultiDictProxy(CIMultiDict(response.internal_response.headers))
     response.headers = rebuilt
     response.content_type = response.headers.get("content-type")
     return response
示例#41
0
def make_mocked_request(method,
                        path,
                        headers=None,
                        *,
                        match_info=sentinel,
                        version=HttpVersion(1, 1),
                        closing=False,
                        app=None,
                        writer=sentinel,
                        protocol=sentinel,
                        transport=sentinel,
                        payload=sentinel,
                        sslcontext=None,
                        client_max_size=1024**2,
                        loop=...):
    """Build a ``Request`` backed by mocks for use in unit tests.

    Every collaborator left at its default (``sentinel``, ``None`` or
    ``...`` for ``loop``) is replaced by a mock or a helper-created object.
    ``closing`` is forced to ``True`` for versions below HTTP/1.1.  The
    returned request has its ``_match_info`` populated and bound to ``app``.
    """

    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    # Normalise headers and derive the raw (bytes) header pairs from them;
    # an empty mapping naturally yields an empty tuple.
    headers = CIMultiDict(headers) if headers else CIMultiDict()
    raw_hdrs = tuple(
        (name.encode('utf-8'), value.encode('utf-8'))
        for name, value in headers.items())

    transfer_encoding = headers.get(hdrs.TRANSFER_ENCODING, '')
    chunked = 'chunked' in transfer_encoding.lower()

    message = RawRequestMessage(method, path, version, headers, raw_hdrs,
                                closing, False, False, chunked, URL(path))
    if app is None:
        app = _create_app_mock()

    if protocol is sentinel:
        protocol = mock.Mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if writer is sentinel:
        writer = mock.Mock()
        for coro_name in ('write_headers', 'write', 'write_eof', 'drain'):
            setattr(writer, coro_name, make_mocked_coro(None))
        writer.transport = transport

    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(message,
                  payload,
                  protocol,
                  writer,
                  task,
                  loop,
                  client_max_size=client_max_size)

    if match_info is sentinel:
        match_dict = {}
    else:
        match_dict = match_info
    url_match = UrlMappingMatchInfo(match_dict, mock.Mock())
    url_match.add_app(app)
    req._match_info = url_match

    return req
示例#42
0
 def _make_headers(headers: t.Optional[LooseHeaders]) -> CIMultiDict:
     """Return a new ``CIMultiDict`` built from ``headers`` (empty if falsy)."""
     return CIMultiDict(headers if headers else {})
示例#43
0
    def __init__(self, *, connector=None, loop=None, cookies=None,
                 headers=None, skip_auto_headers=None,
                 auth=None, json_serialize=json.dumps,
                 request_class=ClientRequest, response_class=ClientResponse,
                 ws_response_class=ClientWebSocketResponse,
                 version=http.HttpVersion11,
                 cookie_jar=None, connector_owner=True, raise_for_status=False,
                 read_timeout=sentinel, conn_timeout=None,
                 auto_decompress=True, trust_env=False,
                 trace_configs=None):
        """Set up the session's loop, connector, cookie jar and defaults.

        The loop is taken from ``loop``, else from ``connector``, else from
        ``asyncio.get_event_loop()``.  Only in the last case, and only if that
        loop is not running, a warning is issued and the loop's exception
        handler is called.

        Raises:
            RuntimeError: if ``connector`` is bound to a different loop.
        """

        # The loop is "implicit" only when neither a loop nor a connector
        # (which carries its own loop) was supplied.
        implicit_loop = loop is None and connector is None
        if loop is None:
            if connector is None:
                loop = asyncio.get_event_loop()
            else:
                loop = connector._loop

        if connector is None:
            connector = TCPConnector(loop=loop)

        if connector._loop is not loop:
            raise RuntimeError(
                "Session and connector has to use same event loop")

        self._loop = loop

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        if implicit_loop and not loop.is_running():
            warnings.warn("Creating a client session outside of coroutine is "
                          "a very dangerous idea",
                          stacklevel=2)
            handler_context = {'client_session': self,
                               'message': 'Creating a client session outside '
                               'of coroutine'}
            if self._source_traceback is not None:
                handler_context['source_traceback'] = self._source_traceback
            loop.call_exception_handler(handler_context)

        if cookie_jar is None:
            cookie_jar = CookieJar(loop=loop)
        self._cookie_jar = cookie_jar

        if cookies is not None:
            self._cookie_jar.update_cookies(cookies)

        self._connector = connector
        self._connector_owner = connector_owner
        self._default_auth = auth
        self._version = version
        self._json_serialize = json_serialize
        if read_timeout is sentinel:
            self._read_timeout = DEFAULT_TIMEOUT
        else:
            self._read_timeout = read_timeout
        self._conn_timeout = conn_timeout
        self._raise_for_status = raise_for_status
        self._auto_decompress = auto_decompress
        self._trust_env = trust_env

        # Default headers are always stored as a fresh CIMultiDict.
        self._default_headers = (
            CIMultiDict(headers) if headers else CIMultiDict())

        if skip_auto_headers is None:
            self._skip_auto_headers = frozenset()
        else:
            self._skip_auto_headers = frozenset(
                istr(name) for name in skip_auto_headers)

        self._request_class = request_class
        self._response_class = response_class
        self._ws_response_class = ws_response_class

        self._trace_configs = trace_configs or []
        for trace_config in self._trace_configs:
            trace_config.freeze()
示例#44
0
    async def _ws_connect(self, url, *,
                          protocols=(),
                          timeout=10.0,
                          receive_timeout=None,
                          autoclose=True,
                          autoping=True,
                          heartbeat=None,
                          auth=None,
                          origin=None,
                          headers=None,
                          proxy=None,
                          proxy_auth=None,
                          ssl=None,
                          verify_ssl=None,
                          fingerprint=None,
                          ssl_context=None,
                          proxy_headers=None,
                          compress=0):
        """Perform the client side of a WebSocket opening handshake.

        A GET request carrying the upgrade headers is sent to ``url``; the
        response status, ``Upgrade``/``Connection`` headers and the
        ``Sec-WebSocket-Accept`` challenge are then validated.

        Args:
            url: target of the handshake request.
            protocols: sub-protocols offered to the server; the first one
                listed in the response that was also offered is selected.
            timeout, receive_timeout, autoclose, autoping, heartbeat:
                forwarded unchanged to ``self._ws_response_class``.
            auth, proxy, proxy_auth, proxy_headers: forwarded to ``self.get``.
            origin: value for the ``Origin`` header, when not ``None``.
            headers: extra request headers.  They are copied, so the
                caller's object is never modified.
            ssl, verify_ssl, fingerprint, ssl_context: combined by
                ``_merge_ssl_params`` into the ``ssl`` argument of the request.
            compress: when truthy, an extensions header generated by
                ``ws_ext_gen`` is sent and the server's answer is parsed.

        Returns:
            An instance of ``self._ws_response_class``.

        Raises:
            WSServerHandshakeError: if any handshake check fails.  The
                response is closed before any exception propagates.
        """

        # Work on a private case-insensitive copy: the handshake headers added
        # below must not leak into the caller's mapping, and the "already
        # present" check must ignore header-name case even for a plain dict.
        if headers is None:
            headers = CIMultiDict()
        else:
            headers = CIMultiDict(headers)

        default_headers = {
            hdrs.UPGRADE: hdrs.WEBSOCKET,
            hdrs.CONNECTION: hdrs.UPGRADE,
            hdrs.SEC_WEBSOCKET_VERSION: '13',
        }

        for key, value in default_headers.items():
            if key not in headers:
                headers[key] = value

        # Random nonce; the server must echo a digest derived from it.
        sec_key = base64.b64encode(os.urandom(16))
        headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode()

        if protocols:
            headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
        if origin is not None:
            headers[hdrs.ORIGIN] = origin
        if compress:
            extstr = ws_ext_gen(compress=compress)
            headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr

        ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)

        # send request
        resp = await self.get(url, headers=headers,
                              read_until_eof=False,
                              auth=auth,
                              proxy=proxy,
                              proxy_auth=proxy_auth,
                              ssl=ssl,
                              proxy_headers=proxy_headers)

        try:
            # check handshake
            if resp.status != 101:
                raise WSServerHandshakeError(
                    resp.request_info,
                    resp.history,
                    message='Invalid response status',
                    code=resp.status,
                    headers=resp.headers)

            if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
                raise WSServerHandshakeError(
                    resp.request_info,
                    resp.history,
                    message='Invalid upgrade header',
                    code=resp.status,
                    headers=resp.headers)

            if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
                raise WSServerHandshakeError(
                    resp.request_info,
                    resp.history,
                    message='Invalid connection header',
                    code=resp.status,
                    headers=resp.headers)

            # key calculation: the accept value must equal
            # base64(sha1(nonce + WS_KEY)).
            key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
            match = base64.b64encode(
                hashlib.sha1(sec_key + WS_KEY).digest()).decode()
            if key != match:
                raise WSServerHandshakeError(
                    resp.request_info,
                    resp.history,
                    message='Invalid challenge response',
                    code=resp.status,
                    headers=resp.headers)

            # websocket protocol
            protocol = None
            if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
                resp_protocols = [
                    proto.strip() for proto in
                    resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]

                for proto in resp_protocols:
                    if proto in protocols:
                        protocol = proto
                        break

            # websocket compress
            notakeover = False
            if compress:
                compress_hdrs = resp.headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
                if compress_hdrs:
                    try:
                        compress, notakeover = ws_ext_parse(compress_hdrs)
                    except WSHandshakeError as exc:
                        raise WSServerHandshakeError(
                            resp.request_info,
                            resp.history,
                            message=exc.args[0],
                            code=resp.status,
                            headers=resp.headers) from exc
                else:
                    # Server sent no extensions header: disable compression.
                    compress = 0
                    notakeover = False

            proto = resp.connection.protocol
            transport = resp.connection.transport
            reader = FlowControlDataQueue(
                proto, limit=2 ** 16, loop=self._loop)
            proto.set_parser(WebSocketReader(reader), reader)
            tcp_nodelay(transport, True)
            writer = WebSocketWriter(
                proto, transport, use_mask=True,
                compress=compress, notakeover=notakeover)
        except BaseException:
            # Never leak the connection, whatever went wrong above.
            resp.close()
            raise
        else:
            return self._ws_response_class(reader,
                                           writer,
                                           protocol,
                                           resp,
                                           timeout,
                                           autoclose,
                                           autoping,
                                           self._loop,
                                           receive_timeout=receive_timeout,
                                           heartbeat=heartbeat,
                                           compress=compress,
                                           client_notakeover=notakeover)
示例#45
0
class StreamResponse(HeadersMixin):
    """HTTP response whose body is written incrementally.

    Status, headers and cookies are collected on the instance.  They are
    handed to a ``ResponseImpl`` and sent when the response is started via
    ``prepare()`` (or the deprecated ``start()``); body data is then pushed
    with ``write()`` and finished with ``write_eof()``.
    """

    def __init__(self, *, status=200, reason=None, headers=None):
        """Initialise state; Content-Type defaults to octet-stream."""
        self._body = None
        self._keep_alive = None
        self._chunked = False
        self._chunk_size = None
        self._compression = False
        self._compression_force = False
        self._headers = CIMultiDict()
        self._cookies = SimpleCookie()

        self._req = None
        self._resp_impl = None
        self._eof_sent = False
        self._body_length = 0

        if headers is not None:
            # TODO: optimize CIMultiDict extending
            self._headers.extend(headers)
        self._headers.setdefault(hdrs.CONTENT_TYPE, 'application/octet-stream')

        self.set_status(status, reason)

    @property
    def prepared(self):
        """True while a ``ResponseImpl`` is attached to this response."""
        return self._resp_impl is not None

    @property
    def started(self):
        """Deprecated alias of ``prepared``; emits a DeprecationWarning."""
        warnings.warn('use Response.prepared instead', DeprecationWarning)
        return self.prepared

    @property
    def task(self):
        """The ``task`` attribute of the bound request, or None."""
        return getattr(self._req, 'task', None)

    @property
    def status(self):
        """Status code as an int."""
        return self._status

    @property
    def chunked(self):
        """Whether chunked transfer encoding was enabled."""
        return self._chunked

    @property
    def compression(self):
        """Whether compression was enabled."""
        return self._compression

    @property
    def reason(self):
        """Reason phrase of the status line."""
        return self._reason

    def set_status(self, status, reason=None):
        """Set the status code and, optionally, the reason phrase.

        ``status`` is converted with ``int()``.  When ``reason`` is None it
        is derived from the converted code by ``ResponseImpl.calc_reason``.

        Raises:
            RuntimeError: if the response is already prepared.
        """
        if self.prepared:
            raise RuntimeError("Cannot change the response status code after "
                               "the headers have been sent")
        self._status = int(status)
        if reason is None:
            # Look the reason up with the normalised int, not the raw
            # argument, so that e.g. a numeric string resolves the same way.
            reason = ResponseImpl.calc_reason(self._status)
        self._reason = reason

    @property
    def keep_alive(self):
        """Keep-alive flag; None until forced or decided in ``_start``."""
        return self._keep_alive

    def force_close(self):
        """Disable keep-alive for this response."""
        self._keep_alive = False

    @property
    def body_length(self):
        """Body length recorded by ``write_eof()`` (0 before that)."""
        return self._body_length

    @property
    def output_length(self):
        """``output_length`` of the attached ``ResponseImpl``."""
        return self._resp_impl.output_length

    def enable_chunked_encoding(self, chunk_size=None):
        """Enables automatic chunked transfer encoding."""
        self._chunked = True
        self._chunk_size = chunk_size

    def enable_compression(self, force=None):
        """Enables response compression encoding."""
        # Backwards compatibility for when force was a bool <0.17.
        if isinstance(force, bool):
            force = ContentCoding.deflate if force else ContentCoding.identity
        elif force is not None:
            assert isinstance(force, ContentCoding), ("force should one of "
                                                      "None, bool or "
                                                      "ContentEncoding")

        self._compression = True
        self._compression_force = force

    @property
    def headers(self):
        """Mutable multidict of response headers."""
        return self._headers

    @property
    def cookies(self):
        """``SimpleCookie`` holding the cookies to be sent."""
        return self._cookies

    def set_cookie(self, name, value, *, expires=None,
                   domain=None, max_age=None, path='/',
                   secure=None, httponly=None, version=None):
        """Set or update response cookie.

        Sets new cookie or updates existent with new value.
        Also updates only those params which are not None.
        """

        old = self._cookies.get(name)
        if old is not None and old.coded_value == '':
            # deleted cookie
            self._cookies.pop(name, None)

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c['expires'] = expires
        elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
            # Drop the epoch expiry that del_cookie() sets.
            del c['expires']

        if domain is not None:
            c['domain'] = domain

        if max_age is not None:
            c['max-age'] = max_age
        elif 'max-age' in c:
            del c['max-age']

        c['path'] = path

        if secure is not None:
            c['secure'] = secure
        if httponly is not None:
            c['httponly'] = httponly
        if version is not None:
            c['version'] = version

    def del_cookie(self, name, *, domain=None, path='/'):
        """Delete cookie.

        Creates new empty expired cookie.
        """
        # TODO: do we need domain/path here?
        self._cookies.pop(name, None)
        self.set_cookie(name, '', max_age=0,
                        expires="Thu, 01 Jan 1970 00:00:00 GMT",
                        domain=domain, path=path)

    @property
    def content_length(self):
        """Content length as provided by the parent class."""
        # Just a placeholder for adding setter
        return super().content_length

    @content_length.setter
    def content_length(self, value):
        """Set the Content-Length header, or remove it for None."""
        if value is not None:
            value = int(value)
            # TODO: raise error if chunked enabled
            self.headers[hdrs.CONTENT_LENGTH] = str(value)
        else:
            self.headers.pop(hdrs.CONTENT_LENGTH, None)

    @property
    def content_type(self):
        """Content type as provided by the parent class."""
        # Just a placeholder for adding setter
        return super().content_type

    @content_type.setter
    def content_type(self, value):
        """Set the content type and regenerate the Content-Type header."""
        self.content_type  # read header values if needed
        self._content_type = str(value)
        self._generate_content_type_header()

    @property
    def charset(self):
        """Charset as provided by the parent class."""
        # Just a placeholder for adding setter
        return super().charset

    @charset.setter
    def charset(self, value):
        """Set (lower-cased) or, for None, remove the charset parameter.

        Raises:
            RuntimeError: if the content type is application/octet-stream.
        """
        ctype = self.content_type  # read header values if needed
        if ctype == 'application/octet-stream':
            raise RuntimeError("Setting charset for application/octet-stream "
                               "doesn't make sense, setup content_type first")
        if value is None:
            self._content_dict.pop('charset', None)
        else:
            self._content_dict['charset'] = str(value).lower()
        self._generate_content_type_header()

    @property
    def last_modified(self, _LAST_MODIFIED=hdrs.LAST_MODIFIED):
        """The value of Last-Modified HTTP header, or None.

        This header is represented as a `datetime` object.
        """
        httpdate = self.headers.get(_LAST_MODIFIED)
        if httpdate is not None:
            timetuple = parsedate(httpdate)
            if timetuple is not None:
                return datetime.datetime(*timetuple[:6],
                                         tzinfo=datetime.timezone.utc)
        return None

    @last_modified.setter
    def last_modified(self, value):
        """Set Last-Modified from None, a timestamp, a datetime or a str.

        None removes the header; a number is rounded up and formatted as
        GMT; a datetime is formatted from its UTC time tuple; a str is stored
        as is.  Any other type leaves the header untouched.
        """
        if value is None:
            self.headers.pop(hdrs.LAST_MODIFIED, None)
        elif isinstance(value, (int, float)):
            self.headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
        elif isinstance(value, datetime.datetime):
            self.headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
        elif isinstance(value, str):
            self.headers[hdrs.LAST_MODIFIED] = value

    @property
    def tcp_nodelay(self):
        """``tcp_nodelay`` of the transport; requires a prepared response."""
        resp_impl = self._resp_impl
        if resp_impl is None:
            raise RuntimeError("Cannot get tcp_nodelay for "
                               "not prepared response")
        return resp_impl.transport.tcp_nodelay

    def set_tcp_nodelay(self, value):
        """Forward ``value`` to the transport's ``set_tcp_nodelay``."""
        resp_impl = self._resp_impl
        if resp_impl is None:
            raise RuntimeError("Cannot set tcp_nodelay for "
                               "not prepared response")
        resp_impl.transport.set_tcp_nodelay(value)

    @property
    def tcp_cork(self):
        """``tcp_cork`` of the transport; requires a prepared response."""
        resp_impl = self._resp_impl
        if resp_impl is None:
            raise RuntimeError("Cannot get tcp_cork for "
                               "not prepared response")
        return resp_impl.transport.tcp_cork

    def set_tcp_cork(self, value):
        """Forward ``value`` to the transport's ``set_tcp_cork``."""
        resp_impl = self._resp_impl
        if resp_impl is None:
            raise RuntimeError("Cannot set tcp_cork for "
                               "not prepared response")
        resp_impl.transport.set_tcp_cork(value)

    def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
        """Rebuild the Content-Type header from type and parameters."""
        params = '; '.join("%s=%s" % i for i in self._content_dict.items())
        if params:
            ctype = self._content_type + '; ' + params
        else:
            ctype = self._content_type
        self.headers[CONTENT_TYPE] = ctype

    def _start_pre_check(self, request):
        """Return the existing ``ResponseImpl`` if already started, else None.

        Raises:
            RuntimeError: if started with a different request object.
        """
        if self._resp_impl is not None:
            if self._req is not request:
                raise RuntimeError(
                    "Response has been started with different request.")
            else:
                return self._resp_impl
        else:
            return None

    def _do_start_compression(self, coding):
        """Apply ``coding`` unless it is the identity coding."""
        if coding != ContentCoding.identity:
            self.headers[hdrs.CONTENT_ENCODING] = coding.value
            self._resp_impl.add_compression_filter(coding.value)
            # Removes any Content-Length header set earlier.
            self.content_length = None

    def _start_compression(self, request):
        """Use the forced coding, or the first one named in Accept-Encoding."""
        if self._compression_force:
            self._do_start_compression(self._compression_force)
        else:
            accept_encoding = request.headers.get(
                hdrs.ACCEPT_ENCODING, '').lower()
            for coding in ContentCoding:
                if coding.value in accept_encoding:
                    self._do_start_compression(coding)
                    return

    def start(self, request):
        """Deprecated synchronous variant of ``prepare()``.

        Unlike ``prepare()`` it does not run ``request._prepare_hook``.
        """
        warnings.warn('use .prepare(request) instead', DeprecationWarning)
        resp_impl = self._start_pre_check(request)
        if resp_impl is not None:
            return resp_impl

        return self._start(request)

    @asyncio.coroutine
    def prepare(self, request):
        """Run the request's prepare hook, then start the response.

        Returns the ``ResponseImpl``; a repeated call with the same request
        returns the existing one without running the hook again.
        """
        resp_impl = self._start_pre_check(request)
        if resp_impl is not None:
            return resp_impl
        yield from request._prepare_hook(self)

        return self._start(request)

    def _start(self, request,
               HttpVersion10=HttpVersion10,
               HttpVersion11=HttpVersion11,
               CONNECTION=hdrs.CONNECTION,
               DATE=hdrs.DATE,
               SERVER=hdrs.SERVER,
               SET_COOKIE=hdrs.SET_COOKIE,
               TRANSFER_ENCODING=hdrs.TRANSFER_ENCODING):
        """Create the ``ResponseImpl``, finalise headers and send them.

        The keyword defaults only bind module-level constants to local
        names; callers are expected to pass ``request`` alone.

        Raises:
            RuntimeError: if chunked encoding is enabled and the request
                version is not HTTP/1.1.
        """
        self._req = request
        keep_alive = self._keep_alive
        if keep_alive is None:
            # Not forced on the response: follow the request.
            keep_alive = request.keep_alive
        self._keep_alive = keep_alive
        version = request.version

        resp_impl = self._resp_impl = ResponseImpl(
            request._writer,
            self._status,
            version,
            not keep_alive,
            self._reason)

        headers = self.headers
        for cookie in self._cookies.values():
            # output(header='') starts with one extra character; strip it.
            value = cookie.output(header='')[1:]
            headers.add(SET_COOKIE, value)

        if self._compression:
            self._start_compression(request)

        if self._chunked:
            if request.version != HttpVersion11:
                raise RuntimeError("Using chunked encoding is forbidden "
                                   "for HTTP/{0.major}.{0.minor}".format(
                                       request.version))
            resp_impl.chunked = True
            if self._chunk_size:
                resp_impl.add_chunking_filter(self._chunk_size)
            headers[TRANSFER_ENCODING] = 'chunked'
        else:
            resp_impl.length = self.content_length

        headers.setdefault(DATE, request.time_service.strtime())
        headers.setdefault(SERVER, resp_impl.SERVER_SOFTWARE)
        if CONNECTION not in headers:
            # A Connection header is added only for keep-alive on HTTP/1.0
            # and for close on HTTP/1.1.
            if keep_alive:
                if version == HttpVersion10:
                    headers[CONNECTION] = 'keep-alive'
            else:
                if version == HttpVersion11:
                    headers[CONNECTION] = 'close'

        resp_impl.headers = headers

        self._send_headers(resp_impl)
        return resp_impl

    def _send_headers(self, resp_impl):
        """Send the headers; kept separate so it can be overridden."""
        # Dirty hack required for
        # https://github.com/KeepSafe/aiohttp/issues/1093
        # File sender may override it
        resp_impl.send_headers()

    def write(self, data):
        """Write ``data`` to the started response.

        Returns whatever ``ResponseImpl.write`` returns, or an empty tuple
        when ``data`` is empty.

        Raises:
            RuntimeError: after ``write_eof()`` or before the response
                has been started.
        """
        assert isinstance(data, (bytes, bytearray, memoryview)), \
            "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            raise RuntimeError("Cannot call write() after write_eof()")
        if self._resp_impl is None:
            raise RuntimeError("Cannot call write() before start()")

        if data:
            return self._resp_impl.write(data)
        else:
            return ()

    @asyncio.coroutine
    def drain(self):
        """Wait on the transport's ``drain()``; requires a started response."""
        if self._resp_impl is None:
            raise RuntimeError("Response has not been started")
        yield from self._resp_impl.transport.drain()

    @asyncio.coroutine
    def write_eof(self):
        """Finish the response; repeated calls are no-ops.

        Records the body length and releases the request and the
        ``ResponseImpl``.

        Raises:
            RuntimeError: if the response has not been started.
        """
        if self._eof_sent:
            return
        if self._resp_impl is None:
            raise RuntimeError("Response has not been started")

        yield from self._resp_impl.write_eof()
        self._eof_sent = True
        self._body_length = self._resp_impl.body_length
        self._req = None
        self._resp_impl = None

    def __repr__(self):
        """Return ``<ClassName reason info>``, info depending on progress."""
        if self._eof_sent:
            info = "eof"
        elif self.prepared:
            # Use ``prepared`` directly: the deprecated ``started`` alias
            # would emit a DeprecationWarning on every repr().
            info = "{} {} ".format(self._req.method, self._req.path)
        else:
            info = "not started"
        return "<{} {} {}>".format(self.__class__.__name__,
                                   self.reason, info)
示例#46
0
async def test_keep_alive_http09():
    """An HTTP/0.9 request asking for keep-alive must not get keep-alive."""
    keep_alive_headers = CIMultiDict(Connection='keep-alive')
    request = make_request(
        'GET', '/',
        version=HttpVersion(0, 9),
        headers=keep_alive_headers)
    response = StreamResponse()
    await response.prepare(request)
    assert not response.keep_alive
示例#47
0
class ClientRequest:
    """Outgoing HTTP request on the client side.

    The constructor normalizes the URL, headers, cookies, auth, proxy and
    body settings; ``send()`` writes the request to a connection and returns
    the response object.
    """

    # Methods for which ``__init__`` skips ``update_transfer_encoding()``
    # unless request data was supplied.
    GET_METHODS = {
        hdrs.METH_GET,
        hdrs.METH_HEAD,
        hdrs.METH_OPTIONS,
        hdrs.METH_TRACE,
    }
    # Methods that get a default Content-Type in ``send()`` when none is set.
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE})

    # Headers added by ``update_auto_headers()`` unless already present or
    # listed in ``skip_auto_headers``.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: "*/*",
        hdrs.ACCEPT_ENCODING: "gzip, deflate",
    }

    # Class-level defaults; instances overwrite them while being built.
    body = b""
    auth = None
    response = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(
        self,
        method: str,
        url: URL,
        *,
        params: Optional[Mapping[str, str]] = None,
        headers: Optional[LooseHeaders] = None,
        skip_auto_headers: Iterable[str] = frozenset(),
        data: Any = None,
        cookies: Optional[LooseCookies] = None,
        auth: Optional[BasicAuth] = None,
        version: http.HttpVersion = http.HttpVersion11,
        compress: Optional[str] = None,
        chunked: Optional[bool] = None,
        expect100: bool = False,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        response_class: Optional[Type["ClientResponse"]] = None,
        proxy: Optional[URL] = None,
        proxy_auth: Optional[BasicAuth] = None,
        timer: Optional[BaseTimerContext] = None,
        session: Optional["ClientSession"] = None,
        ssl: Union[SSLContext, bool, Fingerprint, None] = None,
        proxy_headers: Optional[LooseHeaders] = None,
        traces: Optional[List["Trace"]] = None,
    ):
        """Build a request for ``method`` and ``url``.

        The ``update_*`` helpers run in a fixed order because later steps
        read state produced by earlier ones (for example the headers must
        exist before cookies, auth and body headers are merged into them).
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy
        # FIXME: session is None in tests only, need to fix tests
        # assert session is not None
        self._session = cast("ClientSession", session)
        if params:
            # Append ``params`` to the query string already present in url.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        # ``original_url`` keeps the fragment; ``url`` has it removed.
        self.original_url = url
        self.url = url.with_fragment(None)
        self.method = method.upper()
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.length = None
        if response_class is None:
            real_response_class = ClientResponse
        else:
            real_response_class = response_class
        self.response_class = real_response_class  # type: Type[ClientResponse]
        self._timer = timer if timer is not None else TimerNoop()
        self._ssl = ssl

        if loop.get_debug():
            # Record where the request was created when the loop is in
            # debug mode.
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth, proxy_headers)

        self.update_body_from_data(data)
        # Transfer encoding is only analyzed when there is a body or the
        # method is not one of GET_METHODS.
        if data or self.method not in self.GET_METHODS:
            self.update_transfer_encoding()
        self.update_expect_continue(expect100)
        if traces is None:
            traces = []
        self._traces = traces

    def is_ssl(self) -> bool:
        """Return True when the request URL scheme is ``https`` or ``wss``."""
        scheme = self.url.scheme
        return scheme == "https" or scheme == "wss"

    @property
    def ssl(self) -> Union["SSLContext", None, bool, Fingerprint]:
        """Return the stored ``_ssl`` value (set from the constructor)."""
        return self._ssl

    @property
    def connection_key(self) -> ConnectionKey:
        """Return the ``ConnectionKey`` describing this request's connection.

        The proxy headers take part in the key through the hash of their
        items, or ``None`` when there are no proxy headers.
        """
        proxy_headers = self.proxy_headers
        headers_hash = None  # type: Optional[int]
        if proxy_headers:
            headers_hash = hash(tuple(proxy_headers.items()))
        return ConnectionKey(
            self.host,
            self.port,
            self.is_ssl(),
            self.ssl,
            self.proxy,
            self.proxy_auth,
            headers_hash,
        )

    @property
    def host(self) -> str:
        """Return the raw host of the request URL (asserted to be set)."""
        ret = self.url.raw_host
        assert ret is not None
        return ret

    @property
    def port(self) -> Optional[int]:
        """Return the port of the request URL."""
        return self.url.port

    @property
    def request_info(self) -> RequestInfo:
        """Return a ``RequestInfo`` built from this request.

        It carries the URL, the method, a ``CIMultiDictProxy`` over the
        headers and the original URL.
        """
        headers = CIMultiDictProxy(self.headers)  # type: CIMultiDictProxy[str]
        return RequestInfo(self.url, self.method, headers, self.original_url)

    def update_host(self, url: URL) -> None:
        """Validate the URL host and pick up credentials embedded in the URL.

        Raises ``InvalidURL`` when the URL has no host.  When the URL carries
        a user name, ``self.auth`` is set to a ``BasicAuth`` built from the
        user name and the password (empty string when there is none).
        """
        if not url.raw_host:
            raise InvalidURL(url)

        # basic auth info
        user = url.user
        secret = url.password
        if not user:
            return
        self.auth = helpers.BasicAuth(user, secret or "")

    def update_version(self, version: Union[http.HttpVersion, str]) -> None:
        """Store the request HTTP version, parsing it when given as a string.

        A string such as ``'1.1'`` is converted to ``HttpVersion(1, 1)``;
        any other value is stored unchanged.

        Raises ``ValueError`` when the string is not of the form
        ``'<int>.<int>'`` (including a string without a dot).
        """
        if isinstance(version, str):
            try:
                # Unpacking raises ValueError when the dot is missing, so a
                # value like '1' is reported the same way as 'a.b'.
                major, minor = (
                    int(part.strip()) for part in version.split(".", 1))
            except ValueError:
                raise ValueError(
                    f"Can not parse http version number: {version}") from None
            version = http.HttpVersion(major, minor)
        self.version = version

    def update_headers(self, headers: Optional[LooseHeaders]) -> None:
        """Reset the request headers and fill them from ``headers``.

        A ``Host`` header derived from the URL is always set first; a
        ``Host`` entry in ``headers`` replaces it, every other entry is
        added.
        """
        request_headers = CIMultiDict()  # type: CIMultiDict[str]
        self.headers = request_headers

        # add host
        url = self.url
        netloc = cast(str, url.raw_host)
        if helpers.is_ipv6_address(netloc):
            netloc = f"[{netloc}]"
        port = url.port
        if port is not None and not url.is_default_port():
            netloc += ":" + str(port)
        request_headers[hdrs.HOST] = netloc

        if not headers:
            return

        if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
            pairs = headers.items()  # type: ignore
        else:
            pairs = headers  # type: ignore

        for key, value in pairs:  # type: ignore
            if key.lower() != "host":
                request_headers.add(key, value)
            else:
                # A special case for Host header
                request_headers[key] = value

    def update_auto_headers(self, skip_auto_headers: Iterable[str]) -> None:
        """Add default headers that are neither set nor explicitly skipped.

        ``skip_auto_headers`` is stored as a case-insensitive multidict whose
        values are ``None``.
        """
        skipped = CIMultiDict(
            (name, None) for name in sorted(skip_auto_headers))
        self.skip_auto_headers = skipped

        # Names that must not be generated: already present or skipped.
        taken = self.headers.copy()
        taken.extend(skipped)  # type: ignore

        missing = [(name, value)
                   for name, value in self.DEFAULT_HEADERS.items()
                   if name not in taken]
        for name, value in missing:
            self.headers.add(name, value)

        if hdrs.USER_AGENT not in taken:
            self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE

    def update_cookies(self, cookies: Optional[LooseCookies]) -> None:
        """Merge ``cookies`` into the request ``Cookie`` header.

        Cookies already present in the header are loaded first and the given
        ones are then assigned on top of them.  Nothing happens when
        ``cookies`` is empty.
        """
        if not cookies:
            return

        jar = SimpleCookie()  # type: SimpleCookie[str]
        if hdrs.COOKIE in self.headers:
            jar.load(self.headers.get(hdrs.COOKIE, ""))
            del self.headers[hdrs.COOKIE]

        pairs = cookies.items() if isinstance(cookies, Mapping) else cookies
        for name, value in pairs:  # type: ignore
            if not isinstance(value, Morsel):
                jar[name] = value  # type: ignore
                continue
            # Preserve coded_value
            morsel = value.get(value.key, Morsel())
            morsel.set(value.key, value.value, value.coded_value)
            jar[name] = morsel

        self.headers[hdrs.COOKIE] = jar.output(header="", sep=";").strip()

    def update_content_encoding(self, data: Any) -> None:
        """Set the ``Content-Encoding`` header from ``self.compress``.

        Does nothing without ``data``.  Raises ``ValueError`` when both a
        ``Content-Encoding`` header and ``compress`` are given.  When only
        ``compress`` is set, the header is written and chunked transfer is
        enabled.
        """
        if not data:
            return

        declared = self.headers.get(hdrs.CONTENT_ENCODING, "").lower()
        if declared and self.compress:
            raise ValueError("compress can not be set "
                             "if Content-Encoding header is set")
        if declared or not self.compress:
            return

        # A non-string truthy ``compress`` selects deflate.
        if not isinstance(self.compress, str):
            self.compress = "deflate"
        self.headers[hdrs.CONTENT_ENCODING] = self.compress
        self.chunked = True  # enable chunked, no need to deal with length

    def update_transfer_encoding(self) -> None:
        """Reconcile ``self.chunked`` with the transfer/length headers.

        Raises ``ValueError`` when ``chunked`` is combined with an explicit
        ``Transfer-Encoding: chunked`` or ``Content-Length`` header.  Without
        chunking a missing ``Content-Length`` is computed from the body.
        """
        headers = self.headers
        transfer_encoding = headers.get(hdrs.TRANSFER_ENCODING, "").lower()

        if "chunked" in transfer_encoding:
            if self.chunked:
                raise ValueError(
                    "chunked can not be set "
                    'if "Transfer-Encoding: chunked" header is set')
            return

        if self.chunked:
            if hdrs.CONTENT_LENGTH in headers:
                raise ValueError("chunked can not be set "
                                 "if Content-Length header is set")
            headers[hdrs.TRANSFER_ENCODING] = "chunked"
            return

        if hdrs.CONTENT_LENGTH not in headers:
            headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_auth(self, auth: Optional[BasicAuth]) -> None:
        """Set the ``Authorization`` header from basic-auth credentials.

        Falls back to ``self.auth`` when ``auth`` is ``None``; does nothing
        when neither is set.  Raises ``TypeError`` for anything that is not
        a ``BasicAuth`` instance.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if not isinstance(credentials, helpers.BasicAuth):
            raise TypeError("BasicAuth() tuple is required instead")

        self.headers[hdrs.AUTHORIZATION] = credentials.encode()

    def update_body_from_data(self, body: Any) -> None:
        """Convert ``body`` into a payload and derive body-related headers.

        Does nothing for an empty ``body``.  The value is looked up in the
        payload registry, falling back to wrapping it in ``FormData``.
        Unless chunking is already enabled, ``Content-Length`` is set from
        the payload size, or chunking is switched on when the size is
        unknown.  Payload headers are copied unless already present or
        listed in ``skip_auto_headers``.
        """
        if not body:
            return

        # FormData
        if isinstance(body, FormData):
            body = body()

        try:
            body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
        except payload.LookupError:
            body = FormData(body)()

        self.body = body

        # enable chunked encoding if needed
        if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers:
            size = body.size
            if size is None:
                self.chunked = True
            else:
                self.headers[hdrs.CONTENT_LENGTH] = str(size)

        # copy payload headers
        payload_headers = body.headers
        assert payload_headers
        for key, value in payload_headers.items():
            if key in self.headers or key in self.skip_auto_headers:
                continue
            self.headers[key] = value

    def update_expect_continue(self, expect: bool = False) -> None:
        """Prepare the waiter future for a ``100-continue`` handshake.

        The ``Expect`` header is set when ``expect`` is true; otherwise an
        existing ``Expect: 100-continue`` header enables the handshake.
        """
        if expect:
            self.headers[hdrs.EXPECT] = "100-continue"
            wants_continue = True
        else:
            declared = self.headers.get(hdrs.EXPECT, "").lower()
            wants_continue = declared == "100-continue"

        if wants_continue:
            self._continue = self.loop.create_future()

    def update_proxy(
        self,
        proxy: Optional[URL],
        proxy_auth: Optional[BasicAuth],
        proxy_headers: Optional[LooseHeaders],
    ) -> None:
        """Validate and store the proxy settings.

        Raises ``ValueError`` for a proxy whose scheme is not ``http`` or
        for a ``proxy_auth`` that is not a ``BasicAuth`` instance.
        """
        if proxy:
            if proxy.scheme != "http":
                raise ValueError("Only http proxies are supported")
        if proxy_auth:
            if not isinstance(proxy_auth, helpers.BasicAuth):
                raise ValueError(
                    "proxy_auth must be None or BasicAuth() tuple")
        self.proxy = proxy
        self.proxy_auth = proxy_auth
        self.proxy_headers = proxy_headers

    def keep_alive(self) -> bool:
        """Tell whether the connection may be reused after this request.

        Versions below HTTP/1.0 never keep alive; HTTP/1.0 only does with
        ``Connection: keep-alive``; later versions do unless the header is
        ``Connection: close``.
        """
        version = self.version
        if version < HttpVersion10:
            # keep alive not supported at all
            return False

        connection = self.headers.get(hdrs.CONNECTION)
        if version == HttpVersion10:
            # no headers means we close for Http 1.0
            return connection == "keep-alive"
        return connection != "close"

    async def write_bytes(self, writer: AbstractStreamWriter,
                          conn: "Connection") -> None:
        """Write the request body to ``writer`` and finish with EOF.

        Errors are not raised to the caller; they are handed to the
        connection's protocol through ``set_exception()``.
        """
        # 100 response
        if self._continue is not None:
            # Flush what was written so far, then wait on the
            # '100 Continue' waiter before sending the body.
            await writer.drain()
            await self._continue

        protocol = conn.protocol
        assert protocol is not None
        try:
            if isinstance(self.body, payload.Payload):
                await self.body.write(writer)
            else:
                # Wrap a single bytes object so the loop below writes it as
                # one chunk instead of iterating over its integers.
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body, )  # type: ignore

                for chunk in self.body:
                    await writer.write(chunk)  # type: ignore

            await writer.write_eof()
        except OSError as exc:
            # Re-wrap OS errors so the failing URL is part of the message.
            new_exc = ClientOSError(
                exc.errno, "Can not write request body for %s" % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            protocol.set_exception(new_exc)
        except asyncio.CancelledError as exc:
            # Cancellation is only reported while the connection is open.
            if not conn.closed:
                protocol.set_exception(exc)
        except Exception as exc:
            protocol.set_exception(exc)
        finally:
            # The writer task is finished either way.
            self._writer = None

    async def send(self, conn: "Connection") -> "ClientResponse":
        """Write the request line and headers, then start sending the body.

        The body is written by a task running ``write_bytes()``; the
        response object is created and returned without waiting for it.
        """
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            connect_host = self.url.raw_host
            assert connect_host is not None
            if helpers.is_ipv6_address(connect_host):
                connect_host = f"[{connect_host}]"
            path = f"{connect_host}:{self.url.port}"
        elif self.proxy and not self.is_ssl():
            path = str(self.url)
        else:
            path = self.url.raw_path
            if self.url.raw_query_string:
                path += "?" + self.url.raw_query_string

        protocol = conn.protocol
        assert protocol is not None
        writer = StreamWriter(
            protocol,
            self.loop,
            on_chunk_sent=functools.partial(self._on_chunk_request_sent,
                                            self.method, self.url),
        )

        if self.compress:
            writer.enable_compression(self.compress)

        # NOTE(review): chunking is enabled for any non-None value, which
        # includes ``chunked=False`` - confirm this is intended.
        if self.chunked is not None:
            writer.enable_chunking()

        # set default content-type
        if (self.method in self.POST_METHODS
                and hdrs.CONTENT_TYPE not in self.skip_auto_headers
                and hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

        # set the connection header
        connection = self.headers.get(hdrs.CONNECTION)
        if not connection:
            # Only spell out the header when it differs from the default
            # behaviour of the HTTP version in use.
            if self.keep_alive():
                if self.version == HttpVersion10:
                    connection = "keep-alive"
            else:
                if self.version == HttpVersion11:
                    connection = "close"

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

        # status + headers
        status_line = "{0} {1} HTTP/{2[0]}.{2[1]}".format(
            self.method, path, self.version)
        await writer.write_headers(status_line, self.headers)

        # The body is sent in the background by a separate task.
        self._writer = self.loop.create_task(self.write_bytes(writer, conn))

        response_class = self.response_class
        assert response_class is not None
        self.response = response_class(
            self.method,
            self.original_url,
            writer=self._writer,
            continue100=self._continue,
            timer=self._timer,
            request_info=self.request_info,
            traces=self._traces,
            loop=self.loop,
            session=self._session,
        )
        return self.response

    async def close(self) -> None:
        """Wait for the body writer task, if any, and forget it afterwards."""
        if self._writer is None:
            return
        try:
            await self._writer
        finally:
            self._writer = None

    def terminate(self) -> None:
        """Cancel the body writer task, if any, and forget it.

        The task is only cancelled while the event loop is still open.
        """
        writer = self._writer
        if writer is None:
            return
        if not self.loop.is_closed():
            writer.cancel()
        self._writer = None

    async def _on_chunk_request_sent(self, method: str, url: URL,
                                     chunk: bytes) -> None:
        """Notify every trace in ``self._traces`` about a sent body chunk."""
        for trace in self._traces:
            await trace.send_request_chunk_sent(method, url, chunk)
示例#48
0
def test_body_in_ctor_with_content_type_header_multidict():
    """Content type and charset are taken from a CIMultiDict header."""
    encoded_body = 'текст'.encode('koi8-r')
    content_type_header = CIMultiDict(
        {'Content-Type': 'text/html; charset=koi8-r'})
    response = Response(body=encoded_body, headers=content_type_header)
    assert response.body == encoded_body
    assert response.content_type == 'text/html'
    assert response.charset == 'koi8-r'
示例#49
0
def make_mocked_request(method,
                        path,
                        headers=None,
                        *,
                        version=HttpVersion(1, 1),
                        closing=False,
                        app=None,
                        writer=sentinel,
                        payload_writer=sentinel,
                        protocol=sentinel,
                        transport=sentinel,
                        payload=sentinel,
                        sslcontext=None,
                        secure_proxy_ssl_header=None,
                        client_max_size=1024**2):
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.  Every collaborator
    left at its ``sentinel`` default is replaced by a mock.

    """

    task = mock.Mock()
    loop = mock.Mock()
    loop.create_future.return_value = ()

    # Versions below HTTP/1.1 always mark the message as closing.
    if version < HttpVersion(1, 1):
        closing = True

    headers = CIMultiDict(headers) if headers else CIMultiDict()
    raw_hdrs = tuple((name.encode('utf-8'), value.encode('utf-8'))
                     for name, value in headers.items())

    transfer_encoding = headers.get(hdrs.TRANSFER_ENCODING, '')
    chunked = 'chunked' in transfer_encoding.lower()

    message = RawRequestMessage(method, path, version, headers, raw_hdrs,
                                closing, False, False, chunked, URL(path))

    if app is None:
        app = _create_app_mock()
    if protocol is sentinel:
        protocol = mock.Mock()
    if transport is sentinel:
        transport = _create_transport(sslcontext)
    if writer is sentinel:
        writer = mock.Mock()
        writer.transport = transport
    if payload_writer is sentinel:
        payload_writer = mock.Mock()
        payload_writer.write_eof.side_effect = noop
        payload_writer.drain.side_effect = noop

    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    @contextmanager
    def timeout(*args, **kw):
        yield

    # Time service with fixed time values and a timeout that does nothing.
    time_service = mock.Mock()
    time_service.time.return_value = 12345
    time_service.strtime.return_value = "Tue, 15 Nov 1994 08:12:31 GMT"
    time_service.timeout = mock.Mock()
    time_service.timeout.side_effect = timeout

    request = Request(message,
                      payload,
                      protocol,
                      payload_writer,
                      time_service,
                      task,
                      secure_proxy_ssl_header=secure_proxy_ssl_header,
                      client_max_size=client_max_size)

    match_info = UrlMappingMatchInfo({}, mock.Mock())
    match_info.add_app(app)
    request._match_info = match_info

    return request
示例#50
0
def make_mocked_request(method,
                        path,
                        headers=None,
                        *,
                        version=HttpVersion(1, 1),
                        closing=False,
                        app=None,
                        reader=sentinel,
                        writer=sentinel,
                        transport=sentinel,
                        payload=sentinel,
                        sslcontext=None,
                        secure_proxy_ssl_header=None):
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    :param method: str, that represents HTTP method, like; GET, POST.
    :type method: str

    :param path: str, The URL including *PATH INFO* without the host or scheme
    :type path: str

    :param headers: mapping containing the headers. Can be anything accepted
        by the multidict.CIMultiDict constructor.
    :type headers: dict, multidict.CIMultiDict, list of pairs

    :param version: namedtuple with encoded HTTP version
    :type version: aiohttp.protocol.HttpVersion

    :param closing: flag indicates that connection should be closed after
        response.
    :type closing: bool

    :param app: the aiohttp.web application attached for fake request
    :type app: aiohttp.web.Application

    :param reader: object for storing and managing incoming data
    :type reader: aiohttp.parsers.StreamParser

    :param writer: object for managing outgoing data
    :type writer: aiohttp.parsers.StreamWriter

    :param transport: asyncio transport instance
    :type transport: asyncio.transports.Transport

    :param payload: raw payload reader object
    :type  payload: aiohttp.streams.FlowControlStreamReader

    :param sslcontext: ssl.SSLContext object, for HTTPS connection
    :type sslcontext: ssl.SSLContext

    :param secure_proxy_ssl_header: A tuple representing a HTTP header/value
        combination that signifies a request is secure.
    :type secure_proxy_ssl_header: tuple

    """

    # Versions below HTTP/1.1 always mark the message as closing.
    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        parsed_headers = CIMultiDict(headers)
        # Iterate the multidict, not the raw argument: ``headers`` may be a
        # list of pairs, which has no ``items()`` method.
        raw_hdrs = [(k.encode('utf-8'), v.encode('utf-8'))
                    for k, v in parsed_headers.items()]
    else:
        parsed_headers = CIMultiDict()
        raw_hdrs = []

    message = RawRequestMessage(method, path, version, parsed_headers,
                                raw_hdrs, closing, False)
    if app is None:
        app = _create_app_mock()

    if reader is sentinel:
        reader = mock.Mock()

    if writer is sentinel:
        writer = mock.Mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(app,
                  message,
                  payload,
                  transport,
                  reader,
                  writer,
                  secure_proxy_ssl_header=secure_proxy_ssl_header)

    return req
示例#51
0
def test_raw_headers() -> None:
    """raw_headers exposes the request headers as a tuple of byte pairs."""
    request_headers = CIMultiDict({"X-HEADER": "aaa"})
    request = make_mocked_request("GET", "/", headers=request_headers)
    expected = ((b"X-HEADER", b"aaa"),)
    assert request.raw_headers == expected
示例#52
0
class ClientRequest:
    """Outgoing HTTP request on the client side.

    The constructor normalizes the URL, headers, cookies, auth, proxy and
    body settings through the ``update_*`` methods.
    """

    # Groups of HTTP method names.
    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    # Headers added by ``update_auto_headers()`` unless already present or
    # listed in ``skip_auto_headers``.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    # Value used for the automatically generated User-Agent header.
    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    # Class-level defaults; instances overwrite them while being built.
    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method, url, *,
                 params=None, headers=None, skip_auto_headers=frozenset(),
                 data=None, cookies=None,
                 auth=None, encoding='utf-8',
                 version=aiohttp.HttpVersion11, compress=None,
                 chunked=None, expect100=False,
                 loop=None, response_class=None,
                 proxy=None, proxy_auth=None, timer=None):
        """Build a request for ``method`` and ``url``.

        The ``update_*`` helpers run in a fixed order because later steps
        read state produced by earlier ones (for example the headers must
        exist before cookies, auth and body headers are merged into them).
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy

        if params:
            # Append ``params`` to the query string already present in url.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        # ``url`` has the fragment removed; ``original_url`` keeps it.
        self.url = url.with_fragment(None)
        self.original_url = url
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.response_class = response_class or ClientResponse
        self._timer = timer if timer is not None else _TimeServiceTimeoutNoop()

        if loop.get_debug():
            # Record where the request was created when the loop is in
            # debug mode.
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth)

        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    @property
    def host(self):
        """Return the host of the request URL."""
        return self.url.host

    @property
    def port(self):
        """Return the port of the request URL."""
        return self.url.port

    def update_host(self, url):
        """Validate the URL host and record credentials and the ssl flag.

        Raises ``ValueError`` when the URL has no host.  When the URL carries
        a user name, ``self.auth`` is set to a ``BasicAuth`` built from the
        user name and the password (empty string when there is none).
        ``self.ssl`` is true for the ``https`` and ``wss`` schemes.
        """
        if not url.host:
            raise ValueError('Host could not be detected.')

        # basic auth info
        user = url.user
        secret = url.password
        if user:
            self.auth = helpers.BasicAuth(user, secret or '')

        secure_schemes = ('https', 'wss')
        self.ssl = url.scheme in secure_schemes

    def update_version(self, version):
        """Store the request HTTP version, parsing it when given as a string.

        A string such as ``'1.1'`` is converted to the tuple ``(1, 1)``; any
        other value is stored unchanged.

        Raises ``ValueError`` when the string is not of the form
        ``'<int>.<int>'`` (including a string without a dot).
        """
        if isinstance(version, str):
            try:
                # Unpacking raises ValueError when the dot is missing, so a
                # value like '1' is reported the same way as 'a.b'.
                major, minor = (
                    int(part.strip()) for part in version.split('.', 1))
            except ValueError:
                raise ValueError(
                    'Can not parse http version number: {}'
                    .format(version)) from None
            version = major, minor
        self.version = version

    def update_headers(self, headers):
        """Reset the request headers and add every entry of ``headers``.

        ``headers`` may be a dict, a multidict (or its proxy), or an
        iterable of ``(name, value)`` pairs.
        """
        request_headers = CIMultiDict()
        self.headers = request_headers
        if not headers:
            return

        if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
            pairs = headers.items()
        else:
            pairs = headers

        for name, value in pairs:
            request_headers.add(name, value)

    def update_auto_headers(self, skip_auto_headers):
        """Add the default, Host and User-Agent headers that are not set yet.

        :param skip_auto_headers: iterable of header names that must not be
            generated automatically.
        """
        self.skip_auto_headers = skip_auto_headers
        # set.union() accepts any iterable, so a list or tuple of names
        # works as well as a (frozen)set; the ``|`` operator would raise
        # TypeError for those.
        used_headers = set(self.headers).union(skip_auto_headers)

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        # add host
        if hdrs.HOST not in used_headers:
            netloc = self.url.raw_host
            # The port is only spelled out when it is not the default one
            # for the URL.
            if not self.url.is_default_port():
                netloc += ':' + str(self.url.port)
            self.headers[hdrs.HOST] = netloc

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Merge ``cookies`` into the request ``Cookie`` header.

        Cookies already present in the header are loaded first and the given
        ones are then assigned on top of them.  Nothing happens when
        ``cookies`` is empty.
        """
        if not cookies:
            return

        jar = SimpleCookie()
        if hdrs.COOKIE in self.headers:
            jar.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        for name, value in cookies.items():
            if not isinstance(value, Morsel):
                jar[name] = value
                continue
            # Preserve coded_value
            morsel = value.get(value.key, Morsel())
            morsel.set(value.key, value.value, value.coded_value)
            jar[name] = morsel

        self.headers[hdrs.COOKIE] = jar.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Align ``self.compress`` with the ``Content-Encoding`` header.

        Does nothing without ``data``.  An existing header value becomes the
        compression setting unless ``compress`` is exactly ``False``;
        otherwise a truthy ``compress`` is written to the header.  Both
        cases enable chunked transfer.
        """
        if not data:
            return

        declared = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if declared:
            if self.compress is not False:
                self.compress = declared
                # enable chunked, no need to deal with length
                self.chunked = True
            return

        if not self.compress:
            return

        # A non-string truthy ``compress`` selects deflate.
        if not isinstance(self.compress, str):
            self.compress = 'deflate'
        self.headers[hdrs.CONTENT_ENCODING] = self.compress
        self.chunked = True  # enable chunked, no need to deal with length

    def update_auth(self, auth):
        """Set the ``Authorization`` header from basic-auth credentials.

        Falls back to ``self.auth`` when ``auth`` is ``None``; does nothing
        when neither is set.  Raises ``TypeError`` for anything that is not
        a ``BasicAuth`` instance.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if not isinstance(credentials, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = credentials.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Set ``self.body`` from ``data`` and derive body-related headers.

        The handling depends on the type of ``data``: text/bytes, stream
        readers, coroutines, file-like objects, ``MultipartWriter`` and, as
        a fallback, anything ``FormData`` accepts.  Does nothing for empty
        ``data``.
        """
        if not data:
            return

        # Text is encoded first and then handled as bytes.
        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if (hdrs.CONTENT_TYPE not in self.headers and
                    hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(data, (asyncio.StreamReader, streams.StreamReader,
                               streams.DataQueue)):
            # Stream readers are stored as is; no headers are derived here.
            self.body = data

        elif asyncio.iscoroutine(data):
            self.body = data
            # Chunk the body unless a length or an explicit chunked setting
            # was given.
            if (hdrs.CONTENT_LENGTH not in self.headers and
                    self.chunked is None):
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), \
                'attempt to send text data instead of binary'
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined
                size = len(data.getbuffer())
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif (not self.chunked and
                  isinstance(data, (io.BufferedReader, io.BufferedRandom))):
                # Not chunking if content-length can be determined
                try:
                    # Remaining bytes from the current position to the end
                    # of the file.
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            if hasattr(data, 'mode'):
                # NOTE(review): only the exact mode 'r' is rejected; other
                # text modes such as 'rt' pass this check - confirm.
                if data.mode == 'r':
                    raise ValueError('file {!r} should be open in binary mode'
                                     ''.format(data))
            if (hdrs.CONTENT_TYPE not in self.headers and
                hdrs.CONTENT_TYPE not in skip_auto_headers and
                    hasattr(data, 'name')):
                # Guess the content type from the file name.
                mime = mimetypes.guess_type(data.name)[0]
                mime = 'application/octet-stream' if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            self.chunked = True

        else:
            # Anything else is treated as form data.
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if (hdrs.CONTENT_TYPE not in self.headers and
                    hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = True
            else:
                if (hdrs.CONTENT_LENGTH not in self.headers and
                        not self.chunked):
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Reconcile ``self.chunked`` with the framing headers.

        When chunking is requested, Content-Length is dropped and
        Transfer-Encoding is set to ``chunked`` unless it already mentions
        it.  Otherwise a ``chunked`` Transfer-Encoding header switches
        chunking on; failing that, ``self.chunked`` becomes ``None`` and a
        missing Content-Length is filled in from ``len(self.body)``.
        """
        encoding = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()
        header_is_chunked = 'chunked' in encoding

        if self.chunked:
            if hdrs.CONTENT_LENGTH in self.headers:
                del self.headers[hdrs.CONTENT_LENGTH]
            if not header_is_chunked:
                self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
            return

        if header_is_chunked:
            self.chunked = True
            return

        self.chunked = None
        if hdrs.CONTENT_LENGTH not in self.headers:
            self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_expect_continue(self, expect=False):
        """Arm the ``100-continue`` future when the request expects it.

        A truthy ``expect`` forces the ``Expect: 100-continue`` header;
        otherwise an already present header with that value (compared
        case-insensitively) has the same effect.  In both cases
        ``self._continue`` is set to a new future on ``self.loop``.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        else:
            current = self.headers.get(hdrs.EXPECT, '')
            expect = current.lower() == '100-continue'

        if expect:
            self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth):
        """Validate and remember the proxy settings.

        Raises ``ValueError`` when a proxy is given whose scheme is not
        ``http``, or when ``proxy_auth`` is truthy but not a
        ``helpers.BasicAuth`` instance.
        """
        if proxy and proxy.scheme != 'http':
            raise ValueError("Only http proxies are supported")

        auth_is_invalid = (
            proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth))
        if auth_is_invalid:
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")

        self.proxy, self.proxy_auth = proxy, proxy_auth

    @asyncio.coroutine
    def write_bytes(self, request, conn):
        """Write ``self.body`` to ``request`` and finish with ``write_eof()``.

        Supported body kinds, checked in this order: a coroutine that
        yields bytes objects (and futures to wait on), a stream reader, a
        ``streams.DataQueue``, a binary file object, a single bytes-like
        object, or any iterable of chunks.

        Any ``Exception`` raised while writing the body or the EOF marker
        is wrapped in ``aiohttp.ClientRequestError`` and handed to
        ``conn.protocol.set_exception()`` instead of propagating.
        ``self._writer`` is reset to ``None`` at the end.
        """
        # 100 response
        if self._continue is not None:
            yield from request.drain()
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                exc = None
                value = None
                stream = self.body

                # Drive the user coroutine by hand: forward awaited results
                # (or errors) into it and write every bytes chunk it yields.
                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as stop:
                        # A bytes-like return value is the final chunk;
                        # bytearray is accepted like in the yield case.
                        if isinstance(stop.value, (bytes, bytearray)):
                            yield from request.write(stop.value)
                        break
                    except BaseException:
                        # Deliberately broad: close the response on any
                        # failure (including cancellation) and re-raise.
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result)
                        value = None
                    else:
                        raise ValueError(
                            'Bytes object is expected, got: %s.' %
                            type(result))

            elif isinstance(self.body, (asyncio.StreamReader,
                                        streams.StreamReader)):
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if not chunk:
                            break
                        yield from request.write(chunk)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                # Always read a fixed-size chunk.  ``self.chunked`` is a
                # flag, not a size: read(False) returns b'' and would
                # truncate the body after the first chunk, and read(True)
                # would read a single byte per iteration.
                chunk = self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(streams.DEFAULT_LIMIT)
            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)

                for chunk in self.body:
                    request.write(chunk)

        except Exception as exc:
            new_exc = aiohttp.ClientRequestError(
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            conn.protocol.set_exception(new_exc)
        else:
            try:
                yield from request.write_eof()
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError(
                    'Can not write request body for %s' % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                conn.protocol.set_exception(new_exc)

        self._writer = None

    def send(self, conn):
        """Send the request line and headers and start writing the body.

        Builds an ``aiohttp.Request`` on ``conn.writer``, schedules
        ``write_bytes()`` as ``self._writer`` and returns the freshly
        created response object.
        """
        # Request target forms:
        # - CONNECT sends the authority form (host:port)
        # - a non-SSL request through a proxy sends the absolute URI
        # - everything else sends the origin form (path plus query)
        url = self.url
        if self.method == hdrs.METH_CONNECT:
            target = '{}:{}'.format(url.raw_host, url.port)
        elif self.proxy and not self.ssl:
            target = str(url)
        else:
            target = url.raw_path
            query = url.raw_query_string
            if query:
                target += '?' + query

        request = aiohttp.Request(
            conn.writer, self.method, target, self.version, loop=self.loop)

        if self.compress:
            request.enable_compression(self.compress)

        if self.chunked is not None:
            request.enable_chunking()

        # Fall back to a generic content type for body-carrying methods.
        needs_default_type = (
            self.method in self.POST_METHODS and
            hdrs.CONTENT_TYPE not in self.skip_auto_headers and
            hdrs.CONTENT_TYPE not in self.headers)
        if needs_default_type:
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        for header_name, header_value in self.headers.items():
            request.add_header(header_name, header_value)
        request.send_headers()

        body_writer = self.write_bytes(request, conn)
        self._writer = helpers.ensure_future(body_writer, loop=self.loop)

        self.response = self.response_class(
            self.method, self.original_url,
            writer=self._writer, continue100=self._continue,
            timer=self._timer)

        self.response._post_init(self.loop)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the pending body writer, then forget it.

        ``self._writer`` is cleared even when waiting on it raises.
        """
        if self._writer is None:
            return

        try:
            yield from self._writer
        finally:
            self._writer = None

    def terminate(self):
        """Drop the pending body writer, cancelling it if the loop is open."""
        if self._writer is None:
            return

        loop_is_open = not self.loop.is_closed()
        if loop_is_open:
            self._writer.cancel()
        self._writer = None
示例#53
0
def test_merge_headers_with_list_of_tuples(create_session):
    """Request headers given as a list of pairs replace session defaults."""
    defaults = {"h1": "header1", "h2": "header2"}
    session = create_session(headers=defaults)

    merged = session._prepare_headers([("h1", "h1")])

    expected = CIMultiDict([("h2", "header2"), ("h1", "h1")])
    assert isinstance(merged, CIMultiDict)
    assert merged == expected
示例#54
0
    "method", ["get", "post", "options", "post", "put", "patch", "delete"])
async def test_test_client_methods(method, loop, test_client) -> None:
    """Call the named client method on "/" and expect the hello reply."""
    send_request = getattr(test_client, method)
    resp = await send_request("/")
    assert resp.status == 200
    body = await resp.text()
    assert _hello_world_str == body


async def test_test_client_head(loop, test_client) -> None:
    """A HEAD request to "/" through the test client answers 200."""
    response = await test_client.head("/")
    status = response.status
    assert status == 200


@pytest.mark.parametrize(
    "headers",
    [
        {'token': 'x'},
        CIMultiDict({'token': 'x'}),
        {},
    ])
def test_make_mocked_request(headers) -> None:
    """The mocked request keeps method and path and wraps its headers."""
    mocked = make_mocked_request('GET', '/', headers=headers)
    assert mocked.method == "GET"
    assert mocked.path == "/"
    assert isinstance(mocked, web.Request)
    assert isinstance(mocked.headers, CIMultiDictProxy)


def test_make_mocked_request_sslcontext() -> None:
    """A default mocked request reports no ``sslcontext`` extra info."""
    mocked = make_mocked_request('GET', '/')
    ssl_context = mocked.transport.get_extra_info('sslcontext')
    assert ssl_context is None


def test_make_mocked_request_unknown_extra_info() -> None:
    # NOTE(review): this excerpt ends right after the request is built; the
    # assertion on an unknown extra-info key presumably followed -- confirm
    # against the full test module.
    req = make_mocked_request('GET', '/')
示例#55
0
def test_single_forwarded_header_injection2() -> None:
    """Only the well-formed second Forwarded element yields a ``for``."""
    raw_header = "very bad syntax, for=_real"
    request_headers = CIMultiDict({"Forwarded": raw_header})
    req = make_mocked_request("GET", "/", headers=request_headers)

    assert len(req.forwarded) == 2
    first, second = req.forwarded[0], req.forwarded[1]
    assert "for" not in first
    assert second["for"] == "_real"
示例#56
0
def test_init_headers_list_of_tuples(create_session):
    """Defaults given as pairs compare equal to the matching CIMultiDict."""
    pairs = [("h1", "header1"), ("h2", "header2"), ("h3", "header3")]
    session = create_session(headers=list(pairs))

    expected = CIMultiDict(pairs)
    assert session._default_headers == expected
示例#57
0
def test_init_headers_list_of_tuples_with_duplicates(create_session):
    """Repeated keys survive and match regardless of key case."""
    given = [("h1", "header11"), ("h2", "header21"), ("h1", "header12")]
    session = create_session(headers=given)

    expected = CIMultiDict([
        ("H1", "header11"),
        ("H2", "header21"),
        ("H1", "header12"),
    ])
    assert session._default_headers == expected
示例#58
0
    def __init__(
            self,
            *,
            connector: Optional[BaseConnector] = None,
            loop: Optional[asyncio.AbstractEventLoop] = None,
            cookies: Optional[LooseCookies] = None,
            headers: Optional[LooseHeaders] = None,
            skip_auto_headers: Optional[Iterable[str]] = None,
            auth: Optional[BasicAuth] = None,
            json_serialize: JSONEncoder = json.dumps,
            request_class: Type[ClientRequest] = ClientRequest,
            response_class: Type[ClientResponse] = ClientResponse,
            ws_response_class: Type[
                ClientWebSocketResponse] = ClientWebSocketResponse,  # noqa
            version: HttpVersion = http.HttpVersion11,
            cookie_jar: Optional[AbstractCookieJar] = None,
            connector_owner: bool = True,
            raise_for_status: bool = False,
            read_timeout: Union[float, object] = sentinel,
            conn_timeout: Optional[float] = None,
            timeout: Union[object, ClientTimeout] = sentinel,
            auto_decompress: bool = True,
            trust_env: bool = False,
            requote_redirect_url: bool = True,
            trace_configs: Optional[List[TraceConfig]] = None) -> None:
        """Store the session-wide defaults.

        Raises ``RuntimeError`` when the connector is bound to a different
        loop than the session, and ``ValueError`` when ``timeout`` is
        combined with the deprecated ``read_timeout`` / ``conn_timeout``.
        Using either deprecated argument alone emits a
        ``DeprecationWarning``.
        """
        # Without an explicit loop, borrow the one the connector uses.
        if loop is None and connector is not None:
            loop = connector._loop

        loop = get_running_loop(loop)

        if connector is None:
            connector = TCPConnector(loop=loop)

        if connector._loop is not loop:
            raise RuntimeError(
                "Session and connector has to use same event loop")

        self._loop = loop

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self._cookie_jar = (
            CookieJar(loop=loop) if cookie_jar is None else cookie_jar)
        if cookies is not None:
            self._cookie_jar.update_cookies(cookies)

        self._connector = connector  # type: BaseConnector
        self._connector_owner = connector_owner
        self._default_auth = auth
        self._version = version
        self._json_serialize = json_serialize

        has_read_timeout = read_timeout is not sentinel
        has_conn_timeout = conn_timeout is not None
        if timeout is not sentinel:
            # An explicit timeout excludes the deprecated arguments.
            self._timeout = timeout  # type: ignore
            if has_read_timeout:
                raise ValueError("read_timeout and timeout parameters "
                                 "conflict, please setup "
                                 "timeout.read")
            if has_conn_timeout:
                raise ValueError("conn_timeout and timeout parameters "
                                 "conflict, please setup "
                                 "timeout.connect")
        else:
            # Start from the default and fold the deprecated values in.
            self._timeout = DEFAULT_TIMEOUT
            if has_read_timeout:
                warnings.warn(
                    "read_timeout is deprecated, "
                    "use timeout argument instead",
                    DeprecationWarning,
                    stacklevel=2)
                self._timeout = attr.evolve(self._timeout, total=read_timeout)
            if has_conn_timeout:
                self._timeout = attr.evolve(self._timeout,
                                            connect=conn_timeout)
                warnings.warn(
                    "conn_timeout is deprecated, "
                    "use timeout argument instead",
                    DeprecationWarning,
                    stacklevel=2)

        self._raise_for_status = raise_for_status
        self._auto_decompress = auto_decompress
        self._trust_env = trust_env
        self._requote_redirect_url = requote_redirect_url

        # Normalise the default headers into a CIMultiDict.
        self._default_headers = (
            CIMultiDict(headers) if headers else CIMultiDict())

        if skip_auto_headers is None:
            self._skip_auto_headers = frozenset()
        else:
            self._skip_auto_headers = frozenset(
                istr(name) for name in skip_auto_headers)

        self._request_class = request_class
        self._response_class = response_class
        self._ws_response_class = ws_response_class

        self._trace_configs = trace_configs or []
        for config in self._trace_configs:
            config.freeze()
示例#59
0
class HttpMessage(ABC, PayloadWriter):
    """Buffer an HTTP start line and headers on top of a payload writer.

    Concrete subclasses provide ``status_line``, ``autochunked()`` and a
    ``HOP_HEADERS`` collection.
    """

    # Header names add_header() refuses to store; subclasses must set it.
    HOP_HEADERS = None

    SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
        sys.version_info, aiohttp.__version__)

    upgrade = False  # set by a "Connection" header mentioning upgrade
    websocket = False  # set by an "Upgrade" header mentioning websocket
    has_chunked_hdr = False  # set by "Transfer-Encoding: chunked"

    def __init__(self, transport, version, close, loop=None):
        super().__init__(transport, loop)

        self.headers = CIMultiDict()
        self.headers_sent = False
        self.keepalive = None
        self.length = None
        self.closing = close
        self._version = version

    @property
    @abstractmethod
    def status_line(self):
        return b''

    @abstractmethod
    def autochunked(self):
        return False

    @property
    def version(self):
        return self._version

    @property
    def body_length(self):
        return self.output_length

    def force_close(self):
        """Mark the message so the connection is not kept alive."""
        self.closing = True
        self.keepalive = False

    def keep_alive(self):
        """Return whether the connection should stay open afterwards."""
        if self.keepalive is not None:
            # An explicit decision (from headers or force_close) wins.
            return self.keepalive

        if self.version < HttpVersion10:
            # keep alive not supported at all
            return False

        if self.version == HttpVersion10:
            # HTTP/1.0 only stays open when explicitly asked to.
            return self.headers.get(hdrs.CONNECTION) == 'keep-alive'

        return not self.closing

    def is_headers_sent(self):
        return self.headers_sent

    def add_header(self, name, value):
        """Record one header and update the state derived from it.

        Content-Length sets ``self.length``, Transfer-Encoding sets
        ``has_chunked_hdr``, Connection only adjusts the upgrade and
        keep-alive flags, Upgrade is stored (replacing any previous value),
        and every other header is appended unless listed in
        ``HOP_HEADERS``.
        """
        assert not self.headers_sent, 'headers have been sent already'
        assert isinstance(name, str), \
            'Header name should be a string, got {!r}'.format(name)
        assert set(name).issubset(ASCIISET), \
            'Header name should contain ASCII chars, got {!r}'.format(name)
        assert isinstance(value, str), \
            'Header {!r} should have string value, got {!r}'.format(
                name, value)

        name = istr(name)
        value = value.strip()

        if name == hdrs.CONTENT_LENGTH:
            self.length = int(value)

        if name == hdrs.TRANSFER_ENCODING:
            encoding = value.lower().strip()
            self.has_chunked_hdr = encoding == 'chunked'

        if name == hdrs.CONNECTION:
            # Connection is interpreted here and never stored as-is.
            token = value.lower()
            if 'upgrade' in token:
                self.upgrade = True
            elif 'close' in token:
                self.keepalive = False
            elif 'keep-alive' in token:
                self.keepalive = True

        elif name == hdrs.UPGRADE:
            if 'websocket' in value.lower():
                self.websocket = True
            self.headers[name] = value

        elif name not in self.HOP_HEADERS:
            # hop-by-hop headers are dropped
            self.headers.add(name, value)

    def add_headers(self, *headers):
        """Add every ``(name, value)`` pair via ``add_header()``."""
        for header_name, header_value in headers:
            self.add_header(header_name, header_value)

    def send_headers(self, _sep=': ', _end='\r\n'):
        """Serialize the status line and headers into the write buffer."""
        # Chunked encoding is enabled automatically when the subclass's
        # autochunked() asks for it and chunking is not already active.
        assert not self.headers_sent, 'headers have been sent already'
        self.headers_sent = True

        if not self.chunked and self.autochunked():
            self.enable_chunking()

        if self.chunked:
            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

        self._add_default_headers()

        # status + headers, terminated by an empty line
        status = self.status_line
        fields = ''.join(
            key + _sep + val + _end for key, val in self.headers.items())
        self.buffer_data((status + fields).encode('utf-8') + b'\r\n')

    def _add_default_headers(self):
        """Set the Connection header that matches the current state."""
        connection = None
        if self.upgrade:
            connection = 'Upgrade'
        else:
            if self.keepalive is None:
                stay_open = not self.closing
            else:
                stay_open = self.keepalive

            if stay_open:
                if self.version == HttpVersion10:
                    connection = 'keep-alive'
            elif self.version == HttpVersion11:
                connection = 'close'

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection
示例#60
0
    def append_json(self, obj, headers=None):
        """Append ``obj`` as a JSON payload part and return the result.

        A fresh ``CIMultiDict`` is used when no headers are supplied.
        """
        part_headers = CIMultiDict() if headers is None else headers
        payload = JsonPayload(obj, headers=part_headers)
        return self.append_payload(payload)