Ejemplo n.º 1
0
 def get_data(self, unreader, buf, stop=False):
     """Read one chunk from *unreader* and append it to *buf*.

     If the unreader returns an empty chunk, raise ``StopIteration`` when
     *stop* is true, otherwise ``NoMoreData`` carrying everything that has
     been buffered in *buf* so far.
     """
     chunk = unreader.read()
     if chunk:
         buf.write(chunk)
         return
     # NOTE(review): per PEP 479 (Python 3.7+), a StopIteration raised while
     # a generator frame is executing is turned into RuntimeError, so
     # stop=True is only meaningful for non-generator callers -- confirm.
     if stop:
         raise StopIteration()
     raise NoMoreData(buf.getvalue())
Ejemplo n.º 2
0
 def parse_chunked(self, unreader):
     """Yield the payload bytes of a chunked transfer-encoded body.

     Data bytes of each chunk are yielded in order, possibly split across
     several pieces, until ``self.parse_chunk_size`` reports a size of 0.

     Raises:
         NoMoreData: the stream ends in the middle of a chunk's data.
         ChunkMissingTerminator: the two bytes after a chunk's data are not
             CRLF, including when the stream ends before they arrive.
     """
     (size, rest) = self.parse_chunk_size(unreader)
     while size > 0:
         # Emit whole buffers until what remains of this chunk fits in `rest`.
         while size > len(rest):
             size -= len(rest)
             yield rest
             rest = unreader.read()
             if not rest:
                 raise NoMoreData()
         yield rest[:size]
         # Remove \r\n after chunk. An empty read means EOF: stop reading
         # instead of looping forever, and let the check below report the
         # missing terminator.
         rest = rest[size:]
         while len(rest) < 2:
             more = unreader.read()
             if not more:
                 break
             rest += more
         if rest[:2] != b'\r\n':
             raise ChunkMissingTerminator(rest[:2])
         (size, rest) = self.parse_chunk_size(unreader, data=rest[2:])
Ejemplo n.º 3
0
    def parse_chunked(self, unreader):
        """Decode a chunked transfer-encoded body, yielding its data bytes.

        Each chunk's data is yielded (possibly in several pieces) until
        ``self.parse_chunk_size`` returns a size of 0.

        Raises:
            NoMoreData: the stream ends before a chunk's data is complete.
            ChunkMissingTerminator: a chunk's data is not followed by CRLF,
                including when the stream ends before the CRLF is read.
        """
        size, rest = self.parse_chunk_size(unreader)
        while size > 0:
            # Keep reading this chunk's data until only its tail is left
            # in `rest`.
            while size > len(rest):
                size -= len(rest)
                yield rest
                rest = unreader.read()
                if not rest:
                    raise NoMoreData()
            # Emit the final remaining piece of this chunk's data.
            yield rest[:size]
            # Remove \r\n after chunk. Stop on an empty read (EOF) rather
            # than spinning forever; the check below then raises.
            rest = rest[size:]
            while len(rest) < 2:
                more = unreader.read()
                if not more:
                    break
                rest += more
            # Every chunk's data must be terminated by \r\n.
            if rest[:2] != b'\r\n':
                raise ChunkMissingTerminator(rest[:2])

            # Parse the size line of the next chunk.
            size, rest = self.parse_chunk_size(unreader, data=rest[2:])
Ejemplo n.º 4
0
 def get_data(self, unreader, buf):
     """Append the next chunk read from *unreader* to *buf*.

     Raises ``NoMoreData`` when the unreader returns an empty chunk.
     """
     chunk = unreader.read()
     if chunk:
         buf.write(chunk)
     else:
         raise NoMoreData()
Ejemplo n.º 5
0
    def __call__(self, environ: Environ,
                 start_response: StartResponse) -> Iterable[bytes]:
        """Serve a single WSGI request.

        Seeds a resolver with request-scoped components, dispatches to the
        matched route's handler (or ``handle_404``) wrapped by every entry in
        ``self.middleware``, maps the exception types caught below to
        responses, then calls *start_response* and returns the body iterable
        (an empty list for 204 responses).
        """
        request = Request.from_environ(environ)
        resolver = self.injector.get_resolver({
            Environ: environ,
            Headers: request.headers,
            Host: Host(request.host),
            Method: Method(request.method),
            Port: Port(request.port),
            QueryParams: request.params,
            QueryString: QueryString(environ.get("QUERY_STRING", "")),
            Request: request,
            RequestInput: RequestInput(request.body_file),
            Scheme: Scheme(request.scheme),
            StartResponse: start_response,
        })

        # Only the generic-exception path below reports exc_info.
        exc_info = None
        try:
            handler: Callable[..., Any]
            matched = self.router.match(request.method, request.path)
            if matched is None:
                handler = self.handle_404
                resolver.add_component(RouteComponent(None))
            else:
                route, params = matched
                handler = route.handler
                resolver.add_component(RouteComponent(route))
                resolver.add_component(RouteParamsComponent(params))

            # Wrap inside-out so the first middleware ends up outermost.
            handler = resolver.resolve(handler)
            for middleware in reversed(self.middleware):
                handler = resolver.resolve(middleware(handler))

            response = handler()
        except RequestHandled:
            # Raising NoMoreData breaks out of gunicorn's keep-alive loop,
            # which would otherwise try to read from a closed socket.
            raise NoMoreData()
        except RequestParserNotAvailable:
            response = resolver.resolve(self.handle_415)()
        except ParseError as e:
            response = resolver.resolve(self.handle_parse_error)(exception=e)
        except Exception as e:
            exc_info = sys.exc_info()
            response = resolver.resolve(self.handle_exception)(exception=e)

        content_length = response.get_content_length()
        if content_length is not None:
            response.headers.add("content-length", str(content_length))

        start_response(response.status, list(response.headers), exc_info)
        if response.status != HTTP_204:
            file_wrapper = environ.get("wsgi.file_wrapper", FileWrapper)
            return file_wrapper(response.stream)
        return []
Ejemplo n.º 6
0
    def parse_headers(self, data):
        """Parse a raw header block into a list of ``(NAME, value)`` pairs.

        *data* is split on CRLF. Each line is either ``name: value`` or a
        continuation line starting with a space or tab (obsolete line
        folding). Names are upper-cased. Values are stripped of surrounding
        whitespace, and each fold is replaced by a single space (RFC 7230
        section 3.2.4), so no CR/LF ends up inside a returned value.

        When the peer is trusted (``'*'`` in ``cfg.forwarded_allow_ips``, an
        INET/INET6 peer whose host is listed there, or a UNIX socket), the
        headers named in ``cfg.secure_scheme_headers`` set ``self.scheme``.

        Raises:
            NoMoreData: ``getpeername()`` on the socket raised ``OSError``.
            LimitRequestHeaders: too many header fields, or a field exceeds
                ``limit_request_field_size``.
            InvalidHeader: a line contains no ``:``.
            InvalidHeaderName: ``HEADER_RE`` matches the header name.
            InvalidSchemeHeaders: two scheme headers imply different schemes.
        """
        cfg = self.cfg
        headers = []

        # Split lines on \r\n keeping the \r\n on each line; header_length
        # below therefore counts line terminators as well.
        lines = [bytes_to_str(line) + "\r\n" for line in data.split(b"\r\n")]

        # handle scheme headers: only honoured for trusted peers
        scheme_header = False
        secure_scheme_headers = {}
        if '*' in cfg.forwarded_allow_ips:
            secure_scheme_headers = cfg.secure_scheme_headers
        elif isinstance(self.unreader, SocketUnreader):
            try:
                remote_addr = self.unreader.sock.getpeername()
                if self.unreader.sock.family in (socket.AF_INET,
                                                 socket.AF_INET6):
                    remote_host = remote_addr[0]
                    if remote_host in cfg.forwarded_allow_ips:
                        secure_scheme_headers = cfg.secure_scheme_headers
                elif self.unreader.sock.family == socket.AF_UNIX:
                    secure_scheme_headers = cfg.secure_scheme_headers
            except OSError:
                raise NoMoreData()

        # Parse headers into key/value pairs paying attention
        # to continuation lines.
        while lines:
            if len(headers) >= self.limit_request_fields:
                raise LimitRequestHeaders("limit request headers fields")

            # Parse initial header name : value pair.
            curr = lines.pop(0)
            header_length = len(curr)
            if ":" not in curr:
                raise InvalidHeader(curr.strip())
            name, value = curr.split(":", 1)
            if cfg.strip_header_spaces:
                name = name.rstrip(" \t").upper()
            else:
                name = name.upper()
            if HEADER_RE.search(name):
                raise InvalidHeaderName(name)

            name, parts = name.strip(), [value.strip()]

            # Consume value continuation lines (obsolete line folding).
            while lines and lines[0].startswith((" ", "\t")):
                curr = lines.pop(0)
                header_length += len(curr)
                if header_length > self.limit_request_field_size > 0:
                    raise LimitRequestHeaders("limit request headers "
                                              "fields size")
                parts.append(curr.strip())
            # Replace each fold with a single space instead of keeping the
            # raw CRLF inside the value; empty fragments add nothing.
            value = " ".join(part for part in parts if part)

            if header_length > self.limit_request_field_size > 0:
                raise LimitRequestHeaders("limit request headers fields size")

            if name in secure_scheme_headers:
                secure = value == secure_scheme_headers[name]
                scheme = "https" if secure else "http"
                if scheme_header:
                    # Every scheme header seen must agree on the scheme.
                    if scheme != self.scheme:
                        raise InvalidSchemeHeaders()
                else:
                    scheme_header = True
                    self.scheme = scheme

            headers.append((name, value))

        return headers