Example #1
0
    def Logstream(self, request_iterator, context):
        """Proxy a client's log stream to the channel registered for it.

        Generator: yields every response produced by the backend
        ``RayletLogStreamerStub.Logstream`` call.

        The channel lookup is attempted up to ``LOGSTREAM_RETRIES`` times,
        sleeping ``LOGSTREAM_RETRY_INTERVAL_SEC`` after each miss. If no
        channel is found, ``context`` is set to ``NOT_FOUND`` and nothing is
        yielded. The stream also ends without yielding when the client id
        read from ``context`` is the empty string.

        Exceptions raised while relaying responses are logged and swallowed.
        """
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        logger.debug(f"New logstream connection from client {client_id}: ")

        # The LogClient *may* connect before the DataClient has finished
        # connecting, so the lookup is retried a few times.
        for attempt in range(LOGSTREAM_RETRIES):
            backend_channel = self.proxy_manager.get_channel(client_id)
            if backend_channel is not None:
                break
            logger.warning(
                f"Retrying Logstream connection. {attempt+1} attempts failed.")
            time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC)
        else:
            # Only reached when no attempt found a channel.
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(
                "Logstream proxy failed to connect. Channel for client "
                f"{client_id} not found.")
            return None

        log_stub = ray_client_pb2_grpc.RayletLogStreamerStub(backend_channel)
        call_metadata = [("client_id", client_id)]
        backend_responses = log_stub.Logstream(
            request_iterator, metadata=call_metadata)
        try:
            for log_resp in backend_responses:
                yield log_resp
        except Exception:
            logger.exception("Proxying Logstream failed!")
Example #2
0
    def _call_inner_function(
            self, request, context,
            method: str) -> Optional[ray_client_pb2_grpc.RayletDriverStub]:
        """Forward ``request`` to ``method`` on the client's driver stub.

        The client id is read from ``context`` and used to look up the
        channel on which a ``RayletDriverStub`` is built.

        Args:
            request: The request object handed to the stub method.
            context: The call context; used to read the client id and the
                invocation metadata, and to report errors.
            method: Name of the stub method to invoke.

        Returns:
            Whatever the stub method returns. ``None`` if no channel exists
            for the client (``context`` is set to ``NOT_FOUND``) or if the
            proxied call raised (the error is logged and passed to
            ``_propagate_error_in_context``).

        NOTE(review): the return annotation names the stub type, but the
        value actually returned is the stub method's result.
        """
        client_id = _get_client_id_from_context(context)
        backend_channel = self.proxy_manager.get_channel(client_id)
        if not backend_channel:
            logger.error(f"Channel for Client: {client_id} not found!")
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None

        driver_stub = ray_client_pb2_grpc.RayletDriverStub(backend_channel)
        try:
            # Prefer the caller's own metadata when a context is available.
            call_metadata = (context.invocation_metadata()
                             if context else [("client_id", client_id)])
            rpc = getattr(driver_stub, method)
            return rpc(request, metadata=call_metadata)
        except Exception as err:
            # Error while proxying -- propagate the error's context to user
            logger.exception(f"Proxying call to {method} failed!")
            _propagate_error_in_context(err, context)
        return None
Example #3
0
    def Datapath(self, request_iterator, context):
        """Proxy a client's Datapath stream to the channel for that client.

        Generator: yields each response from the backend
        ``RayletDataStreamerStub.Datapath`` call after passing it through
        ``self.modify_connection_info_resp``.

        New connection (not reconnecting): a server placeholder is created,
        the first request is read from ``request_iterator``, its
        ``init.reconnect_grace_period`` is recorded, and the server is started
        through ``self.proxy_manager``. If startup or the channel lookup
        fails, a single ``DataResponse`` whose ``init`` has ``ok=False`` and
        the formatted traceback is yielded and the stream ends.

        Reconnecting client: the client id must still be present in
        ``self.clients_last_seen``; otherwise ``context`` is set to
        ``NOT_FOUND`` and the stream ends without touching any bookkeeping.

        When the stream ends, the client's bookkeeping is removed unless it
        was already removed or a newer stream for the same client has started
        in the meantime. Unless cleanup was explicitly requested, removal is
        preceded by a wait of up to the recorded grace period.

        Args:
            request_iterator: Iterator of incoming requests. For a new
                connection the first item is consumed here and its ``init``
                and ``req_id`` fields are read.
            context: The call context; used to read the client id and the
                reconnecting flag, and to report error codes/details.
        """
        cleanup_requested = False
        # Timestamp identifying this stream; compared against
        # clients_last_seen in the finally block to detect a newer stream.
        start_time = time.time()
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        reconnecting = _get_reconnecting_from_context(context)

        if reconnecting:
            with self.clients_lock:
                if client_id not in self.clients_last_seen:
                    # Client took too long to reconnect, session has already
                    # been cleaned up
                    context.set_code(grpc.StatusCode.NOT_FOUND)
                    context.set_details(
                        "Attempted to reconnect a session that has already "
                        "been cleaned up")
                    # Returning here is before the try below, so the finally
                    # cleanup does not run for this stream.
                    return
                self.clients_last_seen[client_id] = start_time
            server = self.proxy_manager._get_server_for_client(client_id)
            channel = self.proxy_manager.get_channel(client_id)
            # iterator doesn't need modification on reconnect
            new_iter = request_iterator
        else:
            # Create Placeholder *before* reading the first request.
            server = self.proxy_manager.create_specific_server(client_id)
            with self.clients_lock:
                self.clients_last_seen[client_id] = start_time
                self.num_clients += 1

        try:
            if not reconnecting:
                logger.info(f"New data connection from client {client_id}: ")
                # The first request is consumed here; a (modified) copy is
                # put back in front of the iterator further down.
                init_req = next(request_iterator)
                with self.clients_lock:
                    self.reconnect_grace_periods[client_id] = \
                        init_req.init.reconnect_grace_period
                try:
                    modified_init_req, job_config = prepare_runtime_init_req(
                        init_req)
                    if not self.proxy_manager.start_specific_server(
                            client_id, job_config):
                        logger.error(
                            f"Server startup failed for client: {client_id}, "
                            f"using JobConfig: {job_config}!")
                        raise RuntimeError(
                            "Starting Ray client server failed. See "
                            f"ray_client_server_{server.port}.err for "
                            "detailed logs.")
                    channel = self.proxy_manager.get_channel(client_id)
                    if channel is None:
                        logger.error(f"Channel not found for {client_id}")
                        raise RuntimeError(
                            "Proxy failed to Connect to backend! Check "
                            "`ray_client_server.err` and "
                            f"`ray_client_server_{server.port}.err` on the "
                            "head node of the cluster for the relevant logs. "
                            "By default these are located at "
                            "/tmp/ray/session_latest/logs.")
                except Exception:
                    # Report the startup failure to the client as a failed
                    # init response carrying the traceback, then end the
                    # stream. The outer finally still runs after this return.
                    init_resp = ray_client_pb2.DataResponse(
                        init=ray_client_pb2.InitResponse(
                            ok=False, msg=traceback.format_exc()))
                    init_resp.req_id = init_req.req_id
                    yield init_resp
                    return None

                # Re-attach the (modified) init request ahead of the
                # remaining requests.
                new_iter = chain([modified_init_req], request_iterator)

            stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
            metadata = [("client_id", client_id), ("reconnecting",
                                                   str(reconnecting))]
            resp_stream = stub.Datapath(new_iter, metadata=metadata)
            for resp in resp_stream:
                resp_type = resp.WhichOneof("type")
                if resp_type == "connection_cleanup":
                    # Specific server is skipping cleanup, proxier should too
                    cleanup_requested = True
                yield self.modify_connection_info_resp(resp)
        except Exception as e:
            logger.exception("Proxying Datapath failed!")
            # Propagate error through context
            recoverable = _propagate_error_in_context(e, context)
            if not recoverable:
                # Client shouldn't attempt to recover, clean up connection
                cleanup_requested = True
        finally:
            # NOTE(review): this read of reconnect_grace_periods is not done
            # under clients_lock, unlike the write above.
            cleanup_delay = self.reconnect_grace_periods.get(client_id)
            if not cleanup_requested and cleanup_delay is not None:
                # Delay cleanup, since client may attempt a reconnect
                # Wait on stopped event in case the server closes and we
                # can clean up earlier
                self.stopped.wait(timeout=cleanup_delay)
            with self.clients_lock:
                if client_id not in self.clients_last_seen:
                    logger.info(f"{client_id} not found. Skipping clean up.")
                    # Connection has already been cleaned up
                    return
                last_seen = self.clients_last_seen[client_id]
                logger.info(
                    f"{client_id} last started stream at {last_seen}. Current "
                    f"stream started at {start_time}.")
                # A later timestamp means another stream for this client
                # updated clients_last_seen after this one started.
                if last_seen > start_time:
                    logger.info("Client reconnected. Skipping cleanup.")
                    # Client has reconnected, don't clean up
                    return
                logger.debug(f"Client detached: {client_id}")
                self.num_clients -= 1
                del self.clients_last_seen[client_id]
                if client_id in self.reconnect_grace_periods:
                    del self.reconnect_grace_periods[client_id]
                # Resolve the server placeholder with None; presumably this
                # tells the proxy manager the per-client server is no longer
                # needed -- TODO confirm against the proxy manager.
                server.set_result(None)
Example #4
0
 def _has_channel_for_request(self, context):
     """Return the proxy manager's ``has_channel`` result for this client.

     The client id is read from ``context`` via
     ``_get_client_id_from_context``.
     """
     return self.proxy_manager.has_channel(
         _get_client_id_from_context(context))