Ejemplo n.º 1
0
 def get_headers(self, request, headers):
     """Return a new CIMultiDict combining :attr:`headers` with *headers*.

     The instance headers are copied first; if *headers* is non-empty it is
     applied on top with ``CIMultiDict.update``. *request* is not used by
     this implementation.
     """
     # NOTE: the original author reports a bug in CIMultiDict that affects
     # ``self.headers.copy()``, so the copy is rebuilt from the item pairs.
     merged = CIMultiDict(self.headers.items())
     if not headers:
         return merged
     merged.update(headers)
     return merged
Ejemplo n.º 2
0
def headers_preserve_casing():
    """Report whether ``multidict.CIMultiDict`` keeps header-name casing.

    Returns ``True`` when multidict cannot be imported. Otherwise a key is
    inserted as ``'X-NewRelic-ID'`` and the result says whether it is read
    back with exactly that casing.
    """
    try:
        from multidict import CIMultiDict
    except ImportError:
        # Only a missing package means "no CIMultiDict"; other errors
        # (and KeyboardInterrupt/SystemExit) must not be swallowed.
        return True

    d = CIMultiDict()
    d.update({'X-NewRelic-ID': 'value'})
    return 'X-NewRelic-ID' in dict(d.items())
Ejemplo n.º 3
0
 def get_headers(self, request, headers):
     """Return a new CIMultiDict combining :attr:`headers` with *headers*.

     The instance headers are copied first; if *headers* is non-empty it is
     applied on top with ``CIMultiDict.update``. *request* is not used by
     this implementation.
     """
     # NOTE: the original author reports a bug in CIMultiDict that affects
     # ``self.headers.copy()``, so the copy is rebuilt from the item pairs.
     merged = CIMultiDict(self.headers.items())
     if not headers:
         return merged
     merged.update(headers)
     return merged
Ejemplo n.º 4
0
def headers_preserve_casing():
    """Report whether ``multidict.CIMultiDict`` keeps header-name casing.

    Returns ``True`` when multidict cannot be imported. Otherwise a key is
    inserted as ``"X-NewRelic-ID"`` and the result says whether it is read
    back with exactly that casing.
    """
    try:
        from multidict import CIMultiDict
    except ImportError:
        # Only a missing package means "no CIMultiDict"; other errors
        # (and KeyboardInterrupt/SystemExit) must not be swallowed.
        return True

    d = CIMultiDict()
    d.update({"X-NewRelic-ID": "value"})
    return "X-NewRelic-ID" in dict(d.items())
Ejemplo n.º 5
0
 def _build_response(self,
                     url: 'Union[URL, str]',
                     method: str = hdrs.METH_GET,
                     status: int = 200,
                     body: str = '',
                     content_type: str = 'application/json',
                     payload: Dict = None,
                     headers: Dict = None,
                     response_class: 'ClientResponse' = None,
                     reason: Optional[str] = None) -> ClientResponse:
     """Build a canned response object for *url*.

     An instance of *response_class* (``ClientResponse`` by default) is
     created and its status, reason, headers and body are set directly.
     If *payload* is given it is JSON-encoded and replaces *body*; a
     non-bytes body is encoded with ``str.encode``. ``Content-Type`` starts
     as *content_type* and *headers* is applied on top. For aiohttp >= 3.1.0
     the constructor receives Mock-based ``request_info``/``writer``/``loop``
     objects.
     """
     cls = ClientResponse if response_class is None else response_class
     raw_body = body if payload is None else json.dumps(payload)
     if not isinstance(raw_body, bytes):
         raw_body = str.encode(raw_body)

     kwargs = {}
     loop = None
     if AIOHTTP_VERSION >= StrictVersion('3.1.0'):
         loop = Mock()
         loop.get_debug = Mock()
         loop.get_debug.return_value = True
         kwargs.update(
             request_info=Mock(),
             writer=Mock(),
             continue100=None,
             timer=TimerNoop(),
         )
         if AIOHTTP_VERSION < StrictVersion('3.3.0'):
             kwargs['auto_decompress'] = True
         kwargs.update(traces=[], loop=loop, session=None)

     # Headers have to be set up by hand on the constructed response.
     all_headers = CIMultiDict({hdrs.CONTENT_TYPE: content_type})
     if headers:
         all_headers.update(headers)
     raw = self._build_raw_headers(all_headers)

     resp = cls(method, url, **kwargs)
     if AIOHTTP_VERSION < StrictVersion('3.3.0'):
         resp.headers = all_headers
         resp.raw_headers = raw
     else:
         # Reified attributes: assign the backing fields instead.
         resp._headers = all_headers
         resp._raw_headers = raw
     resp.status = status
     resp.reason = reason
     stream = stream_reader_factory(loop)
     resp.content = stream
     stream.feed_data(raw_body)
     stream.feed_eof()
     return resp
Ejemplo n.º 6
0
 async def call(self,
                method: str,
                prefix: str,
                data: Any,
                headers: 'Optional[dict]' = None):
     """Send a JSON request to ``self.base_url + "/" + prefix``.

     :param method: lower-case verb: ``"get"``, ``"post"``, ``"put"``,
         ``"patch"`` or ``"delete"``.
     :param prefix: path appended to ``self.base_url`` after a ``"/"``.
     :param data: object JSON-encoded (UTF-8) as the request body; only
         sent for ``post``, ``put`` and ``patch``.
     :param headers: extra headers merged over ``self.headers``.
     :returns: ``self.check_result`` applied to the response text, after
         ``check_response(resp, self.logger)`` has been awaited.
     :raises ValueError: if *method* is not one of the supported verbs
         (previously this silently returned ``None``).
     """
     if method not in ("get", "post", "put", "patch", "delete"):
         raise ValueError("Unsupported HTTP method: {!r}".format(method))

     new_headers = CIMultiDict()
     new_headers.update(self.headers)
     if headers:
         new_headers.update(headers)

     request_kwargs = {"headers": new_headers}
     if method in ("post", "put", "patch"):
         request_kwargs["data"] = json.dumps(data).encode("utf-8")

     # Dispatch to the session method named after the verb, so each verb
     # still goes through the same session API as before without
     # duplicating the ``async with`` block five times.
     session_method = getattr(self.session, method)
     async with session_method(self.base_url + "/" + prefix,
                               **request_kwargs) as resp:
         await check_response(resp, self.logger)
         res = await resp.text()
     return self.check_result(res)
Ejemplo n.º 7
0
 async def build_response(
         self, url: URL
 ) -> 'Union[ClientResponse, Exception]':
     """Return the configured exception, or a canned response for *url*.

     If ``self.exception`` is an ``Exception`` instance it is returned
     as-is. Otherwise ``self.response_class`` is instantiated and its
     headers (``Content-Type`` from ``self.content_type`` plus
     ``self.headers``), status, reason and body are set directly.
     """
     if isinstance(self.exception, Exception):
         return self.exception

     kwargs = {}
     if AIOHTTP_VERSION >= StrictVersion('3.1.0'):
         fake_loop = Mock()
         fake_loop.get_debug = Mock()
         fake_loop.get_debug.return_value = True
         kwargs.update(
             request_info=Mock(),
             writer=Mock(),
             continue100=None,
             timer=TimerNoop(),
         )
         if AIOHTTP_VERSION < StrictVersion('3.3.0'):
             kwargs['auto_decompress'] = True
         kwargs.update(traces=[], loop=fake_loop, session=None)

     resp = self.response_class(self.method, url, **kwargs)

     # Headers have to be set up by hand on the constructed response.
     merged = CIMultiDict({hdrs.CONTENT_TYPE: self.content_type})
     if self.headers:
         merged.update(self.headers)
     raw = self._build_raw_headers(merged)
     if AIOHTTP_VERSION < StrictVersion('3.3.0'):
         resp.headers = merged
         resp.raw_headers = raw
     else:
         # Reified attributes: assign the backing fields instead.
         resp._headers = merged
         resp._raw_headers = raw

     resp.status = self.status
     resp.reason = self.reason
     stream = stream_reader_factory()
     resp.content = stream
     stream.feed_data(self.body)
     stream.feed_eof()
     return resp
Ejemplo n.º 8
0
class Payload(ABC):

    _default_content_type = "application/octet-stream"  # type: str
    _size = None  # type: Optional[int]

    def __init__(
        self,
        value: Any,
        headers: Optional[Union[_CIMultiDict, Dict[str, str],
                                Iterable[Tuple[str, str]]]] = None,
        content_type: Union[None, str, _SENTINEL] = sentinel,
        filename: Optional[str] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store *value* and build this payload's headers.

        ``Content-Type`` is chosen in this order: an explicit string
        *content_type*; the type guessed by ``mimetypes`` from *filename*;
        ``_default_content_type``. Entries in *headers* are then applied on
        top and may override it. Extra *kwargs* are accepted and ignored.
        """
        self._encoding = encoding
        self._filename = filename
        self._headers = CIMultiDict()  # type: _CIMultiDict
        self._value = value

        if content_type is not sentinel and content_type is not None:
            assert isinstance(content_type, str)
            resolved = content_type
        elif self._filename is None:
            resolved = self._default_content_type
        else:
            guessed = mimetypes.guess_type(self._filename)[0]
            resolved = (self._default_content_type if guessed is None
                        else guessed)
        self._headers[hdrs.CONTENT_TYPE] = resolved
        self._headers.update(headers or {})

    @property
    def size(self) -> Optional[int]:
        """Size of the payload; ``None`` unless ``_size`` has been set."""
        known_size = self._size
        return known_size

    @property
    def filename(self) -> Optional[str]:
        """Filename given to the constructor, or ``None``."""
        name = self._filename
        return name

    @property
    def headers(self) -> _CIMultiDict:
        """The payload's own headers (the live mapping, not a copy)."""
        own = self._headers
        return own

    @property
    def _binary_headers(self) -> bytes:
        """``self.headers`` as CRLF-terminated ``Name: value`` lines.

        The lines are UTF-8 encoded and followed by one empty CRLF line.
        """
        text = "".join(k + ": " + v + "\r\n"
                       for k, v in self.headers.items())
        return text.encode("utf-8") + b"\r\n"

    @property
    def encoding(self) -> Optional[str]:
        """Text encoding given to the constructor, or ``None``."""
        enc = self._encoding
        return enc

    @property
    def content_type(self) -> str:
        """Current value of the ``Content-Type`` header."""
        own = self._headers
        return own[hdrs.CONTENT_TYPE]

    def set_content_disposition(
        self,
        disptype: str,
        quote_fields: bool = True,
        _charset: str = "utf-8",
        **params: Any,
    ) -> None:
        """Set the ``Content-Disposition`` header.

        The value is built by ``content_disposition_header`` from
        *disptype*, *quote_fields*, *_charset* and any extra *params*.
        """
        header_value = content_disposition_header(
            disptype, quote_fields=quote_fields, _charset=_charset, **params)
        self._headers[hdrs.CONTENT_DISPOSITION] = header_value

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write payload.
Ejemplo n.º 9
0
class ClientRequest:
    """Outgoing HTTP client request.

    The constructor normalises every request parameter (URL and query,
    HTTP version, headers, cookies, content encoding, auth, proxy, body,
    transfer encoding, ``Expect: 100-continue``). :meth:`send` writes the
    request line and headers and schedules :meth:`write_bytes` to stream
    the body. :meth:`close` waits for that task and :meth:`terminate`
    cancels it.
    """

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    # Added by update_auto_headers() unless already set or skipped.
    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self, method, url, *,
                 params=None, headers=None, skip_auto_headers=frozenset(),
                 data=None, cookies=None,
                 auth=None, encoding='utf-8',
                 version=aiohttp.HttpVersion11, compress=None,
                 chunked=None, expect100=False,
                 loop=None, response_class=None,
                 proxy=None, proxy_auth=None,
                 timeout=5 * 60):
        """Validate and normalise all request parameters.

        *url* (and *proxy*, if given) must be ``yarl.URL`` instances
        (asserted). *params* are appended to the query already present in
        *url*, and the fragment is dropped. *loop* defaults to
        ``asyncio.get_event_loop()``. The ``update_*`` helpers run in a
        fixed order because later ones read state (headers, ``chunked``,
        ``compress``) set by earlier ones.
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy

        if params:
            # Keep the existing query pairs and append the new ones.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        self.url = url.with_fragment(None)
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.response_class = response_class or ClientResponse
        self._timeout = timeout

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth)

        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    @property
    def host(self):
        """Host component of the request URL."""
        return self.url.host

    @property
    def port(self):
        """Port of the request URL."""
        return self.url.port

    def update_host(self, url):
        """Validate the host and derive basic auth and the ``ssl`` flag.

        :raises ValueError: if *url* has no host.
        """
        # get host/port
        if not url.host:
            raise ValueError('Host could not be detected.')

        # basic auth info
        username, password = url.user, url.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')

        scheme = url.scheme
        self.ssl = scheme in ('https', 'wss')

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        :raises ValueError: if a string version is not ``"<int>.<int>"``.
        """
        if isinstance(version, str):
            v = [part.strip() for part in version.split('.', 1)]
            try:
                version = int(v[0]), int(v[1])
            except (ValueError, IndexError):
                # IndexError: no '.' in the string (e.g. '1').
                raise ValueError(
                    'Can not parse http version number: {}'
                    .format(version)) from None
        self.version = version

    def update_headers(self, headers):
        """Replace ``self.headers`` with a CIMultiDict built from *headers*.

        *headers* may be a dict, MultiDict/MultiDictProxy or an iterable of
        ``(name, value)`` pairs; duplicate names are kept.
        """
        self.headers = CIMultiDict()
        if headers:
            if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
                headers = headers.items()

            for key, value in headers:
                self.headers.add(key, value)

    def update_auto_headers(self, skip_auto_headers):
        """Add default headers, ``Host`` and ``User-Agent`` when missing.

        A header is added only if its name is neither already present in
        ``self.headers`` nor listed in *skip_auto_headers*; both checks are
        case-insensitive.
        """
        self.skip_auto_headers = skip_auto_headers
        # Header names are case-insensitive: collect them in a CIMultiDict.
        # A plain set would let e.g. a user-supplied 'accept' header coexist
        # with the default 'Accept' one.
        used_headers = CIMultiDict(
            [(hdr, None) for hdr in self.headers] +
            [(hdr, None) for hdr in skip_auto_headers])

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        # add host
        if hdrs.HOST not in used_headers:
            # raw_host is the encoded host form that goes on the wire.
            netloc = self.url.raw_host
            if not self.url.is_default_port():
                netloc += ':' + str(self.url.port)
            self.headers[hdrs.HOST] = netloc

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Update request cookies header.

        Cookies already in the ``Cookie`` header are kept and merged with
        *cookies*. Morsel values are stored under the Morsel's own key.
        """
        if not cookies:
            return

        c = http.cookies.SimpleCookie()
        if hdrs.COOKIE in self.headers:
            c.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        for name, value in cookies.items():
            if isinstance(value, http.cookies.Morsel):
                c[value.key] = value.value
            else:
                c[name] = value

        self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Set request content encoding.

        An explicit ``Content-Encoding`` header wins unless *compress* is
        ``False``; otherwise a truthy ``compress`` (non-str means
        ``'deflate'``) sets that header. Either way chunking is enabled.
        """
        if not data:
            return

        enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if enc:
            if self.compress is not False:
                self.compress = enc
                # enable chunked, no need to deal with length
                self.chunked = True
        elif self.compress:
            if not isinstance(self.compress, str):
                self.compress = 'deflate'
            self.headers[hdrs.CONTENT_ENCODING] = self.compress
            self.chunked = True  # enable chunked, no need to deal with length

    def update_auth(self, auth):
        """Set the ``Authorization`` header from *auth* or URL credentials.

        :raises TypeError: if the auth object is not ``helpers.BasicAuth``.
        """
        if auth is None:
            auth = self.auth
        if auth is None:
            return

        if not isinstance(auth, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = auth.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Store *data* as ``self.body`` and set related headers/chunking.

        Handles str/bytes, stream readers and data queues, coroutines,
        binary file objects, ``MultipartWriter`` and form data (anything
        else is wrapped in ``helpers.FormData``).

        :raises ValueError: for a file object opened in text mode ``'r'``.
        """
        if not data:
            return

        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if (hdrs.CONTENT_TYPE not in self.headers and
                    hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(data, (asyncio.StreamReader, streams.StreamReader,
                               streams.DataQueue)):
            self.body = data

        elif asyncio.iscoroutine(data):
            self.body = data
            if (hdrs.CONTENT_LENGTH not in self.headers and
                    self.chunked is None):
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), \
                'attempt to send text data instead of binary'
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined
                size = len(data.getbuffer())
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif (not self.chunked and
                  isinstance(data, (io.BufferedReader, io.BufferedRandom))):
                # Not chunking if content-length can be determined
                try:
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            if hasattr(data, 'mode'):
                if data.mode == 'r':
                    raise ValueError('file {!r} should be open in binary mode'
                                     .format(data))
            if (hdrs.CONTENT_TYPE not in self.headers and
                hdrs.CONTENT_TYPE not in skip_auto_headers and
                    hasattr(data, 'name')):
                mime = mimetypes.guess_type(data.name)[0]
                mime = 'application/octet-stream' if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            self.chunked = self.chunked or 8192

        else:
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if (hdrs.CONTENT_TYPE not in self.headers and
                    hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = self.chunked or 8192
            else:
                if (hdrs.CONTENT_LENGTH not in self.headers and
                        not self.chunked):
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Analyze transfer-encoding header.

        Afterwards ``self.chunked`` is either an int chunk size (chunked
        transfer) or ``None`` (``Content-Length`` is ensured).
        """
        te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()

        if self.chunked:
            if hdrs.CONTENT_LENGTH in self.headers:
                del self.headers[hdrs.CONTENT_LENGTH]
            if 'chunked' not in te:
                self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

            # ``type(...) is int`` deliberately excludes ``True`` (a bool),
            # which becomes the default chunk size.
            self.chunked = self.chunked if type(self.chunked) is int else 8192
        else:
            if 'chunked' in te:
                self.chunked = 8192
            else:
                self.chunked = None
                if hdrs.CONTENT_LENGTH not in self.headers:
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_expect_continue(self, expect=False):
        """Create the 100-continue waiter when *expect* or the header asks."""
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue':
            expect = True

        if expect:
            self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth):
        """Validate and store proxy settings.

        :raises ValueError: for a non-http proxy or a *proxy_auth* that is
            not ``helpers.BasicAuth``.
        """
        if proxy and proxy.scheme != 'http':
            raise ValueError("Only http proxies are supported")
        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")
        self.proxy = proxy
        self.proxy_auth = proxy_auth

    @asyncio.coroutine
    def write_bytes(self, request, reader):
        """Stream ``self.body`` to *request*, then write EOF.

        Supports coroutines that yield bytes objects (or futures to wait
        on), stream readers, data queues, binary file objects and
        bytes/iterables of chunks. ``Exception`` subclasses raised while
        writing are reported to *reader* as ``aiohttp.ClientRequestError``.
        """
        # 100 response
        if self._continue is not None:
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                request.transport.set_tcp_nodelay(True)
                exc = None
                value = None
                stream = self.body

                # Drive the body coroutine by hand: yielded bytes are
                # written; yielded futures are awaited and their result
                # (or exception) is sent back into the coroutine.
                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as stop:
                        # The coroutine may also return a final chunk.
                        if isinstance(stop.value, bytes):
                            yield from request.write(stop.value, drain=True)
                        break
                    except BaseException:
                        if self.response is not None:
                            self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result, drain=True)
                        value = None
                    else:
                        raise ValueError(
                            'Bytes object is expected, got: %s.' %
                            type(result))

            elif isinstance(self.body, (asyncio.StreamReader,
                                        streams.StreamReader)):
                request.transport.set_tcp_nodelay(True)
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                request.transport.set_tcp_nodelay(True)
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if chunk is EOF_MARKER:
                            break
                        yield from request.write(chunk, drain=True)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                chunk = self.body.read(self.chunked)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(self.chunked)
                request.transport.set_tcp_nodelay(True)

            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)

                for chunk in self.body:
                    request.write(chunk)
                request.transport.set_tcp_nodelay(True)

        except Exception as exc:
            new_exc = aiohttp.ClientRequestError(
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            reader.set_exception(new_exc)
        else:
            assert request.transport.tcp_nodelay
            try:
                ret = request.write_eof()
                # NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
                # see bug #170
                if (asyncio.iscoroutine(ret) or
                        isinstance(ret, asyncio.Future)):
                    yield from ret
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError(
                    'Can not write request body for %s' % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                reader.set_exception(new_exc)

        self._writer = None

    def send(self, writer, reader):
        """Write the request line and headers, then start the body writer.

        Returns the newly created ``self.response``.
        """
        writer.set_tcp_cork(True)
        path = self.url.raw_path
        if self.url.raw_query_string:
            path += '?' + self.url.raw_query_string
        request = aiohttp.Request(writer, self.method, path,
                                  self.version)

        if self.compress:
            request.add_compression_filter(self.compress)

        if self.chunked is not None:
            request.enable_chunked_encoding()
            request.add_chunking_filter(self.chunked)

        # set default content-type
        if (self.method in self.POST_METHODS and
                hdrs.CONTENT_TYPE not in self.skip_auto_headers and
                hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        for k, value in self.headers.items():
            request.add_header(k, value)
        request.send_headers()

        self._writer = helpers.ensure_future(
            self.write_bytes(request, reader), loop=self.loop)

        self.response = self.response_class(
            self.method, self.url,
            writer=self._writer, continue100=self._continue,
            timeout=self._timeout)
        self.response._post_init(self.loop)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the body-writer task (if any) to finish."""
        if self._writer is not None:
            try:
                yield from self._writer
            finally:
                self._writer = None

    def terminate(self):
        """Cancel the body-writer task unless the loop is already closed."""
        if self._writer is not None:
            if not self.loop.is_closed():
                self._writer.cancel()
            self._writer = None
Ejemplo n.º 10
0
class ClientRequest:

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self,
                 method,
                 url,
                 *,
                 params=None,
                 headers=None,
                 skip_auto_headers=frozenset(),
                 data=None,
                 cookies=None,
                 auth=None,
                 encoding='utf-8',
                 version=http.HttpVersion11,
                 compress=None,
                 chunked=None,
                 expect100=False,
                 loop=None,
                 response_class=None,
                 proxy=None,
                 proxy_auth=None,
                 timer=None):
        """Validate and normalise all request parameters.

        *url* (and *proxy*, if given) must be ``yarl.URL`` instances
        (asserted). *params* are appended to the query already in *url*.
        ``self.url`` has the fragment removed, while ``self.original_url``
        keeps it. *loop* defaults to ``asyncio.get_event_loop()``; *timer*
        defaults to ``_TimeServiceTimeoutNoop()``.
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        assert isinstance(url, URL), url
        assert isinstance(proxy, (URL, type(None))), proxy

        if params:
            # Keep the existing query pairs and append the new ones.
            q = MultiDict(url.query)
            url2 = url.with_query(params)
            q.extend(url2.query)
            url = url.with_query(q)
        self.url = url.with_fragment(None)
        self.original_url = url
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.response_class = response_class or ClientResponse
        self._timer = timer if timer is not None else _TimeServiceTimeoutNoop()

        if loop.get_debug():
            # Remember where the request was created, for debug reports.
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Order matters: later helpers read headers / chunked / compress
        # state written by earlier ones.
        self.update_version(version)
        self.update_host(url)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth)

        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    @property
    def host(self):
        """Host component of the request URL."""
        target = self.url
        return target.host

    @property
    def port(self):
        """Port of the request URL."""
        target = self.url
        return target.port

    def update_host(self, url):
        """Validate the host and derive basic auth and the ``ssl`` flag.

        Credentials embedded in *url* become ``self.auth`` (an empty
        password is used when only a user is given). ``self.ssl`` is true
        for the ``https`` and ``wss`` schemes.

        :raises ValueError: if *url* has no host.
        """
        if not url.host:
            raise ValueError('Host could not be detected.')

        user, secret = url.user, url.password
        if user:
            self.auth = helpers.BasicAuth(user, secret or '')

        self.ssl = url.scheme in ('https', 'wss')

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1)

        Non-string versions are stored unchanged.

        :raises ValueError: if a string version is not ``"<int>.<int>"``.
        """
        if isinstance(version, str):
            v = [part.strip() for part in version.split('.', 1)]
            try:
                version = int(v[0]), int(v[1])
            except (ValueError, IndexError):
                # IndexError: no '.' in the string (e.g. '1') - report it
                # with the same ValueError as other malformed versions.
                raise ValueError(
                    'Can not parse http version number: {}'.format(
                        version)) from None
        self.version = version

    def update_headers(self, headers):
        """Replace ``self.headers`` with a CIMultiDict built from *headers*.

        *headers* may be a dict, MultiDict/MultiDictProxy or an iterable of
        ``(name, value)`` pairs; every pair is added, so duplicates are kept.
        """
        self.headers = CIMultiDict()
        if not headers:
            return
        if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
            pairs = headers.items()
        else:
            pairs = headers
        for name, val in pairs:
            self.headers.add(name, val)

    def update_auto_headers(self, skip_auto_headers):
        """Add default headers, ``Host`` and ``User-Agent`` when missing.

        A header is added only if its name is neither already present in
        ``self.headers`` nor listed in *skip_auto_headers*; both checks are
        case-insensitive.
        """
        self.skip_auto_headers = skip_auto_headers
        # Header names are case-insensitive: collect them in a CIMultiDict.
        # A plain set would let e.g. a user-supplied 'accept' header coexist
        # with the default 'Accept' one.
        used_headers = CIMultiDict(
            [(hdr, None) for hdr in self.headers] +
            [(hdr, None) for hdr in skip_auto_headers])

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        # add host
        if hdrs.HOST not in used_headers:
            netloc = self.url.raw_host
            if not self.url.is_default_port():
                netloc += ':' + str(self.url.port)
            self.headers[hdrs.HOST] = netloc

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Merge *cookies* into the ``Cookie`` request header.

        Cookies already in the header are loaded first. A Morsel value is
        copied into a fresh Morsel so that its ``coded_value`` is kept.
        Other values are stored as-is under their mapping key.
        """
        if not cookies:
            return

        jar = SimpleCookie()
        if hdrs.COOKIE in self.headers:
            jar.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        for name, value in cookies.items():
            if not isinstance(value, Morsel):
                jar[name] = value
                continue
            # Preserve coded_value
            morsel = value.get(value.key, Morsel())
            morsel.set(value.key, value.value, value.coded_value)
            jar[name] = morsel

        self.headers[hdrs.COOKIE] = jar.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Set request content encoding.

        An explicit ``Content-Encoding`` header wins unless ``compress`` is
        ``False``. Otherwise a truthy ``compress`` (non-str means
        ``'deflate'``) sets that header. Whenever an encoding is applied,
        chunked transfer is enabled so no length has to be computed.
        """
        if not data:
            return

        header_enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if header_enc:
            if self.compress is False:
                return
            self.compress = header_enc
            self.chunked = True
            return

        if not self.compress:
            return
        if not isinstance(self.compress, str):
            self.compress = 'deflate'
        self.headers[hdrs.CONTENT_ENCODING] = self.compress
        self.chunked = True

    def update_auth(self, auth):
        """Set the ``Authorization`` header from basic-auth credentials.

        When *auth* is ``None``, ``self.auth`` is used instead. If both are
        ``None``, no header is written.

        :raises TypeError: if the credentials are not ``helpers.BasicAuth``.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return
        if isinstance(credentials, helpers.BasicAuth):
            self.headers[hdrs.AUTHORIZATION] = credentials.encode()
            return
        raise TypeError('BasicAuth() tuple is required instead')

    def update_body_from_data(self, data, skip_auto_headers):
        """Store *data* as ``self.body`` and derive body-related headers.

        Depending on the type of *data*, this fills in ``Content-Type`` and
        ``Content-Length`` when they are missing (Content-Type is also
        skipped if listed in *skip_auto_headers*). It also decides whether
        the body must be sent chunked (``self.chunked``).

        Supported *data*:

        * ``str``: encoded with ``self.encoding``, then treated as bytes.
        * ``bytes``/``bytearray``: sent as-is with a Content-Length.
        * ``asyncio.StreamReader``/``streams.StreamReader``/
          ``streams.DataQueue`` and coroutines: streamed, and sent chunked
          unless the caller supplied a Content-Length or a ``chunked``
          value.
        * binary file objects (``io.IOBase``): Content-Length is computed
          for ``BytesIO`` and real files when possible; otherwise chunked.
        * ``MultipartWriter``: serialized and always chunked.
        * anything else: wrapped in ``helpers.FormData``.

        :raises ValueError: if a file object reports text mode ``'r'``.
        """
        if not data:
            return

        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(
                data,
            (asyncio.StreamReader, streams.StreamReader, streams.DataQueue)):
            self.body = data
            # A stream's length is not known up front. Without this,
            # update_transfer_encoding() falls back to len(self.body), which
            # fails for readers without __len__ (e.g. asyncio.StreamReader).
            # Mirror the coroutine branch and switch to chunked transfer
            # unless the caller gave a Content-Length or a chunked value.
            if (hdrs.CONTENT_LENGTH not in self.headers
                    and self.chunked is None):
                self.chunked = True

        elif asyncio.iscoroutine(data):
            # The coroutine produces the body piece by piece (see
            # write_bytes), so its length is unknown.
            self.body = data
            if (hdrs.CONTENT_LENGTH not in self.headers
                    and self.chunked is None):
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), \
                'attempt to send text data instead of binary'
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined.
                # Only the bytes after the current position are sent, so
                # subtract tell(), as the file branch below does.
                size = max(len(data.getbuffer()) - data.tell(), 0)
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif (not self.chunked
                  and isinstance(data,
                                 (io.BufferedReader, io.BufferedRandom))):
                # Not chunking if content-length can be determined
                try:
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            if hasattr(data, 'mode'):
                if data.mode == 'r':
                    raise ValueError('file {!r} should be open in binary mode'
                                     ''.format(data))
            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers
                    and hasattr(data, 'name')):
                # Guess the media type from the file name.
                mime = mimetypes.guess_type(data.name)[0]
                mime = 'application/octet-stream' if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            self.chunked = True

        else:
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = True
            else:
                if (hdrs.CONTENT_LENGTH not in self.headers
                        and not self.chunked):
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Make ``Transfer-Encoding``/``Content-Length`` agree with ``chunked``.

        * Chunking requested: drop any Content-Length and make sure the
          Transfer-Encoding header mentions ``chunked``.
        * Not requested, but the caller's Transfer-Encoding header already
          says ``chunked``: switch chunking on.
        * Otherwise: chunking is off (``None``), and a missing
          Content-Length is derived from ``len(self.body)``.
        """
        transfer_enc = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()
        te_says_chunked = 'chunked' in transfer_enc

        if not self.chunked:
            if te_says_chunked:
                self.chunked = True
                return
            self.chunked = None
            if hdrs.CONTENT_LENGTH not in self.headers:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
            return

        if hdrs.CONTENT_LENGTH in self.headers:
            del self.headers[hdrs.CONTENT_LENGTH]
        if not te_says_chunked:
            self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

    def update_expect_continue(self, expect=False):
        """Prepare a ``100 Continue`` handshake when one is requested.

        The handshake is active if *expect* is true, in which case the
        ``Expect: 100-continue`` header is set, or if the caller already
        supplied that header. Either way ``self._continue`` becomes a
        future that ``write_bytes`` waits on before sending the body.
        """
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        else:
            current = self.headers.get(hdrs.EXPECT, '')
            expect = current.lower() == '100-continue'

        if expect:
            self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth):
        """Validate and store the proxy settings.

        :param proxy: proxy URL (an object with a ``scheme`` attribute), or
            a falsy value for no proxy. Only ``http`` proxies are accepted.
        :param proxy_auth: ``helpers.BasicAuth`` credentials, or a falsy
            value.
        :raises ValueError: on an unsupported proxy scheme or a wrong
            credentials type.
        """
        if proxy and proxy.scheme != 'http':
            raise ValueError("Only http proxies are supported")
        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")
        self.proxy, self.proxy_auth = proxy, proxy_auth

    @asyncio.coroutine
    def write_bytes(self, request, conn):
        """Stream ``self.body`` into *request* and finish the message.

        Runs as the task scheduled by :meth:`send`. If a ``100 Continue``
        handshake is pending, the buffered headers are drained first and
        the body is held back until ``self._continue`` resolves.

        Supported bodies: a coroutine yielding futures or byte chunks,
        ``asyncio.StreamReader``/``streams.StreamReader``,
        ``streams.DataQueue``, binary file objects, a ``bytes``/``bytearray``
        value, or any other iterable of byte chunks.

        Errors raised while writing the body or the final EOF are not
        raised. They are wrapped in ``aiohttp.ClientRequestError`` and
        passed to ``conn.protocol.set_exception``.
        """
        # 100 response
        if self._continue is not None:
            yield from request.drain()
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                # Drive the body coroutine by hand. Futures it yields are
                # awaited here and their result (or exception) is sent back
                # in; byte chunks it yields are written to the request.
                exc = None
                value = None
                stream = self.body

                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as exc:
                        # A bytes return value is written as a final chunk.
                        if isinstance(exc.value, bytes):
                            yield from request.write(exc.value)
                        break
                    except BaseException:
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result)
                        value = None
                    else:
                        raise ValueError('Bytes object is expected, got: %s.' %
                                         type(result))

            elif isinstance(self.body,
                            (asyncio.StreamReader, streams.StreamReader)):
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if not chunk:
                            break
                        yield from request.write(chunk)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                # Use one fixed read size for every chunk. In this class
                # self.chunked is set to a flag (True/False/None), not a size:
                # read(True) would read one byte at a time and read(None)
                # would pull the whole remaining file into memory.
                chunk = self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk)
                    chunk = self.body.read(streams.DEFAULT_LIMIT)
            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body, )

                # Await each write, as the branches above do, so the
                # result of write() is not discarded.
                for chunk in self.body:
                    yield from request.write(chunk)

        except Exception as exc:
            new_exc = aiohttp.ClientRequestError(
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            conn.protocol.set_exception(new_exc)
        else:
            try:
                yield from request.write_eof()
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError(
                    'Can not write request body for %s' % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                conn.protocol.set_exception(new_exc)

        self._writer = None

    def send(self, conn):
        """Write the request line and headers to *conn* and start the body.

        The body is written by :meth:`write_bytes`, which is scheduled as a
        separate task and stored in ``self._writer``.

        :returns: the ``self.response_class`` instance for this request,
            which is also stored on ``self.response``.
        """
        url = self.url
        # Request-target form (RFC 7230, section 5.3):
        #   CONNECT                     -> authority form  "host:port"
        #   plain HTTP through a proxy  -> absolute form   full URL
        #   everything else             -> origin form     "/path?query"
        if self.method == hdrs.METH_CONNECT:
            target = '{}:{}'.format(url.raw_host, url.port)
        elif self.proxy and not self.ssl:
            target = str(url)
        else:
            query = url.raw_query_string
            target = url.raw_path + '?' + query if query else url.raw_path

        request = http.Request(conn.writer, self.method, target,
                               self.version, loop=self.loop)
        if self.compress:
            request.enable_compression(self.compress)
        if self.chunked is not None:
            request.enable_chunking()

        # POST/PUT/PATCH without an explicit (or skipped) Content-Type get
        # a generic binary one.
        needs_default_ct = (
            self.method in self.POST_METHODS
            and hdrs.CONTENT_TYPE not in self.skip_auto_headers
            and hdrs.CONTENT_TYPE not in self.headers)
        if needs_default_ct:
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        for name, val in self.headers.items():
            request.add_header(name, val)
        request.send_headers()

        writer_task = helpers.ensure_future(self.write_bytes(request, conn),
                                            loop=self.loop)
        self._writer = writer_task

        response = self.response_class(self.method,
                                       self.original_url,
                                       writer=writer_task,
                                       continue100=self._continue,
                                       timer=self._timer)
        self.response = response
        response._post_init(self.loop)
        return response

    @asyncio.coroutine
    def close(self):
        """Wait for the body-writer task, if any, then drop the reference."""
        writer = self._writer
        if writer is None:
            return
        try:
            yield from writer
        finally:
            self._writer = None

    def terminate(self):
        """Cancel the body-writer task without waiting for it.

        Cancellation is skipped when the event loop is already closed. The
        task reference is dropped in either case.
        """
        writer = self._writer
        if writer is None:
            return
        if not self.loop.is_closed():
            writer.cancel()
        self._writer = None
Ejemplo n.º 11
0
class ClientRequest:
    """A single outgoing HTTP client request.

    The constructor normalizes the URL, query parameters, headers, cookies,
    auth, proxy settings and body by calling the ``update_*`` methods in a
    fixed order. :meth:`send` then writes the request line and headers and
    schedules :meth:`write_bytes` to stream the body.
    """

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
        {hdrs.METH_DELETE, hdrs.METH_TRACE})

    DEFAULT_HEADERS = {
        hdrs.ACCEPT: '*/*',
        hdrs.ACCEPT_ENCODING: 'gzip, deflate',
    }

    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    body = b''
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(self,
                 method,
                 url,
                 *,
                 params=None,
                 headers=None,
                 skip_auto_headers=frozenset(),
                 data=None,
                 cookies=None,
                 auth=None,
                 encoding='utf-8',
                 version=aiohttp.HttpVersion11,
                 compress=None,
                 chunked=None,
                 expect100=False,
                 loop=None,
                 response_class=None,
                 proxy=None,
                 proxy_auth=None,
                 timeout=5 * 60):
        """Create the request and run every ``update_*`` step.

        :param method: HTTP method; it is upper-cased.
        :param url: absolute URL string; it must contain a host.
        :param params: extra query parameters (see :meth:`update_path`).
        :param headers: initial headers (see :meth:`update_headers`).
        :param skip_auto_headers: header names that must not be generated
            automatically.
        :param data: request body (see :meth:`update_body_from_data`).
        :param chunked: ``True``, or an int chunk size, to force chunked
            transfer encoding.
        :param timeout: passed on to the response object in :meth:`send`.
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        self.url = url
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.response_class = response_class or ClientResponse
        self._timeout = timeout

        if loop.get_debug():
            # Remember where the request was created, for debug reports.
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self.update_version(version)
        self.update_host(url)
        self.update_path(params)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding(data)
        self.update_auth(auth)
        self.update_proxy(proxy, proxy_auth)

        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    def update_host(self, url):
        """Update destination host, port and connection type (ssl).

        Sets ``self.host`` (IDNA-encoded), ``self.port`` (80/443 by default),
        ``self.scheme``, ``self.ssl`` and ``self.netloc``. If the URL
        carries a user name, it also sets ``self.auth``.

        :raises ValueError: if the host or port cannot be parsed, or the
            host is not a valid IDNA label.
        """
        url_parsed = urllib.parse.urlsplit(url)

        # check for network location part
        netloc = url_parsed.netloc
        if not netloc:
            raise ValueError('Host could not be detected.')

        # get host/port
        host = url_parsed.hostname
        if not host:
            raise ValueError('Host could not be detected.')

        try:
            port = url_parsed.port
        except ValueError:
            raise ValueError('Port number could not be converted.') from None

        # check domain idna encoding
        try:
            host = host.encode('idna').decode('utf-8')
            netloc = self.make_netloc(host, url_parsed.port)
        except UnicodeError:
            raise ValueError('URL has an invalid label.') from None

        # basic auth info
        username, password = url_parsed.username, url_parsed.password
        if username:
            self.auth = helpers.BasicAuth(username, password or '')

        # Record entire netloc for usage in host header
        self.netloc = netloc

        scheme = url_parsed.scheme
        self.ssl = scheme in ('https', 'wss')

        # set port number if it isn't already set
        if not port:
            if self.ssl:
                port = HTTPS_PORT
            else:
                port = HTTP_PORT

        self.host, self.port, self.scheme = host, port, scheme

    def make_netloc(self, host, port):
        """Return ``host`` or ``host:port`` when a port is given."""
        ret = host
        if port:
            ret = ret + ':' + str(port)
        return ret

    def update_version(self, version):
        """Convert request version to two elements tuple.

        parser HTTP version '1.1' => (1, 1). Non-string values (such as
        ``aiohttp.HttpVersion11``) are stored unchanged.

        :raises ValueError: if a version string is not ``'<int>.<int>'``.
        """
        if isinstance(version, str):
            parts = [part.strip() for part in version.split('.', 1)]
            try:
                version = int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                # IndexError: the string has no '.', e.g. '1'.
                raise ValueError(
                    'Can not parse http version number: {}'.format(
                        version)) from None
        self.version = version

    def update_path(self, params):
        """Build ``self.path`` (path plus query) and rebuild ``self.url``.

        *params* may be a mapping, a sequence of pairs or a pre-encoded
        string. It is appended to any query already present in the URL,
        joined with ``&``. An empty path becomes ``/``.
        """
        # collections.Mapping was removed in Python 3.10.
        from collections.abc import Mapping

        # extract path
        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(self.url)
        if not path:
            path = '/'

        if isinstance(params, Mapping):
            params = list(params.items())

        if params:
            if not isinstance(params, str):
                params = urllib.parse.urlencode(params)
            if query:
                query = '%s&%s' % (query, params)
            else:
                query = params

        self.path = urllib.parse.urlunsplit(
            ('', '', helpers.requote_uri(path), query, ''))
        self.url = urllib.parse.urlunsplit(
            (scheme, netloc, self.path, '', fragment))

    def update_headers(self, headers):
        """Update request headers.

        Replaces ``self.headers`` with a CIMultiDict built from a dict, a
        MultiDict/MultiDictProxy or an iterable of ``(name, value)`` pairs.
        Repeated names are all kept.
        """
        self.headers = CIMultiDict()
        if headers:
            if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
                headers = headers.items()

            for key, value in headers:
                self.headers.add(key, value)

    def update_auto_headers(self, skip_auto_headers):
        """Add default, ``Host`` and ``User-Agent`` headers unless already used.

        A header counts as "used" when it is present in ``self.headers`` or
        is listed in *skip_auto_headers* (stored as-is on
        ``self.skip_auto_headers``).
        """
        self.skip_auto_headers = skip_auto_headers
        # Header names are case-insensitive (RFC 7230, section 3.2); a plain
        # set of strings would miss e.g. a user-supplied 'host' header.
        used_headers = CIMultiDict(self.headers.items())
        for hdr in skip_auto_headers:
            used_headers.add(hdr, None)

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        # add host
        if hdrs.HOST not in used_headers:
            self.headers[hdrs.HOST] = self.netloc

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Update request cookies header.

        *cookies* may be a dict or an iterable of ``(name, value)`` pairs.
        For Morsel values, the morsel's own key and value are used. Any
        existing ``Cookie`` header is merged in and replaced.
        """
        if not cookies:
            return

        c = http.cookies.SimpleCookie()
        if hdrs.COOKIE in self.headers:
            c.load(self.headers.get(hdrs.COOKIE, ''))
            del self.headers[hdrs.COOKIE]

        if isinstance(cookies, dict):
            cookies = cookies.items()

        for name, value in cookies:
            if isinstance(value, http.cookies.Morsel):
                c[value.key] = value.value
            else:
                c[name] = value

        self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()

    def update_content_encoding(self, data):
        """Set request content encoding.

        An existing ``Content-Encoding`` header wins unless
        ``compress=False``. Otherwise a truthy ``compress`` (non-string
        values become ``'deflate'``) is written into the header. Enabling
        compression also enables chunking.
        """
        if not data:
            return

        enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
        if enc:
            if self.compress is not False:
                self.compress = enc
                # enable chunked, no need to deal with length
                self.chunked = True
        elif self.compress:
            if not isinstance(self.compress, str):
                self.compress = 'deflate'
            self.headers[hdrs.CONTENT_ENCODING] = self.compress
            self.chunked = True  # enable chunked, no need to deal with length

    def update_auth(self, auth):
        """Set basic auth.

        Falls back to ``self.auth`` (from the URL) when *auth* is ``None``.

        :raises TypeError: if the credentials are not ``helpers.BasicAuth``.
        """
        if auth is None:
            auth = self.auth
        if auth is None:
            return

        if not isinstance(auth, helpers.BasicAuth):
            raise TypeError('BasicAuth() tuple is required instead')

        self.headers[hdrs.AUTHORIZATION] = auth.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Store *data* as ``self.body`` and derive body-related headers.

        Handles ``str`` (encoded), bytes, reader/queue streams, coroutines,
        binary file objects, ``MultipartWriter`` and form data (anything
        else). Fills in Content-Type/Content-Length when they are missing
        and decides whether the body is sent chunked.

        :raises ValueError: if a file object reports text mode ``'r'``.
        """
        if not data:
            return

        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(
                data,
            (asyncio.StreamReader, streams.StreamReader, streams.DataQueue)):
            self.body = data
            # A stream's length is unknown. Without this,
            # update_transfer_encoding() would call len(self.body), which
            # fails for readers without __len__ (e.g. asyncio.StreamReader).
            # Mirror the coroutine branch below.
            if (hdrs.CONTENT_LENGTH not in self.headers
                    and self.chunked is None):
                self.chunked = True

        elif asyncio.iscoroutine(data):
            self.body = data
            if (hdrs.CONTENT_LENGTH not in self.headers
                    and self.chunked is None):
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), \
                'attempt to send text data instead of binary'
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined.
                # Only bytes after the current position are sent (read()
                # starts there), so subtract tell(), as the file branch does.
                size = max(len(data.getbuffer()) - data.tell(), 0)
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif not self.chunked and isinstance(data, io.BufferedReader):
                # Not chunking if content-length can be determined
                try:
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            if hasattr(data, 'mode'):
                if data.mode == 'r':
                    raise ValueError('file {!r} should be open in binary mode'
                                     ''.format(data))
            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers
                    and hasattr(data, 'name')):
                mime = mimetypes.guess_type(data.name)[0]
                mime = 'application/octet-stream' if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            self.chunked = self.chunked or 8192

        else:
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if (hdrs.CONTENT_TYPE not in self.headers
                    and hdrs.CONTENT_TYPE not in skip_auto_headers):
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = self.chunked or 8192
            else:
                if (hdrs.CONTENT_LENGTH not in self.headers
                        and not self.chunked):
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Analyze transfer-encoding header.

        After this call ``self.chunked`` is either ``None`` (not chunked,
        Content-Length set) or an int chunk size. The chunk size is passed
        to ``add_chunking_filter`` in :meth:`send` and used as the file read
        size in :meth:`write_bytes`.
        """
        te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()

        if self.chunked:
            if hdrs.CONTENT_LENGTH in self.headers:
                del self.headers[hdrs.CONTENT_LENGTH]
            if 'chunked' not in te:
                self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'

            self.chunked = self.chunked if type(self.chunked) is int else 8192
        else:
            if 'chunked' in te:
                self.chunked = 8192
            else:
                self.chunked = None
                if hdrs.CONTENT_LENGTH not in self.headers:
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_expect_continue(self, expect=False):
        """Create ``self._continue`` when a ``100 Continue`` is expected."""
        if expect:
            self.headers[hdrs.EXPECT] = '100-continue'
        elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue':
            expect = True

        if expect:
            self._continue = helpers.create_future(self.loop)

    def update_proxy(self, proxy, proxy_auth):
        """Validate and store proxy settings (``http://`` proxies only).

        :raises ValueError: on an unsupported proxy or credentials type.
        """
        if proxy and not proxy.startswith('http://'):
            raise ValueError("Only http proxies are supported")
        if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
            raise ValueError("proxy_auth must be None or BasicAuth() tuple")
        self.proxy = proxy
        self.proxy_auth = proxy_auth

    @asyncio.coroutine
    def write_bytes(self, request, reader):
        """Support coroutines that yields bytes objects.

        Streams ``self.body`` into *request*, after waiting on
        ``self._continue`` when a ``100 Continue`` handshake is pending.
        TCP_NODELAY is enabled before writing for streaming bodies and after
        writing for file and in-memory bodies. Write errors are wrapped in
        ``aiohttp.ClientRequestError`` and handed to
        ``reader.set_exception`` instead of being raised.
        """
        # 100 response
        if self._continue is not None:
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                request.transport.set_tcp_nodelay(True)
                exc = None
                value = None
                stream = self.body

                while True:
                    try:
                        if exc is not None:
                            result = stream.throw(exc)
                        else:
                            result = stream.send(value)
                    except StopIteration as exc:
                        if isinstance(exc.value, bytes):
                            yield from request.write(exc.value, drain=True)
                        break
                    except BaseException:
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        exc = None
                        value = None
                        try:
                            value = yield result
                        except Exception as err:
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result, drain=True)
                        value = None
                    else:
                        raise ValueError('Bytes object is expected, got: %s.' %
                                         type(result))

            elif isinstance(self.body,
                            (asyncio.StreamReader, streams.StreamReader)):
                request.transport.set_tcp_nodelay(True)
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                request.transport.set_tcp_nodelay(True)
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if chunk is EOF_MARKER:
                            break
                        yield from request.write(chunk, drain=True)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                chunk = self.body.read(self.chunked)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(self.chunked)
                request.transport.set_tcp_nodelay(True)

            else:
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body, )

                for chunk in self.body:
                    request.write(chunk)
                request.transport.set_tcp_nodelay(True)

        except Exception as exc:
            new_exc = aiohttp.ClientRequestError(
                'Can not write request body for %s' % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            reader.set_exception(new_exc)
        else:
            assert request.transport.tcp_nodelay
            try:
                ret = request.write_eof()
                # NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
                # see bug #170
                if (asyncio.iscoroutine(ret)
                        or isinstance(ret, asyncio.Future)):
                    yield from ret
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError(
                    'Can not write request body for %s' % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                reader.set_exception(new_exc)

        self._writer = None

    def send(self, writer, reader):
        """Write the request line and headers, then start streaming the body.

        :returns: the ``self.response_class`` instance (also stored on
            ``self.response``).
        """
        writer.set_tcp_cork(True)
        request = aiohttp.Request(writer, self.method, self.path, self.version)

        if self.compress:
            request.add_compression_filter(self.compress)

        if self.chunked is not None:
            request.enable_chunked_encoding()
            request.add_chunking_filter(self.chunked)

        # set default content-type
        if (self.method in self.POST_METHODS
                and hdrs.CONTENT_TYPE not in self.skip_auto_headers
                and hdrs.CONTENT_TYPE not in self.headers):
            self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

        for k, value in self.headers.items():
            request.add_header(k, value)
        request.send_headers()

        self._writer = helpers.ensure_future(self.write_bytes(request, reader),
                                             loop=self.loop)

        self.response = self.response_class(self.method,
                                            self.url,
                                            self.host,
                                            writer=self._writer,
                                            continue100=self._continue,
                                            timeout=self._timeout)
        self.response._post_init(self.loop)
        return self.response

    @asyncio.coroutine
    def close(self):
        """Wait for the body-writer task, if any, then drop the reference."""
        if self._writer is not None:
            try:
                yield from self._writer
            finally:
                self._writer = None

    def terminate(self):
        """Cancel the body-writer task (unless the loop is closed) and drop it."""
        if self._writer is not None:
            if not self.loop.is_closed():
                self._writer.cancel()
            self._writer = None
Ejemplo n.º 12
0
class Payload(ABC):

    _default_content_type = 'application/octet-stream'  # type: str
    _size = None  # type: Optional[int]

    def __init__(self,
                 value: Any,
                 headers: Optional[
                     Union[
                         _CIMultiDict,
                         Dict[str, str],
                         Iterable[Tuple[str, str]]
                     ]
                 ] = None,
                 content_type: Optional[str]=sentinel,
                 filename: Optional[str]=None,
                 encoding: Optional[str]=None,
                 **kwargs: Any) -> None:
        """Wrap *value* as a payload and build its headers.

        ``Content-Type`` is chosen in this order:

        1. an explicit, non-``None`` *content_type*;
        2. a type guessed from *filename*;
        3. the class-level ``_default_content_type``.

        *headers* are applied last, so they override anything set above.
        """
        self._encoding = encoding
        self._filename = filename
        self._headers = CIMultiDict()  # type: _CIMultiDict
        self._value = value

        if content_type is sentinel or content_type is None:
            if filename is None:
                content_type = self._default_content_type
            else:
                guessed = mimetypes.guess_type(filename)[0]
                if guessed is None:
                    content_type = self._default_content_type
                else:
                    content_type = guessed
        self._headers[hdrs.CONTENT_TYPE] = content_type
        self._headers.update(headers or {})

    @property
    def size(self) -> Optional[int]:
        """Payload size as stored in ``_size``; ``None`` by class default."""
        size = self._size
        return size

    @property
    def filename(self) -> Optional[str]:
        """File name given to the constructor, or ``None``."""
        name = self._filename
        return name

    @property
    def headers(self) -> _CIMultiDict:
        """Per-payload headers (the live CIMultiDict, not a copy)."""
        item_headers = self._headers
        return item_headers

    @property
    def _binary_headers(self) -> bytes:
        """Headers serialized as an HTTP header block, UTF-8 encoded.

        Each header becomes ``name: value\\r\\n``, and the block ends with an
        empty line (``\\r\\n``).
        """
        block = ''.join(name + ': ' + val + '\r\n'
                        for name, val in self.headers.items())
        return block.encode('utf-8') + b'\r\n'

    @property
    def encoding(self) -> Optional[str]:
        """Text encoding given to the constructor, or ``None``."""
        enc = self._encoding
        return enc

    @property
    def content_type(self) -> str:
        """Current value of the payload's ``Content-Type`` header."""
        ctype = self._headers[hdrs.CONTENT_TYPE]
        return ctype

    def set_content_disposition(self,
                                disptype: str,
                                quote_fields: bool=True,
                                **params: Any) -> None:
        """Set the payload's ``Content-Disposition`` header.

        The value is built by ``content_disposition_header`` from *disptype*
        and *params*; *quote_fields* is forwarded to it unchanged.
        """
        disposition = content_disposition_header(
            disptype, quote_fields=quote_fields, **params)
        self._headers[hdrs.CONTENT_DISPOSITION] = disposition

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write payload.
Ejemplo n.º 13
0
class Request:
    '''
    The API request object.
    '''

    __slots__ = (
        'config',
        'session',
        'method',
        'path',
        'date',
        'headers',
        'params',
        'content_type',
        '_content',
        '_attached_files',
        'reporthook',
    )

    _allowed_methods = frozenset(
        ['GET', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])

    def __init__(self,
                 session: BaseSession,
                 method: str = 'GET',
                 path: str = None,
                 content: RequestContent = None,
                 *,
                 content_type: str = None,
                 params: Mapping[str, str] = None,
                 reporthook: Callable = None) -> None:
        '''
        Initialize an API request.

        :param BaseSession session: The session where this request is executed on.

        :param str method: The HTTP method; :func:`Request.fetch` only accepts
                           the methods listed in ``_allowed_methods``.

        :param str path: The query path. When performing requests, the version number
                         prefix will be automatically prepended if required.
                         One leading slash is stripped; ``None`` means an
                         empty path.

        :param RequestContent content: The raw request body.  ``str`` values
                                       are UTF-8 encoded and other values are
                                       stored as-is; use
                                       :func:`Request.set_json` for JSON.

        :param str content_type: Explicitly set the content type.  See also
                                 :func:`Request.set_content`.

        :param Mapping params: Optional query-string parameters for the URL.

        :param Callable reporthook: Stored on the request unchanged.
        '''
        self.session = session
        self.config = session.config
        self.method = method
        if path is None:
            # The signature permits path=None; treat it as the empty path
            # instead of failing on None.startswith().
            path = ''
        if path.startswith('/'):
            path = path[1:]
        self.path = path
        self.params = params
        self.date = None
        self.headers = CIMultiDict([
            ('User-Agent', self.config.user_agent),
            ('X-BackendAI-Domain', self.config.domain),
            ('X-BackendAI-Version', self.config.version),
        ])
        # Must be initialized before set_content(), which asserts on it.
        self._attached_files = None
        self.set_content(content, content_type=content_type)
        self.reporthook = reporthook

    @property
    def content(self) -> RequestContent:
        '''
        Retrieves the content in the original form.
        Private codes should NOT use this as it incurs duplicate
        encoding/decoding.
        '''
        return self._content

    def set_content(self, value: RequestContent, *, content_type: str = None):
        '''
        Sets the content of the request.

        ``None`` becomes an empty body and ``str`` is UTF-8 encoded; both
        default to ``text/plain``.  Any other value is stored unchanged and
        defaults to ``application/octet-stream``.  An explicit
        *content_type* always overrides the guessed one.

        Must not be called after :func:`Request.attach_files`.
        '''
        assert self._attached_files is None, \
               'cannot set content because you already attached files.'
        if value is None:
            guessed_content_type = 'text/plain'
            self._content = b''
        elif isinstance(value, str):
            guessed_content_type = 'text/plain'
            self._content = value.encode('utf-8')
        else:
            guessed_content_type = 'application/octet-stream'
            self._content = value
        self.content_type = (content_type if content_type is not None else
                             guessed_content_type)

    def set_json(self, value: object):
        '''
        A shortcut for set_content() with JSON objects.
        '''
        self.set_content(modjson.dumps(value, cls=ExtendedJSONEncoder),
                         content_type='application/json')

    def attach_files(self, files: Sequence[AttachedFile]):
        '''
        Attach a list of files represented as AttachedFile.

        The body must still be empty; the files are sent as a multipart form
        whose fields are all named ``src``.
        '''
        assert not self._content, 'content must be empty to attach files.'
        self.content_type = 'multipart/form-data'
        self._attached_files = files

    def _sign(self, rel_url, access_key=None, secret_key=None, hash_type=None):
        '''
        Calculates the signature of the given request and adds the
        Authorization HTTP header.
        It should be called at the very end of request preparation and before
        sending the request to the server.

        For ``api`` endpoints the headers returned by ``generate_signature``
        are merged into :attr:`headers`; for ``session`` endpoints the stored
        cookie jar is loaded instead.  Credentials default to the config.

        :raises ValueError: for any other ``endpoint_type``.
        '''
        if access_key is None:
            access_key = self.config.access_key
        if secret_key is None:
            secret_key = self.config.secret_key
        if hash_type is None:
            hash_type = self.config.hash_type
        if self.config.endpoint_type == 'api':
            sig_headers, _ = generate_signature(self.method, self.config.version,
                                                self.config.endpoint, self.date,
                                                str(rel_url), self.content_type,
                                                self._content, access_key, secret_key,
                                                hash_type)
            self.headers.update(sig_headers)
        elif self.config.endpoint_type == 'session':
            local_state_path = Path(
                appdirs.user_state_dir('backend.ai', 'Lablup'))
            try:
                self.session.aiohttp_session.cookie_jar.load(local_state_path /
                                                             'cookie.dat')
            except OSError:
                # Best effort: if the cookie file cannot be read, continue
                # without loading any cookies.  (IOError and PermissionError
                # are both OSError.)
                pass
        else:
            raise ValueError('unsupported endpoint type')

    def _pack_content(self):
        '''
        Return the body to send: a multipart ``aiohttp.FormData`` when files
        are attached, otherwise the raw encoded content.
        '''
        if self._attached_files is not None:
            data = aiohttp.FormData()
            for f in self._attached_files:
                data.add_field('src',
                               f.stream,
                               filename=f.filename,
                               content_type=f.content_type)
            assert data.is_multipart, 'Failed to pack files as multipart.'
            # Let aiohttp fill up the content-type header including
            # multipart boundaries.
            self.headers.pop('Content-Type', None)
            return data
        else:
            return self._content

    def _build_url(self):
        '''
        Join the endpoint's base path with :attr:`path` (and :attr:`params`
        as the query string).  Session endpoints route everything except
        ``server*`` paths through the ``func/`` prefix.
        '''
        base_url = self.config.endpoint.path.rstrip('/')
        query_path = self.path.lstrip('/')
        if self.config.endpoint_type == 'session':
            if not query_path.startswith('server'):
                query_path = 'func/{0}'.format(query_path)
        path = '{0}/{1}'.format(base_url, query_path)
        url = self.config.endpoint.with_path(path)
        if self.params:
            url = url.with_query(self.params)
        return url

    # TODO: attach rate-limit information

    def fetch(self, **kwargs) -> 'FetchContextManager':
        '''
        Sends the request to the server and reads the response.

        You may use this method either with plain synchronous Session or
        AsyncSession.  Both of the following patterns are valid:

        .. code-block:: python3

          from ai.backend.client.request import Request
          from ai.backend.client.session import Session

          with Session() as sess:
            rqst = Request(sess, 'GET', ...)
            with rqst.fetch() as resp:
              print(resp.text())

        .. code-block:: python3

          from ai.backend.client.request import Request
          from ai.backend.client.session import AsyncSession

          async with AsyncSession() as sess:
            rqst = Request(sess, 'GET', ...)
            async with rqst.fetch() as resp:
              print(await resp.text())

        Pass ``anonymous=True`` to skip request signing; all other keyword
        arguments are forwarded to :class:`FetchContextManager`.
        '''
        assert self.method in self._allowed_methods, \
               'Disallowed HTTP method: {}'.format(self.method)
        self.date = datetime.now(tzutc())
        self.headers['Date'] = self.date.isoformat()
        # An explicitly set Content-Type header takes precedence.
        if self.content_type is not None and 'Content-Type' not in self.headers:
            self.headers['Content-Type'] = self.content_type
        force_anonymous = kwargs.pop('anonymous', False)

        def _rqst_ctx_builder():
            timeout_config = aiohttp.ClientTimeout(
                total=None,
                connect=None,
                sock_connect=self.config.connection_timeout,
                sock_read=self.config.read_timeout,
            )
            full_url = self._build_url()
            if not self.config.is_anonymous and not force_anonymous:
                self._sign(full_url.relative())
            return self.session.aiohttp_session.request(
                self.method,
                str(full_url),
                data=self._pack_content(),
                timeout=timeout_config,
                headers=self.headers)

        return FetchContextManager(self.session, _rqst_ctx_builder, **kwargs)

    def connect_websocket(self, **kwargs) -> 'WebSocketContextManager':
        '''
        Creates a WebSocket connection.

        .. warning::

          This method only works with
          :class:`~ai.backend.client.session.AsyncSession`.
        '''
        assert isinstance(self.session, AsyncSession), \
               'Cannot use websockets with sessions in the synchronous mode'
        assert self.method == 'GET', 'Invalid websocket method'
        self.date = datetime.now(tzutc())
        self.headers['Date'] = self.date.isoformat()
        # websocket is always a "binary" stream.
        self.content_type = 'application/octet-stream'

        def _ws_ctx_builder():
            full_url = self._build_url()
            if not self.config.is_anonymous:
                self._sign(full_url.relative())
            return self.session.aiohttp_session.ws_connect(
                str(full_url),
                autoping=True,
                heartbeat=30.0,
                headers=self.headers)

        # Unlike fetch(), no 'anonymous' kwarg is consumed here; all kwargs
        # go to the context manager.
        return WebSocketContextManager(self.session, _ws_ctx_builder, **kwargs)

    def connect_events(self, **kwargs) -> 'SSEContextManager':
        '''
        Creates a Server-Sent Events connection.

        .. warning::

          This method only works with
          :class:`~ai.backend.client.session.AsyncSession`.
        '''
        assert isinstance(self.session, AsyncSession), \
               'Cannot use event streams with sessions in the synchronous mode'
        assert self.method == 'GET', 'Invalid event stream method'
        self.date = datetime.now(tzutc())
        self.headers['Date'] = self.date.isoformat()
        self.content_type = 'application/octet-stream'

        def _rqst_ctx_builder():
            timeout_config = aiohttp.ClientTimeout(
                total=None,
                connect=None,
                sock_connect=self.config.connection_timeout,
                sock_read=self.config.read_timeout,
            )
            full_url = self._build_url()
            if not self.config.is_anonymous:
                self._sign(full_url.relative())
            return self.session.aiohttp_session.request(self.method,
                                                        str(full_url),
                                                        timeout=timeout_config,
                                                        headers=self.headers)

        return SSEContextManager(self.session, _rqst_ctx_builder, **kwargs)
Ejemplo n.º 14
0
 async def prepare_request_params(self, endpoint_desc, session, request_params):
     """Merge header layers into ``request_params['headers']`` in place.

     Layers are applied in order -- ``self.default_headers``, the endpoint's
     ``'headers'`` entry, then any ``'headers'`` already in *request_params* --
     into a fresh case-insensitive ``CIMultiDict``, so later layers override
     earlier ones.  *session* is not used.
     """
     merged = CIMultiDict()
     layers = (
         self.default_headers,
         endpoint_desc.get('headers', {}),
         request_params.get('headers', {}),
     )
     for layer in layers:
         merged.update(layer)
     request_params['headers'] = merged
Ejemplo n.º 15
0
class Request:
    '''
    The API request object.
    '''

    __slots__ = (
        'config', 'session', 'method', 'path',
        'date', 'headers', 'params', 'content_type',
        '_content', '_attached_files',
        'reporthook',
    )

    _allowed_methods = frozenset([
        'GET', 'HEAD', 'POST',
        'PUT', 'PATCH', 'DELETE',
        'OPTIONS'])

    def __init__(self, session: BaseSession,
                 method: str = 'GET',
                 path: str = None,
                 content: RequestContent = None, *,
                 content_type: str = None,
                 params: Mapping[str, str] = None,
                 reporthook: Callable = None) -> None:
        '''
        Initialize an API request.

        :param BaseSession session: The session where this request is executed on.

        :param str method: The HTTP method; :func:`Request.fetch` only accepts
                           the methods listed in ``_allowed_methods``.

        :param str path: The query path. When performing requests, the version number
                         prefix will be automatically prepended if required.
                         One leading slash is stripped; ``None`` means an
                         empty path.

        :param RequestContent content: The raw request body.  ``str`` values
                                       are UTF-8 encoded and other values are
                                       stored as-is; use
                                       :func:`Request.set_json` for JSON.

        :param str content_type: Explicitly set the content type.  See also
                                 :func:`Request.set_content`.

        :param Mapping params: Optional query-string parameters for the URL.

        :param Callable reporthook: Stored on the request unchanged.
        '''
        self.session = session
        self.config = session.config
        self.method = method
        if path is None:
            # The signature permits path=None; treat it as the empty path
            # instead of failing on None.startswith().
            path = ''
        if path.startswith('/'):
            path = path[1:]
        self.path = path
        self.params = params
        self.date = None
        self.headers = CIMultiDict([
            ('User-Agent', self.config.user_agent),
            ('X-BackendAI-Version', self.config.version),
        ])
        # Must be initialized before set_content(), which asserts on it.
        self._attached_files = None
        self.set_content(content, content_type=content_type)
        self.reporthook = reporthook

    @property
    def content(self) -> RequestContent:
        '''
        Retrieves the content in the original form.
        Private codes should NOT use this as it incurs duplicate
        encoding/decoding.
        '''
        return self._content

    def set_content(self, value: RequestContent, *,
                    content_type: str = None):
        '''
        Sets the content of the request.

        ``None`` becomes an empty body and ``str`` is UTF-8 encoded; both
        default to ``text/plain``.  Any other value is stored unchanged and
        defaults to ``application/octet-stream``.  An explicit
        *content_type* always overrides the guessed one.

        Must not be called after :func:`Request.attach_files`.
        '''
        assert self._attached_files is None, \
               'cannot set content because you already attached files.'
        if value is None:
            guessed_content_type = 'text/plain'
            self._content = b''
        elif isinstance(value, str):
            guessed_content_type = 'text/plain'
            self._content = value.encode('utf-8')
        else:
            guessed_content_type = 'application/octet-stream'
            self._content = value
        self.content_type = (content_type if content_type is not None
                             else guessed_content_type)

    def set_json(self, value: object):
        '''
        A shortcut for set_content() with JSON objects.
        '''
        self.set_content(modjson.dumps(value, cls=ExtendedJSONEncoder),
                         content_type='application/json')

    def attach_files(self, files: Sequence[AttachedFile]):
        '''
        Attach a list of files represented as AttachedFile.

        The body must still be empty; the files are sent as a multipart form
        whose fields are all named ``src``.
        '''
        assert not self._content, 'content must be empty to attach files.'
        self.content_type = 'multipart/form-data'
        self._attached_files = files

    def _sign(self, rel_url, access_key=None, secret_key=None, hash_type=None):
        '''
        Calculates the signature of the given request and adds the
        Authorization HTTP header.
        It should be called at the very end of request preparation and before
        sending the request to the server.

        The headers returned by ``generate_signature`` are merged into
        :attr:`headers`.  Credentials default to the values in the config.
        '''
        if access_key is None:
            access_key = self.config.access_key
        if secret_key is None:
            secret_key = self.config.secret_key
        if hash_type is None:
            hash_type = self.config.hash_type
        sig_headers, _ = generate_signature(
            self.method, self.config.version, self.config.endpoint,
            self.date, str(rel_url), self.content_type, self._content,
            access_key, secret_key, hash_type)
        self.headers.update(sig_headers)

    def _pack_content(self):
        '''
        Return the body to send: a multipart ``aiohttp.FormData`` when files
        are attached, otherwise the raw encoded content.
        '''
        if self._attached_files is not None:
            data = aiohttp.FormData()
            for f in self._attached_files:
                data.add_field('src',
                               f.stream,
                               filename=f.filename,
                               content_type=f.content_type)
            assert data.is_multipart, 'Failed to pack files as multipart.'
            # Let aiohttp fill up the content-type header including
            # multipart boundaries.
            self.headers.pop('Content-Type', None)
            return data
        else:
            return self._content

    def _build_url(self):
        '''
        Join the endpoint's base path with :attr:`path` (and :attr:`params`
        as the query string).
        '''
        base_url = self.config.endpoint.path.rstrip('/')
        query_path = self.path.lstrip('/')
        path = '{0}/{1}'.format(base_url, query_path)
        url = self.config.endpoint.with_path(path)
        if self.params:
            url = url.with_query(self.params)
        return url

    # TODO: attach rate-limit information

    def fetch(self, **kwargs) -> 'FetchContextManager':
        '''
        Sends the request to the server and reads the response.

        You may use this method either with plain synchronous Session or
        AsyncSession.  Both of the following patterns are valid:

        .. code-block:: python3

          from ai.backend.client.request import Request
          from ai.backend.client.session import Session

          with Session() as sess:
            rqst = Request(sess, 'GET', ...)
            with rqst.fetch() as resp:
              print(resp.text())

        .. code-block:: python3

          from ai.backend.client.request import Request
          from ai.backend.client.session import AsyncSession

          async with AsyncSession() as sess:
            rqst = Request(sess, 'GET', ...)
            async with rqst.fetch() as resp:
              print(await resp.text())

        Keyword arguments are forwarded to :class:`FetchContextManager`.
        '''
        assert self.method in self._allowed_methods, \
               'Disallowed HTTP method: {}'.format(self.method)
        self.date = datetime.now(tzutc())
        self.headers['Date'] = self.date.isoformat()
        if self.content_type is not None:
            self.headers['Content-Type'] = self.content_type
        full_url = self._build_url()
        # Sign last: the signature covers the date, URL, content type and body.
        self._sign(full_url.relative())
        rqst_ctx = self.session.aiohttp_session.request(
            self.method,
            str(full_url),
            data=self._pack_content(),
            timeout=_default_request_timeout,
            headers=self.headers)
        return FetchContextManager(self.session, rqst_ctx, **kwargs)

    def connect_websocket(self, **kwargs) -> 'WebSocketContextManager':
        '''
        Creates a WebSocket connection.

        .. warning::

          This method only works with
          :class:`~ai.backend.client.session.AsyncSession`.
        '''
        assert isinstance(self.session, AsyncSession), \
               'Cannot use websockets with sessions in the synchronous mode'
        assert self.method == 'GET', 'Invalid websocket method'
        self.date = datetime.now(tzutc())
        self.headers['Date'] = self.date.isoformat()
        # websocket is always a "binary" stream.
        self.content_type = 'application/octet-stream'
        full_url = self._build_url()
        self._sign(full_url.relative())
        ws_ctx = self.session.aiohttp_session.ws_connect(
            str(full_url),
            autoping=True, heartbeat=30.0,
            headers=self.headers)
        return WebSocketContextManager(self.session, ws_ctx, **kwargs)
Ejemplo n.º 16
0
    async def request(
        self,
        method: str = 'get',
        url: str = '',
        response_processor=None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
        json: Optional[dict] = None,
        data: Optional[dict] = None,
        timeout: Optional[float] = None,
        *args,
        **kwargs,
    ):
        """Send a request to ``self.base_url`` + *url* and process the response.

        :param method: HTTP method passed to ``self.session.request``.
        :param url: path appended to ``self.base_url`` (leading ``/`` ignored).
        :param response_processor: async callable applied to the response;
            defaults to ``self.response_processor``.  If it is falsy, the raw
            response is returned.
        :param params: query parameters merged over ``self.default_params``
            after passing through ``filter_none`` (presumably dropping
            ``None`` values -- confirm against its definition).
        :param headers: headers merged over ``self.default_headers``; the
            result of ``self.headers(**kwargs)`` is applied last.
        :param json: object serialized with ``orjson`` as the body; cannot be
            combined with *data*.  ``Content-Type: application/json`` is added
            unless a Content-Type header (any casing) is already present.
        :param data: raw request body.
        :param timeout: request timeout; defaults to ``self.timeout``.
        :raises RuntimeError: if the client session has not been started.
        :raises ValueError: if both *json* and *data* are given.
        :raises TemporaryError: on the connection failures listed in the
            ``except`` clauses below, after the session has been recreated.
        """
        if self.session is None:
            raise RuntimeError('Client should be started before use')
        if response_processor is None:
            response_processor = self.response_processor
        all_params = CIMultiDict(self.default_params)
        if params:
            all_params.update(filter_none(params))
        all_headers = dict(self.default_headers)
        if headers:
            all_headers.update(headers)
        all_headers.update(self.headers(**kwargs))
        full_url = f"{self.base_url}/{url.lstrip('/')}"

        try:
            await self.pre_request_hook()
            if json:
                if data:
                    raise ValueError(
                        'data and json parameters can not be used at the same time'
                    )
                data = orjson.dumps(json)
                # all_headers is a plain (case-sensitive) dict; compare names
                # case-insensitively so a caller-supplied 'content-type' is not
                # duplicated by a second 'Content-Type' entry.
                if not any(name.lower() == 'content-type' for name in all_headers):
                    all_headers['Content-Type'] = 'application/json'

            response = await self.session.request(
                method,
                full_url,
                # Send the merged defaults + caller params; passing the raw
                # *params* here would silently drop self.default_params.
                params=all_params,
                headers=filter_none(all_headers),
                data=data,
                timeout=timeout or self.timeout,
            )
            if response_processor:
                return await response_processor(response)
            return response
        except ClientConnectorError as e:
            # -2 / 111 are presumably EAI_NONAME / ECONNREFUSED on Linux.
            if isinstance(e.os_error, socket.gaierror) and e.errno in (-2, 111):
                await self.session.close()
                self.session = self._create_session()
                raise TemporaryError(url=full_url, nested_error=str(e))
            else:
                raise
        except (
                aiohttp.client_exceptions.ClientOSError,
                aiohttp.client_exceptions.ServerDisconnectedError,
                asyncio.TimeoutError,
                ConnectionAbortedError,
                ProxyError,
                ProxyTimeoutError,
        ) as e:
            # Recreate the session so the next call starts from a fresh pool.
            await self.session.close()
            self.session = self._create_session()
            raise TemporaryError(url=full_url, nested_error=str(e))
Ejemplo n.º 17
0
class ClientRequest:

    GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
    ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE, hdrs.METH_TRACE})

    DEFAULT_HEADERS = {hdrs.ACCEPT: "*/*", hdrs.ACCEPT_ENCODING: "gzip, deflate"}

    SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

    body = b""
    auth = None
    response = None
    response_class = None

    _writer = None  # async task for streaming data
    _continue = None  # waiter future for '100 Continue' response

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(
        self,
        method,
        url,
        *,
        params=None,
        headers=None,
        skip_auto_headers=frozenset(),
        data=None,
        cookies=None,
        auth=None,
        encoding="utf-8",
        version=aiohttp.HttpVersion11,
        compress=None,
        chunked=None,
        expect100=False,
        loop=None,
        response_class=None
    ):
        """Build a client request by running the ``update_*`` steps in order.

        The order of the ``update_*`` calls matters because each step mutates
        ``self`` and later steps read what earlier ones set.  For example:
        ``update_auto_headers`` uses ``self.netloc`` from ``update_host``;
        ``update_auth`` falls back to ``self.auth`` parsed from the URL by
        ``update_host``; ``update_transfer_encoding`` uses ``self.body`` and
        ``self.chunked`` as adjusted by ``update_body_from_data``.

        :param method: HTTP method; stored upper-cased.
        :param url: absolute URL; must contain a host.
        :param response_class: class used for the response; defaults to
            ``ClientResponse``.
        :param loop: event loop; defaults to ``asyncio.get_event_loop()``.
        """

        if loop is None:
            loop = asyncio.get_event_loop()

        self.url = url
        self.method = method.upper()
        self.encoding = encoding
        self.chunked = chunked
        self.compress = compress
        self.loop = loop
        self.response_class = response_class or ClientResponse

        # In debug mode, remember where the request was created.
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self.update_version(version)
        self.update_host(url)
        self.update_path(params)
        self.update_headers(headers)
        self.update_auto_headers(skip_auto_headers)
        self.update_cookies(cookies)
        self.update_content_encoding()
        self.update_auth(auth)

        # Body handling may set Content-Type/Content-Length and self.chunked,
        # which the transfer-encoding step then reconciles.
        self.update_body_from_data(data, skip_auto_headers)
        self.update_transfer_encoding()
        self.update_expect_continue(expect100)

    def update_host(self, url):
        """Parse *url* and record host, port, scheme, netloc and TLS usage.

        Userinfo in the URL (``user:pass@``) is turned into ``self.auth`` and
        stripped from ``self.netloc``.  Without an explicit port, the default
        HTTPS/HTTP port is chosen from the scheme (``https``/``wss`` use TLS).

        :raises ValueError: if no host can be found, the port is invalid, or
            a label cannot be IDNA-encoded.
        """
        parts = urllib.parse.urlsplit(url)

        netloc = parts.netloc
        host = parts.hostname
        # An empty netloc implies an empty hostname; both report the same error.
        if not netloc or not host:
            raise ValueError("Host could not be detected.")

        try:
            port = parts.port
        except ValueError:
            raise ValueError("Port number could not be converted.") from None

        # Non-ASCII (internationalized) labels are converted with IDNA.
        try:
            netloc = netloc.encode("idna").decode("utf-8")
            host = host.encode("idna").decode("utf-8")
        except UnicodeError:
            raise ValueError("URL has an invalid label.")

        user = parts.username
        secret = parts.password
        if user:
            self.auth = helpers.BasicAuth(user, secret or "")
            netloc = netloc.split("@", 1)[1]

        # Kept for the Host header.
        self.netloc = netloc

        scheme = parts.scheme
        self.ssl = scheme in ("https", "wss")
        if not port:
            port = HTTPS_PORT if self.ssl else HTTP_PORT

        self.host, self.port, self.scheme = host, port, scheme

    def update_version(self, version):
        """Store the request HTTP version as a ``(major, minor)`` pair.

        A string such as ``'1.1'`` is parsed into ``(1, 1)``; any other value
        is stored unchanged.

        :raises ValueError: if a string version is not of the form
            ``"<int>.<int>"``.
        """
        if isinstance(version, str):
            parts = [part.strip() for part in version.split(".", 1)]
            try:
                version = int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                # IndexError: the string has no "." (e.g. "1"); report it the
                # same way as a non-numeric component instead of leaking it.
                raise ValueError(
                    "Can not parse http version number: {}".format(version)
                ) from None
        self.version = version

    def update_path(self, params):
        """Build ``self.path`` (path, query and fragment) and rebuild ``self.url``.

        *params* may be a mapping, an iterable of ``(key, value)`` pairs, or
        an already-encoded query string.  It is appended with ``&`` to any
        query already in the URL.  An empty path becomes ``"/"`` and the path
        is requoted.  ``self.url`` is then rebuilt from the URL's scheme and
        netloc plus the new ``self.path``.
        """
        # ``collections.Mapping`` was removed in Python 3.10; use the abc module.
        from collections.abc import Mapping

        scheme, netloc, path, query, fragment = urllib.parse.urlsplit(self.url)
        if not path:
            path = "/"

        if isinstance(params, Mapping):
            params = list(params.items())

        if params:
            if not isinstance(params, str):
                params = urllib.parse.urlencode(params)
            if query:
                query = "%s&%s" % (query, params)
            else:
                query = params

        self.path = urllib.parse.urlunsplit(("", "", helpers.requote_uri(path), query, fragment))
        self.url = urllib.parse.urlunsplit((scheme, netloc, self.path, "", ""))

    def update_headers(self, headers):
        """Replace ``self.headers`` with a new ``CIMultiDict`` built from *headers*.

        *headers* may be any mapping (a ``MultiDict``/``MultiDictProxy`` keeps
        its repeated keys) or an iterable of ``(name, value)`` pairs.  Every
        pair is added with ``add``, so repeated names are preserved.
        """
        from collections.abc import Mapping

        self.headers = CIMultiDict()
        if not headers:
            return
        if isinstance(headers, (MultiDictProxy, MultiDict, Mapping)):
            # Iterating a mapping directly yields only its keys, which would
            # then be mis-unpacked as (name, value); always use items().
            headers = headers.items()
        for key, value in headers:
            self.headers.add(key, value)

    def update_auto_headers(self, skip_auto_headers):
        """Add default headers, Host and User-Agent unless already decided.

        A header is left alone if the caller already set it, or if it is
        named in *skip_auto_headers*.  Both checks are case-insensitive, as
        HTTP header names are.  *skip_auto_headers* is stored unchanged on
        ``self.skip_auto_headers``.
        """
        self.skip_auto_headers = skip_auto_headers
        # A plain set() of names compares case-sensitively, so e.g.
        # skip_auto_headers={'user-agent'} would not suppress 'User-Agent'.
        used_headers = CIMultiDict(self.headers.items())
        for name in skip_auto_headers:
            used_headers.add(name, "")

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers.add(hdr, val)

        # add host
        if hdrs.HOST not in used_headers:
            self.headers[hdrs.HOST] = self.netloc

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

    def update_cookies(self, cookies):
        """Merge *cookies* into a single ``Cookie`` request header.

        Any existing ``Cookie`` header is parsed first and then replaced.
        *cookies* may be a dict or an iterable of ``(name, value)`` pairs.
        ``Morsel`` values contribute their own key and value.
        """
        if not cookies:
            return

        jar = http.cookies.SimpleCookie()
        if hdrs.COOKIE in self.headers:
            jar.load(self.headers.get(hdrs.COOKIE, ""))
            del self.headers[hdrs.COOKIE]

        pairs = cookies.items() if isinstance(cookies, dict) else cookies
        for name, value in pairs:
            if not isinstance(value, http.cookies.Morsel):
                jar[name] = value
            else:
                jar[value.key] = value.value

        self.headers[hdrs.COOKIE] = jar.output(header="", sep=";").strip()

    def update_content_encoding(self):
        """Reconcile the ``Content-Encoding`` header with ``self.compress``.

        If the header is already set (and compression is not explicitly
        ``False``), its lower-cased value becomes ``self.compress``.
        Otherwise a truthy ``self.compress`` sets the header, using
        ``"deflate"`` when ``compress`` is not a string.  Compressing always
        enables chunked transfer, since the final length is not known up front.
        """
        declared = self.headers.get(hdrs.CONTENT_ENCODING, "").lower()
        if not declared:
            if not self.compress:
                return
            if not isinstance(self.compress, str):
                self.compress = "deflate"
            self.headers[hdrs.CONTENT_ENCODING] = self.compress
            self.chunked = True
            return

        if self.compress is not False:
            self.compress = declared
            self.chunked = True

    def update_auth(self, auth):
        """Set the ``Authorization`` header for basic auth.

        Falls back to ``self.auth`` (e.g. credentials taken from the URL)
        when *auth* is ``None``; does nothing if neither is set.  A
        non-``BasicAuth`` value is unpacked into ``helpers.BasicAuth`` with a
        ``DeprecationWarning``.
        """
        credentials = self.auth if auth is None else auth
        if credentials is None:
            return

        if not isinstance(credentials, helpers.BasicAuth):
            warnings.warn("BasicAuth() tuple is required instead ", DeprecationWarning)
            credentials = helpers.BasicAuth(*credentials)

        self.headers[hdrs.AUTHORIZATION] = credentials.encode()

    def update_body_from_data(self, data, skip_auto_headers):
        """Set ``self.body`` from *data* and derive body-related headers.

        Dispatches on the type of *data*:

        * ``str`` -- encoded with ``self.encoding`` and then handled as bytes.
        * ``bytes``/``bytearray`` -- sent as-is.  Sets Content-Type to
          ``application/octet-stream`` and Content-Length, unless already
          present, skipped, or chunked.
        * ``asyncio.StreamReader`` / ``streams.DataQueue`` -- stored unchanged.
        * coroutine -- stored.  Enables chunking when no Content-Length is
          set and ``chunked`` was left as ``None``.
        * ``io.IOBase`` -- file-like body.  Content-Length is computed for
          ``BytesIO`` and for ``BufferedReader`` with a real file descriptor;
          otherwise the body is chunked.  Content-Type may be guessed from
          ``data.name``.
        * ``MultipartWriter`` -- serialized, its headers merged, and chunked.
        * anything else -- wrapped in ``helpers.FormData``.

        Falsy *data* leaves the body untouched.

        :raises ValueError: if a file-like *data* has ``mode == "r"``
            (text mode).
        """
        if not data:
            return

        if isinstance(data, str):
            data = data.encode(self.encoding)

        if isinstance(data, (bytes, bytearray)):
            self.body = data
            if hdrs.CONTENT_TYPE not in self.headers and hdrs.CONTENT_TYPE not in skip_auto_headers:
                self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"
            if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

        elif isinstance(data, (asyncio.StreamReader, streams.DataQueue)):
            self.body = data

        elif asyncio.iscoroutine(data):
            self.body = data
            # Only auto-enable chunking if the caller did not decide either way.
            if hdrs.CONTENT_LENGTH not in self.headers and self.chunked is None:
                self.chunked = True

        elif isinstance(data, io.IOBase):
            assert not isinstance(data, io.StringIO), "attempt to send text data instead of binary"
            self.body = data
            if not self.chunked and isinstance(data, io.BytesIO):
                # Not chunking if content-length can be determined
                size = len(data.getbuffer())
                self.headers[hdrs.CONTENT_LENGTH] = str(size)
                self.chunked = False
            elif not self.chunked and isinstance(data, io.BufferedReader):
                # Not chunking if content-length can be determined
                # (remaining bytes = file size minus current position).
                try:
                    size = os.fstat(data.fileno()).st_size - data.tell()
                    self.headers[hdrs.CONTENT_LENGTH] = str(size)
                    self.chunked = False
                except OSError:
                    # data.fileno() is not supported, e.g.
                    # io.BufferedReader(io.BytesIO(b'data'))
                    self.chunked = True
            else:
                self.chunked = True

            # Reject files opened in text mode ("r"); only this exact mode is checked.
            if hasattr(data, "mode"):
                if data.mode == "r":
                    raise ValueError("file {!r} should be open in binary mode" "".format(data))
            if (
                hdrs.CONTENT_TYPE not in self.headers
                and hdrs.CONTENT_TYPE not in skip_auto_headers
                and hasattr(data, "name")
            ):
                mime = mimetypes.guess_type(data.name)[0]
                mime = "application/octet-stream" if mime is None else mime
                self.headers[hdrs.CONTENT_TYPE] = mime

        elif isinstance(data, MultipartWriter):
            self.body = data.serialize()
            self.headers.update(data.headers)
            # Keep a caller-supplied chunk size; otherwise default to 8192.
            self.chunked = self.chunked or 8192

        else:
            if not isinstance(data, helpers.FormData):
                data = helpers.FormData(data)

            self.body = data(self.encoding)

            if hdrs.CONTENT_TYPE not in self.headers and hdrs.CONTENT_TYPE not in skip_auto_headers:
                self.headers[hdrs.CONTENT_TYPE] = data.content_type

            if data.is_multipart:
                self.chunked = self.chunked or 8192
            else:
                if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))

    def update_transfer_encoding(self):
        """Reconcile ``self.chunked`` with the Transfer-Encoding header.

        On return ``self.chunked`` is either an ``int`` chunk size (chunked
        transfer, Content-Length removed) or ``None`` (Content-Length is
        filled in from ``len(self.body)`` when not already present).
        """
        transfer_encoding = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower()
        header_says_chunked = "chunked" in transfer_encoding

        if not self.chunked:
            if header_says_chunked:
                # Chunking was requested through the header alone.
                self.chunked = 8192
            else:
                self.chunked = None
                if hdrs.CONTENT_LENGTH not in self.headers:
                    self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
            return

        # Chunked transfer and an explicit Content-Length are mutually exclusive.
        if hdrs.CONTENT_LENGTH in self.headers:
            del self.headers[hdrs.CONTENT_LENGTH]
        if not header_says_chunked:
            self.headers[hdrs.TRANSFER_ENCODING] = "chunked"

        # A non-int truthy flag (e.g. ``True``) becomes the default chunk size.
        if type(self.chunked) is not int:
            self.chunked = 8192

    def update_expect_continue(self, expect=False):
        """Prepare ``self._continue`` for an ``Expect: 100-continue`` request.

        When *expect* is truthy the Expect header is added. Otherwise an
        already-present ``Expect: 100-continue`` header (case-insensitive
        value) enables the same behaviour.
        """
        if expect:
            self.headers[hdrs.EXPECT] = "100-continue"
            wants_continue = True
        else:
            current = self.headers.get(hdrs.EXPECT, "")
            wants_continue = current.lower() == "100-continue"

        if wants_continue:
            self._continue = asyncio.Future(loop=self.loop)

    @asyncio.coroutine
    def write_bytes(self, request, reader):
        """Stream ``self.body`` into *request* and finish it with ``write_eof()``.

        Body kinds handled (checked in this order):

        * a generator-based coroutine yielding ``bytes``/``bytearray`` chunks;
          futures it yields are handed to the running task, and a ``bytes``
          return value is written as the final chunk;
        * an ``asyncio.StreamReader``, read ``streams.DEFAULT_LIMIT`` at a time;
        * a ``streams.DataQueue``, read until ``EOF_MARKER`` or ``EofStream``;
        * an ``io.IOBase`` object, read ``self.chunked`` bytes at a time;
        * ``bytes``/``bytearray`` or any other iterable of chunks.

        Exceptions derived from ``Exception`` raised while writing are not
        propagated: they are wrapped in ``aiohttp.ClientRequestError`` and
        handed to ``reader.set_exception``. ``self._writer`` is cleared at the
        end.
        """
        # ``_continue`` is created by update_expect_continue() for
        # "Expect: 100-continue" requests; don't send the body before it resolves.
        if self._continue is not None:
            yield from self._continue

        try:
            if asyncio.iscoroutine(self.body):
                request.transport.set_tcp_nodelay(True)
                exc = None
                value = None
                stream = self.body

                # Drive the body coroutine by hand so every chunk it yields is
                # written (and drained) before the body is resumed.
                while True:
                    try:
                        if exc is not None:
                            # Deliver a pending exception exactly once. Leaving
                            # it set would re-throw it on the next resumption
                            # when the body recovers and yields a bytes chunk.
                            pending, exc = exc, None
                            result = stream.throw(pending)
                        else:
                            result = stream.send(value)
                    except StopIteration as stop:
                        # A bytes return value is written as the last chunk.
                        if isinstance(stop.value, bytes):
                            yield from request.write(stop.value, drain=True)
                        break
                    except BaseException:
                        self.response.close()
                        raise

                    if isinstance(result, asyncio.Future):
                        exc = None
                        value = None
                        try:
                            # NOTE(review): the future is passed to the enclosing
                            # task with a bare ``yield``; this relies on asyncio
                            # accepting a future the body yielded from inside its
                            # own ``yield from`` -- confirm for the asyncio version
                            # in use. Failures are thrown back into the body.
                            value = yield result
                        except Exception as err:
                            exc = err
                    elif isinstance(result, (bytes, bytearray)):
                        yield from request.write(result, drain=True)
                        value = None
                    else:
                        raise ValueError("Bytes object is expected, got: %s." % type(result))

            elif isinstance(self.body, asyncio.StreamReader):
                request.transport.set_tcp_nodelay(True)
                chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
                while chunk:
                    yield from request.write(chunk, drain=True)
                    chunk = yield from self.body.read(streams.DEFAULT_LIMIT)

            elif isinstance(self.body, streams.DataQueue):
                request.transport.set_tcp_nodelay(True)
                while True:
                    try:
                        chunk = yield from self.body.read()
                        if chunk is EOF_MARKER:
                            break
                        yield from request.write(chunk, drain=True)
                    except streams.EofStream:
                        break

            elif isinstance(self.body, io.IOBase):
                # ``self.chunked`` is an int chunk size or None (read all at once).
                chunk = self.body.read(self.chunked)
                while chunk:
                    request.write(chunk)
                    chunk = self.body.read(self.chunked)
                request.transport.set_tcp_nodelay(True)

            else:
                # Wrap a single buffer so it is written as one chunk rather
                # than iterated byte-by-byte.
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)

                for chunk in self.body:
                    request.write(chunk)
                request.transport.set_tcp_nodelay(True)

        except Exception as exc:
            new_exc = aiohttp.ClientRequestError("Can not write request body for %s" % self.url)
            new_exc.__context__ = exc
            new_exc.__cause__ = exc
            reader.set_exception(new_exc)
        else:
            # Every branch above enables TCP_NODELAY; this AssertionError (if
            # raised) is outside the ``try`` and therefore propagates.
            assert request.transport.tcp_nodelay
            try:
                ret = request.write_eof()
                # NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
                # see bug #170
                if asyncio.iscoroutine(ret) or isinstance(ret, asyncio.Future):
                    yield from ret
            except Exception as exc:
                new_exc = aiohttp.ClientRequestError("Can not write request body for %s" % self.url)
                new_exc.__context__ = exc
                new_exc.__cause__ = exc
                reader.set_exception(new_exc)

        self._writer = None

    def send(self, writer, reader):
        """Send the request line and headers and start writing the body.

        The body is written by a background task (stored in ``self._writer``)
        running :meth:`write_bytes`. The returned response object (also kept
        in ``self.response``) is built from ``self.response_class``.
        """
        writer.set_tcp_cork(True)
        request = aiohttp.Request(writer, self.method, self.path, self.version)

        if self.compress:
            request.add_compression_filter(self.compress)

        if self.chunked is not None:
            request.enable_chunked_encoding()
            request.add_chunking_filter(self.chunked)

        # Body-carrying methods get a generic Content-Type unless the caller
        # set one or opted out of this automatic header.
        needs_default_type = (
            self.method in self.POST_METHODS
            and hdrs.CONTENT_TYPE not in self.skip_auto_headers
            and hdrs.CONTENT_TYPE not in self.headers
        )
        if needs_default_type:
            self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

        for header_name, header_value in self.headers.items():
            request.add_header(header_name, header_value)
        request.send_headers()

        body_coro = self.write_bytes(request, reader)
        self._writer = helpers.ensure_future(body_coro, loop=self.loop)

        response = self.response_class(
            self.method,
            self.url,
            self.host,
            writer=self._writer,
            continue100=self._continue,
        )
        self.response = response
        response._post_init(self.loop)
        return response

    @asyncio.coroutine
    def close(self):
        """Wait for the pending body-writer task, if any, then forget it.

        ``self._writer`` is reset to ``None`` even if awaiting it raises.
        """
        pending_writer = self._writer
        if pending_writer is None:
            return
        try:
            yield from pending_writer
        finally:
            self._writer = None

    def terminate(self):
        """Cancel the body-writer task and drop the reference to it.

        Cancellation is skipped when the loop exposes ``is_closed()`` and
        reports itself closed; the reference is cleared either way.
        """
        task = self._writer
        if task is None:
            return
        loop_closed = hasattr(self.loop, "is_closed") and self.loop.is_closed()
        if not loop_closed:
            task.cancel()
        self._writer = None
Ejemplo n.º 18
0
class BaseRequest:
    '''
    A signed HTTP request to the API endpoint described by ``config``,
    sent synchronously through ``requests``.

    The body is kept encoded in ``_content`` together with its
    ``content_type``; assigning :attr:`content` keeps the Content-Type and
    Content-Length headers in sync.
    '''

    __slots__ = [
        'config', 'method', 'path', 'date', 'headers', 'content_type',
        '_content'
    ]

    _allowed_methods = frozenset(
        ['GET', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])

    def __init__(self,
                 method: str = 'GET',
                 path: Optional[str] = None,
                 content: Optional[Mapping] = None,
                 config: Optional[APIConfig] = None) -> None:
        '''
        :param method: HTTP method name; checked against ``_allowed_methods``
            when :meth:`send` is called.
        :param path: API path relative to the versioned endpoint. A single
            leading ``/`` is stripped; ``None`` means the endpoint root.
        :param content: initial request body (see the :attr:`content`
            setter); ``None`` means an empty ``bytes`` body.
        :param config: API configuration; defaults to ``get_config()``.
        '''
        self.config = config if config else get_config()
        self.method = method
        if path is None:
            # build_url() maps an empty path to the versioned endpoint root.
            path = ''
        elif path.startswith('/'):
            path = path[1:]
        self.path = path
        self.date = datetime.now(tzutc())
        # Headers must exist before ``content`` is assigned, because the
        # setter writes Content-Type/Content-Length into them.
        self.headers = CIMultiDict([
            ('Date', self.date.isoformat()),
            ('X-BackendAI-Version', self.config.version),
        ])
        self.content = content if content is not None else b''

    @property
    def content(self) -> Any:
        '''
        Retrieves the content in its original (decoded) form: ``bytes`` for
        octet-stream, the JSON-decoded value (objects as ``OrderedDict``) for
        JSON, ``str`` for text/plain, and the original sequence for
        multipart/form-data.
        Private code should NOT use this as it incurs duplicate
        encoding/decoding; read ``_content`` instead.
        '''
        if self._content is None:
            raise ValueError('content is not set.')
        if self.content_type == 'application/octet-stream':
            return self._content
        elif self.content_type == 'application/json':
            return json.loads(self._content.decode('utf-8'),
                              object_pairs_hook=OrderedDict)
        elif self.content_type == 'text/plain':
            return self._content.decode('utf-8')
        elif self.content_type == 'multipart/form-data':
            return self._content
        else:
            raise RuntimeError('should not reach here')  # pragma: no cover

    @content.setter
    def content(self, value: Union[bytes, bytearray, str, Mapping[str, Any],
                                   Sequence[Any], None]):
        '''
        Sets the content of the request.
        Depending on the type of content, it automatically sets appropriate
        headers such as content-type and content-length:

        * ``bytes``/``bytearray`` -> application/octet-stream
        * ``str`` -> text/plain (UTF-8 encoded)
        * any mapping -> application/json
        * ``list``/``tuple`` of file fields -> multipart/form-data (headers
          are left to the HTTP client library)

        :raises TypeError: for any other value type (including ``None``).
        '''
        if isinstance(value, (bytes, bytearray)):
            self._set_encoded('application/octet-stream', value)
        elif isinstance(value, str):
            self._set_encoded('text/plain', value.encode('utf-8'))
        elif isinstance(value, Mapping):
            # json.dumps only serializes real dicts; copying keeps key order.
            payload = value if isinstance(value, dict) else dict(value)
            self._set_encoded('application/json',
                              json.dumps(payload).encode('utf-8'))
        elif isinstance(value, (list, tuple)):
            self.content_type = 'multipart/form-data'
            self._content = value
            # Let the http client library decide the header values.
            # (e.g., message boundaries)
            if 'Content-Length' in self.headers:
                del self.headers['Content-Length']
            if 'Content-Type' in self.headers:
                del self.headers['Content-Type']
        else:
            raise TypeError('Unsupported content value type.')

    def _set_encoded(self, content_type: str,
                     data: Union[bytes, bytearray]) -> None:
        '''Store already-encoded *data* and sync Content-Type/Content-Length.'''
        self.content_type = content_type
        self._content = data
        self.headers['Content-Type'] = content_type
        self.headers['Content-Length'] = str(len(data))

    def _sign(self, access_key=None, secret_key=None, hash_type=None):
        '''
        Calculates the signature of the given request and adds the
        Authorization HTTP header.
        It should be called at the very end of request preparation and before
        sending the request to the server.
        Missing credentials/hash type fall back to ``self.config``.
        '''
        if access_key is None:
            access_key = self.config.access_key
        if secret_key is None:
            secret_key = self.config.secret_key
        if hash_type is None:
            hash_type = self.config.hash_type
        sig_headers, _ = generate_signature(self.method, self.config.version,
                                            self.config.endpoint, self.date,
                                            self.path, self.content_type,
                                            self._content, access_key,
                                            secret_key, hash_type)
        self.headers.update(sig_headers)

    def build_url(self):
        '''
        Return the absolute URL: ``config.endpoint`` joined with
        ``v<major>/<path>``, where ``<major>`` is the part of
        ``config.version`` before the first dot.
        '''
        major_ver = self.config.version.split('.', 1)[0]
        path = '/' + self.path if len(self.path) > 0 else ''
        return urljoin(self.config.endpoint, major_ver + path)

    # TODO: attach rate-limit information

    def send(self, *, sess=None):
        '''
        Signs and sends the request to the server, returning a ``Response``.

        :param sess: an optional ``requests.Session`` to reuse. When omitted,
            a temporary session is created and closed before returning.
        :raises BackendClientError: when ``requests`` fails to perform the
            request (connection error, timeout, ...).
        '''
        assert self.method in self._allowed_methods
        owns_session = sess is None
        if owns_session:
            sess = requests.Session()
        else:
            assert isinstance(sess, requests.Session)
        try:
            self._sign()
            reqfunc = getattr(sess, self.method.lower())
            # The HTTP call itself must be inside the try: that is where
            # requests raises its RequestException subclasses.
            try:
                if self.content_type == 'multipart/form-data':
                    files = [(f.name, (f.filename, f.file, f.content_type))
                             for f in self._content]
                    resp = reqfunc(self.build_url(), files=files,
                                   headers=self.headers)
                else:
                    resp = reqfunc(self.build_url(),
                                   data=self._content,
                                   headers=self.headers)
                # NOTE(review): responses without Content-Type/Content-Length
                # (e.g. chunked replies) raise KeyError here.
                return Response(resp.status_code, resp.reason, resp.content,
                                resp.headers['content-type'],
                                resp.headers['content-length'])
            except requests.exceptions.RequestException as e:
                msg = 'Request to the API endpoint has failed.'
                raise BackendClientError(msg) from e
        finally:
            # The response body is already read (non-streaming request), so a
            # session created here can be released safely.
            if owns_session:
                sess.close()