Example 1
 def send_get_response(result: Any) -> None:
     """Pushes GetResponses to the main DataPath loop to send
     to the client. This is called when the object is ready
     on the server side."""
     try:
         serialized = dumps_from_server(result, client_id, self)
         total_size = len(serialized)
         assert total_size > 0, "Serialized object cannot be zero bytes"
         total_chunks = math.ceil(total_size /
                                  OBJECT_TRANSFER_CHUNK_SIZE)
         for chunk_id in range(request.start_chunk_id,
                               total_chunks):
             start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
             end = min(total_size, (chunk_id + 1) *
                       OBJECT_TRANSFER_CHUNK_SIZE)
             get_resp = ray_client_pb2.GetResponse(
                 valid=True,
                 data=serialized[start:end],
                 chunk_id=chunk_id,
                 total_chunks=total_chunks,
                 total_size=total_size,
             )
             chunk_resp = ray_client_pb2.DataResponse(
                 get=get_resp, req_id=req_id)
             result_queue.put(chunk_resp)
     except Exception as exc:
         get_resp = ray_client_pb2.GetResponse(
             valid=False, error=cloudpickle.dumps(exc))
         resp = ray_client_pb2.DataResponse(get=get_resp,
                                            req_id=req_id)
         result_queue.put(resp)
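The chunked protocol in Example 1 implies a matching reassembly step on the receiving side. Below is a minimal, illustrative sketch of that step; it assumes an iterable of GetResponse messages carrying the valid/error/data/chunk_id/total_chunks/total_size fields used above and a transfer that starts at chunk 0, and the helper name assemble_get_response is hypothetical rather than part of the Ray client API.

import cloudpickle


def assemble_get_response(chunks) -> bytes:
    """Reassemble the serialized object from a stream of GetResponse
    chunks like the ones produced by send_get_response (sketch only)."""
    buffer = bytearray()
    total_size = None
    for chunk in chunks:
        if not chunk.valid:
            # The server pickled the exception into the `error` field.
            raise cloudpickle.loads(chunk.error)
        buffer.extend(chunk.data)
        total_size = chunk.total_size
        if chunk.chunk_id + 1 == chunk.total_chunks:
            break
    assert total_size is not None and len(buffer) == total_size
    return bytes(buffer)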
Example 2
    def Datapath(self, request_iterator, context):
        metadata = {k: v for k, v in context.invocation_metadata()}
        client_id = metadata["client_id"]
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.info(f"New data connection from client {client_id}")
        try:
            with self._clients_lock:
                with disable_client_hook():
                    if self._num_clients == 0 and not ray.is_initialized():
                        self.ray_connect_handler()
                self._num_clients += 1
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "get":
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response())
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.info(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)

            with self._clients_lock:
                self._num_clients -= 1

            with disable_client_hook():
                if self._num_clients == 0:
                    ray.shutdown()
Example 3
 def Datapath(self, request_iterator, context):
     metadata = {k: v for k, v in context.invocation_metadata()}
     client_id = metadata["client_id"]
     if client_id == "":
         logger.error("Client connecting with no client_id")
         return
     logger.info(f"New data connection from client {client_id}")
     try:
         with self._clients_lock:
             self._num_clients += 1
         for req in request_iterator:
             resp = None
             req_type = req.WhichOneof("type")
             if req_type == "get":
                 get_resp = self.basic_service._get_object(
                     req.get, client_id)
                 resp = ray_client_pb2.DataResponse(get=get_resp)
             elif req_type == "put":
                 put_resp = self.basic_service._put_object(
                     req.put, client_id)
                 resp = ray_client_pb2.DataResponse(put=put_resp)
             elif req_type == "release":
                 released = []
                 for rel_id in req.release.ids:
                     rel = self.basic_service.release(client_id, rel_id)
                     released.append(rel)
                 resp = ray_client_pb2.DataResponse(
                     release=ray_client_pb2.ReleaseResponse(ok=released))
             elif req_type == "connection_info":
                 with self._clients_lock:
                     cur_num_clients = self._num_clients
                 info = ray_client_pb2.ConnectionInfoResponse(
                     num_clients=cur_num_clients,
                     python_version="{}.{}.{}".format(
                         sys.version_info[0], sys.version_info[1],
                         sys.version_info[2]),
                     ray_version=ray.__version__,
                     ray_commit=ray.__commit__)
                 resp = ray_client_pb2.DataResponse(connection_info=info)
             else:
                 raise Exception(f"Unreachable code: Request type "
                                 f"{req_type} not handled in Datapath")
             resp.req_id = req.req_id
             yield resp
     except grpc.RpcError as e:
         logger.debug(f"Closing data channel: {e}")
     finally:
         logger.info(f"Lost data connection from client {client_id}")
         self.basic_service.release_all(client_id)
         with self._clients_lock:
             self._num_clients -= 1
Example 4
    def Datapath(self, request_iterator, context):
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return

        # Create Placeholder *before* reading the first request.
        server = self.proxy_manager.create_specific_server(client_id)
        try:
            with self.clients_lock:
                self.num_clients += 1

            logger.info(f"New data connection from client {client_id}: ")
            init_req = next(request_iterator)
            try:
                modified_init_req, job_config = prepare_runtime_init_req(
                    init_req)
                if not self.proxy_manager.start_specific_server(
                        client_id, job_config):
                    logger.error(
                        f"Server startup failed for client: {client_id}, "
                        f"using JobConfig: {job_config}!")
                    raise RuntimeError(
                        "Starting Ray client server failed. This is most "
                        "likely because the runtime_env failed to be "
                        "installed. See ray_client_server_[port].err on the "
                        "head node of the cluster for the relevant logs.")
                channel = self.proxy_manager.get_channel(client_id)
                if channel is None:
                    logger.error(f"Channel not found for {client_id}")
                    raise RuntimeError(
                        "Proxy failed to Connect to backend! Check "
                        "`ray_client_server.err` on the cluster.")
                stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
            except Exception:
                init_resp = ray_client_pb2.DataResponse(
                    init=ray_client_pb2.InitResponse(
                        ok=False, msg=traceback.format_exc()))
                init_resp.req_id = init_req.req_id
                yield init_resp
                return None

            new_iter = chain([modified_init_req], request_iterator)
            resp_stream = stub.Datapath(
                new_iter, metadata=[("client_id", client_id)])
            for resp in resp_stream:
                yield self.modify_connection_info_resp(resp)
        except Exception:
            logger.exception("Proxying Datapath failed!")
        finally:
            server.set_result(None)
            with self.clients_lock:
                logger.debug(f"Client detached: {client_id}")
                self.num_clients -= 1
Example 5
 def send_get_response(result: Any) -> None:
     """Pushes a GetResponse to the main DataPath loop to send
     to the client. This is called when the object is ready
     on the server side."""
     try:
         serialized = dumps_from_server(result, client_id, self)
         get_resp = ray_client_pb2.GetResponse(valid=True,
                                               data=serialized)
     except Exception as exc:
         get_resp = ray_client_pb2.GetResponse(
             valid=False, error=cloudpickle.dumps(exc))
     resp = ray_client_pb2.DataResponse(get=get_resp,
                                        req_id=req_id)
     result_queue.put(resp)
Example 6
 def modify_connection_info_resp(self,
                                 init_resp: ray_client_pb2.DataResponse
                                 ) -> ray_client_pb2.DataResponse:
     """
     Modify the `num_clients` returned in the ConnectionInfoResponse because
     individual SpecificServers only have **one** client.
     """
     init_type = init_resp.WhichOneof("type")
     if init_type != "connection_info":
         return init_resp
     modified_resp = ray_client_pb2.DataResponse()
     modified_resp.CopyFrom(init_resp)
     with self.clients_lock:
         modified_resp.connection_info.num_clients = self.num_clients
     return modified_resp
Example 7
 def _init(self, req_init, client_id):
     with self.clients_lock:
         threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
         if self.num_clients >= threshold:
             logger.warning(
                 f"[Data Servicer]: Num clients {self.num_clients} "
                 f"has reached the threshold {threshold}. "
                 f"Rejecting client: {client_id}. ")
             if log_once("client_threshold"):
                 logger.warning(
                     "You can configure the client connection "
                     "threshold by setting the "
                     "RAY_CLIENT_SERVER_MAX_THREADS env var "
                     f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
             return None
         resp_init = self.basic_service.Init(req_init)
         self.num_clients += 1
         return ray_client_pb2.DataResponse(init=resp_init, )
Example 8
    def Datapath(self, request_iterator, context):
        start_time = time.time()
        # set to True if client shuts down gracefully
        cleanup_requested = False
        metadata = {k: v for k, v in context.invocation_metadata()}
        client_id = metadata.get("client_id")
        if client_id is None:
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        accepted_connection = self._init(client_id, context, start_time)
        response_cache = self.response_caches[client_id]
        # Set to False if client requests a reconnect grace period of 0
        reconnect_enabled = True
        if not accepted_connection:
            return
        try:
            request_queue = Queue()
            queue_filler_thread = Thread(target=fill_queue,
                                         daemon=True,
                                         args=(request_iterator,
                                               request_queue))
            queue_filler_thread.start()
            """For non `async get` requests, this loop yields immediately
            For `async get` requests, this loop:
                 1) does not yield, it just continues
                 2) When the result is ready, it yields
            """
            for req in iter(request_queue.get, None):
                if isinstance(req, ray_client_pb2.DataResponse):
                    # Early shortcut if this is the result of an async get.
                    yield req
                    continue

                assert isinstance(req, ray_client_pb2.DataRequest)
                if _should_cache(req) and reconnect_enabled:
                    cached_resp = response_cache.check_cache(req.req_id)
                    if isinstance(cached_resp, Exception):
                        # Cache state is invalid, raise exception
                        raise cached_resp
                    if cached_resp is not None:
                        yield cached_resp
                        continue

                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "init":
                    resp_init = self.basic_service.Init(req.init)
                    resp = ray_client_pb2.DataResponse(init=resp_init, )
                    with self.clients_lock:
                        self.reconnect_grace_periods[
                            client_id] = req.init.reconnect_grace_period
                        if req.init.reconnect_grace_period == 0:
                            reconnect_enabled = False

                elif req_type == "get":
                    if req.get.asynchronous:
                        get_resp = self.basic_service._async_get_object(
                            req.get, client_id, req.req_id, request_queue)
                        if get_resp is None:
                            # Skip sending a response for this request and
                            # continue to the next request. The response for
                            # this request will be sent when the object is
                            # ready.
                            continue
                    else:
                        get_resp = self.basic_service._get_object(
                            req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    if not self.put_request_chunk_collector.add_chunk(
                            req, req.put):
                        # Put request still in progress
                        continue
                    put_resp = self.basic_service._put_object(
                        self.put_request_chunk_collector.data,
                        req.put.client_ref_id,
                        client_id,
                    )
                    self.put_request_chunk_collector.reset()
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response())
                elif req_type == "prep_runtime_env":
                    with self.clients_lock:
                        resp_prep = self.basic_service.PrepRuntimeEnv(
                            req.prep_runtime_env)
                        resp = ray_client_pb2.DataResponse(
                            prep_runtime_env=resp_prep)
                elif req_type == "connection_cleanup":
                    cleanup_requested = True
                    cleanup_resp = ray_client_pb2.ConnectionCleanupResponse()
                    resp = ray_client_pb2.DataResponse(
                        connection_cleanup=cleanup_resp)
                elif req_type == "acknowledge":
                    # Clean up acknowledged cache entries
                    response_cache.cleanup(req.acknowledge.req_id)
                    continue
                elif req_type == "task":
                    with self.clients_lock:
                        task = req.task
                        if not self.client_task_chunk_collector.add_chunk(
                                req, task):
                            # Not all serialized arguments have arrived
                            continue
                        arglist, kwargs = loads_from_client(
                            self.client_task_chunk_collector.data,
                            self.basic_service)
                        self.client_task_chunk_collector.reset()
                        resp_ticket = self.basic_service.Schedule(
                            req.task, arglist, kwargs, context)
                        resp = ray_client_pb2.DataResponse(
                            task_ticket=resp_ticket)
                elif req_type == "terminate":
                    with self.clients_lock:
                        response = self.basic_service.Terminate(
                            req.terminate, context)
                        resp = ray_client_pb2.DataResponse(terminate=response)
                elif req_type == "list_named_actors":
                    with self.clients_lock:
                        response = self.basic_service.ListNamedActors(
                            req.list_named_actors)
                        resp = ray_client_pb2.DataResponse(
                            list_named_actors=response)
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                if _should_cache(req) and reconnect_enabled:
                    response_cache.update_cache(req.req_id, resp)
                yield resp
        except Exception as e:
            logger.exception("Error in data channel:")
            recoverable = _propagate_error_in_context(e, context)
            invalid_cache = response_cache.invalidate(e)
            if not recoverable or invalid_cache:
                context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
                # Connection isn't recoverable, skip cleanup
                cleanup_requested = True
        finally:
            logger.debug(f"Stream is broken with client {client_id}")
            queue_filler_thread.join(QUEUE_JOIN_SECONDS)
            if queue_filler_thread.is_alive():
                logger.error(
                    "Queue filler thread failed to join before timeout: {}".
                    format(QUEUE_JOIN_SECONDS))
            cleanup_delay = self.reconnect_grace_periods.get(client_id)
            if not cleanup_requested and cleanup_delay is not None:
                logger.debug("Cleanup wasn't requested, delaying cleanup by"
                             f"{cleanup_delay} seconds.")
                # Delay cleanup, since client may attempt a reconnect
                # Wait on the "stopped" event in case the grpc server is
                # stopped and we can clean up earlier.
                self.stopped.wait(timeout=cleanup_delay)
            else:
                logger.debug("Cleanup was requested, cleaning up immediately.")
            with self.clients_lock:
                if client_id not in self.client_last_seen:
                    logger.debug("Connection already cleaned up.")
                    # Some other connection has already cleaned up this
                    # client's session. This can happen if the client
                    # reconnects and then gracefully shuts down immediately.
                    return
                last_seen = self.client_last_seen[client_id]
                if last_seen > start_time:
                    # The client successfully reconnected and updated
                    # last seen some time during the grace period
                    logger.debug("Client reconnected, skipping cleanup")
                    return
                # Either the client shut down gracefully, or the client
                # failed to reconnect within the grace period. Clean up
                # the connection.
                self.basic_service.release_all(client_id)
                del self.client_last_seen[client_id]
                if client_id in self.reconnect_grace_periods:
                    del self.reconnect_grace_periods[client_id]
                if client_id in self.response_caches:
                    del self.response_caches[client_id]
                self.num_clients -= 1
                logger.debug(f"Removed client {client_id}, "
                             f"remaining={self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                # NOTE: it is strange to start ray in server.py but shut it
                # down here. Consider consolidating ray lifetime management.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()
Example 9
    def Datapath(self, request_iterator, context):
        cleanup_requested = False
        start_time = time.time()
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        reconnecting = _get_reconnecting_from_context(context)

        if reconnecting:
            with self.clients_lock:
                if client_id not in self.clients_last_seen:
                    # Client took too long to reconnect, session has already
                    # been cleaned up
                    context.set_code(grpc.StatusCode.NOT_FOUND)
                    context.set_details(
                        "Attempted to reconnect a session that has already "
                        "been cleaned up")
                    return
                self.clients_last_seen[client_id] = start_time
            server = self.proxy_manager._get_server_for_client(client_id)
            channel = self.proxy_manager.get_channel(client_id)
            # iterator doesn't need modification on reconnect
            new_iter = request_iterator
        else:
            # Create Placeholder *before* reading the first request.
            server = self.proxy_manager.create_specific_server(client_id)
            with self.clients_lock:
                self.clients_last_seen[client_id] = start_time
                self.num_clients += 1

        try:
            if not reconnecting:
                logger.info(f"New data connection from client {client_id}: ")
                init_req = next(request_iterator)
                with self.clients_lock:
                    self.reconnect_grace_periods[client_id] = \
                        init_req.init.reconnect_grace_period
                try:
                    modified_init_req, job_config = prepare_runtime_init_req(
                        init_req)
                    if not self.proxy_manager.start_specific_server(
                            client_id, job_config):
                        logger.error(
                            f"Server startup failed for client: {client_id}, "
                            f"using JobConfig: {job_config}!")
                        raise RuntimeError(
                            "Starting Ray client server failed. See "
                            f"ray_client_server_{server.port}.err for "
                            "detailed logs.")
                    channel = self.proxy_manager.get_channel(client_id)
                    if channel is None:
                        logger.error(f"Channel not found for {client_id}")
                        raise RuntimeError(
                            "Proxy failed to Connect to backend! Check "
                            "`ray_client_server.err` and "
                            f"`ray_client_server_{server.port}.err` on the "
                            "head node of the cluster for the relevant logs. "
                            "By default these are located at "
                            "/tmp/ray/session_latest/logs.")
                except Exception:
                    init_resp = ray_client_pb2.DataResponse(
                        init=ray_client_pb2.InitResponse(
                            ok=False, msg=traceback.format_exc()))
                    init_resp.req_id = init_req.req_id
                    yield init_resp
                    return None

                new_iter = chain([modified_init_req], request_iterator)

            stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
            metadata = [("client_id", client_id), ("reconnecting",
                                                   str(reconnecting))]
            resp_stream = stub.Datapath(new_iter, metadata=metadata)
            for resp in resp_stream:
                resp_type = resp.WhichOneof("type")
                if resp_type == "connection_cleanup":
                    # Specific server is skipping cleanup, proxier should too
                    cleanup_requested = True
                yield self.modify_connection_info_resp(resp)
        except Exception as e:
            logger.exception("Proxying Datapath failed!")
            # Propagate error through context
            recoverable = _propagate_error_in_context(e, context)
            if not recoverable:
                # Client shouldn't attempt to recover, clean up connection
                cleanup_requested = True
        finally:
            cleanup_delay = self.reconnect_grace_periods.get(client_id)
            if not cleanup_requested and cleanup_delay is not None:
                # Delay cleanup, since client may attempt a reconnect
                # Wait on stopped event in case the server closes and we
                # can clean up earlier
                self.stopped.wait(timeout=cleanup_delay)
            with self.clients_lock:
                if client_id not in self.clients_last_seen:
                    logger.info(f"{client_id} not found. Skipping clean up.")
                    # Connection has already been cleaned up
                    return
                last_seen = self.clients_last_seen[client_id]
                logger.info(
                    f"{client_id} last started stream at {last_seen}. Current "
                    f"stream started at {start_time}.")
                if last_seen > start_time:
                    logger.info("Client reconnected. Skipping cleanup.")
                    # Client has reconnected, don't clean up
                    return
                logger.debug(f"Client detached: {client_id}")
                self.num_clients -= 1
                del self.clients_last_seen[client_id]
                if client_id in self.reconnect_grace_periods:
                    del self.reconnect_grace_periods[client_id]
                server.set_result(None)
Example 10
    def Datapath(self, request_iterator, context):
        metadata = {k: v for k, v in context.invocation_metadata()}
        client_id = metadata["client_id"]
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        accepted_connection = self._init(client_id, context)
        if not accepted_connection:
            return
        try:
            request_queue = Queue()
            queue_filler_thread = Thread(target=fill_queue,
                                         daemon=True,
                                         args=(request_iterator,
                                               request_queue))
            queue_filler_thread.start()
            """For non `async get` requests, this loop yields immediately
            For `async get` requests, this loop:
                 1) does not yield, it just continues
                 2) When the result is ready, it yields
            """
            for req in iter(request_queue.get, None):
                if isinstance(req, ray_client_pb2.DataResponse):
                    # Early shortcut if this is the result of an async get.
                    yield req
                    continue

                assert isinstance(req, ray_client_pb2.DataRequest)
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "init":
                    resp_init = self.basic_service.Init(req.init)
                    resp = ray_client_pb2.DataResponse(init=resp_init, )
                elif req_type == "get":
                    get_resp = None
                    if req.get.asynchronous:
                        get_resp = self.basic_service._async_get_object(
                            req.get, client_id, req.req_id, request_queue)
                        if get_resp is None:
                            # Skip sending a response for this request and
                            # continue to the next request. The response for
                            # this request will be sent when the object is
                            # ready.
                            continue
                    else:
                        get_resp = self.basic_service._get_object(
                            req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response())
                elif req_type == "prep_runtime_env":
                    with self.clients_lock:
                        resp_prep = self.basic_service.PrepRuntimeEnv(
                            req.prep_runtime_env)
                        resp = ray_client_pb2.DataResponse(
                            prep_runtime_env=resp_prep)
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.debug(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)
            queue_filler_thread.join(QUEUE_JOIN_SECONDS)
            if queue_filler_thread.is_alive():
                logger.error(
                    "Queue filler thread failed to  join before timeout: {}".
                    format(QUEUE_JOIN_SECONDS))
            with self.clients_lock:
                # Could fail before client accounting happens
                self.num_clients -= 1
                logger.debug(f"Removed clients. {self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()
Example 11
    def Datapath(self, request_iterator, context):
        metadata = {k: v for k, v in context.invocation_metadata()}
        client_id = metadata["client_id"]
        accepted_connection = False
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        try:
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "init":
                    resp = self._init(req.init, client_id)
                    if resp is None:
                        context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                        return
                    logger.debug(f"Accepted data connection from {client_id}. "
                                 f"Total clients: {self.num_clients}")
                    accepted_connection = True
                else:
                    assert accepted_connection
                    if req_type == "get":
                        get_resp = self.basic_service._get_object(
                            req.get, client_id)
                        resp = ray_client_pb2.DataResponse(get=get_resp)
                    elif req_type == "put":
                        put_resp = self.basic_service._put_object(
                            req.put, client_id)
                        resp = ray_client_pb2.DataResponse(put=put_resp)
                    elif req_type == "release":
                        released = []
                        for rel_id in req.release.ids:
                            rel = self.basic_service.release(client_id, rel_id)
                            released.append(rel)
                        resp = ray_client_pb2.DataResponse(
                            release=ray_client_pb2.ReleaseResponse(
                                ok=released))
                    elif req_type == "connection_info":
                        resp = ray_client_pb2.DataResponse(
                            connection_info=self._build_connection_response())
                    elif req_type == "prep_runtime_env":
                        with self.clients_lock:
                            resp_prep = self.basic_service.PrepRuntimeEnv(
                                req.prep_runtime_env)
                            resp = ray_client_pb2.DataResponse(
                                prep_runtime_env=resp_prep)
                    else:
                        raise Exception(f"Unreachable code: Request type "
                                        f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.debug(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)

            with self.clients_lock:
                if accepted_connection:
                    # Could fail before client accounting happens
                    self.num_clients -= 1
                    logger.debug(f"Removed clients. {self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()
Example 12
    def Datapath(self, request_iterator, context):
        metadata = {k: v for k, v in context.invocation_metadata()}
        client_id = metadata["client_id"]
        accepted_connection = False
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        try:
            with self.clients_lock:
                with disable_client_hook():
                    # It's important to keep the ray initialization call
                    # within this locked context or else Ray could hang.
                    if self.num_clients == 0 and not ray.is_initialized():
                        self.ray_connect_handler()
                threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
                if self.num_clients >= threshold:
                    context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                    logger.warning(
                        f"[Data Servicer]: Num clients {self.num_clients} "
                        f"has reached the threshold {threshold}. "
                        f"Rejecting client: {metadata['client_id']}. ")
                    if log_once("client_threshold"):
                        logger.warning(
                            "You can configure the client connection "
                            "threshold by setting the "
                            "RAY_CLIENT_SERVER_MAX_THREADS env var "
                            f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
                    return

                self.num_clients += 1
                logger.debug(f"Accepted data connection from {client_id}. "
                             f"Total clients: {self.num_clients}")
                accepted_connection = True
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "get":
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response())
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.debug(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)

            with self.clients_lock:
                if accepted_connection:
                    # Could fail before client accounting happens
                    self.num_clients -= 1
                    logger.debug(f"Removed clients. {self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()
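All of the Datapath variants above implement the server side of a bidirectional gRPC stream: the client sends DataRequest messages and matches the returned DataResponse messages by req_id. The sketch below shows, under stated assumptions, how such a stream could be driven from a small test client to fetch connection info against the simpler variants (Examples 2, 3, and 12) that do not require an init request first. The import paths and the ConnectionInfoRequest message name are assumptions based on the ray_client_pb2 / ray_client_pb2_grpc modules referenced in the examples; this is not the real Ray client implementation.

import grpc

from ray.core.generated import ray_client_pb2, ray_client_pb2_grpc


def fetch_connection_info(address: str, client_id: str):
    # Open a data channel and send a single connection_info request.
    channel = grpc.insecure_channel(address)
    stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
    request = ray_client_pb2.DataRequest(
        req_id=1,
        connection_info=ray_client_pb2.ConnectionInfoRequest())
    # The servicer reads client_id from the invocation metadata, as in
    # the examples above.
    responses = stub.Datapath(iter([request]),
                              metadata=[("client_id", client_id)])
    for resp in responses:
        if resp.req_id == 1 and resp.WhichOneof("type") == "connection_info":
            return resp.connection_info
    return None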