async def __call__(self, scope, receive, send):
    http_body_bytes = await self.receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}

    # scope["router"] and scope["endpoint"] contain references to a router
    # and endpoint object, respectively, which each in turn contain a
    # reference to the Serve client, which cannot be serialized.
    # The solution is to delete these from scope, as they will not be used.
    del scope["router"]
    del scope["endpoint"]

    # Modify the path and root path so that reverse lookups and redirection
    # work as expected. We do this here instead of in replicas so it can be
    # changed without restarting the replicas.
    scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
    scope["root_path"] = self.path_prefix

    handle = self.handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers)

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the backend replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        object_ref = await handle.remote(request)
        try:
            result = await object_ref
            break
        except RayActorError:
            logger.warning(
                "Request failed due to replica failure. There are "
                f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
            retries += 1

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(error_message,
                       status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)
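# Standalone illustration (not part of the proxy code above): ASGI delivers
# headers as a list of (bytes, bytes) pairs, which the handler above decodes
# into a str-keyed dict before reading the Serve-specific headers. The header
# values below are made-up examples.
raw_headers = [
    (b"x-serve-call-method", b"predict"),
    (b"x-serve-shard-key", b"user-123"),
    (b"content-type", b"application/json"),
]
headers = {k.decode(): v.decode() for k, v in raw_headers}

print(headers.get("X-SERVE-CALL-METHOD".lower()))  # predict
print(headers.get("X-SERVE-SHARD-KEY".lower()))    # user-123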
async def _send_request_to_handle(handle, scope, receive, send):
    http_body_bytes = await receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}
    handle = handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers,
    )

    # scope["router"] and scope["endpoint"] contain references to a router
    # and endpoint object, respectively, which each in turn contain a
    # reference to the Serve client, which cannot be serialized.
    # The solution is to delete these from scope, as they will not be used.
    # TODO(edoakes): this can be removed once we deprecate the old API.
    if "router" in scope:
        del scope["router"]
    if "endpoint" in scope:
        del scope["endpoint"]

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the backend replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)
    # Perform a pickle here to improve latency. Stdlib pickle for simple
    # dataclasses is 10-100x faster than cloudpickle.
    request = pickle.dumps(request)

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        object_ref = await handle.remote(request)
        try:
            result = await object_ref
            break
        except RayActorError:
            logger.warning("Request failed due to replica failure. There are "
                           f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                           "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
            retries += 1

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(error_message,
                       status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)
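# Standalone sketch of the "pickle the wrapper up front" optimization noted in
# the comments above: a plain dataclass holding the ASGI scope and raw body
# serializes with stdlib pickle. The class and field layout here are
# assumptions for illustration; the real HTTPRequestWrapper may differ.
import pickle
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ExampleRequestWrapper:
    scope: Dict[str, Any]
    body: bytes


wrapped = ExampleRequestWrapper(
    scope={"type": "http", "method": "GET", "path": "/", "headers": []},
    body=b"")
payload = pickle.dumps(wrapped)  # bytes, ready to ship to a replica
assert pickle.loads(payload) == wrapped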
async def _send_request_to_handle(handle, scope, receive, send) -> str:
    http_body_bytes = await receive_http_body(scope, receive, send)

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)
    # Perform a pickle here to improve latency. Stdlib pickle for simple
    # dataclasses is 10-100x faster than cloudpickle.
    request = pickle.dumps(request)

    retries = 0
    backoff_time_s = 0.05
    loop = asyncio.get_event_loop()
    # We have received all of the HTTP request content. The next `receive`
    # call might never arrive; if it does, it can only be `http.disconnect`.
    client_disconnection_task = loop.create_task(receive())
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        assignment_task = loop.create_task(handle.remote(request))
        done, _ = await asyncio.wait(
            [assignment_task, client_disconnection_task],
            return_when=FIRST_COMPLETED)
        if client_disconnection_task in done:
            message = await client_disconnection_task
            assert message["type"] == "http.disconnect", (
                "Received additional request payload that's not disconnect. "
                "This is an invalid HTTP state.")
            logger.warning(
                f"Client from {scope['client']} disconnected, cancelling the "
                "request.")
            # This makes the `await assignment_task` below raise a
            # CancelledError.
            assignment_task.cancel()
        try:
            object_ref = await assignment_task
            result = await object_ref
            client_disconnection_task.cancel()
            break
        except asyncio.CancelledError:
            # The client disconnected, so return a custom error code used for
            # metric tracking.
            return DISCONNECT_ERROR_CODE
        except RayTaskError as error:
            error_message = "Task Error. Traceback: {}.".format(error)
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
            return "500"
        except RayActorError:
            logger.debug("Request failed due to replica failure. There are "
                         f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                         "remaining.")
            await asyncio.sleep(backoff_time_s)
            # Be careful about the exponential backoff scaling here.
            # Assuming 10 retries, 1.5x scaling means the last retry is 38x
            # the initial backoff time, while 2x scaling means 512x the
            # initial.
            backoff_time_s *= 1.5
            retries += 1
    else:
        error_message = ("Task failed with "
                         f"{MAX_REPLICA_FAILURE_RETRIES} retries.")
        await Response(error_message, status_code=500).send(
            scope, receive, send)
        return "500"

    if isinstance(result, (starlette.responses.Response, RawASGIResponse)):
        await result(scope, receive, send)
        return str(result.status_code)
    else:
        await Response(result).send(scope, receive, send)
        return "200"
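# Standalone sketch of the disconnect-race pattern introduced above: race the
# real work against a task that resolves when the client disconnects, and
# cancel the work if the disconnect wins. `do_work` and `fake_receive` are
# illustrative stand-ins, not Serve APIs.
import asyncio
from asyncio import FIRST_COMPLETED


async def do_work() -> str:
    await asyncio.sleep(1.0)  # stand-in for awaiting the replica's result
    return "result"


async def fake_receive() -> dict:
    await asyncio.sleep(0.1)  # stand-in for the ASGI `receive` callable
    return {"type": "http.disconnect"}


async def main() -> None:
    work_task = asyncio.create_task(do_work())
    disconnect_task = asyncio.create_task(fake_receive())

    done, _ = await asyncio.wait(
        [work_task, disconnect_task], return_when=FIRST_COMPLETED)

    if disconnect_task in done:
        # The client went away first: cancel the in-flight work.
        work_task.cancel()
        try:
            await work_task
        except asyncio.CancelledError:
            print("work cancelled after client disconnect")
    else:
        disconnect_task.cancel()
        print("work finished:", work_task.result())


asyncio.run(main())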
async def _send_request_to_handle(handle, scope, receive, send) -> str:
    http_body_bytes = await receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}
    handle = handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers,
    )

    # scope["router"] and scope["endpoint"] contain references to a router
    # and endpoint object, respectively, which each in turn contain a
    # reference to the Serve client, which cannot be serialized.
    # The solution is to delete these from scope, as they will not be used.
    # TODO(edoakes): this can be removed once we deprecate the old API.
    if "router" in scope:
        del scope["router"]
    if "endpoint" in scope:
        del scope["endpoint"]

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)
    # Perform a pickle here to improve latency. Stdlib pickle for simple
    # dataclasses is 10-100x faster than cloudpickle.
    request = pickle.dumps(request)

    retries = 0
    backoff_time_s = 0.05
    loop = asyncio.get_event_loop()
    # We have received all of the HTTP request content. The next `receive`
    # call might never arrive; if it does, it can only be `http.disconnect`.
    client_disconnection_task = loop.create_task(receive())
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        assignment_task = loop.create_task(handle.remote(request))
        done, _ = await asyncio.wait(
            [assignment_task, client_disconnection_task],
            return_when=FIRST_COMPLETED)
        if client_disconnection_task in done:
            message = await client_disconnection_task
            assert message["type"] == "http.disconnect", (
                "Received additional request payload that's not disconnect. "
                "This is an invalid HTTP state.")
            logger.warning(
                f"Client from {scope['client']} disconnected, cancelling the "
                "request.")
            # This makes the `await assignment_task` below raise a
            # CancelledError.
            assignment_task.cancel()
        try:
            object_ref = await assignment_task
            result = await object_ref
            client_disconnection_task.cancel()
            break
        except asyncio.CancelledError:
            # The client disconnected, so return a custom error code used for
            # metric tracking.
            return DISCONNECT_ERROR_CODE
        except RayTaskError as error:
            error_message = "Task Error. Traceback: {}.".format(error)
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
            return "500"
        except RayActorError:
            logger.warning("Request failed due to replica failure. There are "
                           f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                           "remaining.")
            await asyncio.sleep(backoff_time_s)
            # Be careful about the exponential backoff scaling here.
            # Assuming 10 retries, 1.5x scaling means the last retry is 38x
            # the initial backoff time, while 2x scaling means 512x the
            # initial.
            backoff_time_s *= 1.5
            retries += 1
    else:
        error_message = ("Task failed with "
                         f"{MAX_REPLICA_FAILURE_RETRIES} retries.")
        await Response(error_message, status_code=500).send(
            scope, receive, send)
        return "500"

    if isinstance(result, (starlette.responses.Response, RawASGIResponse)):
        await result(scope, receive, send)
        return str(result.status_code)
    else:
        await Response(result).send(scope, receive, send)
        return "200"
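# Quick check of the backoff-factor tradeoff called out in the comments above:
# with 10 retries and a 0.05s initial backoff, 1.5x scaling ends around 38x
# the initial sleep, while 2x scaling ends at 512x.
initial_backoff_s = 0.05
for factor in (1.5, 2.0):
    backoff_s = initial_backoff_s
    for _ in range(10 - 1):  # the first attempt sleeps the initial backoff
        backoff_s *= factor
    print(f"factor {factor}: final backoff {backoff_s:.2f}s "
          f"({backoff_s / initial_backoff_s:.0f}x the initial)")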
async def __call__(self, scope, receive, send):
    """Implements the ASGI protocol.

    See details at:
        https://asgi.readthedocs.io/en/latest/specs/index.html.
    """
    assert scope["type"] == "http"

    self.request_counter.inc(tags={"route": scope["path"]})

    if scope["path"] == "/-/routes":
        return await starlette.responses.JSONResponse(
            self.router.route_info)(scope, receive, send)

    route_prefix, handle = self.router.match_route(scope["path"],
                                                   scope["method"])
    if route_prefix is None:
        return await self._not_found(scope, receive, send)

    http_body_bytes = await receive_http_body(scope, receive, send)

    # Modify the path and root path so that reverse lookups and redirection
    # work as expected. We do this here instead of in replicas so it can be
    # changed without restarting the replicas.
    if route_prefix != "/":
        assert not route_prefix.endswith("/")
        scope["path"] = scope["path"].replace(route_prefix, "", 1)
        scope["root_path"] = route_prefix

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}
    handle = handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers)

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the backend replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        object_ref = await handle.remote(request)
        try:
            result = await object_ref
            break
        except RayActorError:
            logger.warning(
                "Request failed due to replica failure. There are "
                f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
            retries += 1

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(
            error_message, status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)
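# Standalone illustration of the path / root_path rewrite performed above so
# that reverse lookups and redirects resolve correctly behind a route prefix.
# The scope dict is a hand-built stand-in for a real ASGI scope.
route_prefix = "/api"
scope = {"path": "/api/users/42", "root_path": ""}

if route_prefix != "/":
    assert not route_prefix.endswith("/")
    scope["path"] = scope["path"].replace(route_prefix, "", 1)
    scope["root_path"] = route_prefix

print(scope)  # {'path': '/users/42', 'root_path': '/api'}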