Exemplo n.º 1
0
class RestApp:
    def __init__(self, *, app_name, sentry_config):
        """Create the aiohttp application, logger and stats client, and wire lifecycle hooks.

        The two outgoing client sessions start out as ``None``; they are filled
        in by ``create_http_client``, which is registered as a startup hook.
        """
        self.app_name = app_name
        self.app_request_metric = f"{app_name}_request"

        web_app = aiohttp.web.Application()
        web_app.on_startup.append(self.create_http_client)
        web_app.on_cleanup.append(self.cleanup_http_client)
        self.app = web_app

        # Placeholders until the startup hook has run.
        self.http_client_v = None
        self.http_client_no_v = None

        self.log = logging.getLogger(app_name)

        stats_client = StatsClient(sentry_config=sentry_config)
        self.stats = stats_client
        self.raven_client = stats_client.raven_client
        # Registered after the HTTP-client cleanup hook, as before.
        web_app.on_cleanup.append(self.cleanup_stats_client)

    async def cleanup_stats_client(self, app):  # pylint: disable=unused-argument
        self.stats.close()

    async def create_http_client(self, app):  # pylint: disable=unused-argument
        """Startup hook: create the two shared outgoing client sessions.

        ``http_client_no_v`` is built on a connector created with ``ssl=False``;
        ``http_client_v`` uses aiohttp's default connector. Both send
        ``SERVER_NAME`` as their ``User-Agent`` header.
        """
        self.http_client_no_v = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(ssl=False),
            headers={"User-Agent": SERVER_NAME},
        )
        self.http_client_v = aiohttp.ClientSession(
            headers={"User-Agent": SERVER_NAME},
        )

    async def cleanup_http_client(self, app):  # pylint: disable=unused-argument
        if self.http_client_no_v:
            await self.http_client_no_v.close()
        if self.http_client_v:
            await self.http_client_v.close()

    @staticmethod
    def cors_and_server_headers_for_request(*, request, origin="*"):  # pylint: disable=unused-argument
        """Return a fresh dict with the CORS headers and the ``Server`` header.

        ``request`` is accepted but not consulted; ``origin`` is used verbatim
        as the ``Access-Control-Allow-Origin`` value.
        """
        header_names = (
            "Access-Control-Allow-Origin",
            "Access-Control-Allow-Methods",
            "Access-Control-Allow-Headers",
            "Server",
        )
        header_values = (
            origin,
            "DELETE, GET, OPTIONS, POST, PUT",
            "Authorization, Content-Type",
            SERVER_NAME,
        )
        return dict(zip(header_names, header_values))

    def check_rest_headers(self, request: HTTPRequest) -> dict:  # pylint:disable=inconsistent-return-statements
        """Validate the ``Content-Type`` and ``Accept`` headers of a REST request.

        Both headers are reduced to their media type with ``cgi.parse_header``
        and matched against ``REST_CONTENT_TYPE_RE`` / ``REST_ACCEPT_RE``. A
        missing ``Content-Type`` counts as ``application/vnd.kafka.json.v2+json``
        and a missing ``Accept`` as ``*/*``.

        Returns a dict with ``content_type``, ``requests`` (the content-type
        match groups, with ``embedded_format`` falling back to ``"binary"``) and
        ``accepts`` (the accept match groups) when both headers match.

        Otherwise ``http_error`` is called: with 415 for a POST/PUT whose
        content type does not match, and with 406 in every remaining case.
        """
        fallback_content_type = "application/vnd.kafka.json.v2+json"
        result: dict = {"content_type": fallback_content_type}

        raw_content_type = request.get_header("Content-Type", fallback_content_type)
        content_matcher = REST_CONTENT_TYPE_RE.search(cgi.parse_header(raw_content_type)[0])
        raw_accept = request.get_header("Accept", "*/*")
        accept_matcher = REST_ACCEPT_RE.search(cgi.parse_header(raw_accept)[0])

        # Only requests that carry a body are rejected outright for a bad content type.
        if request.method in {"POST", "PUT"} and not content_matcher:
            http_error(
                message="HTTP 415 Unsupported Media Type",
                content_type=result["content_type"],
                code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            )

        if content_matcher and accept_matcher:
            requested = content_matcher.groupdict()
            if not requested.get("embedded_format"):
                requested["embedded_format"] = "binary"
            result["requests"] = requested
            result["accepts"] = accept_matcher.groupdict()
            return result

        self.log.error("Not acceptable: %r", request.get_header("accept"))
        http_error(
            message="HTTP 406 Not Acceptable",
            content_type=result["content_type"],
            code=HTTPStatus.NOT_ACCEPTABLE,
        )

    def check_schema_headers(self, request: HTTPRequest):
        """Validate schema-registry request headers and pick the response content type.

        For POST/PUT the media type of ``Content-Type`` (``JSON_CONTENT_TYPE``
        when the header is absent) must be in ``SCHEMA_CONTENT_TYPES``, otherwise
        ``http_error`` is called with 415.

        Returns ``application/vnd.schemaregistry.v1+json`` when ``Accept`` is
        missing/empty, is ``*`` or starts with ``*/``; otherwise the best match
        of ``Accept`` against ``SCHEMA_ACCEPT_VALUES``. When nothing matches,
        ``http_error`` is called with 406.
        """
        response_default_content_type = "application/vnd.schemaregistry.v1+json"

        declared_content_type = request.get_header("Content-Type", JSON_CONTENT_TYPE)
        if request.method in {"POST", "PUT"}:
            media_type = cgi.parse_header(declared_content_type)[0]
            if media_type not in SCHEMA_CONTENT_TYPES:
                http_error(
                    message="HTTP 415 Unsupported Media Type",
                    content_type=response_default_content_type,
                    code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
                )

        accept_val = request.get_header("Accept")
        if not accept_val:
            return response_default_content_type

        # "*/*" is covered by the "*/" prefix test.
        if accept_val == "*" or accept_val.startswith("*/"):
            return response_default_content_type

        best_match = get_best_match(accept_val, SCHEMA_ACCEPT_VALUES)
        if not best_match:
            self.log.debug("Unexpected Accept value: %r", accept_val)
            http_error(
                message="HTTP 406 Not Acceptable",
                content_type=response_default_content_type,
                code=HTTPStatus.NOT_ACCEPTABLE,
            )
        return best_match

    async def _handle_request(
        self,
        *,
        request,
        path_for_stats,
        callback,
        schema_request=False,
        callback_with_request=False,
        json_request=False,
        rest_request=False
    ):
        """Run ``callback`` for an incoming aiohttp request and build the HTTP response.

        Steps, in order:

        * ``OPTIONS`` requests are answered directly: 400 without an ``Origin``
          header, otherwise 200 with CORS headers for that origin.
        * The body is read. With ``json_request`` it must be non-empty, decodable
          with the charset named in ``Content-Type`` (UTF-8 by default) and valid
          JSON; the parsed value is stored on the wrapped request's ``json``
          attribute. Without ``json_request`` only an empty body or ``{}`` is
          accepted. Violations produce a 400.
        * Callback keyword arguments are taken from ``request.match_info``, plus
          the wrapped request (``callback_with_request``) and the results of the
          REST / schema header checks (``rest_request`` / ``schema_request``).
        * The callback result is serialised: dict/list via ``json_encode``, str
          as UTF-8 ``text/plain``, anything else is passed through unchanged.
          An ``HTTPResponse`` raised by the callback supplies body, status and
          headers; any other ``Exception`` becomes a JSON 500 body.
        * For success statuses an ETag (MD5 of the body) is attached and a
          matching ``If-None-Match`` turns the response into an empty 304.

        A timing metric tagged with path, result status and method is always
        reported, also on failure.

        Returns:
            The ``aiohttp.web.Response`` to send.

        Raises:
            asyncio.CancelledError: re-raised unchanged when the task is cancelled.
        """
        start_time = time.monotonic()
        resp = None
        rapu_request = HTTPRequest(
            headers=request.headers,
            query=request.query,
            method=request.method,
            url=request.url,
            path_for_stats=path_for_stats,
        )
        try:
            if request.method == "OPTIONS":
                origin = request.headers.get("Origin")
                if not origin:
                    raise HTTPResponse(body="OPTIONS missing Origin", status=HTTPStatus.BAD_REQUEST)
                headers = self.cors_and_server_headers_for_request(request=rapu_request, origin=origin)
                # HTTPResponse doubles as control flow: the outer handler turns it into the reply.
                raise HTTPResponse(body=b"", status=HTTPStatus.OK, headers=headers)

            body = await request.read()
            if json_request:
                if not body:
                    raise HTTPResponse(body="Missing request JSON body", status=HTTPStatus.BAD_REQUEST)
                # Bound before the try so the handlers below can always reference it.
                charset = "utf-8"
                try:
                    # `or ""` keeps a request without a Content-Type header from failing inside
                    # parse_header (presumably get_header returns None then -- confirm against
                    # HTTPRequest); such a request is decoded as UTF-8.
                    _, options = cgi.parse_header(rapu_request.get_header("Content-Type") or "")
                    charset = options.get("charset", "utf-8")
                    body_string = body.decode(charset)
                    rapu_request.json = jsonlib.loads(body_string)
                except jsonlib.decoder.JSONDecodeError:
                    raise HTTPResponse(body="Invalid request JSON body", status=HTTPStatus.BAD_REQUEST)
                except UnicodeDecodeError:
                    raise HTTPResponse(body=f"Request body is not valid {charset}", status=HTTPStatus.BAD_REQUEST)
                except LookupError:
                    # bytes.decode raises LookupError for an unknown codec name.
                    raise HTTPResponse(body=f"Unknown charset {charset}", status=HTTPStatus.BAD_REQUEST)
            else:
                if body not in {b"", b"{}"}:
                    raise HTTPResponse(body="No request body allowed for this operation", status=HTTPStatus.BAD_REQUEST)

            callback_kwargs = dict(request.match_info)
            if callback_with_request:
                callback_kwargs["request"] = rapu_request

            if rest_request:
                params = self.check_rest_headers(rapu_request)
                # "requests"/"accepts" are stored on the request; the rest is passed to the callback.
                if "requests" in params:
                    rapu_request.content_type = params.pop("requests")
                if "accepts" in params:
                    rapu_request.accepts = params.pop("accepts")
                callback_kwargs.update(params)

            if schema_request:
                content_type = self.check_schema_headers(rapu_request)
                callback_kwargs["content_type"] = content_type

            try:
                data = await callback(**callback_kwargs)
                status = HTTPStatus.OK
                headers = {}
            except HTTPResponse as ex:
                data = ex.body
                status = ex.status
                headers = ex.headers
            except asyncio.CancelledError:
                # Must not be converted into a 500; the outer handler logs and re-raises it.
                raise
            except Exception:  # pylint: disable=broad-except
                # Exception (not a bare except) so KeyboardInterrupt/SystemExit keep propagating.
                self.log.exception("Internal server error")
                data = {"error_code": HTTPStatus.INTERNAL_SERVER_ERROR.value, "message": "Internal server error"}
                status = HTTPStatus.INTERNAL_SERVER_ERROR
                headers = {}
            headers.update(self.cors_and_server_headers_for_request(request=rapu_request))

            if isinstance(data, (dict, list)):
                resp_bytes = json_encode(data, binary=True, sort_keys=True, compact=True)
            elif isinstance(data, str):
                if "Content-Type" not in headers:
                    headers["Content-Type"] = "text/plain; charset=utf-8"
                resp_bytes = data.encode("utf-8")
            else:
                resp_bytes = data

            # Cache headers are only produced for successful responses; an empty body gets the etag '""'.
            if is_success(status):
                if resp_bytes:
                    etag = '"{}"'.format(hashlib.md5(resp_bytes).hexdigest())
                else:
                    etag = '""'
                if_none_match = request.headers.get("if-none-match")
                # Weak validators ("W/...") are compared by their opaque tag only.
                if if_none_match and if_none_match.replace("W/", "") == etag:
                    status = HTTPStatus.NOT_MODIFIED
                    resp_bytes = b""

                headers["access-control-expose-headers"] = "etag"
                headers["etag"] = etag

            resp = aiohttp.web.Response(body=resp_bytes, status=status.value, headers=headers)
        except HTTPResponse as ex:
            if isinstance(ex.body, str):
                resp = aiohttp.web.Response(text=ex.body, status=ex.status.value, headers=ex.headers)
            else:
                resp = aiohttp.web.Response(body=ex.body, status=ex.status.value, headers=ex.headers)
        except asyncio.CancelledError:
            self.log.debug("Client closed connection")
            raise
        except Exception as ex:  # pylint: disable=broad-except
            self.stats.unexpected_exception(ex=ex, where="rapu_wrapped_callback")
            self.log.exception("Unexpected error handling user request: %s %s", request.method, request.url)
            resp = aiohttp.web.Response(text="Internal Server Error", status=HTTPStatus.INTERNAL_SERVER_ERROR.value)
        finally:
            self.stats.timing(
                self.app_request_metric,
                time.monotonic() - start_time,
                tags={
                    "path": path_for_stats,
                    # no `resp` means that we had a failure in exception handler;
                    # compare with None so the tag never depends on the response's truthiness
                    "result": resp.status if resp is not None else 0,
                    "method": request.method,
                }
            )

        return resp

    def route(self, path, *, callback, method, schema_request=False, with_request=None, json_body=None, rest_request=False):
        # pretty path for statsd reporting
        path_for_stats = re.sub(r"<[\w:]+>", "x", path)

        # bottle compatible routing
        aio_route = path
        aio_route = re.sub(r"<(\w+):path>", r"{\1:.+}", aio_route)
        aio_route = re.sub(r"<(\w+)>", r"{\1}", aio_route)

        if (method in {"POST", "PUT"}) and with_request is None:
            with_request = True

        if with_request and json_body is None:
            json_body = True

        async def wrapped_callback(request):
            return await self._handle_request(
                request=request,
                path_for_stats=path_for_stats,
                callback=callback,
                schema_request=schema_request,
                callback_with_request=with_request,
                json_request=json_body,
                rest_request=rest_request
            )

        async def wrapped_cors(request):
            return await self._handle_request(
                request=request,
                path_for_stats=path_for_stats,
                callback=None,
            )

        if not aio_route.endswith("/"):
            self.app.router.add_route(method, aio_route + "/", wrapped_callback)
            self.app.router.add_route(method, aio_route, wrapped_callback)
        else:
            self.app.router.add_route(method, aio_route, wrapped_callback)
            self.app.router.add_route(method, aio_route[:-1], wrapped_callback)
        try:
            self.app.router.add_route("OPTIONS", aio_route, wrapped_cors)
        except RuntimeError as ex:
            if "Added route will never be executed, method OPTIONS is already registered" not in str(ex):
                raise

    async def http_request(self, url, *, method="GET", json=None, timeout=10.0, verify=True, proxy=None):
        """Perform an outgoing HTTP request and return the result as an ``HTTPResponse``.

        Args:
            url: Target URL.
            method: HTTP method name; the lower-cased name is looked up on the session.
            json: Value passed as ``json=`` to the session method.
            timeout: Passed to ``async_timeout.timeout`` around the whole exchange.
            verify: ``True`` uses the shared verifying session, ``False`` the shared
                non-verifying one; a ``str`` is passed as ``cadata`` to
                ``ssl.create_default_context`` and a dedicated session is created.
            proxy: Optional mapping with ``host``, ``port``, ``username`` and
                ``password``; when given, a dedicated session on a SOCKS5
                connector is used.

        Returns:
            ``HTTPResponse`` whose body is the decoded JSON when the response content
            type starts with ``JSON_CONTENT_TYPE``, otherwise the response text.

        Raises:
            ValueError: if no proxy is given and ``verify`` is neither a bool nor a str.
        """
        close_session = False

        if isinstance(verify, str):
            sslcontext = ssl.create_default_context(cadata=verify)
        else:
            sslcontext = None

        if proxy:
            connector = aiohttp_socks.SocksConnector(
                socks_ver=aiohttp_socks.SocksVer.SOCKS5,
                host=proxy["host"],
                port=proxy["port"],
                username=proxy["username"],
                password=proxy["password"],
                rdns=False,
                verify_ssl=verify,
                ssl_context=sslcontext,
            )
            session = aiohttp.ClientSession(connector=connector)
            close_session = True
        elif sslcontext:
            # `ssl=` matches the connector created in create_http_client; `ssl_context=` is the
            # deprecated spelling of the same option.
            conn = aiohttp.TCPConnector(ssl=sslcontext)
            session = aiohttp.ClientSession(connector=conn)
            close_session = True
        elif verify is True:
            session = self.http_client_v
        elif verify is False:
            session = self.http_client_no_v
        else:
            raise ValueError("invalid arguments to http_request")

        try:
            # Looked up inside the try so that an unknown method cannot leak a dedicated session.
            func = getattr(session, method.lower())
            # The asynchronous form is the supported way to use async_timeout.
            async with async_timeout.timeout(timeout):
                async with func(url, json=json) as response:
                    if response.headers.get("content-type", "").startswith(JSON_CONTENT_TYPE):
                        resp_content = await response.json()
                    else:
                        resp_content = await response.text()
                    result = HTTPResponse(body=resp_content, status=HTTPStatus(response.status))
        finally:
            # Shared sessions stay open; only sessions created for this call are closed.
            if close_session:
                await session.close()

        return result

    def run(self, *, host, port):
        """Serve ``self.app`` on ``host``:``port`` via ``aiohttp.web.run_app`` with a custom access-log format."""
        access_log_format = '%Tfs %{x-client-ip}i "%r" %s "%{user-agent}i" response=%bb request_body=%{content-length}ib'
        aiohttp.web.run_app(app=self.app, host=host, port=port, access_log_format=access_log_format)

    def add_routes(self):
        pass  # Override in sub-classes
Exemplo n.º 2
0
class RestApp:
    def __init__(self, *, app_name, sentry_config):
        """Build the aiohttp application and its helpers, and register lifecycle hooks.

        Startup creates the outgoing HTTP client sessions (``None`` until then);
        cleanup closes those sessions first and the stats client afterwards.
        """
        self.app_name = app_name
        self.app_request_metric = "{}_request".format(app_name)

        self.app = aiohttp.web.Application()
        self.app.on_startup.append(self.create_http_client)
        self.app.on_cleanup.append(self.cleanup_http_client)

        # Not available before the startup hook has run.
        self.http_client_v = self.http_client_no_v = None

        self.log = logging.getLogger(app_name)

        self.stats = StatsClient(sentry_config=sentry_config)
        self.raven_client = self.stats.raven_client
        self.app.on_cleanup.append(self.cleanup_stats_client)

    async def cleanup_stats_client(self, app):  # pylint: disable=unused-argument
        self.stats.close()

    async def create_http_client(self, app):  # pylint: disable=unused-argument
        """Startup hook that creates the shared outgoing client sessions.

        ``http_client_no_v`` gets a connector created with ``verify_ssl=False``,
        ``http_client_v`` the default connector; both send ``SERVER_NAME`` as
        ``User-Agent``.
        """
        insecure_connector = aiohttp.TCPConnector(verify_ssl=False)
        insecure_session = aiohttp.ClientSession(
            connector=insecure_connector,
            headers={"User-Agent": SERVER_NAME},
        )
        self.http_client_no_v = insecure_session

        verifying_session = aiohttp.ClientSession(
            headers={"User-Agent": SERVER_NAME})
        self.http_client_v = verifying_session

    async def cleanup_http_client(self, app):  # pylint: disable=unused-argument
        if self.http_client_no_v:
            await self.http_client_no_v.close()
        if self.http_client_v:
            await self.http_client_v.close()

    @staticmethod
    def cors_and_server_headers_for_request(*, request, origin="*"):  # pylint: disable=unused-argument
        """Return a new dict holding the CORS headers plus the ``Server`` header.

        ``origin`` becomes the ``Access-Control-Allow-Origin`` value; ``request``
        is not used.
        """
        response_headers = {"Access-Control-Allow-Origin": origin}
        response_headers["Access-Control-Allow-Methods"] = "DELETE, GET, OPTIONS, POST, PUT"
        response_headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type"
        response_headers["Server"] = SERVER_NAME
        return response_headers

    def check_schema_headers(self, request):
        """Validate request headers and pick the content type for the response.

        For POST/PUT the media type of the ``Content-Type`` header (parameters
        such as ``; charset=utf-8`` are ignored) must be one of
        ``ACCEPTED_SCHEMA_CONTENT_TYPES``; a missing header is treated as
        unsupported.

        Returns:
            ``application/vnd.schemaregistry.v1+json`` when there is no ``Accept``
            header or it starts with ``*/``; otherwise the best match of the
            ``Accept`` value against ``ACCEPTED_SCHEMA_CONTENT_TYPES``.

        Raises:
            HTTPResponse: status 415 for an unsupported or missing request content
                type on POST/PUT, status 406 when no accepted type matches.
        """
        method = request.method
        headers = request.headers

        content_type = "application/vnd.schemaregistry.v1+json"
        if method in {"POST", "PUT"}:
            # .get() avoids a KeyError (and thus a 500) when the header is absent; only the
            # media type before the first ";" is compared so parameters do not cause a 415.
            request_media_type = headers.get("Content-Type", "").split(";", 1)[0].strip()
            if request_media_type not in ACCEPTED_SCHEMA_CONTENT_TYPES:
                raise HTTPResponse(
                    body=json_encode(
                        {
                            "error_code": 415,
                            "message": "HTTP 415 Unsupported Media Type",
                        },
                        binary=True),
                    headers={"Content-Type": content_type},
                    status=415,
                )

        if "Accept" not in headers:
            return content_type

        accept_val = headers["Accept"]
        # "*/*" is covered by the "*/" prefix test.
        if accept_val.startswith("*/"):
            return content_type

        content_type_match = get_best_match(accept_val,
                                            ACCEPTED_SCHEMA_CONTENT_TYPES)
        if not content_type_match:
            self.log.debug("Unexpected Accept value: %r", accept_val)
            raise HTTPResponse(
                body=json_encode(
                    {
                        "error_code": 406,
                        "message": "HTTP 406 Not Acceptable",
                    },
                    binary=True),
                headers={"Content-Type": content_type},
                status=406,
            )
        return content_type_match

    async def _handle_request(self,
                              *,
                              request,
                              path_for_stats,
                              callback,
                              schema_request=False,
                              callback_with_request=False,
                              json_request=False):
        """Run ``callback`` for an incoming aiohttp request and build the HTTP response.

        Steps, in order:

        * ``OPTIONS`` requests are answered directly: 400 without an ``Origin``
          header, otherwise 200 with CORS headers for that origin.
        * The body is read. With ``json_request`` it must be non-empty, UTF-8
          and valid JSON; the parsed value is stored on the wrapped request's
          ``json`` attribute. Without ``json_request`` only an empty body or
          ``{}`` is accepted. Violations produce a 400.
        * Callback keyword arguments are taken from ``request.match_info``, plus
          the wrapped request (``callback_with_request``) and the content type
          chosen by ``check_schema_headers`` (``schema_request``).
        * The callback result is serialised: dict/list via ``json_encode``, str
          as UTF-8 ``text/plain``, anything else is passed through unchanged.
          An ``HTTPResponse`` raised by the callback supplies body, status and
          headers.
        * For 2xx statuses an ETag (MD5 of the body) is attached and a matching
          ``If-None-Match`` turns the response into an empty 304.

        Any other exception is reported to the stats client, logged and answered
        with a plain-text 500. A timing metric tagged with path, result status
        and method is always reported.

        Returns:
            The ``aiohttp.web.Response`` to send.

        Raises:
            asyncio.CancelledError: re-raised unchanged when the task is cancelled.
        """
        start_time = time.monotonic()
        resp = None
        rapu_request = HTTPRequest(
            headers=request.headers,
            query=request.query,
            method=request.method,
            url=request.url,
            path_for_stats=path_for_stats,
        )
        try:
            if request.method == "OPTIONS":
                origin = request.headers.get("Origin")
                if not origin:
                    raise HTTPResponse(body="OPTIONS missing Origin",
                                       status=400)
                headers = self.cors_and_server_headers_for_request(
                    request=rapu_request, origin=origin)
                # HTTPResponse doubles as control flow: the outer handler turns it into the reply.
                raise HTTPResponse(body=b"", status=200, headers=headers)

            body = await request.read()
            if json_request:
                if not body:
                    raise HTTPResponse(body="Missing request JSON body",
                                       status=400)
                charset = request.charset
                if charset and charset.lower() not in {"utf-8", "utf8"}:
                    raise HTTPResponse(
                        body="Request character set must be UTF-8", status=400)
                try:
                    body_string = body.decode("utf-8")
                    rapu_request.json = jsonlib.loads(body_string)
                except jsonlib.decoder.JSONDecodeError:
                    raise HTTPResponse(body="Invalid request JSON body",
                                       status=400)
                except UnicodeDecodeError:
                    raise HTTPResponse(body="Request body is not valid UTF-8",
                                       status=400)
            else:
                if body not in {b"", b"{}"}:
                    raise HTTPResponse(
                        body="No request body allowed for this operation",
                        status=400)

            callback_kwargs = dict(request.match_info)
            if callback_with_request:
                callback_kwargs["request"] = rapu_request

            if schema_request:
                content_type = self.check_schema_headers(request)
                callback_kwargs["content_type"] = content_type

            try:
                data = await callback(**callback_kwargs)
                status = 200
                headers = {}
            except HTTPResponse as ex:
                data = ex.body
                status = ex.status
                headers = ex.headers
            headers.update(
                self.cors_and_server_headers_for_request(request=rapu_request))

            if isinstance(data, (dict, list)):
                resp_bytes = json_encode(data,
                                         binary=True,
                                         sort_keys=True,
                                         compact=True)
            elif isinstance(data, str):
                if "Content-Type" not in headers:
                    headers["Content-Type"] = "text/plain; charset=utf-8"
                resp_bytes = data.encode("utf-8")
            else:
                resp_bytes = data

            # Cache headers are only produced for 2xx responses; an empty body gets the etag '""'.
            # (The previous test `200 >= status <= 299` was true for every status <= 200, so it
            # skipped 201-299 and matched statuses below 200.)
            if 200 <= status <= 299:
                if resp_bytes:
                    etag = '"{}"'.format(hashlib.md5(resp_bytes).hexdigest())
                else:
                    etag = '""'
                if_none_match = request.headers.get("if-none-match")
                # Weak validators ("W/...") are compared by their opaque tag only.
                if if_none_match and if_none_match.replace("W/", "") == etag:
                    status = 304
                    resp_bytes = b""

                headers["access-control-expose-headers"] = "etag"
                headers["etag"] = etag

            resp = aiohttp.web.Response(body=resp_bytes,
                                        status=status,
                                        headers=headers)
        except HTTPResponse as ex:
            if isinstance(ex.body, str):
                resp = aiohttp.web.Response(text=ex.body,
                                            status=ex.status,
                                            headers=ex.headers)
            else:
                resp = aiohttp.web.Response(body=ex.body,
                                            status=ex.status,
                                            headers=ex.headers)
        except asyncio.CancelledError:
            self.log.debug("Client closed connection")
            raise
        except Exception as ex:  # pylint: disable=broad-except
            self.stats.unexpected_exception(ex=ex,
                                            where="rapu_wrapped_callback")
            self.log.exception("Unexpected error handling user request: %s %s",
                               request.method, request.url)
            resp = aiohttp.web.Response(text="Internal Server Error",
                                        status=500)
        finally:
            self.stats.timing(
                self.app_request_metric,
                time.monotonic() - start_time,
                tags={
                    "path": path_for_stats,
                    # no `resp` means that we had a failure in exception handler;
                    # compare with None so the tag never depends on the response's truthiness
                    "result": resp.status if resp is not None else 0,
                    "method": request.method,
                })

        return resp

    def route(self,
              path,
              *,
              callback,
              method,
              schema_request=False,
              with_request=None,
              json_body=None):
        # pretty path for statsd reporting
        path_for_stats = re.sub(r"<[\w:]+>", "x", path)

        # bottle compatible routing
        aio_route = path
        aio_route = re.sub(r"<(\w+):path>", r"{\1:.+}", aio_route)
        aio_route = re.sub(r"<(\w+)>", r"{\1}", aio_route)

        if (method in {"POST", "PUT"}) and with_request is None:
            with_request = True

        if with_request and json_body is None:
            json_body = True

        async def wrapped_callback(request):
            return await self._handle_request(
                request=request,
                path_for_stats=path_for_stats,
                callback=callback,
                schema_request=schema_request,
                callback_with_request=with_request,
                json_request=json_body,
            )

        async def wrapped_cors(request):
            return await self._handle_request(
                request=request,
                path_for_stats=path_for_stats,
                callback=None,
            )

        self.app.router.add_route(method, aio_route, wrapped_callback)
        try:
            self.app.router.add_route("OPTIONS", aio_route, wrapped_cors)
        except RuntimeError as ex:
            if "Added route will never be executed, method OPTIONS is already registered" not in str(
                    ex):
                raise

    async def http_request(self,
                           url,
                           *,
                           method="GET",
                           json=None,
                           timeout=10.0,
                           verify=True,
                           proxy=None):
        """Perform an outgoing HTTP request and return the result as an ``HTTPResponse``.

        Args:
            url: Target URL.
            method: HTTP method name; the lower-cased name is looked up on the session.
            json: Value passed as ``json=`` to the session method.
            timeout: Passed to ``async_timeout.timeout`` around the whole exchange.
            verify: ``True`` uses the shared verifying session, ``False`` the shared
                non-verifying one; a ``str`` is passed as ``cadata`` to
                ``ssl.create_default_context`` and a dedicated session is created.
            proxy: Optional mapping with ``host``, ``port``, ``username`` and
                ``password``; when given, a dedicated session on a SOCKS5
                connector is used.

        Returns:
            ``HTTPResponse`` whose body is the decoded JSON when the response content
            type starts with ``application/json``, otherwise the response text.

        Raises:
            ValueError: if no proxy is given and ``verify`` is neither a bool nor a str.
        """
        close_session = False

        if isinstance(verify, str):
            sslcontext = ssl.create_default_context(cadata=verify)
        else:
            sslcontext = None

        if proxy:
            connector = aiohttp_socks.SocksConnector(
                socks_ver=aiohttp_socks.SocksVer.SOCKS5,
                host=proxy["host"],
                port=proxy["port"],
                username=proxy["username"],
                password=proxy["password"],
                rdns=False,
                verify_ssl=verify,
                ssl_context=sslcontext,
            )
            session = aiohttp.ClientSession(connector=connector)
            close_session = True
        elif sslcontext:
            conn = aiohttp.TCPConnector(ssl_context=sslcontext)
            session = aiohttp.ClientSession(connector=conn)
            close_session = True
        elif verify is True:
            session = self.http_client_v
        elif verify is False:
            session = self.http_client_no_v
        else:
            raise ValueError("invalid arguments to http_request")

        try:
            # Looked up inside the try so that an unknown method cannot leak a dedicated session.
            func = getattr(session, method.lower())
            # The asynchronous form is the supported way to use async_timeout.
            async with async_timeout.timeout(timeout):
                async with func(url, json=json) as response:
                    if response.headers.get("content-type",
                                            "").startswith("application/json"):
                        resp_content = await response.json()
                    else:
                        resp_content = await response.text()
                    result = HTTPResponse(body=resp_content,
                                          status=response.status)
        finally:
            # Shared sessions stay open; only sessions created for this call are closed.
            if close_session:
                await session.close()

        return result

    def run(self, *, host, port):
        """Serve ``self.app`` on ``host``:``port`` via ``aiohttp.web.run_app`` with a custom access-log format."""
        run_options = {
            "app": self.app,
            "host": host,
            "port": port,
            "access_log_format":
            '%Tfs %{x-client-ip}i "%r" %s "%{user-agent}i" response=%bb request_body=%{content-length}ib',
        }
        aiohttp.web.run_app(**run_options)

    def add_routes(self):
        pass  # Override in sub-classes