Example 1
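Client-side data-channel loop with reconnect support: each pass opens a Datapath stream on a fresh RayletDataStreamerStub, tags the call metadata with the current reconnecting flag, and retries only when _can_reconnect judges the gRPC error recoverable.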
 def _data_main(self) -> None:
     reconnecting = False
     try:
         while not self.client_worker._in_shutdown:
             stub = ray_client_pb2_grpc.RayletDataStreamerStub(
                 self.client_worker.channel)
             metadata = self._metadata + \
                 [("reconnecting", str(reconnecting))]
             resp_stream = stub.Datapath(iter(self.request_queue.get, None),
                                         metadata=metadata,
                                         wait_for_ready=True)
             try:
                 for response in resp_stream:
                     self._process_response(response)
                 return
             except grpc.RpcError as e:
                 reconnecting = self._can_reconnect(e)
                 if not reconnecting:
                     self._last_exception = e
                     return
                 self._reconnect_channel()
     except Exception as e:
         self._last_exception = e
     finally:
         logger.info("Shutting down data channel")
         self._shutdown()
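
The outbound request stream in these client loops is built with the two-argument form of iter: iter(self.request_queue.get, None) keeps calling get() until it returns the None sentinel, which is how the stream is closed. A minimal standalone sketch of that pattern, with illustrative names only (not Ray code):

from queue import Queue

requests: Queue = Queue()
requests.put("req-1")
requests.put("req-2")
requests.put(None)  # sentinel: ends the iteration below

# iter(requests.get, None) calls get() repeatedly and stops when it returns None.
for item in iter(requests.get, None):
    print(item)  # prints req-1, then req-2
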
Example 2
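Proxy-side Datapath handler that starts a dedicated server for the client, then bridges the inbound request iterator to the backend stub through a Queue drained by a forwarding thread.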
    def Datapath(self, request_iterator, context):
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return

        logger.info(f"New data connection from client {client_id}: ")
        modified_init_req, job_config = prepare_runtime_init_req(
            request_iterator)

        queue = Queue()
        queue.put(modified_init_req)
        if not self.proxy_manager.start_specific_server(client_id, job_config):
            logger.error(f"Server startup failed for client: {client_id}, "
                         f"using JobConfig: {job_config}!")
            context.set_code(grpc.StatusCode.ABORTED)
            return None

        channel = self.proxy_manager.get_channel(client_id)
        if channel is None:
            logger.error(f"Channel not found for {client_id}")
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
        thread = Thread(target=forward_streaming_requests,
                        args=(request_iterator, queue),
                        daemon=True)
        thread.start()
        try:
            resp_stream = stub.Datapath(iter(queue.get, None),
                                        metadata=[("client_id", client_id)])
            for resp in resp_stream:
                yield resp
        finally:
            thread.join(1)
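
forward_streaming_requests is referenced here but not shown; a hypothetical sketch of such a helper, assuming it only needs to copy inbound gRPC requests onto the queue and close the stream with the None sentinel:

def forward_streaming_requests(request_iterator, queue) -> None:
    # Hypothetical sketch, not the actual Ray helper: drain the inbound
    # request iterator into the queue consumed by iter(queue.get, None)
    # above, then push the None sentinel so the downstream stream ends
    # when the client stream ends.
    try:
        for request in request_iterator:
            queue.put(request)
    finally:
        queue.put(None)
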
Example 3
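Client-side data loop without reconnection: responses are stored in ready_data keyed by req_id and waiters are notified on the condition variable; gRPC errors shut the channel down with status-specific logging.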
 def _data_main(self) -> None:
     stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
     resp_stream = stub.Datapath(iter(self.request_queue.get, None),
                                 metadata=[("client_id", self._client_id)] +
                                 self._metadata,
                                 wait_for_ready=True)
     try:
         for response in resp_stream:
             if response.req_id == 0:
                 # This is not being waited for.
                 logger.debug(f"Got unawaited response {response}")
                 continue
             with self.cv:
                 self.ready_data[response.req_id] = response
                 self.cv.notify_all()
     except grpc.RpcError as e:
         with self.cv:
             self._in_shutdown = True
             self.cv.notify_all()
         if e.code() == grpc.StatusCode.CANCELLED:
             # Gracefully shutting down
             logger.info("Cancelling data channel")
         elif e.code() == grpc.StatusCode.UNAVAILABLE:
             # TODO(barakmich): The server may have
             # dropped. In theory, we can retry, as per
             # https://grpc.github.io/grpc/core/md_doc_statuscodes.html but
             # in practice we may need to think about the correct semantics
             # here.
             logger.info("Server disconnected from data channel")
         else:
             logger.error(
                 f"Got Error from data channel -- shutting down: {e}")
             raise e
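
The loop above only fills ready_data; a hypothetical consumer counterpart (not part of these snippets) would block on the same condition variable until its response arrives or the channel shuts down:

 def _wait_for_response(self, req_id: int):
     # Hypothetical sketch, assuming the same attributes used above
     # (cv, ready_data, _in_shutdown); not taken from the source.
     with self.cv:
         # Wake up once _data_main stores the response or flags shutdown.
         self.cv.wait_for(
             lambda: req_id in self.ready_data or self._in_shutdown)
         if self._in_shutdown:
             raise ConnectionError("Data channel was shut down")
         return self.ready_data.pop(req_id)
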
Example 4
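Proxy-side handler that stitches the rewritten init request back onto the stream with chain instead of using a forwarding thread, and tracks attached clients under clients_lock.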
    def Datapath(self, request_iterator, context):
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return

        logger.info(f"New data connection from client {client_id}: ")
        modified_init_req, job_config = prepare_runtime_init_req(
            request_iterator)

        if not self.proxy_manager.start_specific_server(client_id, job_config):
            logger.error(f"Server startup failed for client: {client_id}, "
                         f"using JobConfig: {job_config}!")
            context.set_code(grpc.StatusCode.ABORTED)
            return None

        channel = self.proxy_manager.get_channel(client_id)
        if channel is None:
            logger.error(f"Channel not found for {client_id}")
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
        try:
            with self.clients_lock:
                self.num_clients += 1
            new_iter = chain([modified_init_req], request_iterator)
            resp_stream = stub.Datapath(
                new_iter, metadata=[("client_id", client_id)])
            for resp in resp_stream:
                yield self.modify_connection_info_resp(resp)
        finally:
            with self.clients_lock:
                logger.debug(f"Client detached: {client_id}")
                self.num_clients -= 1
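
The first request has already been consumed to build modified_init_req, so itertools.chain stitches the rewritten init message back onto the front of the stream before it reaches the backend stub. A standalone illustration of that pattern, with made-up data:

from itertools import chain

rewritten_init = {"type": "init", "rewritten": True}
remaining = iter([{"type": "get"}, {"type": "put"}])

# chain yields the rewritten init message first, then the untouched remainder.
for req in chain([rewritten_init], remaining):
    print(req["type"])  # init, get, put
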
Example 5
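Earlier proxy-side variant that reads and rewrites the init request explicitly, then forwards the remaining requests to the backend through a queue-fed daemon thread.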
    def Datapath(self, request_iterator, context):
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        logger.debug(f"New data connection from client {client_id}: ")

        init_req = next(request_iterator)
        init_type = init_req.WhichOneof("type")
        assert init_type == "init", ("Received initial message of type "
                                     f"{init_type}, not 'init'.")

        modified_init_req, job_config = prepare_runtime_init_req(init_req.init)
        init_req.init.CopyFrom(modified_init_req)
        queue = Queue()
        queue.put(init_req)

        self.proxy_manager.start_specific_server(client_id)

        channel = self.proxy_manager.get_channel(client_id)
        if channel is None:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
        thread = Thread(target=forward_streaming_requests,
                        args=(request_iterator, queue),
                        daemon=True)
        thread.start()

        resp_stream = stub.Datapath(iter(queue.get, None),
                                    metadata=[("client_id", client_id)])
        for resp in resp_stream:
            yield resp
Example 6
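Proxy-side handler that creates the server placeholder before reading the first request and reports startup or connection failures back to the client as an InitResponse with ok=False.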
    def Datapath(self, request_iterator, context):
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return

        # Create Placeholder *before* reading the first request.
        server = self.proxy_manager.create_specific_server(client_id)
        try:
            with self.clients_lock:
                self.num_clients += 1

            logger.info(f"New data connection from client {client_id}: ")
            init_req = next(request_iterator)
            try:
                modified_init_req, job_config = prepare_runtime_init_req(
                    init_req)
                if not self.proxy_manager.start_specific_server(
                        client_id, job_config):
                    logger.error(
                        f"Server startup failed for client: {client_id}, "
                        f"using JobConfig: {job_config}!")
                    raise RuntimeError(
                        "Starting Ray client server failed. This is most "
                        "likely because the runtime_env failed to be "
                        "installed. See ray_client_server_[port].err on the "
                        "head node of the cluster for the relevant logs.")
                channel = self.proxy_manager.get_channel(client_id)
                if channel is None:
                    logger.error(f"Channel not found for {client_id}")
                    raise RuntimeError(
                        "Proxy failed to Connect to backend! Check "
                        "`ray_client_server.err` on the cluster.")
                stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
            except Exception:
                init_resp = ray_client_pb2.DataResponse(
                    init=ray_client_pb2.InitResponse(
                        ok=False, msg=traceback.format_exc()))
                init_resp.req_id = init_req.req_id
                yield init_resp
                return None

            new_iter = chain([modified_init_req], request_iterator)
            resp_stream = stub.Datapath(
                new_iter, metadata=[("client_id", client_id)])
            for resp in resp_stream:
                yield self.modify_connection_info_resp(resp)
        except Exception:
            logger.exception("Proxying Datapath failed!")
        finally:
            server.set_result(None)
            with self.clients_lock:
                logger.debug(f"Client detached: {client_id}")
                self.num_clients -= 1
Example 7
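Minimal client-side data loop: no error handling, responses are simply dispatched into ready_data.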
 def _data_main(self) -> None:
     stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
     resp_stream = stub.Datapath(iter(self.request_queue.get, None),
                                 metadata=(("client_id",
                                            self._client_id), ))
     for response in resp_stream:
         if response.req_id == 0:
             # This is not being waited for.
             logger.debug(f"Got unawaited response {response}")
             continue
         with self.cv:
             self.ready_data[response.req_id] = response
             self.cv.notify_all()
Example 8
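Client-side data loop with basic gRPC error handling: CANCELLED is treated as a graceful shutdown, anything else is logged and re-raised.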
 def _data_main(self) -> None:
     stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
     resp_stream = stub.Datapath(
         iter(self.request_queue.get, None),
         metadata=(("client_id", self._client_id), ))
     try:
         for response in resp_stream:
             if response.req_id == 0:
                 # This is not being waited for.
                 logger.debug(f"Got unawaited response {response}")
                 continue
             with self.cv:
                 self.ready_data[response.req_id] = response
                 self.cv.notify_all()
     except grpc.RpcError as e:
         if grpc.StatusCode.CANCELLED == e.code():
             # Gracefully shutting down
             logger.info("Cancelling data channel")
         else:
             logger.error(
                 f"Got Error from data channel -- shutting down: {e}")
             raise e
Example 9
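Helper that binds a RayletDataStreamerStub to an externally supplied channel.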
 def set_channel(self, channel: grpc.Channel) -> None:
     self.stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
Example 10
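Full proxy-side handler with reconnect support: it tracks clients_last_seen, honors the client's reconnect grace period before cleaning up, and propagates errors back through the gRPC context.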
    def Datapath(self, request_iterator, context):
        cleanup_requested = False
        start_time = time.time()
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        reconnecting = _get_reconnecting_from_context(context)

        if reconnecting:
            with self.clients_lock:
                if client_id not in self.clients_last_seen:
                    # Client took too long to reconnect, session has already
                    # been cleaned up
                    context.set_code(grpc.StatusCode.NOT_FOUND)
                    context.set_details(
                        "Attempted to reconnect a session that has already "
                        "been cleaned up")
                    return
                self.clients_last_seen[client_id] = start_time
            server = self.proxy_manager._get_server_for_client(client_id)
            channel = self.proxy_manager.get_channel(client_id)
            # iterator doesn't need modification on reconnect
            new_iter = request_iterator
        else:
            # Create Placeholder *before* reading the first request.
            server = self.proxy_manager.create_specific_server(client_id)
            with self.clients_lock:
                self.clients_last_seen[client_id] = start_time
                self.num_clients += 1

        try:
            if not reconnecting:
                logger.info(f"New data connection from client {client_id}: ")
                init_req = next(request_iterator)
                with self.clients_lock:
                    self.reconnect_grace_periods[client_id] = \
                        init_req.init.reconnect_grace_period
                try:
                    modified_init_req, job_config = prepare_runtime_init_req(
                        init_req)
                    if not self.proxy_manager.start_specific_server(
                            client_id, job_config):
                        logger.error(
                            f"Server startup failed for client: {client_id}, "
                            f"using JobConfig: {job_config}!")
                        raise RuntimeError(
                            "Starting Ray client server failed. See "
                            f"ray_client_server_{server.port}.err for "
                            "detailed logs.")
                    channel = self.proxy_manager.get_channel(client_id)
                    if channel is None:
                        logger.error(f"Channel not found for {client_id}")
                        raise RuntimeError(
                            "Proxy failed to Connect to backend! Check "
                            "`ray_client_server.err` and "
                            f"`ray_client_server_{server.port}.err` on the "
                            "head node of the cluster for the relevant logs. "
                            "By default these are located at "
                            "/tmp/ray/session_latest/logs.")
                except Exception:
                    init_resp = ray_client_pb2.DataResponse(
                        init=ray_client_pb2.InitResponse(
                            ok=False, msg=traceback.format_exc()))
                    init_resp.req_id = init_req.req_id
                    yield init_resp
                    return None

                new_iter = chain([modified_init_req], request_iterator)

            stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
            metadata = [("client_id", client_id), ("reconnecting",
                                                   str(reconnecting))]
            resp_stream = stub.Datapath(new_iter, metadata=metadata)
            for resp in resp_stream:
                resp_type = resp.WhichOneof("type")
                if resp_type == "connection_cleanup":
                    # Specific server is skipping cleanup, proxier should too
                    cleanup_requested = True
                yield self.modify_connection_info_resp(resp)
        except Exception as e:
            logger.exception("Proxying Datapath failed!")
            # Propagate error through context
            recoverable = _propagate_error_in_context(e, context)
            if not recoverable:
                # Client shouldn't attempt to recover, clean up connection
                cleanup_requested = True
        finally:
            cleanup_delay = self.reconnect_grace_periods.get(client_id)
            if not cleanup_requested and cleanup_delay is not None:
                # Delay cleanup, since client may attempt a reconnect
                # Wait on stopped event in case the server closes and we
                # can clean up earlier
                self.stopped.wait(timeout=cleanup_delay)
            with self.clients_lock:
                if client_id not in self.clients_last_seen:
                    logger.info(f"{client_id} not found. Skipping clean up.")
                    # Connection has already been cleaned up
                    return
                last_seen = self.clients_last_seen[client_id]
                logger.info(
                    f"{client_id} last started stream at {last_seen}. Current "
                    f"stream started at {start_time}.")
                if last_seen > start_time:
                    logger.info("Client reconnected. Skipping cleanup.")
                    # Client has reconnected, don't clean up
                    return
                logger.debug(f"Client detached: {client_id}")
                self.num_clients -= 1
                del self.clients_last_seen[client_id]
                if client_id in self.reconnect_grace_periods:
                    del self.reconnect_grace_periods[client_id]
                server.set_result(None)
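
_get_reconnecting_from_context is not shown in these examples; a hypothetical sketch, assuming it simply reads back the "reconnecting" metadata entry that the client attaches in Example 1:

def _get_reconnecting_from_context(context) -> bool:
    # Hypothetical sketch, not the actual Ray helper: the client sends
    # str(reconnecting), i.e. "True" or "False", as stream metadata, so
    # read it back from the gRPC invocation metadata.
    metadata = dict(context.invocation_metadata())
    return metadata.get("reconnecting", "False") == "True"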