Esempio n. 1
0
File: server.py Progetto: rlan/ray
 def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
     """Write ``request.value`` under ``request.key`` in Ray's internal KV.

     The write is issued while the client hook is disabled.
     ``request.overwrite`` is forwarded unchanged, and whatever
     ``_internal_kv_put`` returns is reported back in the response's
     ``already_exists`` field. ``context`` is unused.
     """
     with disable_client_hook():
         kv_put = ray.experimental.internal_kv._internal_kv_put
         key_existed = kv_put(
             request.key, request.value, overwrite=request.overwrite)
     response = ray_client_pb2.KVPutResponse(already_exists=key_existed)
     return response
Esempio n. 2
0
File: server.py Progetto: rlan/ray
 def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
     """Delete ``request.key`` from Ray's internal KV store.

     The delete is issued while the client hook is disabled. The response
     carries no fields; ``context`` is unused.
     """
     with disable_client_hook():
         kv_del = ray.experimental.internal_kv._internal_kv_del
         kv_del(request.key)
     empty_response = ray_client_pb2.KVDelResponse()
     return empty_response
Esempio n. 3
0
    def Datapath(self, request_iterator, context):
        """Serve one client's data stream, yielding a response per request.

        The client is identified by the ``client_id`` entry of the call's
        invocation metadata; a stream without one is logged and dropped.
        Incoming requests are moved onto a queue by a background thread
        running ``fill_queue``; the same queue also receives ready-made
        ``DataResponse`` objects for asynchronous gets, which are yielded
        as-is.

        While reconnects are enabled (the client did not ask for a grace
        period of 0), responses for requests accepted by ``_should_cache``
        are stored in the per-client response cache and replayed when the
        same ``req_id`` is seen again.

        When the stream ends, cleanup of the client's state is either done
        immediately (graceful ``connection_cleanup`` request or an
        unrecoverable error) or delayed by the client's reconnect grace
        period. Cleanup is skipped if the session was already removed or if
        ``client_last_seen`` was updated after this stream started.

        Args:
            request_iterator: Iterator of incoming ``DataRequest`` messages.
            context: The gRPC servicer context of this call.

        Yields:
            ``ray_client_pb2.DataResponse`` messages.
        """
        start_time = time.time()
        # set to True if client shuts down gracefully
        cleanup_requested = False
        metadata = {k: v for k, v in context.invocation_metadata()}
        client_id = metadata.get("client_id")
        if client_id is None:
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        accepted_connection = self._init(client_id, context, start_time)
        if not accepted_connection:
            return
        # Only touch the per-client cache once the connection has been
        # accepted, so a rejected client never triggers a cache lookup.
        response_cache = self.response_caches[client_id]
        # Set to False if client requests a reconnect grace period of 0
        reconnect_enabled = True
        try:
            request_queue = Queue()
            queue_filler_thread = Thread(
                target=fill_queue,
                daemon=True,
                args=(request_iterator, request_queue))
            queue_filler_thread.start()
            # For non `async get` requests, this loop yields immediately.
            # For `async get` requests, this loop:
            #   1) does not yield, it just continues
            #   2) When the result is ready, it yields
            for req in iter(request_queue.get, None):
                if isinstance(req, ray_client_pb2.DataResponse):
                    # Early shortcut if this is the result of an async get.
                    yield req
                    continue

                assert isinstance(req, ray_client_pb2.DataRequest)
                if _should_cache(req) and reconnect_enabled:
                    cached_resp = response_cache.check_cache(req.req_id)
                    if isinstance(cached_resp, Exception):
                        # Cache state is invalid, raise exception
                        raise cached_resp
                    if cached_resp is not None:
                        yield cached_resp
                        continue

                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "init":
                    resp_init = self.basic_service.Init(req.init)
                    resp = ray_client_pb2.DataResponse(init=resp_init, )
                    with self.clients_lock:
                        self.reconnect_grace_periods[client_id] = \
                            req.init.reconnect_grace_period
                        if req.init.reconnect_grace_period == 0:
                            reconnect_enabled = False

                elif req_type == "get":
                    if req.get.asynchronous:
                        get_resp = self.basic_service._async_get_object(
                            req.get, client_id, req.req_id, request_queue)
                        if get_resp is None:
                            # Skip sending a response for this request and
                            # continue to the next request. The response for
                            # this request will be sent when the object is
                            # ready.
                            continue
                    else:
                        get_resp = self.basic_service._get_object(
                            req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response())
                elif req_type == "prep_runtime_env":
                    with self.clients_lock:
                        resp_prep = self.basic_service.PrepRuntimeEnv(
                            req.prep_runtime_env)
                        resp = ray_client_pb2.DataResponse(
                            prep_runtime_env=resp_prep)
                elif req_type == "connection_cleanup":
                    cleanup_requested = True
                    cleanup_resp = ray_client_pb2.ConnectionCleanupResponse()
                    resp = ray_client_pb2.DataResponse(
                        connection_cleanup=cleanup_resp)
                elif req_type == "acknowledge":
                    # Clean up acknowledged cache entries
                    response_cache.cleanup(req.acknowledge.req_id)
                    continue
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                if _should_cache(req) and reconnect_enabled:
                    response_cache.update_cache(req.req_id, resp)
                yield resp
        except Exception as e:
            logger.exception("Error in data channel:")
            recoverable = _propagate_error_in_context(e, context)
            invalid_cache = response_cache.invalidate(e)
            if not recoverable or invalid_cache:
                context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
                # Connection isn't recoverable, skip the reconnect grace
                # period and clean up immediately.
                cleanup_requested = True
        finally:
            logger.debug(f"Stream is broken with client {client_id}")
            queue_filler_thread.join(QUEUE_JOIN_SECONDS)
            if queue_filler_thread.is_alive():
                logger.error(
                    "Queue filler thread failed to join before timeout: {}".
                    format(QUEUE_JOIN_SECONDS))
            cleanup_delay = self.reconnect_grace_periods.get(client_id)
            if not cleanup_requested and cleanup_delay is not None:
                # The trailing space keeps the two literals from running
                # together ("by30 seconds") after implicit concatenation.
                logger.debug("Cleanup wasn't requested, delaying cleanup by "
                             f"{cleanup_delay} seconds.")
                # Delay cleanup, since client may attempt a reconnect
                # Wait on the "stopped" event in case the grpc server is
                # stopped and we can clean up earlier.
                self.stopped.wait(timeout=cleanup_delay)
            else:
                logger.debug("Cleanup was requested, cleaning up immediately.")
            with self.clients_lock:
                if client_id not in self.client_last_seen:
                    logger.debug("Connection already cleaned up.")
                    # Some other connection has already cleaned up this
                    # client's session. This can happen if the client
                    # reconnects and then gracefully shuts down immediately.
                    return
                last_seen = self.client_last_seen[client_id]
                if last_seen > start_time:
                    # The client successfully reconnected and updated
                    # last seen some time during the grace period
                    logger.debug("Client reconnected, skipping cleanup")
                    return
                # Either the client shut down gracefully, or the client
                # failed to reconnect within the grace period. Clean up
                # the connection.
                self.basic_service.release_all(client_id)
                del self.client_last_seen[client_id]
                if client_id in self.reconnect_grace_periods:
                    del self.reconnect_grace_periods[client_id]
                if client_id in self.response_caches:
                    del self.response_caches[client_id]
                self.num_clients -= 1
                logger.debug(f"Removed client {client_id}, "
                             f"remaining={self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                # NOTE: it is strange to start ray in server.py but shut it
                # down here. Consider consolidating ray lifetime management.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()
Esempio n. 4
0
 def default_connect_handler(job_config: JobConfig = None):
     """Initialize Ray with ``job_config`` unless it is already running.

     Runs with the client hook disabled. Returns the result of
     ``ray.init`` when initialization happens here, otherwise ``None``.
     """
     with disable_client_hook():
         if ray.is_initialized():
             return None
         return ray.init(job_config=job_config)
Esempio n. 5
0
    def Datapath(self, request_iterator, context):
        """Serve one client's data stream, yielding a response per request.

        The client is identified by the ``client_id`` entry of the call's
        invocation metadata; a stream with a missing or empty id is logged
        and dropped. The connection is accepted by an ``init`` request
        (``self._init``); if that returns ``None`` the call is ended with
        ``RESOURCE_EXHAUSTED``. Every other request type asserts that the
        connection was accepted first.

        When the stream ends, all of the client's references are released,
        the client count is decremented if the connection had been
        accepted, and Ray is shut down once no clients remain.

        Args:
            request_iterator: Iterator of incoming data requests.
            context: The gRPC servicer context of this call.

        Yields:
            ``ray_client_pb2.DataResponse`` messages (or whatever
            ``self._init`` returns for the ``init`` request).
        """
        metadata = {k: v for k, v in context.invocation_metadata()}
        # Use .get so that absent metadata reaches the guard below instead
        # of raising KeyError before the error can be logged.
        client_id = metadata.get("client_id")
        accepted_connection = False
        if not client_id:
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        try:
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "init":
                    resp = self._init(req.init, client_id)
                    if resp is None:
                        context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                        return
                    logger.debug(f"Accepted data connection from {client_id}. "
                                 f"Total clients: {self.num_clients}")
                    accepted_connection = True
                else:
                    assert accepted_connection
                    if req_type == "get":
                        get_resp = self.basic_service._get_object(
                            req.get, client_id)
                        resp = ray_client_pb2.DataResponse(get=get_resp)
                    elif req_type == "put":
                        put_resp = self.basic_service._put_object(
                            req.put, client_id)
                        resp = ray_client_pb2.DataResponse(put=put_resp)
                    elif req_type == "release":
                        released = []
                        for rel_id in req.release.ids:
                            rel = self.basic_service.release(client_id, rel_id)
                            released.append(rel)
                        resp = ray_client_pb2.DataResponse(
                            release=ray_client_pb2.ReleaseResponse(
                                ok=released))
                    elif req_type == "connection_info":
                        resp = ray_client_pb2.DataResponse(
                            connection_info=self._build_connection_response())
                    elif req_type == "prep_runtime_env":
                        with self.clients_lock:
                            resp_prep = self.basic_service.PrepRuntimeEnv(
                                req.prep_runtime_env)
                            resp = ray_client_pb2.DataResponse(
                                prep_runtime_env=resp_prep)
                    else:
                        raise Exception(f"Unreachable code: Request type "
                                        f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.debug(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)

            with self.clients_lock:
                if accepted_connection:
                    # Could fail before client accounting happens
                    self.num_clients -= 1
                    logger.debug(f"Removed clients. {self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()
Esempio n. 6
0
 def disable():
     """Report ``client_mode_should_convert`` inside and outside the hook.

     The first result is queued while ``disable_client_hook()`` is active,
     after which the function acquires ``lock`` before leaving the context.
     The second result is queued once the context has been exited.
     ``q`` and ``lock`` come from the enclosing scope.
     """
     with disable_client_hook():
         inside_result = client_mode_should_convert(auto_init=True)
         q.put(inside_result)
         lock.acquire()
     outside_result = client_mode_should_convert(auto_init=True)
     q.put(outside_result)
Esempio n. 7
0
 def KVExists(self, request,
              context=None) -> ray_client_pb2.KVExistsResponse:
     """Report whether ``request.key`` is present in Ray's internal KV.

     The lookup is issued while the client hook is disabled, and its
     result is returned in the response's ``exists`` field. ``context`` is
     unused.
     """
     with disable_client_hook():
         kv_exists = ray.experimental.internal_kv._internal_kv_exists
         key_present = kv_exists(request.key)
     response = ray_client_pb2.KVExistsResponse(exists=key_present)
     return response
Esempio n. 8
0
File: server.py Progetto: rlan/ray
def shutdown_with_server(server, _exiting_interpreter=False):
    """Stop ``server`` and then shut Ray down with the client hook disabled.

    ``server.stop`` is called with ``1`` (presumably a grace period in
    seconds; confirm against the server type), and
    ``_exiting_interpreter`` is forwarded positionally to ``ray.shutdown``.
    """
    stop_argument = 1
    server.stop(stop_argument)
    with disable_client_hook():
        ray.shutdown(_exiting_interpreter)
Esempio n. 9
0
    def _async_get_object(
        self,
        request: ray_client_pb2.GetRequest,
        client_id: str,
        req_id: int,
        result_queue: queue.Queue,
        context=None,
    ) -> Optional[ray_client_pb2.GetResponse]:
        """Register a completion callback that streams an object to a client.

        On success nothing is returned now; once the object is ready the
        callback serializes it and puts one ``DataResponse`` per chunk
        (tagged with ``req_id``) on ``result_queue``, starting at
        ``request.start_chunk_id``. If the reference is unknown, or
        registering the callback fails, an invalid ``GetResponse`` carrying
        the pickled error is returned immediately instead.

        Raises:
            ValueError: If ``request.ids`` does not hold exactly one id.
        """
        if len(request.ids) != 1:
            raise ValueError("Async get() must have exactly 1 Object ID. "
                             f"Actual: {request}")
        object_id = request.ids[0]
        object_ref = self.object_refs[client_id].get(object_id, None)
        if not object_ref:
            lookup_error = ValueError(
                f"ClientObjectRef with id {object_id} not found for "
                f"client {client_id}")
            return ray_client_pb2.GetResponse(
                valid=False,
                error=cloudpickle.dumps(lookup_error),
            )
        try:
            logger.debug("async get: %s" % object_ref)
            with disable_client_hook():

                def send_get_response(result: Any) -> None:
                    """Queue the chunked result (or an error) for the client.

                    Invoked with the object's value once it is ready; any
                    failure while serializing or chunking is reported as a
                    single invalid response."""
                    try:
                        payload = dumps_from_server(result, client_id, self)
                        payload_size = len(payload)
                        assert payload_size > 0, "Serialized object cannot be zero bytes"
                        chunk_count = math.ceil(
                            payload_size / OBJECT_TRANSFER_CHUNK_SIZE)
                        first_chunk = request.start_chunk_id
                        for index in range(first_chunk, chunk_count):
                            lower = index * OBJECT_TRANSFER_CHUNK_SIZE
                            upper = min(payload_size,
                                        lower + OBJECT_TRANSFER_CHUNK_SIZE)
                            chunk = ray_client_pb2.GetResponse(
                                valid=True,
                                data=payload[lower:upper],
                                chunk_id=index,
                                total_chunks=chunk_count,
                                total_size=payload_size,
                            )
                            result_queue.put(
                                ray_client_pb2.DataResponse(
                                    get=chunk, req_id=req_id))
                    except Exception as exc:
                        failure = ray_client_pb2.GetResponse(
                            valid=False, error=cloudpickle.dumps(exc))
                        result_queue.put(
                            ray_client_pb2.DataResponse(
                                get=failure, req_id=req_id))

                object_ref._on_completed(send_get_response)
                return None

        except Exception as e:
            scheduling_failure = cloudpickle.dumps(e)
            return ray_client_pb2.GetResponse(valid=False,
                                              error=scheduling_failure)
Esempio n. 10
0
 def ray_connect_handler(job_config=None, **ray_init_kwargs):
     """Initialize Ray from ``ray_config`` unless it is already running.

     Runs with the client hook disabled. Neither ``job_config`` nor
     ``ray_init_kwargs`` is used; ``ray_config`` is not defined in this
     function and presumably comes from the enclosing scope (verify against
     the caller).
     """
     from ray._private.client_mode_hook import disable_client_hook
     with disable_client_hook():
         import ray as real_ray
         if real_ray.is_initialized():
             return
         real_ray.init(**ray_config)
Esempio n. 11
0
    def Datapath(self, request_iterator, context):
        """Serve one client's data stream, yielding a response per request.

        The client is identified by the ``client_id`` entry of the call's
        invocation metadata; a stream with a missing or empty id is logged
        and dropped, as is one that ``self._init`` does not accept.
        Incoming requests are moved onto a queue by a background thread
        running ``fill_queue``; the same queue also receives ready-made
        ``DataResponse`` objects for asynchronous gets, which are yielded
        as-is.

        When the stream ends, all of the client's references are released,
        the filler thread is joined (with a timeout), the client count is
        decremented, and Ray is shut down once no clients remain.

        Args:
            request_iterator: Iterator of incoming ``DataRequest`` messages.
            context: The gRPC servicer context of this call.

        Yields:
            ``ray_client_pb2.DataResponse`` messages.
        """
        metadata = {k: v for k, v in context.invocation_metadata()}
        # Use .get so that absent metadata reaches the guard below instead
        # of raising KeyError before the error can be logged.
        client_id = metadata.get("client_id")
        if not client_id:
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        accepted_connection = self._init(client_id, context)
        if not accepted_connection:
            return
        try:
            request_queue = Queue()
            queue_filler_thread = Thread(target=fill_queue,
                                         daemon=True,
                                         args=(request_iterator,
                                               request_queue))
            queue_filler_thread.start()
            # For non `async get` requests, this loop yields immediately.
            # For `async get` requests, this loop:
            #   1) does not yield, it just continues
            #   2) When the result is ready, it yields
            for req in iter(request_queue.get, None):
                if isinstance(req, ray_client_pb2.DataResponse):
                    # Early shortcut if this is the result of an async get.
                    yield req
                    continue

                assert isinstance(req, ray_client_pb2.DataRequest)
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "init":
                    resp_init = self.basic_service.Init(req.init)
                    resp = ray_client_pb2.DataResponse(init=resp_init, )
                elif req_type == "get":
                    if req.get.asynchronous:
                        get_resp = self.basic_service._async_get_object(
                            req.get, client_id, req.req_id, request_queue)
                        if get_resp is None:
                            # Skip sending a response for this request and
                            # continue to the next request. The response for
                            # this request will be sent when the object is
                            # ready.
                            continue
                    else:
                        get_resp = self.basic_service._get_object(
                            req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response())
                elif req_type == "prep_runtime_env":
                    with self.clients_lock:
                        resp_prep = self.basic_service.PrepRuntimeEnv(
                            req.prep_runtime_env)
                        resp = ray_client_pb2.DataResponse(
                            prep_runtime_env=resp_prep)
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.debug(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)
            queue_filler_thread.join(QUEUE_JOIN_SECONDS)
            if queue_filler_thread.is_alive():
                logger.error(
                    "Queue filler thread failed to  join before timeout: {}".
                    format(QUEUE_JOIN_SECONDS))
            with self.clients_lock:
                # This point is only reached for connections that _init
                # accepted, so the count is decremented unconditionally
                # (presumably _init counted this client; verify there).
                self.num_clients -= 1
                logger.debug(f"Removed clients. {self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()
Esempio n. 12
0
 def default_connect_handler():
     """Initialize Ray with default arguments unless it is already running.

     Runs with the client hook disabled. Returns the result of
     ``ray.init`` when initialization happens here, otherwise ``None``.
     """
     with disable_client_hook():
         if ray.is_initialized():
             return None
         return ray.init()
Esempio n. 13
0
    def add_node(self, wait=True, **node_args):
        """Adds a node to the local Ray Cluster.

        The first node added becomes the head node; later calls add worker
        nodes that are pointed at the head's redis/GCS addresses.

        Unless overridden through ``node_args``, nodes are started with:
            num_cpus=1,
            num_gpus=0,
            object_store_memory=150 * 1024 * 1024  # 150 MiB
            min_worker_port=0,
            max_worker_port=0,
            dashboard_port=None

        Args:
            wait (bool): Whether to wait until the node is alive.
            node_args: Keyword arguments used to build the node's
                ``RayParams``. Overrides defaults.

        Returns:
            Node object of the added Ray node.
        """
        ray_params = ray._private.parameter.RayParams(**node_args)
        # Fill in the defaults only for parameters the caller left unset.
        ray_params.update_if_absent(
            num_cpus=1,
            num_gpus=0,
            object_store_memory=150 * 1024 * 1024,  # 150 MiB
            min_worker_port=0,
            max_worker_port=0,
            dashboard_port=None,
        )
        with disable_client_hook():
            if self.head_node is not None:
                # A head already exists: start a worker attached to it.
                ray_params.update_if_absent(redis_address=self.redis_address)
                ray_params.update_if_absent(gcs_address=self.gcs_address)
                # We only need one log monitor per physical node.
                ray_params.update_if_absent(include_log_monitor=False)
                # Let grpc pick a port.
                ray_params.update_if_absent(node_manager_port=0)

                node = ray.node.Node(
                    ray_params,
                    head=False,
                    shutdown_at_exit=self._shutdown_at_exit,
                    spawn_reaper=self._shutdown_at_exit)
                self.worker_nodes.add(node)
            else:
                node = ray.node.Node(
                    ray_params,
                    head=True,
                    shutdown_at_exit=self._shutdown_at_exit,
                    spawn_reaper=self._shutdown_at_exit)
                self.head_node = node
                self.redis_address = self.head_node.redis_address
                self.redis_password = node_args.get(
                    "redis_password", ray_constants.REDIS_DEFAULT_PASSWORD)
                self.webui_url = self.head_node.webui_url
                # Init global state accessor when creating head node.
                if use_gcs_for_bootstrap():
                    gcs_options = GcsClientOptions.from_gcs_address(
                        node.gcs_address)
                else:
                    gcs_options = GcsClientOptions.from_redis_address(
                        self.redis_address, self.redis_password)
                self.global_state._initialize_global_state(gcs_options)

            if wait:
                # Wait for the node to appear in the client table. We do this
                # so that the nodes appear in the client table in the order
                # that the corresponding calls to add_node were made. We do
                # this because in the tests we assume that the driver is
                # connected to the first node that is added.
                self._wait_for_node(node)

        return node
Esempio n. 14
0
File: server.py Progetto: rlan/ray
 def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
     """List the internal KV keys matching ``request.prefix``.

     The listing is issued while the client hook is disabled, and its
     result is returned in the response's ``keys`` field. ``context`` is
     unused.
     """
     with disable_client_hook():
         kv_list = ray.experimental.internal_kv._internal_kv_list
         matching_keys = kv_list(request.prefix)
     response = ray_client_pb2.KVListResponse(keys=matching_keys)
     return response
Esempio n. 15
0
def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
    """Pickle ``obj`` with a ``ClientPickler`` and return the bytes.

    The pickler is built for ``client_id`` with the given ``protocol`` and
    writes into an in-memory buffer; the whole operation runs with the
    client hook disabled.
    """
    with disable_client_hook(), io.BytesIO() as buffer:
        pickler = ClientPickler(client_id, buffer, protocol=protocol)
        pickler.dump(obj)
        return buffer.getvalue()
Esempio n. 16
0
File: server.py Progetto: rlan/ray
 def default_connect_handler(job_config: JobConfig = None,
                             **ray_init_kwargs: Dict[str, Any]):
     """Initialize Ray unless it is already running.

     Runs with the client hook disabled. ``job_config`` and every extra
     keyword argument are forwarded to ``ray.init``. Returns the result of
     ``ray.init`` when initialization happens here, otherwise ``None``.
     """
     with disable_client_hook():
         if ray.is_initialized():
             return None
         return ray.init(job_config=job_config, **ray_init_kwargs)
Esempio n. 17
0
def init_and_serve(connection_str, *args, **kwargs):
    """Initialize Ray, then start serving on ``connection_str``.

    ``args`` and ``kwargs`` are forwarded to ``ray.init``, which runs with
    the client hook disabled. Returns a ``(server_handle, info)`` tuple
    holding the result of ``serve`` and the result of ``ray.init``.
    """
    # Disable client mode inside the worker's environment
    with disable_client_hook():
        init_info = ray.init(*args, **kwargs)
    handle = serve(connection_str)
    return handle, init_info
Esempio n. 18
0
 def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
     """Fetch the value stored under ``request.key`` in Ray's internal KV.

     The read is issued while the client hook is disabled, and its result
     is returned in the response's ``value`` field. ``context`` is unused.
     """
     with disable_client_hook():
         kv_get = ray.experimental.internal_kv._internal_kv_get
         stored_value = kv_get(request.key)
     response = ray_client_pb2.KVGetResponse(value=stored_value)
     return response
Esempio n. 19
0
    def Datapath(self, request_iterator, context):
        """Serve one client's data stream, yielding a response per request.

        The client is identified by the ``client_id`` entry of the call's
        invocation metadata; a stream with a missing or empty id is logged
        and dropped. Under ``clients_lock`` the handler connects Ray if
        this is the first client and Ray is not initialized, then rejects
        the connection with ``RESOURCE_EXHAUSTED`` when the client count
        has reached half of ``CLIENT_SERVER_MAX_THREADS``; otherwise the
        client is counted and its requests are served.

        When the stream ends, all of the client's references are released,
        the client count is decremented if the connection had been
        accepted, and Ray is shut down once no clients remain.

        Args:
            request_iterator: Iterator of incoming data requests.
            context: The gRPC servicer context of this call.

        Yields:
            ``ray_client_pb2.DataResponse`` messages.
        """
        metadata = {k: v for k, v in context.invocation_metadata()}
        # Use .get so that absent metadata reaches the guard below instead
        # of raising KeyError before the error can be logged.
        client_id = metadata.get("client_id")
        accepted_connection = False
        if not client_id:
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        try:
            with self.clients_lock:
                with disable_client_hook():
                    # It's important to keep the ray initialization call
                    # within this locked context or else Ray could hang.
                    if self.num_clients == 0 and not ray.is_initialized():
                        self.ray_connect_handler()
                threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
                if self.num_clients >= threshold:
                    context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                    logger.warning(
                        f"[Data Servicer]: Num clients {self.num_clients} "
                        f"has reached the threshold {threshold}. "
                        f"Rejecting client: {client_id}. ")
                    if log_once("client_threshold"):
                        logger.warning(
                            "You can configure the client connection "
                            "threshold by setting the "
                            "RAY_CLIENT_SERVER_MAX_THREADS env var "
                            f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
                    return

                self.num_clients += 1
                logger.debug(f"Accepted data connection from {client_id}. "
                             f"Total clients: {self.num_clients}")
                accepted_connection = True
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "get":
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response())
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.debug(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)

            with self.clients_lock:
                if accepted_connection:
                    # Could fail before client accounting happens
                    self.num_clients -= 1
                    logger.debug(f"Removed clients. {self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()