Example #1
0
    async def __call__(self, scope, receive, send):
        """Forward one ASGI HTTP request to the Serve handle and reply.

        The request body is obtained through ``self.receive_http_body``, the
        unserializable ``router``/``endpoint`` entries are dropped from
        ``scope``, the path is rewritten relative to ``self.path_prefix`` and
        the request is submitted through ``self.handle``. A ``RayActorError``
        raised while awaiting the result is retried with exponential backoff,
        for at most ``MAX_REPLICA_FAILURE_RETRIES`` attempts; if every attempt
        fails a 500 response is sent.

        Args:
            scope: ASGI connection scope. It is mutated in place.
            receive: ASGI ``receive`` callable.
            send: ASGI ``send`` callable.
        """
        http_body_bytes = await self.receive_http_body(scope, receive, send)

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}

        # scope["router"] and scope["endpoint"] contain references to a router
        # and endpoint object, respectively, which each in turn contain a
        # reference to the Serve client, which cannot be serialized.
        # The solution is to delete these from scope, as they will not be used.
        # pop() is used so that a scope without these keys is not an error.
        scope.pop("router", None)
        scope.pop("endpoint", None)

        # Modify the path and root path so that reverse lookups and redirection
        # work as expected. We do this here instead of in replicas so it can be
        # changed without restarting the replicas.
        scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
        scope["root_path"] = self.path_prefix
        handle = self.handle.options(
            method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                    DEFAULT.VALUE),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
            http_method=scope["method"].upper(),
            http_headers=headers)

        # NOTE(edoakes): it's important that we defer building the starlette
        # request until it reaches the backend replica to avoid unnecessary
        # serialization cost, so we use a simple dataclass here.
        request = HTTPRequestWrapper(scope, http_body_bytes)

        retries = 0
        backoff_time_s = 0.05
        while retries < MAX_REPLICA_FAILURE_RETRIES:
            object_ref = await handle.remote(request)
            try:
                result = await object_ref
                break
            except RayActorError:
                # Count this failed attempt first so the log reports the
                # number of attempts that are actually left.
                retries += 1
                logger.warning(
                    "Request failed due to replica failure. There are "
                    f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                    "remaining.")
                await asyncio.sleep(backoff_time_s)
                backoff_time_s *= 2
        else:
            # Every attempt hit a replica failure, so `result` was never
            # bound. Answer with an explicit 500 instead of crashing below.
            error_message = ("Task failed with "
                             f"{MAX_REPLICA_FAILURE_RETRIES} retries.")
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
            return

        if isinstance(result, RayTaskError):
            error_message = "Task Error. Traceback: {}.".format(result)
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
        elif isinstance(result, starlette.responses.Response):
            await result(scope, receive, send)
        else:
            await Response(result).send(scope, receive, send)
Example #2
0
async def _send_request_to_handle(handle, scope, receive, send):
    """Submit one ASGI HTTP request to ``handle`` and send back the reply.

    The request body is obtained through ``receive_http_body``; the request
    is wrapped in an ``HTTPRequestWrapper``, pickled, and submitted with
    ``handle.remote``. A ``RayActorError`` raised while awaiting the result is
    retried with exponential backoff, for at most
    ``MAX_REPLICA_FAILURE_RETRIES`` attempts; if every attempt fails a 500
    response is sent.

    Args:
        handle: Serve handle the request is routed through.
        scope: ASGI connection scope. It is mutated in place.
        receive: ASGI ``receive`` callable.
        send: ASGI ``send`` callable.
    """
    http_body_bytes = await receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}
    handle = handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers,
    )

    # scope["router"] and scope["endpoint"] contain references to a router
    # and endpoint object, respectively, which each in turn contain a
    # reference to the Serve client, which cannot be serialized.
    # The solution is to delete these from scope, as they will not be used.
    # TODO(edoakes): this can be removed once we deprecate the old API.
    scope.pop("router", None)
    scope.pop("endpoint", None)

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the backend replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)
    # Perform a pickle here to improve latency. Stdlib pickle for simple
    # dataclasses are 10-100x faster than cloudpickle.
    request = pickle.dumps(request)

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        object_ref = await handle.remote(request)
        try:
            result = await object_ref
            break
        except RayActorError:
            # Count this failed attempt first so the log reports the number
            # of attempts that are actually left.
            retries += 1
            logger.warning("Request failed due to replica failure. There are "
                           f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                           "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
    else:
        # Every attempt hit a replica failure, so `result` was never bound.
        # Answer with an explicit 500 instead of crashing below.
        error_message = ("Task failed with "
                         f"{MAX_REPLICA_FAILURE_RETRIES} retries.")
        await Response(error_message,
                       status_code=500).send(scope, receive, send)
        return

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(error_message,
                       status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)
Example #3
0
async def _send_request_to_handle(handle, scope, receive, send) -> str:
    """Submit one ASGI HTTP request to ``handle`` and send back the reply.

    While the request is being assigned to a replica, a background task
    waits on ``receive()`` so that a client disconnect cancels the
    assignment. A ``RayActorError`` raised while awaiting the result is
    retried with exponential backoff, for at most
    ``MAX_REPLICA_FAILURE_RETRIES`` attempts.

    Args:
        handle: Serve handle the request is routed through.
        scope: ASGI connection scope.
        receive: ASGI ``receive`` callable.
        send: ASGI ``send`` callable.

    Returns:
        The status code of the response as a string: the status code of a
        returned response object, ``"200"`` for any other result, ``"500"``
        on a task error or when all retries are used up, or
        ``DISCONNECT_ERROR_CODE`` when the wait was cancelled.
    """
    http_body_bytes = await receive_http_body(scope, receive, send)

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)
    # Perform a pickle here to improve latency. Stdlib pickle for simple
    # dataclasses are 10-100x faster than cloudpickle.
    request = pickle.dumps(request)

    retries = 0
    backoff_time_s = 0.05
    loop = asyncio.get_event_loop()
    # We have received all the http request content. The next `receive`
    # call might never arrive; if it does, it can only be `http.disconnect`.
    client_disconnection_task = loop.create_task(receive())
    try:
        while retries < MAX_REPLICA_FAILURE_RETRIES:
            assignment_task = loop.create_task(handle.remote(request))
            done, _ = await asyncio.wait(
                [assignment_task, client_disconnection_task],
                return_when=FIRST_COMPLETED)
            if client_disconnection_task in done:
                message = await client_disconnection_task
                assert message["type"] == "http.disconnect", (
                    "Received additional request payload that's not "
                    "disconnect. This is an invalid HTTP state.")

                logger.warning(
                    f"Client from {scope['client']} disconnected, cancelling "
                    "the request.")
                # This will make the .result() to raise cancelled error.
                assignment_task.cancel()
            try:
                object_ref = await assignment_task
                result = await object_ref
                client_disconnection_task.cancel()
                break
            except asyncio.CancelledError:
                # Here because the client disconnected, we will return a
                # custom error code for metric tracking.
                return DISCONNECT_ERROR_CODE
            except RayTaskError as error:
                error_message = "Task Error. Traceback: {}.".format(error)
                await Response(error_message,
                               status_code=500).send(scope, receive, send)
                return "500"
            except RayActorError:
                # Count this failed attempt first so the log reports the
                # number of attempts that are actually left.
                retries += 1
                logger.debug(
                    "Request failed due to replica failure. There are "
                    f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                    "remaining.")
                await asyncio.sleep(backoff_time_s)
                # Be careful about the exponential backoff scaling here.
                # Assuming 10 retries, 1.5x scaling means the last retry is
                # 38x the initial backoff time, while 2x scaling means 512x
                # the initial.
                backoff_time_s *= 1.5
        else:
            error_message = ("Task failed with "
                             f"{MAX_REPLICA_FAILURE_RETRIES} retries.")
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
            return "500"
    finally:
        # Make sure the pending `receive()` is never left dangling on the
        # error / early-return paths. cancel() is a no-op once the task has
        # already finished or been cancelled.
        client_disconnection_task.cancel()

    if isinstance(result, (starlette.responses.Response, RawASGIResponse)):
        await result(scope, receive, send)
        return str(result.status_code)
    else:
        await Response(result).send(scope, receive, send)
        return "200"
Example #4
0
async def _send_request_to_handle(handle, scope, receive, send) -> str:
    """Submit one ASGI HTTP request to ``handle`` and send back the reply.

    Per-request handle options are derived from the request headers and
    method. While the request is being assigned to a replica, a background
    task waits on ``receive()`` so that a client disconnect cancels the
    assignment. A ``RayActorError`` raised while awaiting the result is
    retried with exponential backoff, for at most
    ``MAX_REPLICA_FAILURE_RETRIES`` attempts.

    Args:
        handle: Serve handle the request is routed through.
        scope: ASGI connection scope. It is mutated in place.
        receive: ASGI ``receive`` callable.
        send: ASGI ``send`` callable.

    Returns:
        The status code of the response as a string: the status code of a
        returned response object, ``"200"`` for any other result, ``"500"``
        on a task error or when all retries are used up, or
        ``DISCONNECT_ERROR_CODE`` when the wait was cancelled.
    """
    http_body_bytes = await receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}
    handle = handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers,
    )

    # scope["router"] and scope["endpoint"] contain references to a router
    # and endpoint object, respectively, which each in turn contain a
    # reference to the Serve client, which cannot be serialized.
    # The solution is to delete these from scope, as they will not be used.
    # TODO(edoakes): this can be removed once we deprecate the old API.
    scope.pop("router", None)
    scope.pop("endpoint", None)

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)
    # Perform a pickle here to improve latency. Stdlib pickle for simple
    # dataclasses are 10-100x faster than cloudpickle.
    request = pickle.dumps(request)

    retries = 0
    backoff_time_s = 0.05
    loop = asyncio.get_event_loop()
    # We have received all the http request content. The next `receive`
    # call might never arrive; if it does, it can only be `http.disconnect`.
    client_disconnection_task = loop.create_task(receive())
    try:
        while retries < MAX_REPLICA_FAILURE_RETRIES:
            assignment_task = loop.create_task(handle.remote(request))
            done, _ = await asyncio.wait(
                [assignment_task, client_disconnection_task],
                return_when=FIRST_COMPLETED)
            if client_disconnection_task in done:
                message = await client_disconnection_task
                assert message["type"] == "http.disconnect", (
                    "Received additional request payload that's not "
                    "disconnect. This is an invalid HTTP state.")

                logger.warning(
                    f"Client from {scope['client']} disconnected, cancelling "
                    "the request.")
                # This will make the .result() to raise cancelled error.
                assignment_task.cancel()
            try:
                object_ref = await assignment_task
                result = await object_ref
                client_disconnection_task.cancel()
                break
            except asyncio.CancelledError:
                # Here because the client disconnected, we will return a
                # custom error code for metric tracking.
                return DISCONNECT_ERROR_CODE
            except RayTaskError as error:
                error_message = "Task Error. Traceback: {}.".format(error)
                await Response(error_message,
                               status_code=500).send(scope, receive, send)
                return "500"
            except RayActorError:
                # Count this failed attempt first so the log reports the
                # number of attempts that are actually left.
                retries += 1
                logger.warning(
                    "Request failed due to replica failure. There are "
                    f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                    "remaining.")
                await asyncio.sleep(backoff_time_s)
                # Be careful about the exponential backoff scaling here.
                # Assuming 10 retries, 1.5x scaling means the last retry is
                # 38x the initial backoff time, while 2x scaling means 512x
                # the initial.
                backoff_time_s *= 1.5
        else:
            error_message = ("Task failed with "
                             f"{MAX_REPLICA_FAILURE_RETRIES} retries.")
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
            return "500"
    finally:
        # Make sure the pending `receive()` is never left dangling on the
        # error / early-return paths. cancel() is a no-op once the task has
        # already finished or been cancelled.
        client_disconnection_task.cancel()

    if isinstance(result, (starlette.responses.Response, RawASGIResponse)):
        await result(scope, receive, send)
        return str(result.status_code)
    else:
        await Response(result).send(scope, receive, send)
        return "200"
Example #5
0
    async def __call__(self, scope, receive, send):
        """Implements the ASGI protocol.

        See details at:
            https://asgi.readthedocs.io/en/latest/specs/index.html.

        The request is counted, ``/-/routes`` is answered directly with the
        router's route info, and any other path is matched against the
        router. Unmatched paths are delegated to ``self._not_found``. Matched
        requests are forwarded to the route's handle; a ``RayActorError``
        raised while awaiting the result is retried with exponential backoff
        for at most ``MAX_REPLICA_FAILURE_RETRIES`` attempts, after which a
        500 response is sent.

        Args:
            scope: ASGI connection scope; must be of type ``"http"``. It is
                mutated in place when the matched route prefix is not "/".
            receive: ASGI ``receive`` callable.
            send: ASGI ``send`` callable.
        """

        assert scope["type"] == "http"
        self.request_counter.inc(tags={"route": scope["path"]})

        if scope["path"] == "/-/routes":
            return await starlette.responses.JSONResponse(
                self.router.route_info)(scope, receive, send)

        route_prefix, handle = self.router.match_route(scope["path"],
                                                       scope["method"])
        if route_prefix is None:
            return await self._not_found(scope, receive, send)

        http_body_bytes = await receive_http_body(scope, receive, send)

        # Modify the path and root path so that reverse lookups and redirection
        # work as expected. We do this here instead of in replicas so it can be
        # changed without restarting the replicas.
        if route_prefix != "/":
            assert not route_prefix.endswith("/")
            scope["path"] = scope["path"].replace(route_prefix, "", 1)
            scope["root_path"] = route_prefix

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}
        handle = handle.options(
            method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                    DEFAULT.VALUE),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
            http_method=scope["method"].upper(),
            http_headers=headers)

        # NOTE(edoakes): it's important that we defer building the starlette
        # request until it reaches the backend replica to avoid unnecessary
        # serialization cost, so we use a simple dataclass here.
        request = HTTPRequestWrapper(scope, http_body_bytes)

        retries = 0
        backoff_time_s = 0.05
        while retries < MAX_REPLICA_FAILURE_RETRIES:
            object_ref = await handle.remote(request)
            try:
                result = await object_ref
                break
            except RayActorError:
                # Count this failed attempt first so the log reports the
                # number of attempts that are actually left.
                retries += 1
                logger.warning(
                    "Request failed due to replica failure. There are "
                    f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                    "remaining.")
                await asyncio.sleep(backoff_time_s)
                backoff_time_s *= 2
        else:
            # Every attempt hit a replica failure, so `result` was never
            # bound. Answer with an explicit 500 instead of crashing below.
            error_message = ("Task failed with "
                             f"{MAX_REPLICA_FAILURE_RETRIES} retries.")
            await Response(
                error_message, status_code=500).send(scope, receive, send)
            return

        if isinstance(result, RayTaskError):
            error_message = "Task Error. Traceback: {}.".format(result)
            await Response(
                error_message, status_code=500).send(scope, receive, send)
        elif isinstance(result, starlette.responses.Response):
            await result(scope, receive, send)
        else:
            await Response(result).send(scope, receive, send)