Example #1
0
class StreamResponse(collections.MutableMapping, HeadersMixin):

    _length_check = True

    def __init__(self, *, status=200, reason=None, headers=None):
        """Create an unprepared response.

        ``status`` and ``reason`` are forwarded to set_status(); ``headers``,
        when given, is copied into a fresh CIMultiDict.
        """
        # Header and cookie containers.
        self._headers = (CIMultiDict() if headers is None
                         else CIMultiDict(headers))
        self._cookies = SimpleCookie()

        # Body / transfer options; nothing is enabled by default.
        self._body = None
        self._chunked = False
        self._compression = False
        self._compression_force = False
        self._keep_alive = None  # resolved in _start() unless forced earlier

        # State that only exists while a request is being answered.
        self._req = None
        self._payload_writer = None
        self._eof_sent = False
        self._body_length = 0

        # Backing storage for the mapping-style item access.
        self._state = {}

        self.set_status(status, reason)

    @property
    def prepared(self):
        """Whether a payload writer is currently attached to the response."""
        return self._payload_writer is not None

    @property
    def task(self):
        """The ``task`` attribute of the bound request, or None.

        None is returned both when no request is bound and when the bound
        request has no ``task`` attribute.
        """
        return getattr(self._req, 'task', None)

    @property
    def status(self):
        """Status code as stored by set_status() (an int)."""
        return self._status

    @property
    def chunked(self):
        """True once enable_chunked_encoding() has been called."""
        return self._chunked

    @property
    def compression(self):
        """True once enable_compression() has been called."""
        return self._compression

    @property
    def reason(self):
        """Reason phrase as stored by set_status()."""
        return self._reason

    def set_status(self, status, reason=None, _RESPONSES=RESPONSES):
        """Store the status code and its reason phrase.

        ``status`` is coerced with int().  When ``reason`` is None the phrase
        is looked up in ``_RESPONSES`` and falls back to an empty string if
        that lookup fails.  Asserts that the response is not prepared yet.
        """
        assert not self.prepared, \
            'Cannot change the response status code after ' \
            'the headers have been sent'
        code = int(status)
        self._status = code
        if reason is not None:
            self._reason = reason
            return
        try:
            phrase = _RESPONSES[code][0]
        except Exception:
            phrase = ''
        self._reason = phrase

    @property
    def keep_alive(self):
        """Keep-alive flag.

        None until _start() resolves it from the request or force_close()
        sets it to False.
        """
        return self._keep_alive

    def force_close(self):
        """Mark the response as not keep-alive."""
        self._keep_alive = False

    @property
    def body_length(self):
        """Output size captured from the payload writer in write_eof().

        0 until write_eof() has completed.
        """
        return self._body_length

    @property
    def output_length(self):
        """Deprecated: buffer size reported by the payload writer.

        Only usable while the response is prepared; before prepare() and
        after write_eof() no payload writer is attached and the attribute
        access fails.
        """
        # stacklevel=2 attributes the warning to the code reading the
        # property, the same way drain() reports its deprecation.
        warnings.warn('output_length is deprecated', DeprecationWarning,
                      stacklevel=2)
        return self._payload_writer.buffer_size

    def enable_chunked_encoding(self, chunk_size=None):
        """Enables automatic chunked transfer encoding.

        ``chunk_size`` is deprecated: any value other than None only
        triggers a DeprecationWarning.

        Raises RuntimeError if a Content-Length header is already set; the
        response is left unchanged in that case.
        """
        # Validate first so a rejected call does not leave the response
        # flagged as chunked while it still carries a Content-Length.
        if hdrs.CONTENT_LENGTH in self._headers:
            raise RuntimeError("You can't enable chunked encoding when "
                               "a content length is set")
        self._chunked = True

        if chunk_size is not None:
            # stacklevel=2 points the warning at the caller.
            warnings.warn('Chunk size is deprecated #1615',
                          DeprecationWarning, stacklevel=2)

    def enable_compression(self, force=None):
        """Enables response compression encoding.

        ``force`` may be:

        * None -- the coding is picked from the request's Accept-Encoding
          header in _start_compression();
        * a ContentCoding member -- that coding is used unconditionally;
        * a bool (legacy) -- True maps to ContentCoding.deflate, False to
          ContentCoding.identity.
        """
        # Backwards compatibility for when force was a bool <0.17.
        if isinstance(force, bool):
            force = ContentCoding.deflate if force else ContentCoding.identity
        elif force is not None:
            assert isinstance(force, ContentCoding), ("force should be one of "
                                                      "None, bool or "
                                                      "ContentCoding")

        self._compression = True
        self._compression_force = force

    @property
    def headers(self):
        """The response's header multidict (the live object, not a copy)."""
        return self._headers

    @property
    def cookies(self):
        """The SimpleCookie jar holding cookies to be sent with the response."""
        return self._cookies

    def set_cookie(self, name, value, *, expires=None,
                   domain=None, max_age=None, path='/',
                   secure=None, httponly=None, version=None):
        """Set or update response cookie.

        Stores ``value`` under ``name`` in the cookie jar.  Attributes whose
        argument is not None are written to the cookie; ``path`` is always
        written.  A stale max-age, and the epoch expiry left behind by a
        deletion, are cleared when not supplied again.
        """
        jar = self._cookies

        previous = jar.get(name)
        if previous is not None and previous.coded_value == '':
            # deleted cookie: drop the tombstone before storing a new value
            jar.pop(name, None)

        jar[name] = value
        morsel = jar[name]

        if expires is not None:
            morsel['expires'] = expires
        elif morsel.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
            del morsel['expires']

        if domain is not None:
            morsel['domain'] = domain

        if max_age is not None:
            morsel['max-age'] = max_age
        elif 'max-age' in morsel:
            del morsel['max-age']

        morsel['path'] = path

        # Remaining optional attributes are only written when supplied.
        optional = (('secure', secure),
                    ('httponly', httponly),
                    ('version', version))
        for attr, setting in optional:
            if setting is not None:
                morsel[attr] = setting

    def del_cookie(self, name, *, domain=None, path='/'):
        """Delete cookie.

        Replaces any stored cookie called ``name`` with an empty one that
        has max-age 0 and an expiry date at the Unix epoch.
        """
        # TODO: do we need domain/path here?
        epoch = "Thu, 01 Jan 1970 00:00:00 GMT"
        self._cookies.pop(name, None)
        self.set_cookie(name, '', max_age=0, expires=epoch,
                        domain=domain, path=path)

    @property
    def content_length(self):
        # Getter is inherited; redefined only so a setter can be attached.
        return super().content_length

    @content_length.setter
    def content_length(self, value):
        """Set or remove the Content-Length header.

        None removes the header.  Any other value is coerced with int();
        RuntimeError is raised if chunked encoding has been enabled.
        """
        if value is None:
            self._headers.pop(hdrs.CONTENT_LENGTH, None)
            return

        length = int(value)
        if self._chunked:
            raise RuntimeError("You can't set content length when "
                               "chunked encoding is enable")
        self._headers[hdrs.CONTENT_LENGTH] = str(length)

    @property
    def content_type(self):
        """Content type as computed by the parent class."""
        # Just a placeholder for adding setter
        return super().content_type

    @content_type.setter
    def content_type(self, value):
        """Replace the content type and regenerate the Content-Type header.

        ``value`` is coerced with str().
        """
        # The bare property read is kept for its side effect; presumably the
        # parent getter parses the current header into _content_type /
        # _content_dict -- confirm against the parent class.
        self.content_type  # read header values if needed
        self._content_type = str(value)
        self._generate_content_type_header()

    @property
    def charset(self):
        # Getter is inherited; redefined only so a setter can be attached.
        return super().charset

    @charset.setter
    def charset(self, value):
        """Set or remove the charset parameter of the Content-Type header.

        None removes the parameter; any other value is stored lower-cased.
        Raises RuntimeError while the content type is
        application/octet-stream.
        """
        # Reading content_type also loads the header values if needed.
        if self.content_type == 'application/octet-stream':
            raise RuntimeError("Setting charset for application/octet-stream "
                               "doesn't make sense, setup content_type first")

        params = self._content_dict
        if value is not None:
            params['charset'] = str(value).lower()
        else:
            params.pop('charset', None)
        self._generate_content_type_header()

    @property
    def last_modified(self, _LAST_MODIFIED=hdrs.LAST_MODIFIED):
        """The value of Last-Modified HTTP header, or None.

        This header is represented as a `datetime` object.  None is returned
        when the header is absent or cannot be parsed as a date.
        """
        httpdate = self.headers.get(_LAST_MODIFIED)
        if httpdate is not None:
            timetuple = parsedate(httpdate)
            if timetuple is not None:
                return datetime.datetime(*timetuple[:6],
                                         tzinfo=datetime.timezone.utc)
        return None

    @last_modified.setter
    def last_modified(self, value):
        """Set or remove the Last-Modified header.

        Accepted values:

        * None -- remove the header;
        * int / float -- seconds since the epoch, rounded up to a whole
          second;
        * datetime.datetime -- formatted from its UTC time tuple;
        * str -- stored verbatim.

        Raises TypeError for any other type instead of silently leaving the
        header untouched.
        """
        if value is None:
            self.headers.pop(hdrs.LAST_MODIFIED, None)
        elif isinstance(value, (int, float)):
            self.headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
        elif isinstance(value, datetime.datetime):
            self.headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
        elif isinstance(value, str):
            self.headers[hdrs.LAST_MODIFIED] = value
        else:
            raise TypeError('Unsupported type for last_modified: {}'.format(
                type(value).__name__))

    def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
        """Rebuild the Content-Type header from the stored type and params.

        The header is ``_content_type`` followed by every ``key=value`` pair
        of ``_content_dict``, joined with "; ".
        """
        extras = ["%s=%s" % pair for pair in self._content_dict.items()]
        value = self._content_type
        if extras:
            value = '; '.join([value] + extras)
        self.headers[CONTENT_TYPE] = value

    def _do_start_compression(self, coding):
        """Turn on ``coding`` for the payload; identity is a no-op."""
        if coding == ContentCoding.identity:
            return

        encoding = coding.value
        self.headers[hdrs.CONTENT_ENCODING] = encoding
        self._payload_writer.enable_compression(encoding)
        # The compressed payload may differ in size from the declared
        # length, so the Content-Length header is dropped.
        self._headers.popall(hdrs.CONTENT_LENGTH, None)

    def _start_compression(self, request):
        """Pick a content coding and start compressing.

        A forced coding wins.  Otherwise the first ContentCoding member whose
        value occurs in the lower-cased Accept-Encoding request header is
        used; if none occurs, nothing is enabled.
        """
        forced = self._compression_force
        if forced:
            self._do_start_compression(forced)
            return

        accepted = request.headers.get(hdrs.ACCEPT_ENCODING, '').lower()
        chosen = next(
            (coding for coding in ContentCoding if coding.value in accepted),
            None)
        if chosen is not None:
            self._do_start_compression(chosen)

    async def prepare(self, request):
        """Start the response for ``request`` and return its payload writer.

        Returns None once EOF has been sent.  When already prepared, the
        current writer is returned and nothing is started again.
        """
        if self._eof_sent:
            return None

        writer = self._payload_writer
        if writer is None:
            await request._prepare_hook(self)
            writer = self._start(request)
        return writer

    def _start(self, request,
               HttpVersion10=HttpVersion10,
               HttpVersion11=HttpVersion11,
               CONNECTION=hdrs.CONNECTION,
               DATE=hdrs.DATE,
               SERVER=hdrs.SERVER,
               CONTENT_TYPE=hdrs.CONTENT_TYPE,
               CONTENT_LENGTH=hdrs.CONTENT_LENGTH,
               SET_COOKIE=hdrs.SET_COOKIE,
               SERVER_SOFTWARE=SERVER_SOFTWARE,
               TRANSFER_ENCODING=hdrs.TRANSFER_ENCODING):
        """Bind the response to ``request``, finalise headers and send them.

        Resolves keep-alive, emits Set-Cookie headers, starts compression,
        chooses between chunked and length-delimited framing, fills in
        default headers and finally hands the status line and headers to the
        request's payload writer, which is returned.

        Only ``request`` is meant to be passed; the remaining keyword
        parameters bind module-level names at definition time (presumably to
        avoid global lookups -- confirm).

        Raises RuntimeError when chunked encoding was explicitly enabled but
        the request is not HTTP/1.1.
        """
        self._req = request

        # An explicit force_close() wins; otherwise follow the request.
        keep_alive = self._keep_alive
        if keep_alive is None:
            keep_alive = request.keep_alive
        self._keep_alive = keep_alive

        version = request.version
        writer = self._payload_writer = request._payload_writer

        headers = self._headers
        for cookie in self._cookies.values():
            # With an empty header name the rendered cookie starts with a
            # separator character, which [1:] removes.
            value = cookie.output(header='')[1:]
            headers.add(SET_COOKIE, value)

        if self._compression:
            self._start_compression(request)

        if self._chunked:
            # Chunking was requested explicitly: only valid on HTTP/1.1.
            if version != HttpVersion11:
                raise RuntimeError(
                    "Using chunked encoding is forbidden "
                    "for HTTP/{0.major}.{0.minor}".format(request.version))
            writer.enable_chunking()
            headers[TRANSFER_ENCODING] = 'chunked'
            if CONTENT_LENGTH in headers:
                del headers[CONTENT_LENGTH]
        elif self._length_check:
            writer.length = self.content_length
            if writer.length is None:
                # Unknown length: chunk on HTTP/1.1+, otherwise give up on
                # keep-alive for this response.
                if version >= HttpVersion11:
                    writer.enable_chunking()
                    headers[TRANSFER_ENCODING] = 'chunked'
                    if CONTENT_LENGTH in headers:
                        del headers[CONTENT_LENGTH]
                else:
                    # NOTE(review): only the local flag is cleared here;
                    # self._keep_alive keeps the value assigned above --
                    # confirm that is intended.
                    keep_alive = False

        # Defaults for headers the caller did not set.
        headers.setdefault(CONTENT_TYPE, 'application/octet-stream')
        headers.setdefault(DATE, rfc822_formatted_time())
        headers.setdefault(SERVER, SERVER_SOFTWARE)

        # connection header
        # Added only when it deviates from what the code treats as the
        # protocol default: keep-alive on HTTP/1.0, close on HTTP/1.1.
        if CONNECTION not in headers:
            if keep_alive:
                if version == HttpVersion10:
                    headers[CONNECTION] = 'keep-alive'
            else:
                if version == HttpVersion11:
                    headers[CONNECTION] = 'close'

        # status line
        status_line = 'HTTP/{}.{} {} {}\r\n'.format(
            version[0], version[1], self._status, self._reason)
        writer.write_headers(status_line, headers)

        return writer

    async def write(self, data):
        """Send a chunk of body data through the payload writer.

        ``data`` must be bytes, bytearray or memoryview.  Raises RuntimeError
        when called after write_eof() or before prepare().
        """
        assert isinstance(data, (bytes, bytearray, memoryview)), \
            "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            raise RuntimeError("Cannot call write() after write_eof()")

        writer = self._payload_writer
        if writer is None:
            raise RuntimeError("Cannot call write() before prepare()")

        await writer.write(data)

    async def drain(self):
        """Deprecated: wait for the payload writer to drain.

        Asserts that EOF has not been sent and that the response has been
        started, emits a DeprecationWarning attributed to the caller, then
        awaits the payload writer's drain().
        """
        assert not self._eof_sent, "EOF has already been sent"
        assert self._payload_writer is not None, \
            "Response has not been started"
        warnings.warn("drain method is deprecated, use await resp.write()",
                      DeprecationWarning,
                      stacklevel=2)
        await self._payload_writer.drain()

    async def write_eof(self, data=b''):
        """Finish the response, optionally sending a final chunk of data.

        ``data`` must be bytes, bytearray or memoryview.  Calling this again
        after EOF has been sent does nothing.  Asserts that the response has
        been started.
        """
        assert isinstance(data, (bytes, bytearray, memoryview)), \
            "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            return

        assert self._payload_writer is not None, \
            "Response has not been started"

        await self._payload_writer.write_eof(data)
        # Record completion and detach from the request; the writer's output
        # size is captured before the writer reference is dropped.
        self._eof_sent = True
        self._req = None
        self._body_length = self._payload_writer.output_size
        self._payload_writer = None

    def __repr__(self):
        """Describe the response as ``<ClassName reason state>``.

        The state is "eof", "not prepared", or the bound request's method
        and path.
        """
        if self._eof_sent:
            state = "eof"
        elif not self.prepared:
            state = "not prepared"
        else:
            req = self._req
            state = "{} {} ".format(req.method, req.path)
        return "<{} {} {}>".format(self.__class__.__name__,
                                   self.reason, state)

    # Mapping protocol: item access is delegated to the private _state dict.

    def __getitem__(self, key):
        """Return the value stored under ``key``; KeyError if absent."""
        return self._state[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        self._state[key] = value

    def __delitem__(self, key):
        """Remove ``key``; KeyError if absent."""
        del self._state[key]

    def __len__(self):
        """Number of stored items."""
        return len(self._state)

    def __iter__(self):
        """Iterate over the stored keys."""
        return iter(self._state)

    def __hash__(self):
        """Hash by object identity, independent of the stored items."""
        return hash(id(self))
Example #2
0
class StreamResponse(BaseClass, HeadersMixin, CookieMixin):

    # Instance attributes assigned by this class.  "__weakref__" is listed so
    # that instances can still be weakly referenced despite using slots.
    __slots__ = (
        "_length_check",
        "_body",
        "_keep_alive",
        "_chunked",
        "_compression",
        "_compression_force",
        "_req",
        "_payload_writer",
        "_eof_sent",
        "_body_length",
        "_state",
        "_headers",
        "_status",
        "_reason",
        "__weakref__",
    )

    def __init__(
        self,
        *,
        status: int = 200,
        reason: Optional[str] = None,
        headers: Optional[LooseHeaders] = None,
    ) -> None:
        """Create an unprepared response.

        ``status`` and ``reason`` are forwarded to set_status(); ``headers``,
        when given, is copied into a fresh CIMultiDict.
        """
        super().__init__()

        # Headers: copy the caller's mapping, or start empty.
        self._headers: CIMultiDict[str] = (
            CIMultiDict() if headers is None else CIMultiDict(headers)
        )

        # Body / transfer options.
        self._length_check = True
        self._body = None
        self._chunked = False
        self._compression = False
        self._compression_force: Optional[ContentCoding] = None
        self._keep_alive: Optional[bool] = None

        # State that only exists while a request is being answered.
        self._req: Optional[BaseRequest] = None
        self._payload_writer: Optional[AbstractStreamWriter] = None
        self._eof_sent = False
        self._body_length = 0

        # Backing storage for the mapping-style item access.
        self._state: Dict[str, Any] = {}

        self.set_status(status, reason)

    @property
    def prepared(self) -> bool:
        """Whether a payload writer is currently attached to the response."""
        return self._payload_writer is not None

    @property
    def task(self) -> "Optional[asyncio.Task[None]]":
        """Task of the bound request, or None when no request is bound."""
        # Compare against None explicitly.  A truthiness test would treat a
        # bound request as absent whenever the request object evaluates as
        # false, which mapping-like objects do while they hold no items
        # unless they also define __bool__.
        if self._req is None:
            return None
        return self._req.task

    @property
    def status(self) -> int:
        """Status code as stored by set_status()."""
        return self._status

    @property
    def chunked(self) -> bool:
        """True once enable_chunked_encoding() has been called."""
        return self._chunked

    @property
    def compression(self) -> bool:
        """True once enable_compression() has been called."""
        return self._compression

    @property
    def reason(self) -> str:
        """Reason phrase as stored by set_status()."""
        return self._reason

    def set_status(
        self,
        status: int,
        reason: Optional[str] = None,
        _RESPONSES: Mapping[int, Tuple[str, str]] = RESPONSES,
    ) -> None:
        """Store the status code and its reason phrase.

        ``status`` is coerced with int().  When ``reason`` is None the phrase
        is looked up in ``_RESPONSES`` and falls back to an empty string if
        that lookup fails.  Asserts that the response is not prepared yet.
        """
        assert not self.prepared, (
            "Cannot change the response status code after "
            "the headers have been sent")
        code = int(status)
        self._status = code
        if reason is not None:
            self._reason = reason
            return
        try:
            phrase = _RESPONSES[code][0]
        except Exception:
            phrase = ""
        self._reason = phrase

    @property
    def keep_alive(self) -> Optional[bool]:
        """Keep-alive flag.

        None until _prepare_headers() resolves it from the request or
        force_close() sets it to False.
        """
        return self._keep_alive

    def force_close(self) -> None:
        """Mark the response as not keep-alive."""
        self._keep_alive = False

    @property
    def body_length(self) -> int:
        """Output size captured from the payload writer in write_eof().

        0 until write_eof() has completed.
        """
        return self._body_length

    def enable_chunked_encoding(self) -> None:
        """Enables automatic chunked transfer encoding.

        Raises RuntimeError if a Content-Length header is already set; the
        response is left unchanged in that case.
        """
        # Validate first so a rejected call does not leave the response
        # flagged as chunked while it still carries a Content-Length.
        if hdrs.CONTENT_LENGTH in self._headers:
            raise RuntimeError("You can't enable chunked encoding when "
                               "a content length is set")
        self._chunked = True

    def enable_compression(self,
                           force: Optional[ContentCoding] = None) -> None:
        """Enables response compression encoding.

        ``force`` selects a coding to use unconditionally.  When it is None
        the coding is picked from the request's Accept-Encoding header in
        _start_compression().
        """
        self._compression = True
        self._compression_force = force

    @property
    def headers(self) -> "CIMultiDict[str]":
        """The response's header multidict (the live object, not a copy)."""
        return self._headers

    @property
    def content_length(self) -> Optional[int]:
        # Getter is inherited; redefined only so a setter can be attached.
        return super().content_length

    @content_length.setter
    def content_length(self, value: Optional[int]) -> None:
        """Set or remove the Content-Length header.

        None removes the header.  Any other value is coerced with int();
        RuntimeError is raised if chunked encoding has been enabled.
        """
        if value is None:
            self._headers.pop(hdrs.CONTENT_LENGTH, None)
            return

        length = int(value)
        if self._chunked:
            raise RuntimeError("You can't set content length when "
                               "chunked encoding is enable")
        self._headers[hdrs.CONTENT_LENGTH] = str(length)

    @property
    def content_type(self) -> str:
        """Content type as computed by the parent class."""
        # Just a placeholder for adding setter
        return super().content_type

    @content_type.setter
    def content_type(self, value: str) -> None:
        """Replace the content type and regenerate the Content-Type header.

        ``value`` is coerced with str().
        """
        # The bare property read is kept for its side effect; presumably the
        # parent getter parses the current header into _content_type /
        # _content_dict -- confirm against the parent class.
        self.content_type  # read header values if needed
        self._content_type = str(value)
        self._generate_content_type_header()

    @property
    def charset(self) -> Optional[str]:
        # Getter is inherited; redefined only so a setter can be attached.
        return super().charset

    @charset.setter
    def charset(self, value: Optional[str]) -> None:
        """Set or remove the charset parameter of the Content-Type header.

        None removes the parameter; any other value is stored lower-cased.
        Raises RuntimeError while the content type is
        application/octet-stream.
        """
        # Reading content_type also loads the header values if needed.
        if self.content_type == "application/octet-stream":
            raise RuntimeError("Setting charset for application/octet-stream "
                               "doesn't make sense, setup content_type first")

        params = self._content_dict
        assert params is not None
        if value is not None:
            params["charset"] = str(value).lower()
        else:
            params.pop("charset", None)
        self._generate_content_type_header()

    @property
    def last_modified(self) -> Optional[datetime.datetime]:
        """The value of Last-Modified HTTP header, or None.

        This header is represented as a `datetime` object.
        """
        return parse_http_date(self._headers.get(hdrs.LAST_MODIFIED))

    @last_modified.setter
    def last_modified(
            self, value: Optional[Union[int, float, datetime.datetime,
                                        str]]) -> None:
        """Set or remove the Last-Modified header.

        Accepted values:

        * None -- remove the header;
        * int / float -- seconds since the epoch, rounded up to a whole
          second;
        * datetime.datetime -- formatted from its UTC time tuple;
        * str -- stored verbatim.

        Raises TypeError for any other type instead of silently leaving the
        header untouched.
        """
        if value is None:
            self._headers.pop(hdrs.LAST_MODIFIED, None)
        elif isinstance(value, (int, float)):
            self._headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
        elif isinstance(value, datetime.datetime):
            self._headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
        elif isinstance(value, str):
            self._headers[hdrs.LAST_MODIFIED] = value
        else:
            raise TypeError(
                f"Unsupported type for last_modified: {type(value).__name__}")

    @property
    def etag(self) -> Optional[ETag]:
        """The ETag header parsed into an ETag object, or None.

        None is returned when the header is missing or empty, or when it
        does not fully match QUOTED_ETAG_RE.  A header equal to ETAG_ANY is
        returned as ``ETag(value=ETAG_ANY)``.
        """
        quoted_value = self._headers.get(hdrs.ETAG)
        if not quoted_value:
            return None
        elif quoted_value == ETAG_ANY:
            return ETag(value=ETAG_ANY)
        match = QUOTED_ETAG_RE.fullmatch(quoted_value)
        if not match:
            return None
        # Group 1 is the weakness marker, group 2 the tag value.
        is_weak, value = match.group(1, 2)
        return ETag(
            is_weak=bool(is_weak),
            value=value,
        )

    @etag.setter
    def etag(self, value: Optional[Union[ETag, str]]) -> None:
        """Set or remove the ETag header.

        None removes the header.  ETAG_ANY, given as a str or as the value
        of an ETag, is stored as is.  Any other str or ETag value is passed
        to validate_etag_value() and stored quoted, with a ``W/`` prefix for
        a weak ETag.  Raises ValueError for every other kind of value.
        """
        if value is None:
            self._headers.pop(hdrs.ETAG, None)
        # The ETAG_ANY case must be tested before the generic str / ETag
        # branches, which would validate and quote it.
        elif (isinstance(value, str)
              and value == ETAG_ANY) or (isinstance(value, ETag)
                                         and value.value == ETAG_ANY):
            self._headers[hdrs.ETAG] = ETAG_ANY
        elif isinstance(value, str):
            validate_etag_value(value)
            self._headers[hdrs.ETAG] = f'"{value}"'
        elif isinstance(value, ETag) and isinstance(value.value, str):
            validate_etag_value(value.value)
            hdr_value = f'W/"{value.value}"' if value.is_weak else f'"{value.value}"'
            self._headers[hdrs.ETAG] = hdr_value
        else:
            raise ValueError(f"Unsupported etag type: {type(value)}. "
                             f"etag must be str, ETag or None")

    def _generate_content_type_header(self,
                                      CONTENT_TYPE: istr = hdrs.CONTENT_TYPE
                                      ) -> None:
        """Rebuild the Content-Type header from the stored type and params.

        The header is ``_content_type`` followed by every ``key=value`` pair
        of ``_content_dict``, joined with "; ".
        """
        assert self._content_dict is not None
        assert self._content_type is not None
        extras = [f"{k}={v}" for k, v in self._content_dict.items()]
        value = self._content_type
        if extras:
            value = "; ".join([value] + extras)
        self._headers[CONTENT_TYPE] = value

    async def _do_start_compression(self, coding: ContentCoding) -> None:
        """Turn on ``coding`` for the payload; identity is a no-op."""
        if coding == ContentCoding.identity:
            return

        writer = self._payload_writer
        assert writer is not None
        encoding = coding.value
        self._headers[hdrs.CONTENT_ENCODING] = encoding
        writer.enable_compression(encoding)
        # The compressed payload may differ in size from the declared
        # length, so the Content-Length header is dropped.
        self._headers.popall(hdrs.CONTENT_LENGTH, None)

    async def _start_compression(self, request: "BaseRequest") -> None:
        """Pick a content coding and start compressing.

        A forced coding wins.  Otherwise the first ContentCoding member whose
        value occurs in the lower-cased Accept-Encoding request header is
        used; if none occurs, nothing is enabled.
        """
        forced = self._compression_force
        if forced:
            await self._do_start_compression(forced)
            return

        accepted = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
        chosen = next(
            (coding for coding in ContentCoding if coding.value in accepted),
            None,
        )
        if chosen is not None:
            await self._do_start_compression(chosen)

    async def prepare(
            self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
        """Start the response for ``request`` and return its payload writer.

        Returns None once EOF has been sent.  When already prepared, the
        current writer is returned and nothing is started again.
        """
        if self._eof_sent:
            return None

        writer = self._payload_writer
        if writer is None:
            writer = await self._start(request)
        return writer

    async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
        """Bind the response to ``request`` and send its headers.

        Returns the request's payload writer.
        """
        self._req = request
        writer = self._payload_writer = request._payload_writer

        # Order matters: headers are finalised first, then the request's
        # prepare hook runs with this response, and only afterwards are the
        # headers written out.
        await self._prepare_headers()
        await request._prepare_hook(self)
        await self._write_headers()

        return writer

    async def _prepare_headers(self) -> None:
        """Finalise the response headers before they are written.

        Resolves keep-alive, adds cookie headers, starts compression,
        chooses between chunked and length-delimited framing and fills in
        default headers.  Requires ``_req`` and ``_payload_writer`` to be
        set already.

        Raises RuntimeError when chunked encoding was explicitly enabled but
        the request is not HTTP/1.1.
        """
        request = self._req
        assert request is not None
        writer = self._payload_writer
        assert writer is not None
        # An explicit force_close() wins; otherwise follow the request.
        keep_alive = self._keep_alive
        if keep_alive is None:
            keep_alive = request.keep_alive
        self._keep_alive = keep_alive

        version = request.version

        headers = self._headers
        populate_with_cookies(headers, self.cookies)

        if self._compression:
            await self._start_compression(request)

        if self._chunked:
            # Chunking was requested explicitly: only valid on HTTP/1.1.
            if version != HttpVersion11:
                raise RuntimeError("Using chunked encoding is forbidden "
                                   "for HTTP/{0.major}.{0.minor}".format(
                                       request.version))
            writer.enable_chunking()
            headers[hdrs.TRANSFER_ENCODING] = "chunked"
            if hdrs.CONTENT_LENGTH in headers:
                del headers[hdrs.CONTENT_LENGTH]
        elif self._length_check:
            writer.length = self.content_length
            if writer.length is None:
                # Unknown length: chunk on HTTP/1.1+ (except for 204),
                # otherwise give up on keep-alive for this response.
                if version >= HttpVersion11 and self.status != 204:
                    writer.enable_chunking()
                    headers[hdrs.TRANSFER_ENCODING] = "chunked"
                    if hdrs.CONTENT_LENGTH in headers:
                        del headers[hdrs.CONTENT_LENGTH]
                else:
                    # NOTE(review): only the local flag is cleared here;
                    # self._keep_alive keeps the value assigned above --
                    # confirm that is intended.
                    keep_alive = False
            # A length is known, but for these statuses on HTTP/1.1+ the
            # Content-Length header is removed.
            # HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2
            # HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4
            elif version >= HttpVersion11 and self.status in (100, 101, 102,
                                                              103, 204):
                del headers[hdrs.CONTENT_LENGTH]

        # Defaults for headers the caller did not set; 204 and 304 responses
        # get no default Content-Type.
        if self.status not in (204, 304):
            headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream")
        headers.setdefault(hdrs.DATE, rfc822_formatted_time())
        headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE)

        # connection header
        # Added only when it deviates from what the code treats as the
        # protocol default: keep-alive on HTTP/1.0, close on HTTP/1.1.
        if hdrs.CONNECTION not in headers:
            if keep_alive:
                if version == HttpVersion10:
                    headers[hdrs.CONNECTION] = "keep-alive"
            else:
                if version == HttpVersion11:
                    headers[hdrs.CONNECTION] = "close"

    async def _write_headers(self) -> None:
        """Send the status line and the current headers via the writer.

        Requires ``_req`` and ``_payload_writer`` to be set already.
        """
        req = self._req
        assert req is not None
        writer = self._payload_writer
        assert writer is not None

        # status line
        version = req.version
        status_line = (
            f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}"
        )
        await writer.write_headers(status_line, self._headers)

    async def write(self, data: bytes) -> None:
        """Send a chunk of body data through the payload writer.

        ``data`` must be bytes, bytearray or memoryview.  Raises RuntimeError
        when called after write_eof() or before prepare().
        """
        byteish = (bytes, bytearray, memoryview)
        assert isinstance(data, byteish), (
            "data argument must be byte-ish (%r)" % type(data))

        if self._eof_sent:
            raise RuntimeError("Cannot call write() after write_eof()")

        writer = self._payload_writer
        if writer is None:
            raise RuntimeError("Cannot call write() before prepare()")

        await writer.write(data)

    async def drain(self) -> None:
        """Deprecated: wait for the payload writer to drain.

        Asserts that EOF has not been sent and that the response has been
        started, emits a DeprecationWarning attributed to the caller, then
        awaits the payload writer's drain().
        """
        assert not self._eof_sent, "EOF has already been sent"
        assert self._payload_writer is not None, "Response has not been started"
        warnings.warn(
            "drain method is deprecated, use await resp.write()",
            DeprecationWarning,
            stacklevel=2,
        )
        await self._payload_writer.drain()

    async def write_eof(self, data: bytes = b"") -> None:
        """Finish the response, optionally sending a final chunk of data.

        ``data`` must be bytes, bytearray or memoryview.  Calling this again
        after EOF has been sent does nothing.  Asserts that the response has
        been started.
        """
        assert isinstance(
            data,
            (bytes, bytearray,
             memoryview)), "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            return

        assert self._payload_writer is not None, "Response has not been started"

        await self._payload_writer.write_eof(data)
        # Record completion and detach from the request; the writer's output
        # size is captured before the writer reference is dropped.
        self._eof_sent = True
        self._req = None
        self._body_length = self._payload_writer.output_size
        self._payload_writer = None

    def __repr__(self) -> str:
        """Describe the response as ``<ClassName reason state>``.

        The state is "eof", "not prepared", or the bound request's method
        and path.
        """
        if self._eof_sent:
            state = "eof"
        elif not self.prepared:
            state = "not prepared"
        else:
            req = self._req
            assert req is not None
            state = f"{req.method} {req.path} "
        return f"<{self.__class__.__name__} {self.reason} {state}>"

    # Mapping protocol: item access is delegated to the private _state dict.

    def __getitem__(self, key: str) -> Any:
        """Return the value stored under ``key``; KeyError if absent."""
        return self._state[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``."""
        self._state[key] = value

    def __delitem__(self, key: str) -> None:
        """Remove ``key``; KeyError if absent."""
        del self._state[key]

    def __len__(self) -> int:
        """Number of stored items."""
        return len(self._state)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the stored keys."""
        return iter(self._state)

    # Hashing and equality are identity-based, independent of stored items.

    def __hash__(self) -> int:
        """Hash by object identity."""
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        """Two responses are equal only if they are the same object."""
        return self is other
Example #3
0
import collections.abc


class StreamResponse(collections.abc.MutableMapping, HeadersMixin):
    """HTTP response whose body is written incrementally.

    Status, headers and cookies are configured first.  ``prepare()`` then
    sends the status line and headers through the request's payload
    writer, after which ``write()`` / ``write_eof()`` send the body.

    The instance is also a mutable mapping over a private ``_state``
    dict, so arbitrary per-response data can be attached to it.
    """

    # When true, ``_start`` takes the writer length from Content-Length
    # and falls back to chunked encoding (HTTP/1.1+) or to treating the
    # connection as non-keep-alive when no length is known.
    _length_check = True

    def __init__(self, *, status=200, reason=None, headers=None):
        """Create an unprepared response.

        :param status: HTTP status code (coerced with ``int``).
        :param reason: reason phrase; looked up in ``RESPONSES`` when None.
        :param headers: optional initial headers, copied into a new
            ``CIMultiDict``.
        """
        self._body = None
        self._keep_alive = None
        self._chunked = False
        self._compression = False
        self._compression_force = None
        self._cookies = SimpleCookie()

        self._req = None
        self._payload_writer = None
        self._eof_sent = False
        self._body_length = 0
        self._state = {}

        if headers is not None:
            self._headers = CIMultiDict(headers)
        else:
            self._headers = CIMultiDict()

        self.set_status(status, reason)

    @property
    def prepared(self):
        """True while a payload writer is attached to the response."""
        return self._payload_writer is not None

    @property
    def task(self):
        """The ``task`` attribute of the attached request, or None."""
        return getattr(self._req, 'task', None)

    @property
    def status(self):
        """The integer HTTP status code."""
        return self._status

    @property
    def chunked(self):
        """True if chunked encoding was explicitly enabled."""
        return self._chunked

    @property
    def compression(self):
        """True if compression was enabled."""
        return self._compression

    @property
    def reason(self):
        """The HTTP reason phrase."""
        return self._reason

    def set_status(self, status, reason=None, _RESPONSES=RESPONSES):
        """Set the status code and reason phrase.

        :param status: status code, coerced with ``int``.
        :param reason: reason phrase; when None it is taken from
            ``_RESPONSES`` and falls back to ``''`` if the lookup fails.

        Asserts that the response has not been prepared yet.
        """
        assert not self.prepared, \
            'Cannot change the response status code after ' \
            'the headers have been sent'
        self._status = int(status)
        if reason is None:
            try:
                reason = _RESPONSES[self._status][0]
            except Exception:
                # Unknown status code: send an empty reason phrase.
                reason = ''
        self._reason = reason

    @property
    def keep_alive(self):
        """Keep-alive flag; None until decided or forced."""
        return self._keep_alive

    def force_close(self):
        """Mark the response as not keep-alive."""
        self._keep_alive = False

    @property
    def body_length(self):
        """The writer's ``output_size`` recorded by ``write_eof()``."""
        return self._body_length

    @property
    def output_length(self):
        """Deprecated: the payload writer's ``buffer_size``."""
        warnings.warn('output_length is deprecated', DeprecationWarning)
        return self._payload_writer.buffer_size

    def enable_chunked_encoding(self, chunk_size=None):
        """Enables automatic chunked transfer encoding.

        :param chunk_size: deprecated; a non-None value only triggers a
            ``DeprecationWarning``.
        :raises RuntimeError: if a Content-Length header is already set.
        """
        # Validate before mutating, so a rejected call does not leave the
        # response flagged as chunked.
        if hdrs.CONTENT_LENGTH in self._headers:
            raise RuntimeError("You can't enable chunked encoding when "
                               "a content length is set")
        self._chunked = True
        if chunk_size is not None:
            warnings.warn('Chunk size is deprecated #1615', DeprecationWarning)

    def enable_compression(self, force=None):
        """Enables response compression encoding.

        :param force: None to negotiate via the request's Accept-Encoding,
            a ``ContentCoding`` member to force that coding, or (legacy)
            a bool mapped to ``deflate`` / ``identity``.
        """
        # Backwards compatibility for when force was a bool <0.17.
        if isinstance(force, bool):
            force = ContentCoding.deflate if force else ContentCoding.identity
        elif force is not None:
            assert isinstance(force, ContentCoding), ("force should one of "
                                                      "None, bool or "
                                                      "ContentEncoding")

        self._compression = True
        self._compression_force = force

    @property
    def headers(self):
        """The response headers (a ``CIMultiDict``)."""
        return self._headers

    @property
    def cookies(self):
        """The ``SimpleCookie`` holding cookies to be sent."""
        return self._cookies

    def set_cookie(self,
                   name,
                   value,
                   *,
                   expires=None,
                   domain=None,
                   max_age=None,
                   path='/',
                   secure=None,
                   httponly=None,
                   version=None):
        """Set or update response cookie.

        Sets new cookie or updates existent with new value.
        Also updates only those params which are not None.
        """

        old = self._cookies.get(name)
        if old is not None and old.coded_value == '':
            # deleted cookie
            self._cookies.pop(name, None)

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c['expires'] = expires
        elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
            # Drop the epoch expiry that del_cookie() uses as its marker.
            del c['expires']

        if domain is not None:
            c['domain'] = domain

        if max_age is not None:
            c['max-age'] = max_age
        elif 'max-age' in c:
            del c['max-age']

        c['path'] = path

        if secure is not None:
            c['secure'] = secure
        if httponly is not None:
            c['httponly'] = httponly
        if version is not None:
            c['version'] = version

    def del_cookie(self, name, *, domain=None, path='/'):
        """Delete cookie.

        Creates new empty expired cookie.
        """
        # TODO: do we need domain/path here?
        self._cookies.pop(name, None)
        self.set_cookie(name,
                        '',
                        max_age=0,
                        expires="Thu, 01 Jan 1970 00:00:00 GMT",
                        domain=domain,
                        path=path)

    @property
    def content_length(self):
        """Content length as reported by the parent class."""
        # Just a placeholder for adding setter
        return super().content_length

    @content_length.setter
    def content_length(self, value):
        """Set (or, with None, remove) the Content-Length header.

        :raises RuntimeError: if chunked encoding is enabled.
        """
        if value is not None:
            value = int(value)
            if self._chunked:
                raise RuntimeError("You can't set content length when "
                                   "chunked encoding is enable")
            self._headers[hdrs.CONTENT_LENGTH] = str(value)
        else:
            self._headers.pop(hdrs.CONTENT_LENGTH, None)

    @property
    def content_type(self):
        """Content type as reported by the parent class."""
        # Just a placeholder for adding setter
        return super().content_type

    @content_type.setter
    def content_type(self, value):
        """Set the content type and regenerate the Content-Type header."""
        self.content_type  # read header values if needed
        self._content_type = str(value)
        self._generate_content_type_header()

    @property
    def charset(self):
        """Charset as reported by the parent class."""
        # Just a placeholder for adding setter
        return super().charset

    @charset.setter
    def charset(self, value):
        """Set (lower-cased) or, with None, remove the charset parameter.

        :raises RuntimeError: if the content type is
            ``application/octet-stream``.
        """
        ctype = self.content_type  # read header values if needed
        if ctype == 'application/octet-stream':
            raise RuntimeError("Setting charset for application/octet-stream "
                               "doesn't make sense, setup content_type first")
        if value is None:
            self._content_dict.pop('charset', None)
        else:
            self._content_dict['charset'] = str(value).lower()
        self._generate_content_type_header()

    @property
    def last_modified(self, _LAST_MODIFIED=hdrs.LAST_MODIFIED):
        """The value of Last-Modified HTTP header, or None.

        This header is represented as a `datetime` object.
        """
        httpdate = self.headers.get(_LAST_MODIFIED)
        if httpdate is not None:
            timetuple = parsedate(httpdate)
            if timetuple is not None:
                return datetime.datetime(*timetuple[:6],
                                         tzinfo=datetime.timezone.utc)
        return None

    @last_modified.setter
    def last_modified(self, value):
        """Set the Last-Modified header.

        None removes the header, a number is formatted as a (ceil-ed)
        epoch timestamp in GMT, a ``datetime`` is formatted from its
        ``utctimetuple()`` and a str is stored verbatim.  Any other type
        is ignored.
        """
        if value is None:
            self.headers.pop(hdrs.LAST_MODIFIED, None)
        elif isinstance(value, (int, float)):
            self.headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
        elif isinstance(value, datetime.datetime):
            self.headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
        elif isinstance(value, str):
            self.headers[hdrs.LAST_MODIFIED] = value

    def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
        """Rebuild the Content-Type header from the type and its params."""
        params = '; '.join("%s=%s" % i for i in self._content_dict.items())
        if params:
            ctype = self._content_type + '; ' + params
        else:
            ctype = self._content_type
        self.headers[CONTENT_TYPE] = ctype

    def _do_start_compression(self, coding):
        """Enable *coding* on the writer unless it is ``identity``."""
        if coding != ContentCoding.identity:
            self.headers[hdrs.CONTENT_ENCODING] = coding.value
            self._payload_writer.enable_compression(coding.value)
            # Compressed payload may have different content length,
            # remove the header
            self._headers.popall(hdrs.CONTENT_LENGTH, None)

    def _start_compression(self, request):
        """Start the forced coding, or the first one the client accepts.

        Without a forced coding, ``ContentCoding`` members are tried in
        declaration order and matched as substrings of the lower-cased
        Accept-Encoding request header.
        """
        if self._compression_force:
            self._do_start_compression(self._compression_force)
        else:
            accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING,
                                                  '').lower()
            for coding in ContentCoding:
                if coding.value in accept_encoding:
                    self._do_start_compression(coding)
                    return

    async def prepare(self, request):
        """Send the status line and headers for *request*.

        Returns None if EOF was already sent and the existing payload
        writer if the response is already prepared.  Otherwise awaits
        ``request._prepare_hook(self)`` and returns the writer produced
        by ``_start()``.
        """
        if self._eof_sent:
            return
        if self._payload_writer is not None:
            return self._payload_writer

        await request._prepare_hook(self)
        return await self._start(request)

    async def _start(self,
                     request,
                     HttpVersion10=HttpVersion10,
                     HttpVersion11=HttpVersion11,
                     CONNECTION=hdrs.CONNECTION,
                     DATE=hdrs.DATE,
                     SERVER=hdrs.SERVER,
                     CONTENT_TYPE=hdrs.CONTENT_TYPE,
                     CONTENT_LENGTH=hdrs.CONTENT_LENGTH,
                     SET_COOKIE=hdrs.SET_COOKIE,
                     SERVER_SOFTWARE=SERVER_SOFTWARE,
                     TRANSFER_ENCODING=hdrs.TRANSFER_ENCODING):
        """Attach to *request*, finalize headers and write them out.

        Only *request* is meant to be passed; the remaining parameters
        bind module-level constants as defaults.

        :raises RuntimeError: if chunked encoding was requested for a
            request whose version is not HTTP/1.1.
        :returns: the request's payload writer.
        """
        self._req = request

        # An explicit setting (e.g. force_close()) wins over the request.
        keep_alive = self._keep_alive
        if keep_alive is None:
            keep_alive = request.keep_alive
        self._keep_alive = keep_alive

        version = request.version
        writer = self._payload_writer = request._payload_writer

        headers = self._headers
        for cookie in self._cookies.values():
            # output(header='') leaves a leading space; strip it.
            value = cookie.output(header='')[1:]
            headers.add(SET_COOKIE, value)

        if self._compression:
            self._start_compression(request)

        if self._chunked:
            if version != HttpVersion11:
                raise RuntimeError("Using chunked encoding is forbidden "
                                   "for HTTP/{0.major}.{0.minor}".format(
                                       request.version))
            writer.enable_chunking()
            headers[TRANSFER_ENCODING] = 'chunked'
            if CONTENT_LENGTH in headers:
                del headers[CONTENT_LENGTH]
        elif self._length_check:
            writer.length = self.content_length
            if writer.length is None:
                if version >= HttpVersion11:
                    # Unknown length on HTTP/1.1+: fall back to chunking.
                    writer.enable_chunking()
                    headers[TRANSFER_ENCODING] = 'chunked'
                    if CONTENT_LENGTH in headers:
                        del headers[CONTENT_LENGTH]
                else:
                    # Unknown length before HTTP/1.1: do not advertise
                    # keep-alive in the Connection header below.
                    keep_alive = False

        headers.setdefault(CONTENT_TYPE, 'application/octet-stream')
        headers.setdefault(DATE, rfc822_formatted_time())
        headers.setdefault(SERVER, SERVER_SOFTWARE)

        # connection header
        if CONNECTION not in headers:
            if keep_alive:
                if version == HttpVersion10:
                    headers[CONNECTION] = 'keep-alive'
            else:
                if version == HttpVersion11:
                    headers[CONNECTION] = 'close'

        # status line
        status_line = 'HTTP/{}.{} {} {}\r\n'.format(version[0], version[1],
                                                    self._status, self._reason)
        await writer.write_headers(status_line, headers)

        return writer

    async def write(self, data):
        """Write a bytes-like chunk of the body.

        :raises RuntimeError: if called after ``write_eof()`` or before
            ``prepare()``.
        """
        assert isinstance(data, (bytes, bytearray, memoryview)), \
            "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            raise RuntimeError("Cannot call write() after write_eof()")
        if self._payload_writer is None:
            raise RuntimeError("Cannot call write() before prepare()")

        await self._payload_writer.write(data)

    async def drain(self):
        """Deprecated: await the payload writer's ``drain()``."""
        assert not self._eof_sent, "EOF has already been sent"
        assert self._payload_writer is not None, \
            "Response has not been started"
        warnings.warn("drain method is deprecated, use await resp.write()",
                      DeprecationWarning,
                      stacklevel=2)
        await self._payload_writer.drain()

    async def write_eof(self, data=b''):
        """Send the final bytes-like chunk and finish the response.

        A repeated call after EOF returns immediately.  Afterwards the
        request and writer references are dropped and the writer's
        ``output_size`` is kept as ``body_length``.
        """
        assert isinstance(data, (bytes, bytearray, memoryview)), \
            "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            return

        assert self._payload_writer is not None, \
            "Response has not been started"

        await self._payload_writer.write_eof(data)
        self._eof_sent = True
        self._req = None
        self._body_length = self._payload_writer.output_size
        self._payload_writer = None

    def __repr__(self):
        """Return ``<ClassName reason info>`` describing the current state."""
        if self._eof_sent:
            info = "eof"
        elif self.prepared:
            info = "{} {} ".format(self._req.method, self._req.path)
        else:
            info = "not prepared"
        return "<{} {} {}>".format(self.__class__.__name__, self.reason, info)

    # Mapping protocol: all operations delegate to the ``_state`` dict.

    def __getitem__(self, key):
        return self._state[key]

    def __setitem__(self, key, value):
        self._state[key] = value

    def __delitem__(self, key):
        del self._state[key]

    def __len__(self):
        return len(self._state)

    def __iter__(self):
        return iter(self._state)

    def __hash__(self):
        return hash(id(self))

    def __eq__(self, other):
        # Identity comparison keeps equality consistent with __hash__;
        # the inherited Mapping.__eq__ compares the stored items instead,
        # which would make distinct responses with equal state compare
        # equal while hashing differently.
        return self is other
Example #4
0
class StreamResponse(BaseClass, HeadersMixin):
    """HTTP response whose body is written incrementally.

    Status, headers and cookies are configured first.  ``prepare()`` then
    sends the status line and headers through the request's payload
    writer, after which ``write()`` / ``write_eof()`` send the body.

    The instance also implements the mapping protocol over a private
    ``_state`` dict, so arbitrary per-response data can be attached.
    """

    # When true, ``_prepare_headers`` takes the writer length from
    # Content-Length and falls back to chunked encoding (HTTP/1.1+) or to
    # treating the connection as non-keep-alive when no length is known.
    _length_check = True

    def __init__(self,
                 *,
                 status: int = 200,
                 reason: Optional[str] = None,
                 headers: Optional[LooseHeaders] = None) -> None:
        """Create an unprepared response.

        :param status: HTTP status code (coerced with ``int``).
        :param reason: reason phrase; looked up in ``RESPONSES`` when None.
        :param headers: optional initial headers, copied into a new
            ``CIMultiDict``.
        """
        self._body = None
        self._keep_alive = None  # type: Optional[bool]
        self._chunked = False
        self._compression = False
        self._compression_force = None  # type: Optional[ContentCoding]
        self._cookies = SimpleCookie()  # type: SimpleCookie[str]

        self._req = None  # type: Optional[BaseRequest]
        self._payload_writer = None  # type: Optional[AbstractStreamWriter]
        self._eof_sent = False
        self._body_length = 0
        self._state = {}  # type: Dict[str, Any]

        if headers is not None:
            self._headers = CIMultiDict(headers)  # type: CIMultiDict[str]
        else:
            self._headers = CIMultiDict()

        self.set_status(status, reason)

    @property
    def prepared(self) -> bool:
        """True while a payload writer is attached to the response."""
        return self._payload_writer is not None

    @property
    def task(self) -> "Optional[asyncio.Task[None]]":
        """The ``task`` attribute of the attached request, or None."""
        return getattr(self._req, "task", None)

    @property
    def status(self) -> int:
        """The integer HTTP status code."""
        return self._status

    @property
    def chunked(self) -> bool:
        """True if chunked encoding was explicitly enabled."""
        return self._chunked

    @property
    def compression(self) -> bool:
        """True if compression was enabled."""
        return self._compression

    @property
    def reason(self) -> str:
        """The HTTP reason phrase."""
        return self._reason

    def set_status(
        self,
        status: int,
        reason: Optional[str] = None,
        _RESPONSES: Mapping[int, Tuple[str, str]] = RESPONSES,
    ) -> None:
        """Set the status code and reason phrase.

        :param status: status code, coerced with ``int``.
        :param reason: reason phrase; when None it is taken from
            ``_RESPONSES`` and falls back to ``""`` if the lookup fails.

        Asserts that the response has not been prepared yet.
        """
        assert not self.prepared, (
            "Cannot change the response status code after "
            "the headers have been sent")
        self._status = int(status)
        if reason is None:
            try:
                reason = _RESPONSES[self._status][0]
            except Exception:
                # Unknown status code: send an empty reason phrase.
                reason = ""
        self._reason = reason

    @property
    def keep_alive(self) -> Optional[bool]:
        """Keep-alive flag; None until decided or forced."""
        return self._keep_alive

    def force_close(self) -> None:
        """Mark the response as not keep-alive."""
        self._keep_alive = False

    @property
    def body_length(self) -> int:
        """The writer's ``output_size`` recorded by ``write_eof()``."""
        return self._body_length

    @property
    def output_length(self) -> int:
        """Deprecated: the payload writer's ``buffer_size``."""
        warnings.warn("output_length is deprecated", DeprecationWarning)
        assert self._payload_writer
        return self._payload_writer.buffer_size

    def enable_chunked_encoding(self,
                                chunk_size: Optional[int] = None) -> None:
        """Enables automatic chunked transfer encoding.

        :param chunk_size: deprecated; a non-None value only triggers a
            ``DeprecationWarning``.
        :raises RuntimeError: if a Content-Length header is already set.
        """
        # Validate before mutating, so a rejected call does not leave the
        # response flagged as chunked.
        if hdrs.CONTENT_LENGTH in self._headers:
            raise RuntimeError("You can't enable chunked encoding when "
                               "a content length is set")
        self._chunked = True
        if chunk_size is not None:
            warnings.warn("Chunk size is deprecated #1615", DeprecationWarning)

    def enable_compression(self,
                           force: Optional[Union[bool, ContentCoding]] = None
                           ) -> None:
        """Enables response compression encoding.

        :param force: None to negotiate via the request's Accept-Encoding,
            a ``ContentCoding`` member to force that coding, or
            (deprecated) a bool mapped to ``deflate`` / ``identity``.
        """
        # Backwards compatibility for when force was a bool <0.17.
        if isinstance(force, bool):
            force = ContentCoding.deflate if force else ContentCoding.identity
            warnings.warn("Using boolean for force is deprecated #3318",
                          DeprecationWarning)
        elif force is not None:
            assert isinstance(force, ContentCoding), ("force should one of "
                                                      "None, bool or "
                                                      "ContentEncoding")

        self._compression = True
        self._compression_force = force

    @property
    def headers(self) -> "CIMultiDict[str]":
        """The response headers."""
        return self._headers

    @property
    def cookies(self) -> "SimpleCookie[str]":
        """The ``SimpleCookie`` holding cookies to be sent."""
        return self._cookies

    def set_cookie(self,
                   name: str,
                   value: str,
                   *,
                   expires: Optional[str] = None,
                   domain: Optional[str] = None,
                   max_age: Optional[Union[int, str]] = None,
                   path: str = "/",
                   secure: Optional[bool] = None,
                   httponly: Optional[bool] = None,
                   version: Optional[str] = None,
                   samesite: Optional[str] = None) -> None:
        """Set or update response cookie.

        Sets new cookie or updates existent with new value.
        Also updates only those params which are not None.
        """

        old = self._cookies.get(name)
        if old is not None and old.coded_value == "":
            # deleted cookie
            self._cookies.pop(name, None)

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c["expires"] = expires
        elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
            # Drop the epoch expiry that del_cookie() uses as its marker.
            del c["expires"]

        if domain is not None:
            c["domain"] = domain

        if max_age is not None:
            c["max-age"] = str(max_age)
        elif "max-age" in c:
            del c["max-age"]

        c["path"] = path

        if secure is not None:
            c["secure"] = secure
        if httponly is not None:
            c["httponly"] = httponly
        if version is not None:
            c["version"] = version
        if samesite is not None:
            c["samesite"] = samesite

    def del_cookie(self,
                   name: str,
                   *,
                   domain: Optional[str] = None,
                   path: str = "/") -> None:
        """Delete cookie.

        Creates new empty expired cookie.
        """
        # TODO: do we need domain/path here?
        self._cookies.pop(name, None)
        self.set_cookie(
            name,
            "",
            max_age=0,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            domain=domain,
            path=path,
        )

    @property
    def content_length(self) -> Optional[int]:
        """Content length as reported by the parent class."""
        # Just a placeholder for adding setter
        return super().content_length

    @content_length.setter
    def content_length(self, value: Optional[int]) -> None:
        """Set (or, with None, remove) the Content-Length header.

        :raises RuntimeError: if chunked encoding is enabled.
        """
        if value is not None:
            value = int(value)
            if self._chunked:
                raise RuntimeError("You can't set content length when "
                                   "chunked encoding is enable")
            self._headers[hdrs.CONTENT_LENGTH] = str(value)
        else:
            self._headers.pop(hdrs.CONTENT_LENGTH, None)

    @property
    def content_type(self) -> str:
        """Content type as reported by the parent class."""
        # Just a placeholder for adding setter
        return super().content_type

    @content_type.setter
    def content_type(self, value: str) -> None:
        """Set the content type and regenerate the Content-Type header."""
        self.content_type  # read header values if needed
        self._content_type = str(value)
        self._generate_content_type_header()

    @property
    def charset(self) -> Optional[str]:
        """Charset as reported by the parent class."""
        # Just a placeholder for adding setter
        return super().charset

    @charset.setter
    def charset(self, value: Optional[str]) -> None:
        """Set (lower-cased) or, with None, remove the charset parameter.

        :raises RuntimeError: if the content type is
            ``application/octet-stream``.
        """
        ctype = self.content_type  # read header values if needed
        if ctype == "application/octet-stream":
            raise RuntimeError("Setting charset for application/octet-stream "
                               "doesn't make sense, setup content_type first")
        assert self._content_dict is not None
        if value is None:
            self._content_dict.pop("charset", None)
        else:
            self._content_dict["charset"] = str(value).lower()
        self._generate_content_type_header()

    @property
    def last_modified(self) -> Optional[datetime.datetime]:
        """The value of Last-Modified HTTP header, or None.

        This header is represented as a `datetime` object.
        """
        httpdate = self._headers.get(hdrs.LAST_MODIFIED)
        if httpdate is not None:
            timetuple = parsedate(httpdate)
            if timetuple is not None:
                return datetime.datetime(*timetuple[:6],
                                         tzinfo=datetime.timezone.utc)
        return None

    @last_modified.setter
    def last_modified(
            self, value: Optional[Union[int, float, datetime.datetime,
                                        str]]) -> None:
        """Set the Last-Modified header.

        None removes the header, a number is formatted as a (ceil-ed)
        epoch timestamp in GMT, a ``datetime`` is formatted from its
        ``utctimetuple()`` and a str is stored verbatim.  Any other type
        is ignored.
        """
        if value is None:
            self._headers.pop(hdrs.LAST_MODIFIED, None)
        elif isinstance(value, (int, float)):
            self._headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
        elif isinstance(value, datetime.datetime):
            self._headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
        elif isinstance(value, str):
            self._headers[hdrs.LAST_MODIFIED] = value

    def _generate_content_type_header(self,
                                      CONTENT_TYPE: istr = hdrs.CONTENT_TYPE
                                      ) -> None:
        """Rebuild the Content-Type header from the type and its params."""
        assert self._content_dict is not None
        assert self._content_type is not None
        params = "; ".join("{}={}".format(k, v)
                           for k, v in self._content_dict.items())
        if params:
            ctype = self._content_type + "; " + params
        else:
            ctype = self._content_type
        self._headers[CONTENT_TYPE] = ctype

    async def _do_start_compression(self, coding: ContentCoding) -> None:
        """Enable *coding* on the writer unless it is ``identity``."""
        if coding != ContentCoding.identity:
            assert self._payload_writer is not None
            self._headers[hdrs.CONTENT_ENCODING] = coding.value
            self._payload_writer.enable_compression(coding.value)
            # Compressed payload may have different content length,
            # remove the header
            self._headers.popall(hdrs.CONTENT_LENGTH, None)

    async def _start_compression(self, request: "BaseRequest") -> None:
        """Start the forced coding, or the first one the client accepts.

        Without a forced coding, ``ContentCoding`` members are tried in
        declaration order and matched as substrings of the lower-cased
        Accept-Encoding request header.
        """
        if self._compression_force:
            await self._do_start_compression(self._compression_force)
        else:
            accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING,
                                                  "").lower()
            for coding in ContentCoding:
                if coding.value in accept_encoding:
                    await self._do_start_compression(coding)
                    return

    async def prepare(
            self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
        """Send the status line and headers for *request*.

        Returns None if EOF was already sent and the existing payload
        writer if the response is already prepared; otherwise returns
        the writer produced by ``_start()``.
        """
        if self._eof_sent:
            return None
        if self._payload_writer is not None:
            return self._payload_writer

        return await self._start(request)

    async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
        """Attach to *request*, build the headers and write them out.

        Headers are prepared before ``request._prepare_hook(self)`` is
        awaited and written afterwards.
        """
        self._req = request
        writer = self._payload_writer = request._payload_writer

        await self._prepare_headers()
        await request._prepare_hook(self)
        await self._write_headers()

        return writer

    async def _prepare_headers(self) -> None:
        """Finalize the header set for the attached request.

        Adds Set-Cookie headers, starts compression, selects chunked
        versus length-delimited framing, fills in default Content-Type,
        Date and Server headers, and sets the Connection header when it
        is not already present.

        :raises RuntimeError: if chunked encoding was requested for a
            request whose version is not HTTP/1.1.
        """
        request = self._req
        assert request is not None
        writer = self._payload_writer
        assert writer is not None
        # An explicit setting (e.g. force_close()) wins over the request.
        keep_alive = self._keep_alive
        if keep_alive is None:
            keep_alive = request.keep_alive
        self._keep_alive = keep_alive

        version = request.version

        headers = self._headers
        for cookie in self._cookies.values():
            # output(header="") leaves a leading space; strip it.
            value = cookie.output(header="")[1:]
            headers.add(hdrs.SET_COOKIE, value)

        if self._compression:
            await self._start_compression(request)

        if self._chunked:
            if version != HttpVersion11:
                raise RuntimeError("Using chunked encoding is forbidden "
                                   "for HTTP/{0.major}.{0.minor}".format(
                                       request.version))
            writer.enable_chunking()
            headers[hdrs.TRANSFER_ENCODING] = "chunked"
            if hdrs.CONTENT_LENGTH in headers:
                del headers[hdrs.CONTENT_LENGTH]
        elif self._length_check:
            writer.length = self.content_length
            if writer.length is None:
                if version >= HttpVersion11:
                    # Unknown length on HTTP/1.1+: fall back to chunking.
                    writer.enable_chunking()
                    headers[hdrs.TRANSFER_ENCODING] = "chunked"
                    if hdrs.CONTENT_LENGTH in headers:
                        del headers[hdrs.CONTENT_LENGTH]
                else:
                    # Unknown length before HTTP/1.1: do not advertise
                    # keep-alive in the Connection header below.
                    keep_alive = False
            # HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2
            # HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4
            elif version >= HttpVersion11 and self.status in (100, 101, 102,
                                                              103, 204):
                del headers[hdrs.CONTENT_LENGTH]

        headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream")
        headers.setdefault(hdrs.DATE, rfc822_formatted_time())
        headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE)

        # connection header
        if hdrs.CONNECTION not in headers:
            if keep_alive:
                if version == HttpVersion10:
                    headers[hdrs.CONNECTION] = "keep-alive"
            else:
                if version == HttpVersion11:
                    headers[hdrs.CONNECTION] = "close"

    async def _write_headers(self) -> None:
        """Write the status line and the prepared headers to the writer."""
        request = self._req
        assert request is not None
        writer = self._payload_writer
        assert writer is not None
        # status line
        version = request.version
        status_line = "HTTP/{}.{} {} {}".format(version[0], version[1],
                                                self._status, self._reason)
        await writer.write_headers(status_line, self._headers)

    async def write(self, data: bytes) -> None:
        """Write a bytes-like chunk of the body.

        :raises RuntimeError: if called after ``write_eof()`` or before
            ``prepare()``.
        """
        assert isinstance(
            data,
            (bytes, bytearray,
             memoryview)), "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            raise RuntimeError("Cannot call write() after write_eof()")
        if self._payload_writer is None:
            raise RuntimeError("Cannot call write() before prepare()")

        await self._payload_writer.write(data)

    async def drain(self) -> None:
        """Deprecated: await the payload writer's ``drain()``."""
        assert not self._eof_sent, "EOF has already been sent"
        assert self._payload_writer is not None, "Response has not been started"
        warnings.warn(
            "drain method is deprecated, use await resp.write()",
            DeprecationWarning,
            stacklevel=2,
        )
        await self._payload_writer.drain()

    async def write_eof(self, data: bytes = b"") -> None:
        """Send the final bytes-like chunk and finish the response.

        A repeated call after EOF returns immediately.  Afterwards the
        request and writer references are dropped and the writer's
        ``output_size`` is kept as ``body_length``.
        """
        assert isinstance(
            data,
            (bytes, bytearray,
             memoryview)), "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            return

        assert self._payload_writer is not None, "Response has not been started"

        await self._payload_writer.write_eof(data)
        self._eof_sent = True
        self._req = None
        self._body_length = self._payload_writer.output_size
        self._payload_writer = None

    def __repr__(self) -> str:
        """Return ``<ClassName reason info>`` describing the current state."""
        if self._eof_sent:
            info = "eof"
        elif self.prepared:
            assert self._req is not None
            info = "{} {} ".format(self._req.method, self._req.path)
        else:
            info = "not prepared"
        return "<{} {} {}>".format(self.__class__.__name__, self.reason, info)

    # Mapping protocol: all operations delegate to the ``_state`` dict.

    def __getitem__(self, key: str) -> Any:
        return self._state[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._state[key] = value

    def __delitem__(self, key: str) -> None:
        del self._state[key]

    def __len__(self) -> int:
        return len(self._state)

    def __iter__(self) -> Iterator[str]:
        return iter(self._state)

    # Responses hash and compare by identity, not by stored state.

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        return self is other
Example #5
0
async def test_push (golden):
    """Push every golden item through WarcHandler and verify the WARC output.

    ``golden`` is an iterable of event objects (ControllerStart, Script,
    ScreenshotEvent, DomSnapshotEvent, RequestResponsePair).  All of them
    are pushed into a handler writing to a temporary file; the file is then
    read back and each resulting record is compared against the item that
    produced it, in order.  Any other item type fails the test.
    """
    # Local import keeps this block self-contained.
    import tempfile

    def checkWarcinfoId (headers):
        # Records written after a warcinfo record must point back to it.
        if lastWarcinfoRecordid is not None:
            assert headers['WARC-Warcinfo-ID'] == lastWarcinfoRecordid

    lastWarcinfoRecordid = None

    # null logger
    logger = Logger ()
    # Anonymous temporary file (binary read/write by default) instead of a
    # fixed path under /tmp: portable, removed automatically and safe when
    # several test runs execute at the same time.
    with tempfile.TemporaryFile () as fd:
        with WarcHandler (fd, logger) as handler:
            for g in golden:
                await handler.push (g)

        # Rewind and read back what was written.
        fd.seek (0)
        it = iter (ArchiveIterator (fd))
        for g in golden:
            if isinstance (g, ControllerStart):
                rec = next (it)

                headers = rec.rec_headers
                assert headers['warc-type'] == 'warcinfo'
                assert 'warc-target-uri' not in headers
                assert 'x-crocoite-type' not in headers

                data = json.load (rec.raw_stream)
                assert data == g.payload

                # remembered so following records can be checked against it
                lastWarcinfoRecordid = headers['warc-record-id']
                assert lastWarcinfoRecordid
            elif isinstance (g, Script):
                rec = next (it)

                headers = rec.rec_headers
                assert headers['warc-type'] == 'resource'
                assert headers['content-type'] == 'application/javascript; charset=utf-8'
                assert headers['x-crocoite-type'] == 'script'
                checkWarcinfoId (headers)
                if g.path:
                    assert URL (headers['warc-target-uri']) == URL ('file://' + g.abspath)
                else:
                    assert 'warc-target-uri' not in headers

                data = rec.raw_stream.read ().decode ('utf-8')
                assert data == g.data
            elif isinstance (g, ScreenshotEvent):
                # XXX: check refers-to header
                rec = next (it)

                headers = rec.rec_headers
                assert headers['warc-type'] == 'conversion'
                assert headers['x-crocoite-type'] == 'screenshot'
                checkWarcinfoId (headers)
                assert URL (headers['warc-target-uri']) == g.url, (headers['warc-target-uri'], g.url)
                assert headers['warc-refers-to'] is None
                assert int (headers['X-Crocoite-Screenshot-Y-Offset']) == g.yoff

                assert rec.raw_stream.read () == g.data
            elif isinstance (g, DomSnapshotEvent):
                rec = next (it)

                headers = rec.rec_headers
                assert headers['warc-type'] == 'conversion'
                assert headers['x-crocoite-type'] == 'dom-snapshot'
                checkWarcinfoId (headers)
                assert URL (headers['warc-target-uri']) == g.url
                assert headers['warc-refers-to'] is None

                assert rec.raw_stream.read () == g.document
            elif isinstance (g, RequestResponsePair):
                rec = next (it)

                # request
                headers = rec.rec_headers
                assert headers['warc-type'] == 'request'
                assert 'x-crocoite-type' not in headers
                checkWarcinfoId (headers)
                assert URL (headers['warc-target-uri']) == g.url
                assert headers['x-chrome-request-id'] == g.id

                assert CIMultiDict (rec.http_headers.headers) == g.request.headers
                if g.request.hasPostData:
                    if g.request.body is not None:
                        assert rec.raw_stream.read () == g.request.body
                    else:
                        # body fetch failed
                        assert headers['warc-truncated'] == 'unspecified'
                        assert not rec.raw_stream.read ()
                else:
                    assert not rec.raw_stream.read ()

                # response
                if g.response:
                    rec = next (it)
                    headers = rec.rec_headers
                    httpheaders = rec.http_headers
                    assert headers['warc-type'] == 'response'
                    checkWarcinfoId (headers)
                    assert URL (headers['warc-target-uri']) == g.url
                    assert headers['x-chrome-request-id'] == g.id
                    assert 'x-crocoite-type' not in headers

                    # these are checked separately
                    # NOTE: this also removes them from the golden item's
                    # own response headers.
                    filteredHeaders = CIMultiDict (httpheaders.headers)
                    for b in {'content-type', 'content-length'}:
                        if b in g.response.headers:
                            g.response.headers.popall (b)
                        if b in filteredHeaders:
                            filteredHeaders.popall (b)
                    assert filteredHeaders == g.response.headers

                    expectedContentType = g.response.mimeType
                    if expectedContentType is not None:
                        assert httpheaders['content-type'].startswith (expectedContentType)

                    if g.response.body is not None:
                        assert rec.raw_stream.read () == g.response.body
                        assert httpheaders['content-length'] == str (len (g.response.body))
                        # body is never truncated if it exists
                        assert headers['warc-truncated'] is None

                        # unencoded strings are converted to utf8
                        if isinstance (g.response.body, UnicodeBody) and httpheaders['content-type'] is not None:
                            assert httpheaders['content-type'].endswith ('; charset=utf-8')
                    else:
                        # body fetch failed
                        assert headers['warc-truncated'] == 'unspecified'
                        assert not rec.raw_stream.read ()
                        # content-length header should be kept intact
            else:
                assert False, f"invalid golden type {type(g)}" # pragma: no cover

        # no further records
        with pytest.raises (StopIteration):
            next (it)
Example #6
0
class StreamResponse(BaseClass, HeadersMixin):
    """HTTP response whose headers and body go through a payload writer.

    The response is configured (status, headers, cookies, chunking,
    compression) while it is unprepared.  ``prepare()`` binds it to a
    request and sends the status line and headers, ``write()`` sends body
    data and ``write_eof()`` finishes the response and detaches it from the
    request and the writer.

    Instances also act as a mutable mapping of arbitrary state (backed by
    ``_state``) and compare and hash by identity.
    """

    __slots__ = ('_length_check', '_body', '_keep_alive', '_chunked',
                 '_compression', '_compression_force', '_cookies', '_req',
                 '_payload_writer', '_eof_sent', '_body_length', '_state',
                 '_headers', '_status', '_reason', '__weakref__')

    def __init__(self,
                 *,
                 status: int = 200,
                 reason: Optional[str] = None,
                 headers: Optional[LooseHeaders] = None) -> None:
        """Create an unprepared response.

        ``headers``, when given, are copied into a new ``CIMultiDict``.
        ``status`` and ``reason`` are applied through ``set_status()``.
        """
        super().__init__()
        self._length_check = True
        self._body = None
        self._keep_alive = None  # type: Optional[bool]
        self._chunked = False
        self._compression = False
        self._compression_force = None  # type: Optional[ContentCoding]
        self._cookies = SimpleCookie()  # type: SimpleCookie[str]

        self._req = None  # type: Optional[BaseRequest]
        self._payload_writer = None  # type: Optional[AbstractStreamWriter]
        self._eof_sent = False
        self._body_length = 0
        self._state = {}  # type: Dict[str, Any]

        if headers is not None:
            self._headers = CIMultiDict(headers)  # type: CIMultiDict[str]
        else:
            self._headers = CIMultiDict()

        self.set_status(status, reason)

    @property
    def prepared(self) -> bool:
        """True while a payload writer is attached to the response."""
        return self._payload_writer is not None

    @property
    def task(self) -> 'asyncio.Task[None]':
        """The ``task`` attribute of the bound request, or None."""
        return getattr(self._req, 'task', None)

    @property
    def status(self) -> int:
        """The response status code."""
        return self._status

    @property
    def chunked(self) -> bool:
        """Whether chunked transfer encoding has been enabled."""
        return self._chunked

    @property
    def compression(self) -> bool:
        """Whether response compression has been enabled."""
        return self._compression

    @property
    def reason(self) -> str:
        """The reason phrase used in the status line."""
        return self._reason

    def set_status(
            self,
            status: int,
            reason: Optional[str] = None,
            _RESPONSES: Mapping[int, Tuple[str, str]] = RESPONSES) -> None:
        """Set the status code and reason phrase.

        When ``reason`` is None it is looked up in ``_RESPONSES``; if the
        lookup fails the reason becomes an empty string.  Must be called
        before the response is prepared (checked with an assertion).
        """
        assert not self.prepared, \
            'Cannot change the response status code after ' \
            'the headers have been sent'
        self._status = int(status)
        if reason is None:
            try:
                reason = _RESPONSES[self._status][0]
            except Exception:
                reason = ''
        self._reason = reason

    @property
    def keep_alive(self) -> Optional[bool]:
        """Keep-alive setting; None until decided or forced."""
        return self._keep_alive

    def force_close(self) -> None:
        """Disable keep-alive for this response."""
        self._keep_alive = False

    @property
    def body_length(self) -> int:
        """Writer output size recorded by ``write_eof()``; 0 before that."""
        return self._body_length

    def enable_chunked_encoding(self) -> None:
        """Enables automatic chunked transfer encoding.

        Raises RuntimeError if a Content-Length header is already set; the
        response is left unchanged in that case.
        """
        # Validate before mutating so a rejected call does not leave the
        # response flagged as chunked.
        if hdrs.CONTENT_LENGTH in self._headers:
            raise RuntimeError("You can't enable chunked encoding when "
                               "a content length is set")

        self._chunked = True

    def enable_compression(self,
                           force: Optional[ContentCoding] = None) -> None:
        """Enables response compression encoding.

        ``force`` selects the coding to use unconditionally; when it is
        None the coding is chosen from the request's Accept-Encoding header
        when the response is started.
        """
        self._compression = True
        self._compression_force = force

    @property
    def headers(self) -> 'CIMultiDict[str]':
        """The mutable response headers."""
        return self._headers

    @property
    def cookies(self) -> 'SimpleCookie[str]':
        """The cookies that will be sent as Set-Cookie headers."""
        return self._cookies

    def set_cookie(self,
                   name: str,
                   value: str,
                   *,
                   expires: Optional[str] = None,
                   domain: Optional[str] = None,
                   max_age: Optional[Union[int, str]] = None,
                   path: str = '/',
                   secure: Optional[bool] = None,
                   httponly: Optional[bool] = None,
                   version: Optional[str] = None,
                   samesite: Optional[str] = None) -> None:
        """Set or update response cookie.

        Sets new cookie or updates existent with new value.
        Also updates only those params which are not None.
        """

        old = self._cookies.get(name)
        if old is not None and old.coded_value == '':
            # deleted cookie
            self._cookies.pop(name, None)

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c['expires'] = expires
        elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
            # drop the epoch expiry that del_cookie() sets
            del c['expires']

        if domain is not None:
            c['domain'] = domain

        if max_age is not None:
            c['max-age'] = str(max_age)
        elif 'max-age' in c:
            del c['max-age']

        c['path'] = path

        if secure is not None:
            c['secure'] = secure
        if httponly is not None:
            c['httponly'] = httponly
        if version is not None:
            c['version'] = version
        if samesite is not None:
            c['samesite'] = samesite

    def del_cookie(self,
                   name: str,
                   *,
                   domain: Optional[str] = None,
                   path: str = '/') -> None:
        """Delete cookie.

        Creates new empty expired cookie.
        """
        # TODO: do we need domain/path here?
        self._cookies.pop(name, None)
        self.set_cookie(name,
                        '',
                        max_age=0,
                        expires="Thu, 01 Jan 1970 00:00:00 GMT",
                        domain=domain,
                        path=path)

    @property
    def content_length(self) -> Optional[int]:
        # Just a placeholder for adding setter
        return super().content_length

    @content_length.setter
    def content_length(self, value: Optional[int]) -> None:
        # None removes the header; any other value is coerced to int and
        # rejected with RuntimeError when chunked encoding is enabled.
        if value is not None:
            value = int(value)
            if self._chunked:
                raise RuntimeError("You can't set content length when "
                                   "chunked encoding is enable")
            self._headers[hdrs.CONTENT_LENGTH] = str(value)
        else:
            self._headers.pop(hdrs.CONTENT_LENGTH, None)

    @property
    def content_type(self) -> str:
        # Just a placeholder for adding setter
        return super().content_type

    @content_type.setter
    def content_type(self, value: str) -> None:
        self.content_type  # read header values if needed
        self._content_type = str(value)
        self._generate_content_type_header()

    @property
    def charset(self) -> Optional[str]:
        # Just a placeholder for adding setter
        return super().charset

    @charset.setter
    def charset(self, value: Optional[str]) -> None:
        # None removes the charset parameter; other values are stored
        # lower-cased.  The Content-Type header is regenerated either way.
        ctype = self.content_type  # read header values if needed
        if ctype == 'application/octet-stream':
            raise RuntimeError("Setting charset for application/octet-stream "
                               "doesn't make sense, setup content_type first")
        assert self._content_dict is not None
        if value is None:
            self._content_dict.pop('charset', None)
        else:
            self._content_dict['charset'] = str(value).lower()
        self._generate_content_type_header()

    @property
    def last_modified(self) -> Optional[datetime.datetime]:
        """The value of Last-Modified HTTP header, or None.

        This header is represented as a `datetime` object.
        """
        httpdate = self._headers.get(hdrs.LAST_MODIFIED)
        if httpdate is not None:
            timetuple = parsedate(httpdate)
            if timetuple is not None:
                return datetime.datetime(*timetuple[:6],
                                         tzinfo=datetime.timezone.utc)
        return None

    @last_modified.setter
    def last_modified(
            self, value: Optional[Union[int, float, datetime.datetime,
                                        str]]) -> None:
        # None removes the header, numbers are treated as POSIX timestamps
        # (rounded up), datetimes are formatted from their UTC time tuple,
        # strings are stored verbatim.  Other types are ignored.
        if value is None:
            self._headers.pop(hdrs.LAST_MODIFIED, None)
        elif isinstance(value, (int, float)):
            self._headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
        elif isinstance(value, datetime.datetime):
            self._headers[hdrs.LAST_MODIFIED] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
        elif isinstance(value, str):
            self._headers[hdrs.LAST_MODIFIED] = value

    def _generate_content_type_header(self,
                                      CONTENT_TYPE: istr = hdrs.CONTENT_TYPE
                                      ) -> None:
        """Rebuild the Content-Type header from the type and its params."""
        assert self._content_dict is not None
        assert self._content_type is not None
        params = '; '.join("{}={}".format(k, v)
                           for k, v in self._content_dict.items())
        if params:
            ctype = self._content_type + '; ' + params
        else:
            ctype = self._content_type
        self._headers[CONTENT_TYPE] = ctype

    async def _do_start_compression(self, coding: ContentCoding) -> None:
        """Enable ``coding`` on the writer unless it is the identity coding."""
        if coding != ContentCoding.identity:
            assert self._payload_writer is not None
            self._headers[hdrs.CONTENT_ENCODING] = coding.value
            self._payload_writer.enable_compression(coding.value)
            # Compressed payload may have different content length,
            # remove the header
            self._headers.popall(hdrs.CONTENT_LENGTH, None)

    async def _start_compression(self, request: 'BaseRequest') -> None:
        """Start compression with the forced or the first accepted coding."""
        if self._compression_force:
            await self._do_start_compression(self._compression_force)
        else:
            accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING,
                                                  '').lower()
            # First coding, in ContentCoding order, whose name occurs in
            # the lower-cased Accept-Encoding header.
            for coding in ContentCoding:
                if coding.value in accept_encoding:
                    await self._do_start_compression(coding)
                    return

    async def prepare(
            self, request: 'BaseRequest') -> Optional[AbstractStreamWriter]:
        """Bind the response to ``request`` and send status line and headers.

        Returns None if EOF was already sent and the existing writer if the
        response is already prepared.  Otherwise the request's prepare hook
        is awaited and the response is started.
        """
        if self._eof_sent:
            return None
        if self._payload_writer is not None:
            return self._payload_writer

        await request._prepare_hook(self)
        return await self._start(request)

    async def _start(self, request: 'BaseRequest') -> AbstractStreamWriter:
        """Finalize the headers, write them out and return the writer."""
        self._req = request

        # An explicit keep-alive setting wins over the request's.
        keep_alive = self._keep_alive
        if keep_alive is None:
            keep_alive = request.keep_alive
        self._keep_alive = keep_alive

        version = request.version
        writer = self._payload_writer = request._payload_writer

        headers = self._headers
        for cookie in self._cookies.values():
            # [1:] drops the leading space left by the empty header name
            value = cookie.output(header='')[1:]
            headers.add(hdrs.SET_COOKIE, value)

        if self._compression:
            await self._start_compression(request)

        if self._chunked:
            if version != HttpVersion11:
                raise RuntimeError("Using chunked encoding is forbidden "
                                   "for HTTP/{0.major}.{0.minor}".format(
                                       request.version))
            writer.enable_chunking()
            headers[hdrs.TRANSFER_ENCODING] = 'chunked'
            if hdrs.CONTENT_LENGTH in headers:
                del headers[hdrs.CONTENT_LENGTH]
        elif self._length_check:
            writer.length = self.content_length
            if writer.length is None:
                if version >= HttpVersion11:
                    # unknown length: fall back to chunked transfer
                    writer.enable_chunking()
                    headers[hdrs.TRANSFER_ENCODING] = 'chunked'
                    if hdrs.CONTENT_LENGTH in headers:
                        del headers[hdrs.CONTENT_LENGTH]
                else:
                    # Only the local value changes; it drives the
                    # Connection header below, self._keep_alive is kept.
                    keep_alive = False

        headers.setdefault(hdrs.CONTENT_TYPE, 'application/octet-stream')
        headers.setdefault(hdrs.DATE, rfc822_formatted_time())
        headers.setdefault(hdrs.SERVER, 'nginx')

        # connection header
        if hdrs.CONNECTION not in headers:
            if keep_alive:
                if version == HttpVersion10:
                    headers[hdrs.CONNECTION] = 'keep-alive'
            else:
                if version == HttpVersion11:
                    headers[hdrs.CONNECTION] = 'close'

        # status line
        status_line = 'HTTP/{}.{} {} {}'.format(version[0], version[1],
                                                self._status, self._reason)
        await writer.write_headers(status_line, headers)

        return writer

    async def write(self, data: bytes) -> None:
        """Send ``data`` through the payload writer.

        ``data`` must be bytes-like (checked with an assertion).  Raises
        RuntimeError when called after ``write_eof()`` or before
        ``prepare()``.
        """
        assert isinstance(data, (bytes, bytearray, memoryview)), \
            "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            raise RuntimeError("Cannot call write() after write_eof()")
        if self._payload_writer is None:
            raise RuntimeError("Cannot call write() before prepare()")

        await self._payload_writer.write(data)

    async def drain(self) -> None:
        """Await the payload writer's ``drain()`` (deprecated).

        Emits a DeprecationWarning pointing callers at ``write()``.
        """
        assert not self._eof_sent, "EOF has already been sent"
        assert self._payload_writer is not None, \
            "Response has not been started"
        warnings.warn("drain method is deprecated, use await resp.write()",
                      DeprecationWarning,
                      stacklevel=2)
        await self._payload_writer.drain()

    async def write_eof(self, data: bytes = b'') -> None:
        """Send the final ``data`` and finish the response.

        Does nothing if EOF was already sent.  Afterwards the writer's
        output size is stored as the body length and the references to the
        request and the writer are dropped.
        """
        assert isinstance(data, (bytes, bytearray, memoryview)), \
            "data argument must be byte-ish (%r)" % type(data)

        if self._eof_sent:
            return

        assert self._payload_writer is not None, \
            "Response has not been started"

        await self._payload_writer.write_eof(data)
        self._eof_sent = True
        self._req = None
        self._body_length = self._payload_writer.output_size
        self._payload_writer = None

    def __repr__(self) -> str:
        """Return ``<ClassName reason info>`` describing the current state."""
        if self._eof_sent:
            info = "eof"
        elif self.prepared:
            assert self._req is not None
            info = "{} {} ".format(self._req.method, self._req.path)
        else:
            info = "not prepared"
        return "<{} {} {}>".format(self.__class__.__name__, self.reason, info)

    # Mutable-mapping protocol over the per-response state dict.

    def __getitem__(self, key: str) -> Any:
        return self._state[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._state[key] = value

    def __delitem__(self, key: str) -> None:
        del self._state[key]

    def __len__(self) -> int:
        return len(self._state)

    def __iter__(self) -> Iterator[str]:
        return iter(self._state)

    # Responses hash and compare by identity, not by mapping content.

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        return self is other