Example #1
    async def list_placement_groups(self, *, option: ListApiOptions) -> dict:
        """List all placement group information from the cluster.

        Returns:
            {pg_id -> pg_data_in_dict}
            pg_data_in_dict's schema is in PlacementGroupState
        """
        reply = await self._client.get_all_placement_group_info(
            timeout=option.timeout)
        result = []
        for message in reply.placement_group_table_data:
            data = self._message_to_dict(
                message=message,
                fields_to_decode=["placement_group_id"],
            )
            data = filter_fields(data, PlacementGroupState)
            result.append(data)

        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["placement_group_id"])
        return {
            d["placement_group_id"]: d
            for d in islice(result, option.limit)
        }
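For context, `filter_fields` above is assumed to drop any columns that are not part of the target schema. Its actual implementation lives elsewhere; a minimal sketch, assuming the schema is a plain dataclass, might look like this (hypothetical, not the real code):

import dataclasses

def filter_fields(data: dict, state_dataclass) -> dict:
    # Keep only the keys declared as fields on the schema dataclass.
    allowed = {f.name for f in dataclasses.fields(state_dataclass)}
    return {k: v for k, v in data.items() if k in allowed}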
Example #2
    async def list_tasks(self, *, option: ListApiOptions) -> dict:
        """List all task information from the cluster.

        Returns:
            {task_id -> task_data_in_dict}
            task_data_in_dict's schema is in TaskState
        """
        replies = await asyncio.gather(*[
            self._client.get_task_info(node_id, timeout=option.timeout)
            for node_id in self._client.get_all_registered_raylet_ids()
        ])

        result = []
        for reply in replies:
            tasks = reply.task_info_entries
            for task in tasks:
                data = self._message_to_dict(
                    message=task,
                    fields_to_decode=["task_id"],
                )
                data = filter_fields(data, TaskState)
                result.append(data)

        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["task_id"])
        return {d["task_id"]: d for d in islice(result, option.limit)}
Example #3
    async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all worker information from the cluster.

        Returns:
            {worker_id -> worker_data_in_dict}
            worker_data_in_dict's schema is in WorkerState
        """
        try:
            reply = await self._client.get_all_worker_info(timeout=option.timeout)
        except DataSourceUnavailable:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)

        result = []
        for message in reply.worker_table_data:
            data = self._message_to_dict(
                message=message, fields_to_decode=["worker_id"]
            )
            data["worker_id"] = data["worker_address"]["worker_id"]
            data = filter_fields(data, WorkerState)
            result.append(data)

        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["worker_id"])
        return ListApiResponse(
            result={d["worker_id"]: d for d in islice(result, option.limit)}
        )
Example #4
    def _filter(
        self,
        data: List[dict],
        filters: List[Tuple[str, str, SupportedFilterType]],
        state_dataclass: StateSchema,
        detail: bool,
    ) -> List[dict]:
        """Return the filtered data given filters.

        Args:
            data: A list of state data.
            filters: A list of (column, predicate, value) tuples used to
                filter data. Supported predicates are "=" and "!=", and an
                entry must satisfy every filter to be kept.
            state_dataclass: The state schema.
            detail: Whether to include columns that are only returned
                when detail output is requested.

        Returns:
            A list of filtered state data in dictionary. Each state data's
            unnecessary columns are filtered by the given state_dataclass schema.
        """
        filters = _convert_filters_type(filters, state_dataclass)
        result = []
        for datum in data:
            match = True
            for filter_column, filter_predicate, filter_value in filters:
                filterable_columns = state_dataclass.filterable_columns()
                filter_column = filter_column.lower()
                if filter_column not in filterable_columns:
                    raise ValueError(
                        f"The given filter column {filter_column} is not supported. "
                        f"Supported filter columns: {filterable_columns}")

                if filter_predicate == "=":
                    match = datum[filter_column] == filter_value
                elif filter_predicate == "!=":
                    match = datum[filter_column] != filter_value
                else:
                    raise ValueError(
                        f"Unsupported filter predicate {filter_predicate} is given. "
                        "Available predicates: =, !=.")

                if not match:
                    break

            if match:
                result.append(filter_fields(datum, state_dataclass, detail))
        return result
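To illustrate the predicate semantics of `_filter`, here is a standalone sketch of the matching loop; the column names and sample rows are invented for illustration:

rows = [
    {"state": "RUNNING", "node_id": "a"},
    {"state": "FAILED", "node_id": "b"},
]
# Filters are ANDed: a row must satisfy every (column, predicate, value) tuple.
filters = [("state", "=", "RUNNING"), ("node_id", "!=", "b")]

matched = []
for row in rows:
    match = True
    for col, pred, val in filters:
        match = row[col] == val if pred == "=" else row[col] != val
        if not match:
            break
    if match:
        matched.append(row)

print(matched)  # [{'state': 'RUNNING', 'node_id': 'a'}]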
Example #5
    async def list_nodes(self, *, option: ListApiOptions) -> dict:
        """List all node information from the cluster.

        Returns:
            {node_id -> node_data_in_dict}
            node_data_in_dict's schema is in NodeState
        """
        reply = await self._client.get_all_node_info(timeout=option.timeout)
        result = []
        for message in reply.node_info_list:
            data = self._message_to_dict(message=message,
                                         fields_to_decode=["node_id"])
            data = filter_fields(data, NodeState)
            result.append(data)

        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["node_id"])
        return {d["node_id"]: d for d in islice(result, option.limit)}
Example #6
    async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]:
        """List all runtime env information from the cluster.

        Returns:
            A list of runtime env information in the cluster.
            The schema of returned "dict" is equivalent to the
            `RuntimeEnvState` protobuf message.
            We don't have id -> data mapping like other API because runtime env
            doesn't have unique ids.
        """
        agent_ids = self._client.get_all_registered_agent_ids()
        replies = await asyncio.gather(*[
            self._client.get_runtime_envs_info(node_id, timeout=option.timeout)
            for node_id in agent_ids
        ])
        result = []
        for node_id, reply in zip(agent_ids, replies):
            states = reply.runtime_env_states
            for state in states:
                data = self._message_to_dict(message=state,
                                             fields_to_decode=[])
                # Need to deserialize this field.
                data["runtime_env"] = RuntimeEnv.deserialize(
                    data["runtime_env"]).to_dict()
                data["node_id"] = node_id
                data = filter_fields(data, RuntimeEnvState)
                result.append(data)

        # Sort to make the output deterministic.
        def sort_func(entry):
            # If the creation time is missing (the runtime env failed to be
            # created or has not been created yet), it sorts first.
            # Otherwise, entries with a larger creation time come first.
            if "creation_time_ms" not in entry:
                return float("inf")
            elif entry["creation_time_ms"] is None:
                return float("inf")
            else:
                return float(entry["creation_time_ms"])

        result.sort(key=sort_func, reverse=True)
        return list(islice(result, option.limit))
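The ordering produced by `sort_func` can be checked in isolation; a short illustration with made-up entries:

entries = [{"creation_time_ms": 100}, {"creation_time_ms": None}, {"creation_time_ms": 200}, {}]

def sort_func(entry):
    if entry.get("creation_time_ms") is None:
        return float("inf")
    return float(entry["creation_time_ms"])

# Missing/None creation times sort first, then newer entries before older.
entries.sort(key=sort_func, reverse=True)
print(entries)
# [{'creation_time_ms': None}, {}, {'creation_time_ms': 200}, {'creation_time_ms': 100}]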
Example #7
    async def list_objects(self, *, option: ListApiOptions) -> dict:
        """List all object information from the cluster.

        Returns:
            {object_id -> object_data_in_dict}
            object_data_in_dict's schema is in ObjectState
        """
        replies = await asyncio.gather(*[
            self._client.get_object_info(node_id, timeout=option.timeout)
            for node_id in self._client.get_all_registered_raylet_ids()
        ])

        worker_stats = []
        for reply in replies:
            for core_worker_stat in reply.core_workers_stats:
                # NOTE: Set preserving_proto_field_name=False here because
                # `construct_memory_table` requires a dictionary that has
                # modified protobuf name
                # (e.g., workerId instead of worker_id) as a key.
                worker_stats.append(
                    self._message_to_dict(
                        message=core_worker_stat,
                        fields_to_decode=["object_id"],
                        preserving_proto_field_name=False,
                    ))

        result = []
        memory_table = memory_utils.construct_memory_table(worker_stats)
        for entry in memory_table.table:
            data = entry.as_dict()
            # `construct_memory_table` returns an `object_ref` field that is
            # actually the object id, so it is renamed here.
            # TODO(sang): Refactor `construct_memory_table`.
            data["object_id"] = data["object_ref"]
            del data["object_ref"]
            data = filter_fields(data, ObjectState)
            result.append(data)

        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["object_id"])
        return {d["object_id"]: d for d in islice(result, option.limit)}
Example #8
    async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all node information from the cluster.

        Returns:
            A list of node_data_in_dict.
            node_data_in_dict's schema is in NodeState
        """
        try:
            reply = await self._client.get_all_node_info(timeout=option.timeout)
        except DataSourceUnavailable:
            raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)

        result = []
        for message in reply.node_info_list:
            data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
            data["node_ip"] = data["node_manager_address"]
            data = filter_fields(data, NodeState)
            result.append(data)

        result = self._filter(result, option.filters, NodeState)
        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["node_id"])
        return ListApiResponse(result=list(islice(result, option.limit)))
Example #9
    def _filter(
        self,
        data: List[dict],
        filters: List[Tuple[str, SupportedFilterType]],
        state_dataclass: StateSchema,
    ) -> List[dict]:
        """Return the filtered data given filters.

        Args:
            data: A list of state data.
            filters: A list of (key, value) tuples used to filter data.
                An entry is dropped if data[key] != value for any filter.
            state_dataclass: The state schema.

        Returns:
            A list of filtered state data in dictionary. Each state data's
            unnecessary columns are filtered by the given state_dataclass schema.
        """
        filters = _convert_filters_type(filters, state_dataclass)
        result = []
        for datum in data:
            match = True
            for filter_column, filter_value in filters:
                filterable_columns = state_dataclass.filterable_columns()
                if filter_column not in filterable_columns:
                    raise ValueError(
                        f"The given filter column {filter_column} is not supported. "
                        f"Supported filter columns: {filterable_columns}"
                    )

                if datum[filter_column] != filter_value:
                    match = False
                    break

            if match:
                result.append(filter_fields(datum, state_dataclass))
        return result
Example #10
    async def list_tasks(self, *, option: ListApiOptions) -> dict:
        """List all task information from the cluster.

        Returns:
            {task_id -> task_data_in_dict}
            task_data_in_dict's schema is in TaskState
        """
        replies = await asyncio.gather(*[
            self._client.get_task_info(node_id, timeout=option.timeout)
            for node_id in self._client.get_all_registered_raylet_ids()
        ])

        running_task_id = set()
        for reply in replies:
            for task_id in reply.running_task_ids:
                running_task_id.add(binary_to_hex(task_id))

        result = []
        for reply in replies:
            tasks = reply.owned_task_info_entries
            for task in tasks:
                data = self._message_to_dict(
                    message=task,
                    fields_to_decode=["task_id"],
                )
                if data["task_id"] in running_task_id:
                    data[
                        "scheduling_state"] = TaskStatus.DESCRIPTOR.values_by_number[
                            TaskStatus.RUNNING].name
                data = filter_fields(data, TaskState)
                result.append(data)

        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["task_id"])
        return {d["task_id"]: d for d in islice(result, option.limit)}
Example #11
    async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all runtime env information from the cluster.

        Returns:
            A list of runtime env information in the cluster.
            The schema of returned "dict" is equivalent to the
            `RuntimeEnvState` protobuf message.
            We don't have id -> data mapping like other API because runtime env
            doesn't have unique ids.
        """
        agent_ids = self._client.get_all_registered_agent_ids()
        replies = await asyncio.gather(
            *[
                self._client.get_runtime_envs_info(node_id, timeout=option.timeout)
                for node_id in agent_ids
            ],
            return_exceptions=True,
        )

        result = []
        unresponsive_nodes = 0
        for node_id, reply in zip(agent_ids, replies):
            if isinstance(reply, DataSourceUnavailable):
                unresponsive_nodes += 1
                continue
            elif isinstance(reply, Exception):
                raise reply

            states = reply.runtime_env_states
            for state in states:
                data = self._message_to_dict(message=state, fields_to_decode=[])
                # Need to deserialize this field.
                data["runtime_env"] = RuntimeEnv.deserialize(
                    data["runtime_env"]
                ).to_dict()
                data["node_id"] = node_id
                data = filter_fields(data, RuntimeEnvState)
                result.append(data)

        partial_failure_warning = None
        if len(agent_ids) > 0 and unresponsive_nodes > 0:
            warning_msg = NODE_QUERY_FAILURE_WARNING.format(
                type="agent",
                total=len(agent_ids),
                network_failures=unresponsive_nodes,
                log_command="dashboard_agent.log",
            )
            if unresponsive_nodes == len(agent_ids):
                raise DataSourceUnavailable(warning_msg)
            partial_failure_warning = (
                f"The returned data may contain incomplete results. {warning_msg}"
            )

        # Sort to make the output deterministic.
        def sort_func(entry):
            # If the creation time is missing (the runtime env failed to be
            # created or has not been created yet), it sorts first.
            # Otherwise, entries with a larger creation time come first.
            if "creation_time_ms" not in entry:
                return float("inf")
            elif entry["creation_time_ms"] is None:
                return float("inf")
            else:
                return float(entry["creation_time_ms"])

        result.sort(key=sort_func, reverse=True)
        return ListApiResponse(
            result=list(islice(result, option.limit)),
            partial_failure_warning=partial_failure_warning,
        )
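The fan-out pattern shared by this and the following examples (gather with return_exceptions=True, count per-node DataSourceUnavailable as unresponsive, re-raise anything else) can be sketched standalone; the node ids and query function here are hypothetical:

import asyncio

class DataSourceUnavailable(Exception):
    pass

async def query(node_id):
    if node_id == "bad":
        raise DataSourceUnavailable(node_id)
    return f"reply from {node_id}"

async def main():
    node_ids = ["a", "bad", "c"]
    replies = await asyncio.gather(
        *(query(n) for n in node_ids), return_exceptions=True
    )
    unresponsive_nodes, results = 0, []
    for reply in replies:
        if isinstance(reply, DataSourceUnavailable):
            unresponsive_nodes += 1  # tolerated; surfaced as a partial failure
        elif isinstance(reply, Exception):
            raise reply  # unexpected errors still propagate
        else:
            results.append(reply)
    print(results, unresponsive_nodes)  # ['reply from a', 'reply from c'] 1

asyncio.run(main())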
Example #12
    async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all object information from the cluster.

        Returns:
            {object_id -> object_data_in_dict}
            object_data_in_dict's schema is in ObjectState
        """
        raylet_ids = self._client.get_all_registered_raylet_ids()
        replies = await asyncio.gather(
            *[
                self._client.get_object_info(node_id, timeout=option.timeout)
                for node_id in raylet_ids
            ],
            return_exceptions=True,
        )

        unresponsive_nodes = 0
        worker_stats = []
        for reply in replies:
            if isinstance(reply, DataSourceUnavailable):
                unresponsive_nodes += 1
                continue
            elif isinstance(reply, Exception):
                raise reply

            for core_worker_stat in reply.core_workers_stats:
                # NOTE: Set preserving_proto_field_name=False here because
                # `construct_memory_table` requires a dictionary that has
                # modified protobuf name
                # (e.g., workerId instead of worker_id) as a key.
                worker_stats.append(
                    self._message_to_dict(
                        message=core_worker_stat,
                        fields_to_decode=["object_id"],
                        preserving_proto_field_name=False,
                    )
                )

        partial_failure_warning = None
        if len(raylet_ids) > 0 and unresponsive_nodes > 0:
            warning_msg = NODE_QUERY_FAILURE_WARNING.format(
                type="raylet",
                total=len(raylet_ids),
                network_failures=unresponsive_nodes,
                log_command="raylet.out",
            )
            if unresponsive_nodes == len(raylet_ids):
                raise DataSourceUnavailable(warning_msg)
            partial_failure_warning = (
                f"The returned data may contain incomplete results. {warning_msg}"
            )

        result = []
        memory_table = memory_utils.construct_memory_table(worker_stats)
        for entry in memory_table.table:
            data = entry.as_dict()
            # `construct_memory_table` returns an `object_ref` field that is
            # actually the object id, so it is renamed here.
            # TODO(sang): Refactor `construct_memory_table`.
            data["object_id"] = data["object_ref"]
            del data["object_ref"]
            data = filter_fields(data, ObjectState)
            result.append(data)

        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["object_id"])
        return ListApiResponse(
            result={d["object_id"]: d for d in islice(result, option.limit)},
            partial_failure_warning=partial_failure_warning,
        )
Example #13
    async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all task information from the cluster.

        Returns:
            {task_id -> task_data_in_dict}
            task_data_in_dict's schema is in TaskState
        """
        raylet_ids = self._client.get_all_registered_raylet_ids()
        replies = await asyncio.gather(
            *[
                self._client.get_task_info(node_id, timeout=option.timeout)
                for node_id in raylet_ids
            ],
            return_exceptions=True,
        )

        unresponsive_nodes = 0
        running_task_id = set()
        successful_replies = []
        for reply in replies:
            if isinstance(reply, DataSourceUnavailable):
                unresponsive_nodes += 1
                continue
            elif isinstance(reply, Exception):
                raise reply

            successful_replies.append(reply)
            for task_id in reply.running_task_ids:
                running_task_id.add(binary_to_hex(task_id))

        partial_failure_warning = None
        if len(raylet_ids) > 0 and unresponsive_nodes > 0:
            warning_msg = NODE_QUERY_FAILURE_WARNING.format(
                type="raylet",
                total=len(raylet_ids),
                network_failures=unresponsive_nodes,
                log_command="raylet.out",
            )
            if unresponsive_nodes == len(raylet_ids):
                raise DataSourceUnavailable(warning_msg)
            partial_failure_warning = (
                f"The returned data may contain incomplete results. {warning_msg}"
            )

        result = []
        for reply in successful_replies:
            assert not isinstance(reply, Exception)
            tasks = reply.owned_task_info_entries
            for task in tasks:
                data = self._message_to_dict(
                    message=task,
                    fields_to_decode=["task_id"],
                )
                if data["task_id"] in running_task_id:
                    data["scheduling_state"] = TaskStatus.DESCRIPTOR.values_by_number[
                        TaskStatus.RUNNING
                    ].name
                data = filter_fields(data, TaskState)
                result.append(data)

        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["task_id"])
        return ListApiResponse(
            result={d["task_id"]: d for d in islice(result, option.limit)},
            partial_failure_warning=partial_failure_warning,
        )