Example #1
0
 def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
     """Handle an Init RPC: bring up the server-side Ray instance.

     Unpickles the client's JobConfig (if one was sent), initializes Ray
     via ``self.ray_connect_handler`` when it is not already running, and
     verifies that the requested runtime-env URIs match the current job's.

     Args:
         request: InitRequest carrying an optional pickled job_config.
         context: gRPC context (unused).

     Returns:
         InitResponse with ok=True on success, or ok=False plus a msg when
         the runtime environments are incompatible.
     """
     # NOTE(review): pickle.loads on request data assumes a trusted client.
     import pickle
     if request.job_config:
         job_config = pickle.loads(request.job_config)
         job_config.client_job = True
     else:
         job_config = None
     current_job_config = None
     with disable_client_hook():
         if ray.is_initialized():
             worker = ray.worker.global_worker
             current_job_config = worker.core_worker.get_job_config()
         else:
             self.ray_connect_handler(job_config)
     if job_config is None:
         # BUG FIX: previously returned InitResponse() with no arguments;
         # the proto-default for `ok` is False, so a successful init with
         # no job config was reported to the client as a failure.
         return ray_client_pb2.InitResponse(ok=True)
     job_config = job_config.get_proto_job_config()
     # If the server has been initialized, we need to compare whether the
     # runtime env is compatible.
     if current_job_config and \
        set(job_config.runtime_env.uris) != set(
             current_job_config.runtime_env.uris) and \
             len(job_config.runtime_env.uris) > 0:
         return ray_client_pb2.InitResponse(
             ok=False,
             msg="Runtime environment doesn't match "
             f"request one {job_config.runtime_env.uris} "
             f"current one {current_job_config.runtime_env.uris}")
     return ray_client_pb2.InitResponse(ok=True)
Example #2
0
    def Init(self,
             request: ray_client_pb2.InitRequest,
             context=None) -> ray_client_pb2.InitResponse:
        """Set up the server-side Ray instance for a connecting client.

        Deserializes the client's JobConfig (if one was sent), calls
        ``ray.init()`` via ``self.ray_connect_handler`` when Ray is not yet
        running, and rejects the request if its runtime-env URIs conflict
        with those of the job already running on the server.

        Args:
            request: InitRequest with an optional pickled job_config and a
                JSON string of extra ``ray.init()`` kwargs.
            context: gRPC context (unused).

        Returns:
            InitResponse with ok=True on success, or ok=False and an
            explanatory msg on failure.
        """
        cfg = None
        if request.job_config:
            # NOTE(review): pickle.loads assumes a trusted client.
            cfg = pickle.loads(request.job_config)
            cfg.client_job = True
        active_cfg = None
        with disable_client_hook():
            if ray.is_initialized():
                active_cfg = (ray._private.worker.global_worker
                              .core_worker.get_job_config())
            else:
                init_kwargs = json.loads(request.ray_init_kwargs or "{}")
                try:
                    self.ray_connect_handler(cfg, **init_kwargs)
                except Exception as e:
                    logger.exception("Running Ray Init failed:")
                    return ray_client_pb2.InitResponse(
                        ok=False,
                        msg=f"Call to `ray.init()` on the server failed with: {e}",
                    )
        if cfg is None:
            return ray_client_pb2.InitResponse(ok=True)

        # NOTE(edoakes): this code should not be necessary anymore because we
        # only allow a single client/job per server. There is an existing test
        # that tests the behavior of multiple clients with the same job config
        # connecting to one server (test_client_init.py::test_num_clients),
        # so I'm leaving it here for now.
        proto_cfg = cfg.get_proto_job_config()
        # If the server has been initialized, compare whether the runtime env
        # is compatible with the one already in use.
        if active_cfg:
            # NOTE(review): working_dir_uri looks like a scalar string, so
            # set() over it yields a set of characters — behavior preserved
            # as-is; confirm against the proto definition.
            requested = set(proto_cfg.runtime_env_info.uris.working_dir_uri)
            requested.update(proto_cfg.runtime_env_info.uris.py_modules_uris)
            existing = set(active_cfg.runtime_env_info.uris.working_dir_uri)
            existing.update(active_cfg.runtime_env_info.uris.py_modules_uris)
            if requested != existing and len(requested) > 0:
                return ray_client_pb2.InitResponse(
                    ok=False,
                    msg="Runtime environment doesn't match "
                    f"request one {proto_cfg.runtime_env_info.uris} "
                    f"current one {active_cfg.runtime_env_info.uris}",
                )
        return ray_client_pb2.InitResponse(ok=True)
Example #3
0
    def Datapath(self, request_iterator, context):
        """Proxy the bidirectional data stream between a client and its
        dedicated Ray client server.

        Consumes the first (init) request to start a per-client specific
        server, then forwards the rest of the stream through that server's
        channel, yielding each backend response (with connection info
        possibly rewritten) back to the client.

        Args:
            request_iterator: Iterator of incoming data requests; the first
                item is expected to be the init request.
            context: gRPC context; must carry the client_id metadata.

        Yields:
            Backend DataResponse messages, or a single failed InitResponse
            if server startup fails.
        """
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            # Empty id means the context was rejected — presumably
            # _get_client_id_from_context already flagged the error; verify.
            return

        # Create Placeholder *before* reading the first request.
        server = self.proxy_manager.create_specific_server(client_id)
        try:
            with self.clients_lock:
                self.num_clients += 1

            logger.info(f"New data connection from client {client_id}: ")
            init_req = next(request_iterator)
            try:
                modified_init_req, job_config = prepare_runtime_init_req(
                    init_req)
                if not self.proxy_manager.start_specific_server(
                        client_id, job_config):
                    logger.error(
                        f"Server startup failed for client: {client_id}, "
                        f"using JobConfig: {job_config}!")
                    raise RuntimeError(
                        "Starting Ray client server failed. This is most "
                        "likely because the runtime_env failed to be "
                        "installed. See ray_client_server_[port].err on the "
                        "head node of the cluster for the relevant logs.")
                channel = self.proxy_manager.get_channel(client_id)
                if channel is None:
                    logger.error(f"Channel not found for {client_id}")
                    raise RuntimeError(
                        "Proxy failed to Connect to backend! Check "
                        "`ray_client_server.err` on the cluster.")
                stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
            except Exception:
                # Surface any startup failure to the client as a failed init
                # response rather than a raw gRPC stream error.
                init_resp = ray_client_pb2.DataResponse(
                    init=ray_client_pb2.InitResponse(
                        ok=False, msg=traceback.format_exc()))
                init_resp.req_id = init_req.req_id
                yield init_resp
                return None

            # Re-attach the (modified) init request ahead of the live stream
            # so the backend sees a complete request sequence.
            new_iter = chain([modified_init_req], request_iterator)
            resp_stream = stub.Datapath(
                new_iter, metadata=[("client_id", client_id)])
            for resp in resp_stream:
                yield self.modify_connection_info_resp(resp)
        except Exception:
            logger.exception("Proxying Datapath failed!")
        finally:
            # Always release the placeholder and drop the client count, even
            # if the stream terminated abnormally.
            server.set_result(None)
            with self.clients_lock:
                logger.debug(f"Client detached: {client_id}")
                self.num_clients -= 1
Example #4
0
 def Init(self,
          request: ray_client_pb2.InitRequest,
          context=None) -> ray_client_pb2.InitResponse:
     """Initialize Ray on the server for a connecting client.

     Unpickles the client's JobConfig when present, runs ``ray.init()``
     through ``self.ray_connect_handler`` if Ray is not already running,
     and refuses the connection when the requested runtime-env URIs differ
     from those of the current job.

     Returns:
         InitResponse with ok=True on success, or ok=False and a msg on
         failure.
     """
     cfg = None
     if request.job_config:
         # NOTE(review): pickle.loads assumes a trusted client.
         cfg = pickle.loads(request.job_config)
         cfg.client_job = True
     active_cfg = None
     with disable_client_hook():
         if not ray.is_initialized():
             init_kwargs = json.loads(request.ray_init_kwargs or "{}")
             try:
                 self.ray_connect_handler(cfg, **init_kwargs)
             except Exception as e:
                 logger.exception("Running Ray Init failed:")
                 return ray_client_pb2.InitResponse(
                     ok=False,
                     msg=f"Call to `ray.init()` on the server failed with: {e}")
         else:
             active_cfg = (ray.worker.global_worker
                           .core_worker.get_job_config())
     if cfg is None:
         return ray_client_pb2.InitResponse(ok=True)
     proto_cfg = cfg.get_proto_job_config()
     # If the server has been initialized, compare whether the runtime env
     # is compatible with the one already in use.
     if active_cfg:
         requested = set(proto_cfg.runtime_env.uris)
         existing = set(active_cfg.runtime_env.uris)
         if requested != existing and len(requested) > 0:
             return ray_client_pb2.InitResponse(
                 ok=False,
                 msg="Runtime environment doesn't match "
                 f"request one {proto_cfg.runtime_env.uris} "
                 f"current one {active_cfg.runtime_env.uris}")
     return ray_client_pb2.InitResponse(ok=True)
Example #5
0
    def Datapath(self, request_iterator, context):
        """Proxy the data stream between a client and its specific server,
        with support for client reconnection.

        On a fresh connection, reads the init request, starts a dedicated
        server for the client, and forwards the stream through it. On a
        reconnect, reuses the existing server/channel. Cleanup after the
        stream ends is delayed by the client's reconnect grace period and
        skipped entirely if the client reconnects in the meantime or the
        backend requested cleanup be skipped.

        Args:
            request_iterator: Iterator of incoming data requests; on a fresh
                connection the first item is the init request.
            context: gRPC context; carries client_id and reconnecting
                metadata.

        Yields:
            Backend DataResponse messages (connection info may be
            rewritten), or a single failed InitResponse on startup error.
        """
        # Set when the backend (or an unrecoverable error) tells us to skip
        # the delayed-cleanup grace period.
        cleanup_requested = False
        start_time = time.time()
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        reconnecting = _get_reconnecting_from_context(context)

        if reconnecting:
            with self.clients_lock:
                if client_id not in self.clients_last_seen:
                    # Client took too long to reconnect, session has already
                    # been cleaned up
                    context.set_code(grpc.StatusCode.NOT_FOUND)
                    context.set_details(
                        "Attempted to reconnect a session that has already "
                        "been cleaned up")
                    return
                self.clients_last_seen[client_id] = start_time
            server = self.proxy_manager._get_server_for_client(client_id)
            channel = self.proxy_manager.get_channel(client_id)
            # iterator doesn't need modification on reconnect
            new_iter = request_iterator
        else:
            # Create Placeholder *before* reading the first request.
            server = self.proxy_manager.create_specific_server(client_id)
            with self.clients_lock:
                self.clients_last_seen[client_id] = start_time
                self.num_clients += 1

        try:
            if not reconnecting:
                logger.info(f"New data connection from client {client_id}: ")
                init_req = next(request_iterator)
                # Remember how long this client wants us to wait for a
                # reconnect before tearing the session down.
                with self.clients_lock:
                    self.reconnect_grace_periods[client_id] = \
                        init_req.init.reconnect_grace_period
                try:
                    modified_init_req, job_config = prepare_runtime_init_req(
                        init_req)
                    if not self.proxy_manager.start_specific_server(
                            client_id, job_config):
                        logger.error(
                            f"Server startup failed for client: {client_id}, "
                            f"using JobConfig: {job_config}!")
                        raise RuntimeError(
                            "Starting Ray client server failed. See "
                            f"ray_client_server_{server.port}.err for "
                            "detailed logs.")
                    channel = self.proxy_manager.get_channel(client_id)
                    if channel is None:
                        logger.error(f"Channel not found for {client_id}")
                        raise RuntimeError(
                            "Proxy failed to Connect to backend! Check "
                            "`ray_client_server.err` and "
                            f"`ray_client_server_{server.port}.err` on the "
                            "head node of the cluster for the relevant logs. "
                            "By default these are located at "
                            "/tmp/ray/session_latest/logs.")
                except Exception:
                    # Report startup failure to the client as a failed init
                    # response instead of a raw stream error.
                    init_resp = ray_client_pb2.DataResponse(
                        init=ray_client_pb2.InitResponse(
                            ok=False, msg=traceback.format_exc()))
                    init_resp.req_id = init_req.req_id
                    yield init_resp
                    return None

                # Re-attach the (modified) init request ahead of the stream.
                new_iter = chain([modified_init_req], request_iterator)

            stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
            metadata = [("client_id", client_id), ("reconnecting",
                                                   str(reconnecting))]
            resp_stream = stub.Datapath(new_iter, metadata=metadata)
            for resp in resp_stream:
                resp_type = resp.WhichOneof("type")
                if resp_type == "connection_cleanup":
                    # Specific server is skipping cleanup, proxier should too
                    cleanup_requested = True
                yield self.modify_connection_info_resp(resp)
        except Exception as e:
            logger.exception("Proxying Datapath failed!")
            # Propagate error through context
            recoverable = _propagate_error_in_context(e, context)
            if not recoverable:
                # Client shouldn't attempt to recover, clean up connection
                cleanup_requested = True
        finally:
            cleanup_delay = self.reconnect_grace_periods.get(client_id)
            if not cleanup_requested and cleanup_delay is not None:
                # Delay cleanup, since client may attempt a reconnect
                # Wait on stopped event in case the server closes and we
                # can clean up earlier
                self.stopped.wait(timeout=cleanup_delay)
            with self.clients_lock:
                if client_id not in self.clients_last_seen:
                    logger.info(f"{client_id} not found. Skipping clean up.")
                    # Connection has already been cleaned up
                    return
                last_seen = self.clients_last_seen[client_id]
                logger.info(
                    f"{client_id} last started stream at {last_seen}. Current "
                    f"stream started at {start_time}.")
                if last_seen > start_time:
                    logger.info("Client reconnected. Skipping cleanup.")
                    # Client has reconnected, don't clean up
                    return
                logger.debug(f"Client detached: {client_id}")
                self.num_clients -= 1
                del self.clients_last_seen[client_id]
                if client_id in self.reconnect_grace_periods:
                    del self.reconnect_grace_periods[client_id]
                server.set_result(None)