Ejemplo n.º 1
0
 def __init__(self, gcs_channel: grpc.aio.Channel):
     """Set up the GCS client and the empty per-node registries.

     Args:
         gcs_channel: Channel handed to ``register_gcs_client``.
     """
     self.register_gcs_client(gcs_channel)

     # Stub registries, all initially empty.
     # NOTE(review): presumably keyed by node id -- confirm against the
     # register_* methods of the enclosing class.
     self._log_agent_stub = dict()
     self._runtime_env_agent_stub = dict()
     self._raylet_stubs = dict()

     # Helper objects; constructed in the same relative order as before.
     self._job_client = JobInfoStorageClient()
     self._id_id_map = IdToIpMap()
Ejemplo n.º 2
0
 def __init__(self, dashboard_head):
     """Initialise the module state for the given dashboard head.

     Args:
         dashboard_head: Forwarded to the parent initialiser and also kept
             on the instance as ``_dashboard_head``.
     """
     super().__init__(dashboard_head)
     # Stubs start as None.
     # NOTE(review): presumably assigned later once a GCS channel is
     # available -- confirm where they are populated.
     self._gcs_job_info_stub = None
     self._gcs_actor_info_stub = None
     self._dashboard_head = dashboard_head
     # Guard: the internal KV must be initialised before the storage client
     # below is created. Note that this check is stripped under ``python -O``.
     assert _internal_kv_initialized()
     self._job_info_client = JobInfoStorageClient()
     # For offloading CPU intensive work.
     self._thread_pool = concurrent.futures.ThreadPoolExecutor(
         max_workers=2, thread_name_prefix="api_head")
Ejemplo n.º 3
0
    def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str, str]):
        """Record the job identity and build the helpers used to supervise it.

        Args:
            job_id: Identifier of the job; also used as the default value for
                the job-id and job-name metadata entries.
            entrypoint: Stored as-is on ``_entrypoint``.
            user_metadata: Entries merged over the default metadata, so a
                user-supplied key overrides the default with the same key.
        """
        self._job_id = job_id
        self._entrypoint = entrypoint

        # Clients and runtime environment, created in their original order.
        self._job_info_client = JobInfoStorageClient()
        self._log_client = JobLogStorageClient()
        self._runtime_env = ray.get_runtime_context().runtime_env

        # Start from the defaults (both keys map to the job id), then let the
        # caller's metadata take precedence.
        merged_metadata = dict.fromkeys(
            (JOB_ID_METADATA_KEY, JOB_NAME_METADATA_KEY), job_id)
        merged_metadata.update(user_metadata)
        self._metadata = merged_metadata

        # fire and forget call from outer job manager to this actor
        self._stop_event = asyncio.Event()
Ejemplo n.º 4
0
 def __init__(self, gcs_channel: grpc.aio.Channel):
     """Set up the GCS client and the empty stub registries.

     Args:
         gcs_channel: Channel handed to ``register_gcs_client``.
     """
     self.register_gcs_client(gcs_channel)

     # Stub registries, both initially empty.
     # NOTE(review): presumably keyed by node id -- confirm against the
     # register_* methods of the enclosing class.
     self._agent_stubs = dict()
     self._raylet_stubs = dict()

     self._job_client = JobInfoStorageClient()
Ejemplo n.º 5
0
class StateDataSourceClient:
    """The client to query states from various data sources such as Raylet, GCS, Agents.

    Note that it doesn't directly query core workers. They are proxied through raylets.

    The module is not in charge of service discovery. The caller is responsible for
    finding services and register stubs through `register*` APIs.

    Non `register*` APIs
    - throw a ValueError if it cannot find the source.
    - throw `StateSourceNetworkException` if there's any network errors.
    """

    def __init__(self, gcs_channel: grpc.aio.Channel):
        """Create the GCS stubs and start with empty raylet/agent registries.

        Args:
            gcs_channel: Async gRPC channel used to build the GCS stubs.
        """
        self.register_gcs_client(gcs_channel)
        # node_id -> stub, filled in by the register_* methods below.
        self._raylet_stubs = {}
        self._agent_stubs = {}
        self._job_client = JobInfoStorageClient()

    def register_gcs_client(self, gcs_channel: grpc.aio.Channel):
        """(Re)build the actor, placement group, node and worker GCS stubs.

        Args:
            gcs_channel: Async gRPC channel every stub is bound to.
        """
        self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
            gcs_channel)
        self._gcs_pg_info_stub = gcs_service_pb2_grpc.PlacementGroupInfoGcsServiceStub(
            gcs_channel)
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            gcs_channel)
        self._gcs_worker_info_stub = gcs_service_pb2_grpc.WorkerInfoGcsServiceStub(
            gcs_channel)

    def register_raylet_client(self, node_id: str, address: str, port: int):
        """Open a channel to ``address:port`` and store a raylet stub for the node.

        An existing stub registered under the same ``node_id`` is replaced.
        """
        full_addr = f"{address}:{port}"
        options = ray_constants.GLOBAL_GRPC_OPTIONS
        channel = ray._private.utils.init_grpc_channel(full_addr,
                                                       options,
                                                       asynchronous=True)
        self._raylet_stubs[node_id] = NodeManagerServiceStub(channel)

    def unregister_raylet_client(self, node_id: str):
        """Forget the raylet stub of ``node_id``.

        Raises:
            KeyError: If no raylet is registered for ``node_id``.
        """
        self._raylet_stubs.pop(node_id)

    def register_agent_client(self, node_id, address: str, port: int):
        """Open a channel to ``address:port`` and store an agent stub for the node.

        An existing stub registered under the same ``node_id`` is replaced.
        """
        options = ray_constants.GLOBAL_GRPC_OPTIONS
        channel = ray._private.utils.init_grpc_channel(f"{address}:{port}",
                                                       options=options,
                                                       asynchronous=True)
        self._agent_stubs[node_id] = RuntimeEnvServiceStub(channel)

    def unregister_agent_client(self, node_id: str):
        """Forget the agent stub of ``node_id``.

        Raises:
            KeyError: If no agent is registered for ``node_id``.
        """
        self._agent_stubs.pop(node_id)

    def get_all_registered_raylet_ids(self) -> List[str]:
        """Return the node ids that currently have a raylet stub.

        A list snapshot is returned (matching the annotation) rather than a
        live ``dict_keys`` view, so callers may iterate it across ``await``
        points even if nodes are registered or unregistered in the meantime.
        """
        return list(self._raylet_stubs)

    def get_all_registered_agent_ids(self) -> List[str]:
        """Return the node ids that currently have an agent stub.

        A list snapshot is returned for the same reason as in
        ``get_all_registered_raylet_ids``.
        """
        return list(self._agent_stubs)

    @handle_network_errors
    async def get_all_actor_info(self,
                                 timeout: int = None) -> GetAllActorInfoReply:
        """Query the GCS actor service for all actors.

        Args:
            timeout: Passed through to the gRPC call; ``None`` means no
                timeout is given to the call.
        """
        request = GetAllActorInfoRequest()
        reply = await self._gcs_actor_info_stub.GetAllActorInfo(
            request, timeout=timeout)
        return reply

    @handle_network_errors
    async def get_all_placement_group_info(self,
                                           timeout: int = None
                                           ) -> GetAllPlacementGroupReply:
        """Query the GCS placement group service for all placement groups.

        Args:
            timeout: Passed through to the gRPC call.
        """
        request = GetAllPlacementGroupRequest()
        reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
            request, timeout=timeout)
        return reply

    @handle_network_errors
    async def get_all_node_info(self,
                                timeout: int = None) -> GetAllNodeInfoReply:
        """Query the GCS node service for all nodes.

        Args:
            timeout: Passed through to the gRPC call.
        """
        request = GetAllNodeInfoRequest()
        reply = await self._gcs_node_info_stub.GetAllNodeInfo(request,
                                                              timeout=timeout)
        return reply

    @handle_network_errors
    async def get_all_worker_info(self,
                                  timeout: int = None
                                  ) -> GetAllWorkerInfoReply:
        """Query the GCS worker service for all workers.

        Args:
            timeout: Passed through to the gRPC call.
        """
        request = GetAllWorkerInfoRequest()
        reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
            request, timeout=timeout)
        return reply

    def get_job_info(self) -> Dict[str, JobInfo]:
        """Return all jobs known to the job info storage client.

        Raises:
            StateSourceNetworkException: If the underlying lookup fails for
                any reason; the original exception is chained as the cause.
        """
        # Cannot use @handle_network_errors because async def is not supported yet.
        # TODO(sang): Support timeout & make it async
        try:
            return self._job_client.get_all_jobs()
        except Exception as e:
            raise StateSourceNetworkException(
                "Failed to query the job info.") from e

    @handle_network_errors
    async def get_task_info(self,
                            node_id: str,
                            timeout: int = None) -> GetTasksInfoReply:
        """Ask the raylet of ``node_id`` for its task information.

        Raises:
            ValueError: If no raylet stub is registered for ``node_id``.
        """
        stub = self._raylet_stubs.get(node_id)
        if not stub:
            raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")

        reply = await stub.GetTasksInfo(GetTasksInfoRequest(), timeout=timeout)
        return reply

    @handle_network_errors
    async def get_object_info(self,
                              node_id: str,
                              timeout: int = None) -> GetNodeStatsReply:
        """Ask the raylet of ``node_id`` for node stats including memory info.

        Raises:
            ValueError: If no raylet stub is registered for ``node_id``.
        """
        stub = self._raylet_stubs.get(node_id)
        if not stub:
            raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")

        reply = await stub.GetNodeStats(
            GetNodeStatsRequest(include_memory_info=True),
            timeout=timeout,
        )
        return reply

    @handle_network_errors
    async def get_runtime_envs_info(
            self,
            node_id: str,
            timeout: int = None) -> GetRuntimeEnvsInfoReply:
        """Ask the agent of ``node_id`` for its runtime environment information.

        Raises:
            ValueError: If no agent stub is registered for ``node_id``.
        """
        stub = self._agent_stubs.get(node_id)
        if not stub:
            raise ValueError(f"Agent for a node id, {node_id} doesn't exist.")

        reply = await stub.GetRuntimeEnvsInfo(
            GetRuntimeEnvsInfoRequest(),
            timeout=timeout,
        )
        return reply
Ejemplo n.º 6
0
class APIHead(dashboard_utils.DashboardHeadModule):
    """Dashboard head module serving actor-kill, snapshot and activity routes."""

    def __init__(self, dashboard_head):
        """Initialise the module; the GCS stubs are created later in ``run``.

        Args:
            dashboard_head: Forwarded to the parent initialiser and kept to
                reach the GCS channel in ``run``.
        """
        super().__init__(dashboard_head)
        self._gcs_job_info_stub = None
        self._gcs_actor_info_stub = None
        self._dashboard_head = dashboard_head
        # The internal KV must be ready before the storage client is built.
        assert _internal_kv_initialized()
        self._job_info_client = JobInfoStorageClient()
        # For offloading CPU intensive work.
        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2, thread_name_prefix="api_head")

    @routes.get("/api/actors/kill")
    async def kill_actor_gcs(self, req) -> aiohttp.web.Response:
        """Ask the GCS to kill the actor named by the ``actor_id`` query param.

        ``force_kill`` and ``no_restart`` are treated as true only when the
        query value is exactly ``"true"`` or ``"True"``.
        """
        actor_id = req.query.get("actor_id")
        force_kill = req.query.get("force_kill", False) in ("true", "True")
        no_restart = req.query.get("no_restart", False) in ("true", "True")
        if not actor_id:
            return dashboard_optional_utils.rest_response(
                success=False, message="actor_id is required.")

        request = gcs_service_pb2.KillActorViaGcsRequest()
        request.actor_id = bytes.fromhex(actor_id)
        request.force_kill = force_kill
        request.no_restart = no_restart
        await self._gcs_actor_info_stub.KillActorViaGcs(request, timeout=5)

        message = (f"Force killed actor with id {actor_id}" if force_kill else
                   f"Requested actor with id {actor_id} to terminate. " +
                   "It will exit once running tasks complete")

        return dashboard_optional_utils.rest_response(success=True,
                                                      message=message)

    @routes.get("/api/snapshot")
    async def snapshot(self, req):
        """Gather job, submission, actor, serve and session data concurrently."""
        (
            job_info,
            job_submission_data,
            actor_data,
            serve_data,
            session_name,
        ) = await asyncio.gather(
            self.get_job_info(),
            self.get_job_submission_info(),
            self.get_actor_info(),
            self.get_serve_info(),
            self.get_session_name(),
        )
        snapshot = {
            "jobs": job_info,
            "job_submission": job_submission_data,
            "actors": actor_data,
            "deployments": serve_data,
            "session_name": session_name,
            "ray_version": ray.__version__,
            "ray_commit": ray.__commit__,
        }
        return dashboard_optional_utils.rest_response(success=True,
                                                      message="hello",
                                                      snapshot=snapshot)

    @routes.get("/api/component_activities")
    async def get_component_activities(self, req) -> aiohttp.web.Response:
        """Return driver activity as JSON.

        The ``timeout`` query parameter is used when it is a non-empty string
        of digits; otherwise a default of 5 is applied.
        """
        # Get activity information for driver
        timeout = req.query.get("timeout", None)
        if timeout and timeout.isdigit():
            timeout = int(timeout)
        else:
            timeout = 5

        driver_activity_info = await self._get_job_activity_info(
            timeout=timeout)

        resp = {"driver": dataclasses.asdict(driver_activity_info)}
        return aiohttp.web.Response(
            text=json.dumps(resp),
            content_type="application/json",
            status=aiohttp.web.HTTPOk.status_code,
        )

    async def _get_job_activity_info(self,
                                     timeout: int) -> RayActivityResponse:
        """Count live drivers outside the internal job-info namespace.

        Args:
            timeout: Timeout passed to the GCS ``GetAllJobInfo`` call.
        """
        # Returns if there is Ray activity from drivers (job).
        # Drivers in namespaces that start with _ray_internal_job_info_ are not
        # considered activity.
        request = gcs_service_pb2.GetAllJobInfoRequest()
        reply = await self._gcs_job_info_stub.GetAllJobInfo(request,
                                                            timeout=timeout)

        num_active_drivers = 0
        for job_table_entry in reply.job_info_list:
            is_dead = bool(job_table_entry.is_dead)
            in_internal_namespace = job_table_entry.config.ray_namespace.startswith(
                JobInfoStorageClient.JOB_DATA_KEY_PREFIX)
            if not is_dead and not in_internal_namespace:
                num_active_drivers += 1

        return RayActivityResponse(
            is_active=num_active_drivers > 0,
            reason=f"Number of active drivers: {num_active_drivers}"
            if num_active_drivers else None,
            timestamp=datetime.now().timestamp(),
        )

    def _get_job_info(self, metadata: Dict[str, str]) -> Optional[JobInfo]:
        """Look up the job info for the submission id found in ``metadata``."""
        # If a job submission ID has been added to a job, the status is
        # guaranteed to be returned.
        job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
        return self._job_info_client.get_info(job_submission_id)

    async def get_job_info(self):
        """Return info for each job.  Here a job is a Ray driver."""
        request = gcs_service_pb2.GetAllJobInfoRequest()
        reply = await self._gcs_job_info_stub.GetAllJobInfo(request, timeout=5)

        jobs = {}
        for job_table_entry in reply.job_info_list:
            job_id = job_table_entry.job_id.hex()
            metadata = dict(job_table_entry.config.metadata)
            config = {
                "namespace":
                job_table_entry.config.ray_namespace,
                "metadata":
                metadata,
                "runtime_env":
                RuntimeEnv.deserialize(job_table_entry.config.runtime_env_info.
                                       serialized_runtime_env),
            }
            info = self._get_job_info(metadata)
            entry = {
                "status": None if info is None else info.status,
                "status_message": None if info is None else info.message,
                "is_dead": job_table_entry.is_dead,
                "start_time": job_table_entry.start_time,
                "end_time": job_table_entry.end_time,
                "config": config,
            }
            jobs[job_id] = entry

        return jobs

    async def get_job_submission_info(self):
        """Info for Ray job submission.  Here a job can have 0 or many drivers."""

        jobs = {}

        for job_submission_id, job_info in self._job_info_client.get_all_jobs(
        ).items():
            if job_info is not None:
                entry = {
                    "job_submission_id": job_submission_id,
                    "status": job_info.status,
                    "message": job_info.message,
                    "error_type": job_info.error_type,
                    "start_time": job_info.start_time,
                    "end_time": job_info.end_time,
                    "metadata": job_info.metadata,
                    "runtime_env": job_info.runtime_env,
                    "entrypoint": job_info.entrypoint,
                }
                jobs[job_submission_id] = entry
        return jobs

    async def get_actor_info(self):
        """Return a dict of actor id (hex) to actor details.

        Actors that are Serve replicas additionally get a ``"serve"`` entry
        in their ``"metadata"`` dict.
        """
        # TODO (Alex): GCS still needs to return actors from dead jobs.
        request = gcs_service_pb2.GetAllActorInfoRequest()
        request.show_dead_jobs = True
        reply = await self._gcs_actor_info_stub.GetAllActorInfo(request,
                                                                timeout=5)
        actors = {}
        for actor_table_entry in reply.actor_table_data:
            actor_id = actor_table_entry.actor_id.hex()
            runtime_env = json.loads(actor_table_entry.serialized_runtime_env)
            entry = {
                "job_id":
                actor_table_entry.job_id.hex(),
                "state":
                gcs_pb2.ActorTableData.ActorState.Name(
                    actor_table_entry.state),
                "name":
                actor_table_entry.name,
                "namespace":
                actor_table_entry.ray_namespace,
                "runtime_env":
                runtime_env,
                "start_time":
                actor_table_entry.start_time,
                "end_time":
                actor_table_entry.end_time,
                "is_detached":
                actor_table_entry.is_detached,
                "resources":
                dict(actor_table_entry.required_resources),
                "actor_class":
                actor_table_entry.class_name,
                "current_worker_id":
                actor_table_entry.address.worker_id.hex(),
                "current_raylet_id":
                actor_table_entry.address.raylet_id.hex(),
                "ip_address":
                actor_table_entry.address.ip_address,
                "port":
                actor_table_entry.address.port,
                "metadata":
                dict(),
            }
            actors[actor_id] = entry

        # As before, the serve snapshot is not queried when there are no actors.
        if not actors:
            return actors

        # Fetch the serve snapshot once, after every actor is known, instead
        # of once per actor: the merged result is the same but the KV store is
        # queried a single time and deployments are scanned a single time.
        deployments = await self.get_serve_info()
        for deployment_info in deployments.values():
            for replica_actor_id, actor_info in deployment_info[
                    "actors"].items():
                if replica_actor_id in actors:
                    actors[replica_actor_id]["metadata"]["serve"] = {
                        "replica_tag": actor_info["replica_tag"],
                        "deployment_name": deployment_info["name"],
                        "version": actor_info["version"],
                    }
        return actors

    async def get_serve_info(self) -> Dict[str, Any]:
        """Return Serve deployments keyed by the SHA-1 hex of their name.

        Returns an empty dict when the Serve modules cannot be imported.
        """
        # Conditionally import serve to prevent ModuleNotFoundError from serve
        # dependencies when only ray[default] is installed (#17712)
        try:
            from ray.serve.constants import SERVE_CONTROLLER_NAME
            from ray.serve.controller import SNAPSHOT_KEY as SERVE_SNAPSHOT_KEY
        except Exception:
            return {}

        # Serve wraps Ray's internal KV store and specially formats the keys.
        # These are the keys we are interested in:
        # SERVE_CONTROLLER_NAME(+ optional random letters):SERVE_SNAPSHOT_KEY
        # TODO: Convert to async GRPC, if CPU usage is not a concern.
        def get_deployments():
            serve_keys = _internal_kv_list(
                SERVE_CONTROLLER_NAME,
                namespace=ray_constants.KV_NAMESPACE_SERVE)
            serve_snapshot_keys = filter(
                lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys)

            deployments_per_controller: List[Dict[str, Any]] = []
            for key in serve_snapshot_keys:
                # A missing value is treated as an empty JSON object.
                val_bytes = _internal_kv_get(
                    key, namespace=ray_constants.KV_NAMESPACE_SERVE
                ) or "{}".encode("utf-8")
                deployments_per_controller.append(
                    json.loads(val_bytes.decode("utf-8")))
            # Merge the deployments dicts of all controllers.
            deployments: Dict[str, Any] = {
                k: v
                for d in deployments_per_controller for k, v in d.items()
            }
            # Replace the keys (deployment names) with their hashes to prevent
            # collisions caused by the automatic conversion to camelcase by the
            # dashboard agent.
            return {
                hashlib.sha1(name.encode()).hexdigest(): info
                for name, info in deployments.items()
            }

        # The blocking KV reads run on the module's thread pool.
        return await asyncio.get_event_loop().run_in_executor(
            executor=self._thread_pool, func=get_deployments)

    async def get_session_name(self):
        """Read and decode the ``session_name`` key from the session KV namespace."""

        # TODO(yic): Convert to async GRPC.
        def get_session():
            return ray.experimental.internal_kv._internal_kv_get(
                "session_name",
                namespace=ray_constants.KV_NAMESPACE_SESSION).decode()

        return await asyncio.get_event_loop().run_in_executor(
            executor=self._thread_pool, func=get_session)

    async def run(self, server):
        """Create the job and actor GCS stubs from the dashboard head's channel."""
        self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
            self._dashboard_head.aiogrpc_gcs_channel)
        self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
            self._dashboard_head.aiogrpc_gcs_channel)

    @staticmethod
    def is_minimal_module():
        """Report that this module is not a minimal module (always False)."""
        return False
Ejemplo n.º 7
0
class StateDataSourceClient:
    """The client to query states from various data sources such as Raylet, GCS, Agents.

    Note that it doesn't directly query core workers. They are proxied through raylets.

    The module is not in charge of service discovery. The caller is responsible for
    finding services and register stubs through `register*` APIs.

    Non `register*` APIs
    - Return the protobuf directly if it succeeds to query the source.
    - Raises an exception if there's any network issue.
    - throw a ValueError if it cannot find the source.
    """

    def __init__(self, gcs_channel: grpc.aio.Channel):
        """Create the GCS stubs and start with empty per-node registries.

        Args:
            gcs_channel: Async gRPC channel used to build the GCS stubs.
        """
        self.register_gcs_client(gcs_channel)
        # node_id -> stub, filled in by the register_* methods below.
        self._raylet_stubs = {}
        self._runtime_env_agent_stub = {}
        self._log_agent_stub = {}
        self._job_client = JobInfoStorageClient()
        self._id_id_map = IdToIpMap()

    def register_gcs_client(self, gcs_channel: grpc.aio.Channel):
        """(Re)build the actor, placement group, node and worker GCS stubs.

        Args:
            gcs_channel: Async gRPC channel every stub is bound to.
        """
        self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
            gcs_channel)
        self._gcs_pg_info_stub = gcs_service_pb2_grpc.PlacementGroupInfoGcsServiceStub(
            gcs_channel)
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            gcs_channel)
        self._gcs_worker_info_stub = gcs_service_pb2_grpc.WorkerInfoGcsServiceStub(
            gcs_channel)

    def register_raylet_client(self, node_id: str, address: str, port: int):
        """Store a raylet stub for the node and record its id/address pair."""
        full_addr = f"{address}:{port}"
        options = ray_constants.GLOBAL_GRPC_OPTIONS
        channel = ray._private.utils.init_grpc_channel(full_addr,
                                                       options,
                                                       asynchronous=True)
        self._raylet_stubs[node_id] = NodeManagerServiceStub(channel)
        self._id_id_map.put(node_id, address)

    def unregister_raylet_client(self, node_id: str):
        """Drop the raylet stub and the id/address record of ``node_id``.

        Raises:
            KeyError: If no raylet is registered for ``node_id``.
        """
        self._raylet_stubs.pop(node_id)
        self._id_id_map.pop(node_id)

    def register_agent_client(self, node_id, address: str, port: int):
        """Store runtime-env and log stubs (sharing one channel) for the node."""
        options = ray_constants.GLOBAL_GRPC_OPTIONS
        channel = ray._private.utils.init_grpc_channel(f"{address}:{port}",
                                                       options=options,
                                                       asynchronous=True)
        self._runtime_env_agent_stub[node_id] = RuntimeEnvServiceStub(channel)
        self._log_agent_stub[node_id] = LogServiceStub(channel)
        self._id_id_map.put(node_id, address)

    def unregister_agent_client(self, node_id: str):
        """Drop both agent stubs and the id/address record of ``node_id``.

        Raises:
            KeyError: If no agent is registered for ``node_id``.
        """
        self._runtime_env_agent_stub.pop(node_id)
        self._log_agent_stub.pop(node_id)
        self._id_id_map.pop(node_id)

    def get_all_registered_raylet_ids(self) -> List[str]:
        """Return the node ids that currently have a raylet stub.

        A list snapshot is returned (matching the annotation) rather than a
        live ``dict_keys`` view, so callers may iterate it across ``await``
        points even if nodes are registered or unregistered in the meantime.
        """
        return list(self._raylet_stubs)

    def get_all_registered_agent_ids(self) -> List[str]:
        """Return the node ids that currently have agent stubs.

        A list snapshot is returned for the same reason as in
        ``get_all_registered_raylet_ids``.
        """
        # Both agent registries are always updated together.
        assert len(self._log_agent_stub) == len(self._runtime_env_agent_stub)
        return list(self._runtime_env_agent_stub)

    def ip_to_node_id(self, ip: Optional[str]) -> Optional[str]:
        """Return the node id that corresponds to the given ip.

        Args:
            ip: The ip address.

        Returns:
            None if the corresponding id doesn't exist.
            Node id otherwise. If None node_ip is given,
            it will also return None.
        """
        if not ip:
            return None
        return self._id_id_map.get_node_id(ip)

    @handle_grpc_network_errors
    async def get_all_actor_info(
            self,
            timeout: int = None,
            limit: int = None) -> Optional[GetAllActorInfoReply]:
        """Query the GCS for actors.

        Args:
            timeout: Passed through to the gRPC call.
            limit: Sent in the request; a falsy value (``None`` or ``0``)
                is replaced by ``MAX_LIMIT``.
        """
        if not limit:
            limit = MAX_LIMIT

        request = GetAllActorInfoRequest(limit=limit)
        reply = await self._gcs_actor_info_stub.GetAllActorInfo(
            request, timeout=timeout)
        return reply

    @handle_grpc_network_errors
    async def get_all_placement_group_info(
            self,
            timeout: int = None,
            limit: int = None) -> Optional[GetAllPlacementGroupReply]:
        """Query the GCS for placement groups.

        Args:
            timeout: Passed through to the gRPC call.
            limit: Sent in the request; a falsy value is replaced by
                ``MAX_LIMIT``.
        """
        if not limit:
            limit = MAX_LIMIT

        request = GetAllPlacementGroupRequest(limit=limit)
        reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
            request, timeout=timeout)
        return reply

    @handle_grpc_network_errors
    async def get_all_node_info(self,
                                timeout: int = None
                                ) -> Optional[GetAllNodeInfoReply]:
        """Query the GCS for all nodes.

        Args:
            timeout: Passed through to the gRPC call.
        """
        request = GetAllNodeInfoRequest()
        reply = await self._gcs_node_info_stub.GetAllNodeInfo(request,
                                                              timeout=timeout)
        return reply

    @handle_grpc_network_errors
    async def get_all_worker_info(
            self,
            timeout: int = None,
            limit: int = None) -> Optional[GetAllWorkerInfoReply]:
        """Query the GCS for workers.

        Args:
            timeout: Passed through to the gRPC call.
            limit: Sent in the request; a falsy value is replaced by
                ``MAX_LIMIT``.
        """
        if not limit:
            limit = MAX_LIMIT

        request = GetAllWorkerInfoRequest(limit=limit)
        reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
            request, timeout=timeout)
        return reply

    def get_job_info(self) -> Optional[Dict[str, JobInfo]]:
        """Return all jobs known to the job info storage client.

        Raises:
            DataSourceUnavailable: If the lookup fails with a gRPC
                ``DEADLINE_EXCEEDED`` or ``UNAVAILABLE`` status.
            grpc.aio.AioRpcError: For any other gRPC failure (logged first).
        """
        # Cannot use @handle_grpc_network_errors because async def is not supported yet.
        # TODO(sang): Support timeout & make it async
        try:
            return self._job_client.get_all_jobs()
        except grpc.aio.AioRpcError as e:
            # ``code`` is a method on AioRpcError; it must be called, otherwise
            # the comparison is between a bound method and an enum member and
            # can never be true.
            status = e.code()
            if (status == grpc.StatusCode.DEADLINE_EXCEEDED
                    or status == grpc.StatusCode.UNAVAILABLE):
                raise DataSourceUnavailable(
                    "Failed to query the data source. "
                    "It is either there's a network issue, or the source is down."
                ) from e
            else:
                logger.exception(e)
                raise e

    @handle_grpc_network_errors
    async def get_task_info(self,
                            node_id: str,
                            timeout: int = None,
                            limit: int = None) -> Optional[GetTasksInfoReply]:
        """Ask the raylet of ``node_id`` for its task information.

        Raises:
            ValueError: If no raylet stub is registered for ``node_id``.
        """
        if not limit:
            limit = MAX_LIMIT

        stub = self._raylet_stubs.get(node_id)
        if not stub:
            raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")

        reply = await stub.GetTasksInfo(GetTasksInfoRequest(limit=limit),
                                        timeout=timeout)
        return reply

    @handle_grpc_network_errors
    async def get_object_info(
            self,
            node_id: str,
            timeout: int = None,
            limit: int = None) -> Optional[GetObjectsInfoReply]:
        """Ask the raylet of ``node_id`` for its object information.

        Raises:
            ValueError: If no raylet stub is registered for ``node_id``.
        """
        if not limit:
            limit = MAX_LIMIT

        stub = self._raylet_stubs.get(node_id)
        if not stub:
            raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")

        reply = await stub.GetObjectsInfo(
            GetObjectsInfoRequest(limit=limit),
            timeout=timeout,
        )
        return reply

    @handle_grpc_network_errors
    async def get_runtime_envs_info(
            self,
            node_id: str,
            timeout: int = None,
            limit: int = None) -> Optional[GetRuntimeEnvsInfoReply]:
        """Ask the agent of ``node_id`` for its runtime environment information.

        Raises:
            ValueError: If no agent stub is registered for ``node_id``.
        """
        if not limit:
            limit = MAX_LIMIT

        stub = self._runtime_env_agent_stub.get(node_id)
        if not stub:
            raise ValueError(f"Agent for a node id, {node_id} doesn't exist.")

        reply = await stub.GetRuntimeEnvsInfo(
            GetRuntimeEnvsInfoRequest(limit=limit),
            timeout=timeout,
        )
        return reply

    @handle_grpc_network_errors
    async def list_logs(self,
                        node_id: str,
                        glob_filter: str,
                        timeout: int = None) -> ListLogsReply:
        """Ask the log agent of ``node_id`` to list logs matching ``glob_filter``.

        Raises:
            ValueError: If no log agent stub is registered for ``node_id``.
        """
        stub = self._log_agent_stub.get(node_id)
        if not stub:
            raise ValueError(f"Agent for node id: {node_id} doesn't exist.")
        return await stub.ListLogs(ListLogsRequest(glob_filter=glob_filter),
                                   timeout=timeout)

    @handle_grpc_network_errors
    async def stream_log(
        self,
        node_id: str,
        log_file_name: str,
        keep_alive: bool,
        lines: int,
        interval: Optional[float],
        timeout: int,
    ) -> UnaryStreamCall:
        """Open a log stream on the agent of ``node_id`` and validate it.

        Raises:
            ValueError: If no log agent stub is registered for ``node_id``,
                or if the agent reports that the file was not found.
        """
        stub = self._log_agent_stub.get(node_id)
        if not stub:
            raise ValueError(f"Agent for node id: {node_id} doesn't exist.")
        stream = stub.StreamLog(
            StreamLogRequest(
                keep_alive=keep_alive,
                log_file_name=log_file_name,
                lines=lines,
                interval=interval,
            ),
            timeout=timeout,
        )
        # Pass the identifiers along so a "not found" error names the file
        # and node that were actually requested.
        await self._validate_stream(stream,
                                    log_file_name=log_file_name,
                                    node_id=node_id)
        return stream

    @staticmethod
    async def _validate_stream(stream, log_file_name=None, node_id=None):
        """Raise if the stream's initial metadata reports a missing file.

        Args:
            stream: The streaming call whose initial metadata is inspected.
            log_file_name: Requested file name, used only in the error message.
            node_id: Requested node id, used only in the error message.

        Raises:
            ValueError: If the metadata carries the file-not-found marker.
        """
        metadata = await stream.initial_metadata()
        if metadata.get(
                log_consts.LOG_GRPC_ERROR) == log_consts.FILE_NOT_FOUND:
            # The message must be an f-string; previously the placeholders
            # were emitted literally.
            raise ValueError(
                f'File "{log_file_name}" not found on node {node_id}')
Ejemplo n.º 8
0
class JobSupervisor:
    """
    Ray actor created by JobManager for each submitted job, responsible to
    setup runtime_env, execute given shell command in subprocess, update job
    status, persist job logs and manage subprocess group cleaning.

    One job supervisor actor maps to one subprocess, for one job_id.
    Job supervisor actor should fate share with subprocess it created.
    """

    # Seconds slept between two polls of the driver subprocess in _polling().
    SUBPROCESS_POLL_PERIOD_S = 0.1

    def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str,
                                                                         str]):
        """Store the job description and create the storage clients.

        Args:
            job_id: Identifier of the job handled by this supervisor.
            entrypoint: Shell command later executed as the driver subprocess.
            user_metadata: Entries merged over the default job id / job name
                metadata (user values win on key collisions).
        """
        self._job_id = job_id
        self._job_info_client = JobInfoStorageClient()
        self._log_client = JobLogStorageClient()
        self._runtime_env = ray.get_runtime_context().runtime_env
        self._entrypoint = entrypoint

        # Default metadata if not passed by the user.
        self._metadata = {
            JOB_ID_METADATA_KEY: job_id,
            JOB_NAME_METADATA_KEY: job_id
        }
        self._metadata.update(user_metadata)

        # fire and forget call from outer job manager to this actor
        self._stop_event = asyncio.Event()

    def ping(self):
        """Used to check the health of the actor."""
        pass

    def _exec_entrypoint(self, logs_path: str) -> subprocess.Popen:
        """
        Runs the entrypoint command as a child process, streaming stderr &
        stdout to given log files.

        Meanwhile we start a daemon process and group driver
        subprocess in same pgid, such that if job actor dies, entire process
        group also fate share with it.

        Args:
            logs_path: File path on head node's local disk to store driver
                command's stdout & stderr.
        Returns:
            child_process: Child process that runs the driver command. Can be
                terminated or killed upon user calling stop().
        """
        with open(logs_path, "w") as logs_file:
            child_process = subprocess.Popen(
                self._entrypoint,
                shell=True,
                start_new_session=True,
                stdout=logs_file,
                stderr=subprocess.STDOUT,
            )
            parent_pid = os.getpid()
            # Create new pgid with new subprocess to execute driver command
            child_pid = child_process.pid
            child_pgid = os.getpgid(child_pid)

            # Open a new subprocess to kill the child process when the parent
            # process dies kill -s 0 parent_pid will succeed if the parent is
            # alive. If it fails, SIGKILL the child process group and exit
            subprocess.Popen(
                f"while kill -s 0 {parent_pid}; do sleep 1; done; kill -9 -{child_pgid}",  # noqa: E501
                shell=True,
                # Suppress output
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            return child_process

    async def _polling(self, child_process) -> int:
        """Poll the child process until it exits and return its exit code.

        Args:
            child_process: Object exposing ``poll()`` and ``kill()``; when it
                is ``None`` the loop is skipped and ``None`` is returned.

        Returns:
            The value returned by ``poll()`` once it is not ``None``, or ``1``
            if an ``Exception`` was raised while polling (the child is killed
            in that case).
        """
        try:
            while child_process is not None:
                return_code = child_process.poll()
                if return_code is not None:
                    # subprocess finished with return code
                    return return_code
                else:
                    # still running, yield control, 0.1s by default
                    await asyncio.sleep(self.SUBPROCESS_POLL_PERIOD_S)
        except Exception:
            if child_process:
                # TODO (jiaodong): Improve this with SIGTERM then SIGKILL
                child_process.kill()
            return 1

    async def run(
        self,
        # Signal actor used in testing to capture PENDING -> RUNNING cases
        _start_signal_actor: Optional[ActorHandle] = None,
    ):
        """
        Stop and start both happen asynchronously, coordinated by asyncio
        event and coroutine, respectively.

        1) Sets job status as running
        2) Pass runtime env and metadata to subprocess as serialized env
            variables.
        3) Handle concurrent events of driver execution and stop requests:
            whichever of "subprocess exited" and "stop() was called" happens
            first determines the status written for the job.

        The actor always exits at the end, whether or not an error occurred.
        """
        curr_status = self._job_info_client.get_status(self._job_id)
        assert curr_status == JobStatus.PENDING, "Run should only be called once."

        if _start_signal_actor:
            # Block in PENDING state until start signal received.
            await _start_signal_actor.wait.remote()

        self._job_info_client.put_status(self._job_id, JobStatus.RUNNING)

        try:
            # Set JobConfig for the child process (runtime_env, metadata).
            os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps({
                "runtime_env":
                self._runtime_env,
                "metadata":
                self._metadata,
            })
            # Always set RAY_ADDRESS as find_bootstrap_address address for
            # job submission. In case of local development, prevent user from
            # re-using http://{address}:{dashboard_port} to interact with
            # jobs SDK.
            # TODO:(mwtian) Check why "auto" does not work in entrypoint script
            os.environ[ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE] = (
                ray._private.services.find_bootstrap_address().pop())

            # Set PYTHONUNBUFFERED=1 to stream logs during the job instead of
            # only streaming them upon completion of the job.
            os.environ["PYTHONUNBUFFERED"] = "1"
            logger.info(
                "Submitting job with RAY_ADDRESS = "
                f"{os.environ[ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE]}"
            )
            log_path = self._log_client.get_log_file_path(self._job_id)
            child_process = self._exec_entrypoint(log_path)

            polling_task = create_task(self._polling(child_process))
            # asyncio.wait() must be given tasks/futures: passing a bare
            # coroutine is deprecated since Python 3.8 and raises TypeError
            # from Python 3.11 on, so wrap the stop waiter explicitly.
            stop_task = create_task(self._stop_event.wait())
            finished, _ = await asyncio.wait(
                [polling_task, stop_task], return_when=FIRST_COMPLETED)

            if self._stop_event.is_set():
                polling_task.cancel()
                # TODO (jiaodong): Improve this with SIGTERM then SIGKILL
                child_process.kill()
                self._job_info_client.put_status(self._job_id,
                                                 JobStatus.STOPPED)
            else:
                # Child process finished execution and no stop event is set
                # at the same time
                # The stop waiter is no longer needed; do not leave it
                # pending.
                stop_task.cancel()
                assert len(
                    finished) == 1, "Should have only one coroutine done"
                [child_process_task] = finished
                return_code = child_process_task.result()
                if return_code == 0:
                    self._job_info_client.put_status(self._job_id,
                                                     JobStatus.SUCCEEDED)
                else:
                    log_tail = self._log_client.get_last_n_log_lines(
                        self._job_id)
                    if log_tail is not None and log_tail != "":
                        message = ("Job failed due to an application error, "
                                   "last available logs:\n" + log_tail)
                    else:
                        message = None
                    self._job_info_client.put_status(self._job_id,
                                                     JobStatus.FAILED,
                                                     message=message)
        except Exception:
            logger.error(
                "Got unexpected exception while trying to execute driver "
                f"command. {traceback.format_exc()}")
        finally:
            # clean up actor after tasks are finished
            ray.actor.exit_actor()

    def stop(self):
        """Set stop_event and let run() handle the rest in its asyncio.wait()."""
        self._stop_event.set()
Ejemplo n.º 9
0
    def __init__(self):
        """Create the storage clients and supervisor class, then recover jobs.

        Builds a ``JobInfoStorageClient``, a ``JobLogStorageClient`` and the
        ``ray.remote`` wrapper of ``JobSupervisor``, stores them on the
        instance and finally calls ``self._recover_running_jobs()``.
        """
        info_client = JobInfoStorageClient()
        log_client = JobLogStorageClient()
        supervisor_cls = ray.remote(JobSupervisor)

        self._job_info_client = info_client
        self._log_client = log_client
        self._supervisor_actor_cls = supervisor_cls
        self._recover_running_jobs()
Ejemplo n.º 10
0
class JobManager:
    """Provide python APIs for job submission and management.

    It does not provide persistence, all info will be lost if the cluster
    goes down.
    """

    JOB_ACTOR_NAME = "_ray_internal_job_actor_{job_id}"
    # Time that we will sleep while tailing logs if no new log line is
    # available.
    LOG_TAIL_SLEEP_S = 1
    JOB_MONITOR_LOOP_PERIOD_S = 1

    def __init__(self):
        """Set up storage clients and the supervisor actor class.

        After the job info client, the log client and the ``ray.remote``
        wrapper of ``JobSupervisor`` are stored on the instance,
        ``self._recover_running_jobs()`` is invoked.
        """
        job_info_client = JobInfoStorageClient()
        job_log_client = JobLogStorageClient()
        remote_supervisor_cls = ray.remote(JobSupervisor)

        self._job_info_client = job_info_client
        self._log_client = job_log_client
        self._supervisor_actor_cls = remote_supervisor_cls
        self._recover_running_jobs()

    def _recover_running_jobs(self):
        """Recovers all running jobs from the status client.

        For each job, we will spawn a coroutine to monitor it.
        Each will be added to self._running_jobs and reconciled.
        """
        all_jobs = self._job_info_client.get_all_jobs()
        for job_id, job_info in all_jobs.items():
            if not job_info.status.is_terminal():
                create_task(self._monitor_job(job_id))

    def _get_actor_for_job(self, job_id: str) -> Optional[ActorHandle]:
        """Look up the supervisor actor of a job by its formatted name.

        Args:
            job_id: Inserted into ``JOB_ACTOR_NAME`` to build the actor name.

        Returns:
            The result of ``ray.get_actor`` for that name, or ``None`` when a
            ``ValueError`` is raised during the lookup.
        """
        try:
            actor_name = self.JOB_ACTOR_NAME.format(job_id=job_id)
            handle = ray.get_actor(actor_name)
        except ValueError:
            # Ray returns ValueError for nonexistent actor.
            handle = None
        return handle

    async def _monitor_job(self,
                           job_id: str,
                           job_supervisor: Optional[ActorHandle] = None):
        """Ping a job's supervisor until the ping fails, then clean up.

        Args:
            job_id: Job whose supervisor is watched.
            job_supervisor: Supervisor handle; when ``None`` it is looked up
                with ``self._get_actor_for_job(job_id)``.

        If no supervisor can be obtained, the job is marked FAILED and the
        method returns. Otherwise the supervisor is pinged every
        ``JOB_MONITOR_LOOP_PERIOD_S`` seconds. On the first exception the job
        is marked FAILED unless its stored status is already terminal, and
        the supervisor actor is killed afterwards.
        """
        if job_supervisor is None:
            job_supervisor = self._get_actor_for_job(job_id)

        if job_supervisor is None:
            # Without a supervisor there is nothing to ping or to kill.
            logger.error(f"Failed to get job supervisor for job {job_id}.")
            self._job_info_client.put_status(
                job_id,
                JobStatus.FAILED,
                message=
                "Unexpected error occurred: Failed to get job supervisor.",
            )
            return

        while True:
            try:
                await job_supervisor.ping.remote()
                await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S)
            except Exception as e:
                current = self._job_info_client.get_status(job_id)
                # If the job is already in a terminal state, then the actor
                # exiting is expected and no status update is needed.
                if not current.is_terminal():
                    if isinstance(e, RuntimeEnvSetupError):
                        logger.info(
                            f"Failed to set up runtime_env for job {job_id}.")
                        failure_message = f"runtime_env setup failed: {e}"
                    else:
                        logger.warning(
                            f"Job supervisor for job {job_id} failed unexpectedly: {e}."
                        )
                        failure_message = f"Unexpected error occurred: {e}"
                    self._job_info_client.put_status(
                        job_id,
                        JobStatus.FAILED,
                        message=failure_message,
                    )
                break

        # Kill the actor defensively to avoid leaking actors in unexpected error cases.
        ray.kill(job_supervisor, no_restart=True)

    def _get_current_node_resource_key(self) -> str:
        """Return the ``node:``-prefixed resource key of the current node.

        The current node id comes from the Ray runtime context; the entries
        of ``ray.nodes()`` are searched for the one with that ``NodeID`` and
        the first of its ``Resources`` keys starting with ``"node:"`` is
        returned. The key can be used for actor placement.

        Raises:
            ValueError: If no such key is found.
        """
        own_node_id = ray.get_runtime_context().node_id.hex()
        for node in ray.nodes():
            if node["NodeID"] != own_node_id:
                continue
            # Found the node.
            for resource_key in node["Resources"].keys():
                if resource_key.startswith("node:"):
                    return resource_key

        raise ValueError("Cannot find the node dictionary for current node.")

    def _handle_supervisor_startup(self, job_id: str,
                                   result: Optional[Exception]):
        """Handle the result of starting a job supervisor actor.

        If started successfully, result should be None. Otherwise it should be
        an Exception.

        On failure, the job will be marked failed with a relevant error
        message.
        """
        if result is None:
            return

    def submit_job(
        self,
        *,
        entrypoint: str,
        job_id: Optional[str] = None,
        runtime_env: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, str]] = None,
        _start_signal_actor: Optional[ActorHandle] = None,
    ) -> str:
        """Record a new job as PENDING and launch its supervisor actor.

        Job execution happens asynchronously: the supervisor's ``run`` is
        fired without waiting for it, and a background task is scheduled to
        monitor the job. If creating or starting the supervisor raises, the
        job is marked FAILED and the job id is still returned.

        Args:
            entrypoint: Driver command handed to the supervisor actor.
            job_id: Identifier to use for the job. When omitted, one is
                obtained from ``generate_job_id()``.
            runtime_env: Stored in the job info and used as the
                ``runtime_env`` option of the supervisor actor.
            metadata: Stored in the job info and passed to the supervisor
                (an empty dict is passed when this is falsy).
            _start_signal_actor: Used in testing only to capture state
                transitions between PENDING -> RUNNING. Regular user shouldn't
                need this.

        Returns:
            The id of the submitted job.

        Raises:
            RuntimeError: If ``job_id`` is given and a status is already
                stored for it.
        """
        if job_id is not None:
            if self._job_info_client.get_status(job_id) is not None:
                raise RuntimeError(f"Job {job_id} already exists.")
        else:
            job_id = generate_job_id()

        logger.info(f"Starting job with job_id: {job_id}")
        submitted_at_ms = int(time.time() * 1000)
        self._job_info_client.put_info(
            job_id,
            JobInfo(
                entrypoint=entrypoint,
                status=JobStatus.PENDING,
                start_time=submitted_at_ms,
                metadata=metadata,
                runtime_env=runtime_env,
            ),
        )

        # Everything needed to start the actor stays inside the try block so
        # that any startup error marks the job as FAILED instead of raising.
        try:
            actor_options = {
                "lifetime": "detached",
                "name": self.JOB_ACTOR_NAME.format(job_id=job_id),
                "num_cpus": 0,
                # Currently we assume JobManager is created by dashboard server
                # running on headnode, same for job supervisor actors scheduled
                "resources": {
                    self._get_current_node_resource_key(): 0.001,
                },
                "runtime_env": runtime_env,
            }
            configured_cls = self._supervisor_actor_cls.options(
                **actor_options)
            supervisor = configured_cls.remote(job_id, entrypoint, metadata
                                               or {})
            supervisor.run.remote(_start_signal_actor=_start_signal_actor)

            # Monitor the job in the background so we can detect errors without
            # requiring a client to poll.
            create_task(self._monitor_job(job_id, job_supervisor=supervisor))
        except Exception as e:
            self._job_info_client.put_status(
                job_id,
                JobStatus.FAILED,
                message=f"Failed to start job supervisor: {e}.",
            )

        return job_id

    def stop_job(self, job_id) -> bool:
        """Request a job to exit, fire and forget.

        Returns whether or not the job was running.
        """
        job_supervisor_actor = self._get_actor_for_job(job_id)
        if job_supervisor_actor is not None:
            # Actor is still alive, signal it to stop the driver, fire and
            # forget
            job_supervisor_actor.stop.remote()
            return True
        else:
            return False

    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
        """Get latest status of a job."""
        return self._job_info_client.get_status(job_id)

    def get_job_info(self, job_id: str) -> Optional[JobInfo]:
        """Get latest info of a job."""
        return self._job_info_client.get_info(job_id)

    def list_jobs(self) -> Dict[str, JobInfo]:
        """Get info for all jobs."""
        return self._job_info_client.get_all_jobs()

    def get_job_logs(self, job_id: str) -> str:
        """Get all logs produced by a job."""
        return self._log_client.get_logs(job_id)

    async def tail_job_logs(self, job_id: str) -> Iterator[str]:
        """Return an iterator following the logs of a job."""
        if self.get_job_status(job_id) is None:
            raise RuntimeError(f"Job '{job_id}' does not exist.")

        for line in self._log_client.tail_logs(job_id):
            if line is None:
                # Return if the job has exited and there are no new log lines.
                status = self.get_job_status(job_id)
                if status not in {JobStatus.PENDING, JobStatus.RUNNING}:
                    return

                await asyncio.sleep(self.LOG_TAIL_SLEEP_S)
            else:
                yield line