Example #1
0
    def _call_schedule_for_task(self, task: ray_client_pb2.ClientTask,
                                num_returns: int) -> List[Future]:
        """Schedule ``task`` via the data client and return id futures.

        The task is stamped with ``self._client_id`` and passed to
        ``self.data_client.Schedule`` together with a callback. One future
        is created per expected return value; the callback completes every
        one of them, either with a raw return id or with an exception.

        Args:
            task: The task message to schedule. Its ``client_id`` field is
                overwritten with ``self._client_id``.
            num_returns: Number of return ids expected in the response.
                ``None`` is treated as 1.

        Returns:
            A list of ``num_returns`` futures, ordered like the return ids
            in the response's task ticket.
        """
        logger.debug("Scheduling %s", task)
        task.client_id = self._client_id
        if num_returns is None:
            num_returns = 1

        id_futures = [Future() for _ in range(num_returns)]

        def fail_all(error: BaseException) -> None:
            # Every future must be completed; a future left pending would
            # make any caller waiting on its result block indefinitely.
            for future in id_futures:
                future.set_exception(error)

        def populate_ids(
                resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
            if isinstance(resp, Exception):
                if isinstance(resp, grpc.RpcError):
                    resp = decode_exception(resp)
                fail_all(resp)
                return

            ticket = resp.task_ticket
            if not ticket.valid:
                # NOTE(review): the error payload is unpickled as received;
                # this presumes the responding server is trusted.
                try:
                    ex = cloudpickle.loads(ticket.error)
                except (pickle.UnpicklingError, TypeError) as e_new:
                    # Surface the deserialization failure itself instead.
                    ex = e_new
                fail_all(ex)
                return

            if len(ticket.return_ids) != num_returns:
                # Fail ALL futures, not just those that happen to pair up
                # with a received id: zipping against a shorter id list
                # would leave the trailing futures unresolved forever.
                fail_all(
                    ValueError(f"Expected {num_returns} returns but received "
                               f"{len(ticket.return_ids)}"))
                return

            for future, raw_id in zip(id_futures, ticket.return_ids):
                future.set_result(raw_id)

        self.data_client.Schedule(task, populate_ids)

        self.total_outbound_message_size_bytes += task.ByteSize()
        if self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD \
                and log_once("client_communication_overhead_warning"):
            warnings.warn(
                "More than 10MB of messages have been created to schedule "
                "tasks on the server. This can be slow on Ray Client due to "
                "communication overhead over the network. If you're running "
                "many fine-grained tasks, consider running them inside a "
                "single remote function. See the section on \"Too "
                "fine-grained tasks\" in the Ray Design Patterns document for "
                f"more details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}. If "
                "your functions frequently use large objects, consider "
                "storing the objects remotely with ray.put. An example of "
                "this is shown in the \"Closure capture of large / "
                "unserializable object\" section of the Ray Design Patterns "
                "document, available here: "
                f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}", UserWarning)
        return id_futures
Example #2
0
 def _call_schedule_for_task(
         self, task: ray_client_pb2.ClientTask) -> List[bytes]:
     """Schedule ``task`` on the server and return the ticket's return ids.

     The task's ``client_id`` is set to ``self._client_id`` before it is
     sent with ``self.server.Schedule``.

     Raises:
         The result of ``decode_exception`` when the call fails with a
         ``grpc.RpcError``, or the object unpickled from ``ticket.error``
         when the returned ticket is not valid.
     """
     logger.debug("Scheduling %s" % task)
     task.client_id = self._client_id
     try:
         response = self.server.Schedule(task, metadata=self.metadata)
     except grpc.RpcError as rpc_error:
         # NOTE(review): ``details`` is passed without being called --
         # confirm this is what ``decode_exception`` expects.
         raise decode_exception(rpc_error.details)
     if response.valid:
         return response.return_ids
     # Invalid ticket: its ``error`` field carries the pickled failure.
     raise cloudpickle.loads(response.error)
Example #3
0
    def _call_schedule_for_task(self, task: ray_client_pb2.ClientTask,
                                num_returns: int) -> List[bytes]:
        """Schedule ``task`` on the server and return the ticket's return ids.

        The task is stamped with ``self._client_id`` and sent through
        ``self._call_stub("Schedule", ...)``. On success the scheduled-task
        counter and outbound byte counter are updated, and a ``UserWarning``
        may be emitted once either counter exceeds its threshold.

        Args:
            task: The task message to schedule. Its ``client_id`` field is
                overwritten with ``self._client_id``.
            num_returns: Number of return ids expected in the ticket.
                ``None`` is treated as 1.

        Returns:
            ``ticket.return_ids`` from the server's response.

        Raises:
            The result of ``decode_exception`` when the call fails with a
            ``grpc.RpcError``; the object unpickled from ``ticket.error``
            when the ticket is not valid; ``TypeError`` when the number of
            return ids differs from ``num_returns``.
        """
        logger.debug("Scheduling %s", task)
        task.client_id = self._client_id
        metadata = self._add_ids_to_metadata(self.metadata)
        if num_returns is None:
            num_returns = 1

        try:
            ticket = self._call_stub("Schedule", task, metadata=metadata)
        except grpc.RpcError as e:
            raise decode_exception(e)

        if not ticket.valid:
            # Only the deserialization is guarded. Raising inside the ``try``
            # would make a correctly decoded TypeError/UnpicklingError from
            # the server get logged as a deserialization failure.
            try:
                remote_error = cloudpickle.loads(ticket.error)
            except (pickle.UnpicklingError, TypeError):
                logger.exception("Failed to deserialize {}".format(
                    ticket.error))
                raise
            # If the payload is not an exception, Python itself raises
            # TypeError here.
            raise remote_error
        self.total_num_tasks_scheduled += 1
        self.total_outbound_message_size_bytes += task.ByteSize()
        # NOTE(review): both warnings share one ``log_once`` key, so
        # presumably only whichever triggers first is shown -- confirm this
        # is intended.
        if self.total_num_tasks_scheduled > TASK_WARNING_THRESHOLD and \
                log_once("client_communication_overhead_warning"):
            warnings.warn(
                f"More than {TASK_WARNING_THRESHOLD} remote tasks have been "
                "scheduled. This can be slow on Ray Client due to "
                "communication overhead over the network. If you're running "
                "many fine-grained tasks, consider running them in a single "
                "remote function. See the section on \"Too fine-grained "
                "tasks\" in the Ray Design Patterns document for more "
                f"details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}",
                UserWarning)
        if self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD \
                and log_once("client_communication_overhead_warning"):
            warnings.warn(
                "More than 10MB of messages have been created to schedule "
                "tasks on the server. This can be slow on Ray Client due to "
                "communication overhead over the network. If you're running "
                "many fine-grained tasks, consider running them inside a "
                "single remote function. See the section on \"Too "
                "fine-grained tasks\" in the Ray Design Patterns document for "
                f"more details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}. If "
                "your functions frequently use large objects, consider "
                "storing the objects remotely with ray.put. An example of "
                "this is shown in the \"Closure capture of large / "
                "unserializable object\" section of the Ray Design Patterns "
                "document, available here: "
                f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}", UserWarning)
        if num_returns != len(ticket.return_ids):
            # Report counts on both sides so the message compares like
            # with like.
            raise TypeError("Unexpected number of returned values. Expected "
                            f"{num_returns} actual {len(ticket.return_ids)}")
        return ticket.return_ids