Exemple #1
0
    def task_table(self, task_id=None):
        """Look up task table entries for one task or for every known task.

        Args:
            task_id: A hex string of the task ID to fetch information about.
                If this is None, the whole task table is fetched.

        Returns:
            The result of ``self._task_table`` for the requested task when
            ``task_id`` is given; otherwise a dictionary mapping each task
            ID (as a hex string) to its ``self._task_table`` result.
        """
        self._check_connected()

        if task_id is not None:
            return self._task_table(ray.ObjectID(hex_to_binary(task_id)))

        prefix = ray.gcs_utils.TablePrefix_RAYLET_TASK_string
        # Each matching key is the table prefix followed by the binary task
        # ID, so slicing off the prefix leaves the raw ID.
        binary_ids = [
            key[len(prefix):] for key in self._keys(prefix + "*")
        ]
        return {
            binary_to_hex(binary_id): self._task_table(
                ray.ObjectID(binary_id))
            for binary_id in binary_ids
        }
Exemple #2
0
    def fetch_and_register_remote_function(self, key):
        """Import a remote function and register it for execution.

        The function's metadata and pickled body are read from the Redis hash
        stored at ``key``. A placeholder that raises is registered first, and
        is replaced by the real function only if unpickling succeeds. If
        unpickling fails, the traceback is pushed to the driver instead.

        Args:
            key: The Redis key of the hash holding the exported function.
        """
        # Read all exported fields in a single round trip. Note that
        # num_return_vals and resources are fetched but not used here.
        (driver_id_str, function_id_str, function_name, serialized_function,
         num_return_vals, module, resources,
         max_calls) = self._worker.redis_client.hmget(key, [
             "driver_id", "function_id", "name", "function", "num_return_vals",
             "module", "resources", "max_calls"
         ])
        function_id = ray.ObjectID(function_id_str)
        driver_id = ray.ObjectID(driver_id_str)
        function_name = decode(function_name)
        max_calls = int(max_calls)
        module = decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered.
        def f():
            raise Exception("This function was not imported properly.")

        self._function_execution_info[driver_id][function_id] = (
            FunctionExecutionInfo(function=f,
                                  function_name=function_name,
                                  max_calls=max_calls))
        # Start the per-function execution counter at zero.
        self._num_task_executions[driver_id][function_id] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # If an exception was thrown when the remote function was imported,
            # we record the traceback and notify the scheduler of the failure.
            traceback_str = format_error_message(traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                traceback_str,
                driver_id=driver_id,
                data={
                    "function_id": function_id.id(),
                    "function_name": function_name
                })
        else:
            # The below line is necessary. Because in the driver process,
            # if the function is defined in the file where the python script
            # was started from, its module is `__main__`.
            # However in the worker process, the `__main__` module is a
            # different module, which is `default_worker.py`
            function.__module__ = module
            # Replace the placeholder with the successfully loaded function.
            self._function_execution_info[driver_id][function_id] = (
                FunctionExecutionInfo(function=function,
                                      function_name=function_name,
                                      max_calls=max_calls))
            # Add the function to the function table.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.id(), self._worker.worker_id)
Exemple #3
0
def test_single_prod_cons_queue(serve_instance):
    """Round-trip one request through a linked service and backend."""
    queues = CentralizedQueues()
    queues.link("svc", "backend")

    # Producer side: enqueue a request whose body is 1.
    result_token = queues.enqueue_request("svc", 1)

    # Consumer side: pull the work item and check its payload.
    work_token = queues.dequeue_request("backend")
    work_item = ray.get(ray.ObjectID(work_token))
    assert work_item.request_body == 1

    # Store a result for the request and check the producer can read it.
    ray.worker.global_worker.put_object(work_item.result_object_id, 2)
    outcome = ray.get(ray.ObjectID(result_token))
    assert outcome == 2
Exemple #4
0
def test_single_prod_cons_queue(serve_instance):
    """Round-trip one request (args and kwargs) through a linked queue."""
    queues = CentralizedQueues()
    queues.link("svc", "backend")

    # Producer side: enqueue a request with args 1 and kwargs "kwargs".
    result_token = queues.enqueue_request("svc", 1, "kwargs", None)

    # Consumer side: pull the work item and check both payload fields.
    work_token = queues.dequeue_request("backend")
    work_item = ray.get(ray.ObjectID(work_token))
    assert work_item.request_args == 1
    assert work_item.request_kwargs == "kwargs"

    # Store a result for the request and check the producer can read it.
    ray.worker.global_worker.put_object(2, work_item.result_object_id)
    outcome = ray.get(ray.ObjectID(result_token))
    assert outcome == 2
Exemple #5
0
def test_split_traffic(serve_instance):
    """Dequeue from two backends around a change of traffic policy.

    Two requests are enqueued, an empty traffic policy is set, both backends
    dequeue, and only then is a 50/50 split installed. Both dequeued work
    items must carry ``request_args == 1``.
    """
    queues = CentralizedQueues()

    for _ in range(2):
        queues.enqueue_request("svc", 1, "kwargs", None)
    queues.set_traffic("svc", {})

    work_tokens = [
        queues.dequeue_request(backend)
        for backend in ("backend-1", "backend-2")
    ]
    queues.set_traffic("svc", {"backend-1": 0.5, "backend-2": 0.5})

    work_items = ray.get([ray.ObjectID(token) for token in work_tokens])
    assert [item.request_args for item in work_items] == [1, 1]
Exemple #6
0
def test_split_traffic(serve_instance):
    """Check that a 50/50 traffic split feeds both backends."""
    queues = CentralizedQueues()
    queues.set_traffic("svc", {"backend-1": 0.5, "backend-2": 0.5})

    # Assuming a 50% split, the probability that all 20 requests land in
    # one given queue is 0.5^20 ~ 1e-6, so both backends should get work.
    request_count = 20
    for _ in range(request_count):
        queues.enqueue_request("svc", 1, "kwargs", None)

    work_tokens = [
        queues.dequeue_request(backend)
        for backend in ("backend-1", "backend-2")
    ]

    work_items = ray.get([ray.ObjectID(token) for token in work_tokens])
    assert [item.request_args for item in work_items] == [1, 1]
Exemple #7
0
    def flush_profile_data(self):
        """Push the logged profiling data to the global control store.

        By default, profiling information for a given task won't appear in
        the timeline until after the task has completed. For very
        long-running tasks, we may want profiling information to appear more
        quickly; this function can be called in such cases. As an
        alternative, a background thread on workers could call this
        automatically.
        """
        # Take ownership of the buffered events and reset the buffer while
        # holding the lock; the actual push happens outside the lock.
        with self.lock:
            pending_events, self.events = self.events, []

        worker = self.worker
        if worker.use_raylet:
            if worker.mode == ray.WORKER_MODE:
                component_type = "worker"
            else:
                component_type = "driver"

            worker.local_scheduler_client.push_profile_events(
                component_type, ray.ObjectID(worker.worker_id),
                worker.node_ip_address, pending_events)
            return

        # Non-raylet path: log the events as JSON under a per-worker key.
        log_key = b"event_log:" + worker.worker_id
        log_value = json.dumps(pending_events)
        worker.local_scheduler_client.log_event(log_key, log_value,
                                                time.time())
Exemple #8
0
    def remote(self, *args, **kwargs):
        """Enqueue a Python-context request on the router for this endpoint.

        Args:
            *args: Must be empty; only keyword arguments are supported.
            **kwargs: Forwarded to the router as the request's keyword
                arguments. The optional ``slo_ms`` entry is removed first and
                passed separately as the request SLO.

        Returns:
            A ``ray.ObjectID`` built from the value returned by the router's
            ``enqueue_request`` call.

        Raises:
            RayServeException: If positional arguments are given, or if
                ``slo_ms`` cannot be converted to a float or is negative.
        """
        if args:
            raise RayServeException(
                "handle.remote must be invoked with keyword arguments.")

        # get slo_ms before enqueuing the query
        request_slo_ms = kwargs.pop("slo_ms", None)
        if request_slo_ms is not None:
            try:
                request_slo_ms = float(request_slo_ms)
            except (TypeError, ValueError) as e:
                # float() raises TypeError for non-numeric types (e.g. a
                # list) and ValueError for unparsable strings; report both
                # the same way and keep the original cause attached.
                raise RayServeException(str(e)) from e
            if request_slo_ms < 0:
                raise RayServeException(
                    "Request SLO must be positive, it is {}".format(
                        request_slo_ms))

        result_object_id_bytes = ray.get(
            self.router_handle.enqueue_request.remote(
                service=self.endpoint_name,
                request_args=(),
                request_kwargs=kwargs,
                request_context=TaskContext.Python,
                request_slo_ms=request_slo_ms))
        return ray.ObjectID(result_object_id_bytes)
Exemple #9
0
def test_ray_serve_mixin(serve_instance):
    """Serve Python-context requests through an actor using RayServeMixin."""
    router = CentralizedQueuesActor.remote()

    consumer_name = "runner-cls"
    producer_name = "prod-cls"
    increment = 3

    class MyAdder:
        def __init__(self, inc):
            self.increment = inc

        def __call__(self, flask_request, i=None):
            return i + self.increment

    @ray.remote
    class CustomActor(MyAdder, RayServeMixin):
        pass

    runner = CustomActor.remote(increment)

    # Attach the runner to the router and start its processing loop.
    runner._ray_serve_setup.remote(consumer_name, router, runner)
    runner._ray_serve_main_loop.remote()

    router.link.remote(producer_name, consumer_name)

    for query in (333, 444, 555):
        enqueue_ref = router.enqueue_request.remote(
            producer_name,
            request_args=None,
            request_kwargs={"i": query},
            request_context=context.TaskContext.Python)
        result_token = ray.ObjectID(ray.get(enqueue_ref))
        assert ray.get(result_token) == query + increment
Exemple #10
0
def test_put_pins_object(ray_start_object_store_memory):
    """Check which object IDs keep a ``ray.put`` object retrievable.

    The test fills the store with large arrays between checks and asserts
    that the object stays readable while the ID returned by ``ray.put`` is
    alive, and raises ``UnreconstructableError`` once only a copied ID (or a
    ``weakref=True`` put) refers to it.
    """
    x_id = ray.put("HI")
    # A second ID for the same object, rebuilt from the raw binary form.
    x_copy = ray.ObjectID(x_id.binary())
    assert ray.get(x_copy) == "HI"

    # x cannot be evicted since x_id pins it
    for _ in range(10):
        ray.put(np.zeros(10 * 1024 * 1024))
    assert ray.get(x_id) == "HI"
    assert ray.get(x_copy) == "HI"

    # After x_id is deleted only x_copy remains, and x_copy does not pin the
    # object, so it is expected to be evicted by the puts below.
    del x_id
    for _ in range(10):
        ray.put(np.zeros(10 * 1024 * 1024))
    with pytest.raises(ray.exceptions.UnreconstructableError):
        ray.get(x_copy)

    # weakref put: the returned ID is expected not to pin the object either.
    y_id = ray.put("HI", weakref=True)
    for _ in range(10):
        ray.put(np.zeros(10 * 1024 * 1024))
    with pytest.raises(ray.exceptions.UnreconstructableError):
        ray.get(y_id)

    # An ID received inside a task (nested in a list) must have no buffer
    # reference, while the ID returned by ray.put in this process has one.
    @ray.remote
    def check_no_buffer_ref(x):
        assert x[0].get_buffer_ref() is None

    z_id = ray.put("HI")
    assert z_id.get_buffer_ref() is not None
    ray.get(check_no_buffer_ref.remote([z_id]))
Exemple #11
0
def test_object_id_properties():
    """Exercise construction, nil/random IDs and pickling of ObjectID."""
    raw_id = b"00112233445566778899"

    # A 20-byte ID round-trips through binary(); nil() reports is_nil().
    assert ray.ObjectID(raw_id).binary() == raw_id
    assert ray.ObjectID.nil().is_nil()

    # Both too-long and too-short byte strings are rejected.
    length_error = r".*needs to have length 20.*"
    for bad_bytes in (raw_id + b"1234", b"0123456789"):
        with pytest.raises(ValueError, match=length_error):
            ray.ObjectID(bad_bytes)

    random_id = ray.ObjectID.from_random()
    assert not random_id.is_nil()
    assert random_id.binary() != raw_id

    # Pickling must preserve equality.
    restored_id = pickle.loads(pickle.dumps(random_id))
    assert restored_id == random_id
Exemple #12
0
    def _object_table(self, object_id):
        """Look up the object table entry for one object ID.

        Args:
            object_id: The object to look up, either as a ``ray.ObjectID``
                or as a hex string.

        Returns:
            An empty dictionary if the lookup returns no message; otherwise
            a dictionary with the "DataSize" and "Manager" fields of the
            first table entry.
        """
        # Normalise a hex string into an ObjectID.
        if isinstance(object_id, ray.ObjectID):
            oid = object_id
        else:
            oid = ray.ObjectID(hex_to_binary(object_id))

        reply = self._execute_command(oid, "RAY.TABLE_LOOKUP",
                                      ray.gcs_utils.TablePrefix.OBJECT, "",
                                      oid.binary())
        if reply is None:
            return {}

        table_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
            reply, 0)
        assert table_entry.EntriesLength() > 0

        # Only the first entry is inspected.
        object_data = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
            table_entry.Entries(0), 0)
        return {
            "DataSize": object_data.ObjectSize(),
            "Manager": object_data.Manager(),
        }
Exemple #13
0
    async def __call__(self, scope, receive, send):
        """Handle one ASGI scope: lifespan start-up or a routed request.

        Implements the ASGI protocol specified in
        https://asgi.readthedocs.io/en/latest/specs/index.html

        Args:
            scope: The ASGI scope; its "type" and "path" entries are read.
            receive: The ASGI receive callable, passed to the response.
            send: The ASGI send callable, passed to the response.
        """
        if scope["type"] == "lifespan":
            await _async_init()
            # Start the periodic route checker in the background.
            asyncio.ensure_future(
                self.route_checker(interval=HTTP_ROUTER_CHECKER_INTERVAL_S))
            return

        path = scope["path"]

        # The root path reports the routing table itself.
        if path == "/":
            await JSONResponse(self.route_table)(scope, receive, send)
            return

        # Unknown paths get a 404 with a hint about the routing table.
        if path not in self.route_table:
            not_found = ("Path {} not found. "
                         "Please ping http://.../ for routing table"
                         ).format(path)
            response = JSONResponse({"error": not_found}, status_code=404)
            await response(scope, receive, send)
            return

        # Known path: enqueue the request, then wait for its result object.
        endpoint_name = self.route_table[path]
        result_id_bytes = await as_future(
            self.router.enqueue_request.remote(endpoint_name, scope))
        result = await as_future(ray.ObjectID(result_id_bytes))

        if isinstance(result, ray.exceptions.RayTaskError):
            payload = {
                "error": "internal error, please use python API to debug"
            }
        else:
            payload = {"result": result}
        await JSONResponse(payload)(scope, receive, send)
Exemple #14
0
    def object_table(self, object_id=None):
        """Fetch and parse the object table info for one or more object IDs.

        Args:
            object_id: The object to fetch information about, given either
                as a hex string or as a ``ray.ObjectID``. If this is None,
                then the entire object table is fetched.

        Returns:
            For a single object, the result of ``self._gen_object_info`` for
            that object, or an empty dictionary if no information is found.
            For the whole table, a dictionary mapping each object ID (as a
            hex string) to its ``self._gen_object_info`` result.
        """
        self._check_connected()

        if object_id is None:
            results = {}
            # Iterate the serialized entries directly rather than by index.
            for serialized in self.global_state_accessor.get_object_table():
                location_info = gcs_utils.ObjectLocationInfo.FromString(
                    serialized)
                results[binary_to_hex(location_info.object_id)] = \
                    self._gen_object_info(location_info)
            return results

        # Accept an ObjectID as-is; only hex strings need converting.
        if not isinstance(object_id, ray.ObjectID):
            object_id = ray.ObjectID(hex_to_binary(object_id))
        object_info = self.global_state_accessor.get_object_info(object_id)
        if object_info is None:
            return {}
        return self._gen_object_info(
            gcs_utils.ObjectLocationInfo.FromString(object_info))
Exemple #15
0
def test_runner_actor(serve_instance):
    """Serve Python-context requests through a TaskRunnerActor echo."""
    router = CentralizedQueuesActor.remote()

    def echo(flask_request, i=None):
        return i

    consumer_name = "runner"
    producer_name = "prod"

    runner = TaskRunnerActor.remote(echo)

    # Attach the runner to the router and start its processing loop.
    runner._ray_serve_setup.remote(consumer_name, router, runner)
    runner._ray_serve_main_loop.remote()

    router.link.remote(producer_name, consumer_name)

    for query in (333, 444, 555):
        enqueue_ref = router.enqueue_request.remote(
            producer_name,
            request_args=None,
            request_kwargs={"i": query},
            request_context=context.TaskContext.Python)
        result_token = ray.ObjectID(ray.get(enqueue_ref))
        assert ray.get(result_token) == query
Exemple #16
0
    def function_id(self):
        """Return this descriptor's function id as a ``ray.ObjectID``.

        Returns:
            A ``ray.ObjectID`` constructed from ``self._function_id``.
        """
        function_id = ray.ObjectID(self._function_id)
        return function_id
Exemple #17
0
 def __init__(self, channel_id_str: str):
     """Store the channel id string and the ObjectID derived from it.

     Args:
         channel_id_str: string representation of channel id
     """
     self.channel_id_str = channel_id_str
     channel_id_bytes = channel_id_str_to_bytes(channel_id_str)
     self.object_qid = ray.ObjectID(channel_id_bytes)
Exemple #18
0
    def fetch_and_execute_function_to_run(self, key):
        """Run an arbitrary exported function on the worker.

        The function is read from the Redis hash stored at ``key``,
        unpickled, and called with ``{"worker": self.worker}``. Any exception
        raised while unpickling or running it is reported to the driver
        rather than propagated.

        Args:
            key: The Redis key of the hash holding the exported function.
        """
        (driver_id, serialized_function,
         run_on_other_drivers) = self.redis_client.hmget(
             key, ["driver_id", "function", "run_on_other_drivers"])

        # Skip functions that are restricted to their own driver when this
        # process is a script-mode driver with a different driver ID.
        if (utils.decode(run_on_other_drivers) == "False"
                and self.worker.mode == ray.SCRIPT_MODE
                and driver_id != self.worker.task_driver_id.id()):
            return

        # Pre-initialise so the error path can tell whether unpickling
        # succeeded without inspecting locals().
        function = None
        try:
            # Deserialize the function.
            # NOTE(review): pickle.loads executes arbitrary code; this is
            # only safe if the Redis store is trusted.
            function = pickle.loads(serialized_function)
            # Run the function.
            function({"worker": self.worker})
        except Exception:
            # If an exception was thrown when the function was run, we record
            # the traceback and notify the scheduler of the failure.
            traceback_str = traceback.format_exc()
            # Fall back to an empty name if unpickling failed or the loaded
            # object has no __name__.
            name = getattr(function, "__name__", "")
            # Log the error message.
            utils.push_error_to_driver(
                self.worker,
                ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
                traceback_str,
                driver_id=ray.ObjectID(driver_id),
                data={"name": name})
Exemple #19
0
    def _object_table(self, object_id):
        """Look up the object table entry for one object ID.

        Args:
            object_id: The object to look up, either as a ``ray.ObjectID``
                or as a hex string.

        Returns:
            An empty dictionary if the lookup returns no message; otherwise
            a dictionary with the "DataSize" and "Manager" fields of the
            first table entry.
        """
        # Normalise a hex string into an ObjectID.
        if isinstance(object_id, ray.ObjectID):
            oid = object_id
        else:
            oid = ray.ObjectID(hex_to_binary(object_id))

        reply = self._execute_command(oid, "RAY.TABLE_LOOKUP",
                                      gcs_utils.TablePrefix.Value("OBJECT"),
                                      "", oid.binary())
        if reply is None:
            return {}

        table_entry = gcs_utils.GcsEntry.FromString(reply)
        assert len(table_entry.entries) > 0

        # Only the first entry is inspected.
        object_data = gcs_utils.ObjectTableData.FromString(
            table_entry.entries[0])
        return {
            "DataSize": object_data.object_size,
            "Manager": object_data.manager,
        }
Exemple #20
0
    def error_messages(self, job_id=None):
        """Get the error messages for all jobs or a specific job.

        Args:
            job_id: The specific job to get the errors for. If this is None,
                then this method retrieves the errors for all jobs.

        Returns:
            For a specific job, the result of ``self._error_messages`` for
            it. Otherwise a dictionary mapping each job ID (as a hex
            string) to the ``self._error_messages`` result for that job.
        """
        if job_id is not None:
            return self._error_messages(job_id)

        prefix = ray.gcs_utils.TablePrefix_ERROR_INFO_string
        # Each key is the error-table prefix followed by the binary job ID.
        binary_job_ids = [
            key[len(prefix):]
            for key in self.redis_client.keys(prefix + "*")
        ]

        messages_by_job = {}
        for binary_job_id in binary_job_ids:
            messages_by_job[binary_to_hex(binary_job_id)] = (
                self._error_messages(ray.ObjectID(binary_job_id)))
        return messages_by_job
Exemple #21
0
def test_task_runner_check_context(serve_instance):
    """A Python-context request that touches flask_request must fail."""
    router = CentralizedQueuesActor.remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    consumer_name = "runner"
    producer_name = "producer"

    runner = TaskRunnerActor.remote(echo)

    # Attach the runner to the router and start its processing loop.
    runner._ray_serve_setup.remote(consumer_name, router, runner)
    runner._ray_serve_main_loop.remote()

    router.link.remote(producer_name, consumer_name)
    enqueue_ref = router.enqueue_request.remote(
        producer_name,
        request_args=None,
        request_kwargs={"i": 42},
        request_context=context.TaskContext.Python)
    result_token = ray.ObjectID(ray.get(enqueue_ref))

    with pytest.raises(ray.exceptions.RayTaskError):
        ray.get(result_token)
Exemple #22
0
def plasma_free(oids):
    """
    Free the given ray objects via ``ray.internal.free``
    :param oids: The ray objects to delete; each item is wrapped in a
        ``ray.ObjectID`` before being freed
    """
    object_ids = [ray.ObjectID(oid) for oid in oids]
    ray.internal.free(object_ids)
Exemple #23
0
def test_alter_backend(serve_instance):
    """Route all traffic to one backend, then switch it to another."""
    queues = CentralizedQueues()

    # The same round trip is run twice, once per backend, with the traffic
    # policy pointed entirely at that backend each time.
    for backend in ("backend-1", "backend-2"):
        queues.set_traffic("svc", {backend: 1})
        result_token = queues.enqueue_request("svc", 1, "kwargs", None)
        work_token = queues.dequeue_request(backend)

        work_item = ray.get(ray.ObjectID(work_token))
        assert work_item.request_args == 1

        ray.worker.global_worker.put_object(work_item.result_object_id, 2)
        outcome = ray.get(ray.ObjectID(result_token))
        assert outcome == 2
Exemple #24
0
def plasma_prefetch(oids):
    """
    Ask the global worker's raylet client to fetch or reconstruct objects
    :param oids: The ray objects to prefetch; each item is wrapped in a
        ``ray.ObjectID`` first
    """
    client = ray.worker.global_worker.raylet_client
    object_ids = [ray.ObjectID(oid) for oid in oids]
    client.fetch_or_reconstruct(object_ids, True)
Exemple #25
0
def test_alter_backend(serve_instance, task_runner_mock_actor):
    """Route all traffic to one backend, then switch it to another."""
    queues = RandomPolicyQueue()

    # The same round trip is run twice, once per backend, with the traffic
    # policy pointed entirely at that backend each time.
    for backend in ("backend-1", "backend-2"):
        queues.set_traffic("svc", {backend: 1})
        result_token = queues.enqueue_request("svc", 1, "kwargs", None)
        queues.dequeue_request(backend, task_runner_mock_actor)

        # The mock actor records the work item it was most recently given.
        work_item = ray.get(task_runner_mock_actor.get_recent_call.remote())
        assert work_item.request_args == 1

        ray.worker.global_worker.put_object(2, work_item.result_object_id)
        outcome = ray.get(ray.ObjectID(result_token))
        assert outcome == 2
Exemple #26
0
def fetch(oids):
    """Request each object in ``oids`` through the active backend.

    With raylet enabled, each ID is passed to the local scheduler client's
    ``reconstruct_objects``; otherwise each ID is passed to the plasma
    client's ``fetch``. Objects are requested one at a time.

    Args:
        oids: An iterable of raw object IDs.
    """
    if not ray.global_state.use_raylet:
        for oid in oids:
            plasma_id = ray.pyarrow.plasma.ObjectID(oid)
            ray.worker.global_worker.plasma_client.fetch([plasma_id])
        return

    scheduler_client = ray.worker.global_worker.local_scheduler_client
    for oid in oids:
        scheduler_client.reconstruct_objects([ray.ObjectID(oid)], True)
Exemple #27
0
def _object_table_shard(shard_index):
    """Collect object table entries for every object key on one shard.

    Args:
        shard_index: Index into ``ray.global_state.redis_clients``.

    Returns:
        A dictionary mapping each object ID (as a hex string) to the result
        of ``ray.global_state._object_table`` for that object.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_length = len(OBJECT_LOCATION_PREFIX)

    entries = {}
    for key in shard_client.keys(OBJECT_LOCATION_PREFIX + b"*"):
        # The key is the location prefix followed by the binary object ID.
        binary_id = key[prefix_length:]
        entries[binary_to_hex(binary_id)] = ray.global_state._object_table(
            ray.ObjectID(binary_id))
    return entries
Exemple #28
0
def _task_table_shard(shard_index):
    """Collect task table entries for every task key on one shard.

    Args:
        shard_index: Index into ``ray.global_state.redis_clients``.

    Returns:
        A dictionary mapping each task ID (as a hex string) to the result
        of ``ray.global_state._task_table`` for that task.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_length = len(TASK_PREFIX)

    entries = {}
    for key in shard_client.keys(TASK_PREFIX + b"*"):
        # The key is the task prefix followed by the binary task ID.
        binary_id = key[prefix_length:]
        entries[binary_to_hex(binary_id)] = ray.global_state._task_table(
            ray.ObjectID(binary_id))
    return entries
Exemple #29
0
    def _ray_serve_main_loop(self, my_handle):
        """Process at most one work item, then re-schedule this method.

        A work token is obtained from the router (or reused from a previous
        attempt), waited on briefly, and, if ready, turned into a call to
        ``self.__call__``. The result, or the wrapped exception, is stored
        under the work item's result object ID. In every case the method
        finishes by scheduling itself again on ``my_handle``.

        Args:
            my_handle: A handle to this actor, used to re-invoke the loop.
        """
        assert self._ray_serve_setup_completed
        self._ray_serve_self_handle = my_handle

        # Only retrieve the next task if we have completed previous task.
        if self._ray_serve_cached_work_token is None:
            # NOTE(review): "requestr" in the attribute name looks like a
            # typo, but it must match wherever the attribute is assigned.
            work_token = ray.get(
                self._ray_serve_router_handle.dequeue_request.remote(
                    self._ray_serve_dequeue_requestr_name))
        else:
            work_token = self._ray_serve_cached_work_token

        work_token_id = ray.ObjectID(work_token)
        # Wait with a 0.5 timeout (presumably seconds) rather than blocking
        # indefinitely on the work item.
        ready, not_ready = ray.wait([work_token_id],
                                    num_returns=1,
                                    timeout=0.5)
        if len(ready) == 1:
            work_item = ray.get(work_token_id)
            self._ray_serve_cached_work_token = None
        else:
            # Nothing arrived in time: remember the token so the next
            # invocation waits on the same one, and try again.
            self._ray_serve_cached_work_token = work_token
            self._ray_serve_self_handle._ray_serve_main_loop.remote(my_handle)
            return

        # Build the call arguments according to where the request came from.
        if work_item.request_context == TaskContext.Web:
            serve_context.web = True
            asgi_scope, body_bytes = work_item.request_args
            flask_request = build_flask_request(asgi_scope, body_bytes)
            args = (flask_request, )
            kwargs = {}
        else:
            serve_context.web = False
            # Non-web requests get a stand-in request object plus the
            # caller's keyword arguments.
            args = (FakeFlaskQuest(), )
            kwargs = work_item.request_kwargs

        result_object_id = work_item.result_object_id

        start_timestamp = time.time()
        try:
            result = self.__call__(*args, **kwargs)
            ray.worker.global_worker.put_object(result_object_id, result)
        except Exception as e:
            # Store the wrapped exception as the result instead of raising,
            # and count the failure.
            wrapped_exception = wrap_to_ray_error(e)
            self._serve_metric_error_counter += 1
            ray.worker.global_worker.put_object(result_object_id,
                                                wrapped_exception)
        # Latency is recorded for both successful and failed calls.
        self._serve_metric_latency_list.append(time.time() - start_timestamp)

        serve_context.web = False
        # The worker finished one unit of work.
        # It will now tail recursively schedule the main_loop again.

        # TODO(simon): remove tail recursion, ask router to callback instead
        self._ray_serve_self_handle._ray_serve_main_loop.remote(my_handle)
Exemple #30
0
    def remote(self, *args, **kwargs):
        """Enqueue a Python-context request on the router for this endpoint.

        Args:
            *args: Must be empty; only keyword arguments are supported.
            **kwargs: Forwarded to the router as the request's keyword
                arguments.

        Returns:
            A ``ray.ObjectID`` built from the value returned by the router's
            ``enqueue_request`` call.

        Raises:
            RayServeException: If any positional arguments are given.
        """
        if args:
            raise RayServeException(
                "handle.remote must be invoked with keyword arguments.")

        enqueue_ref = self.router_handle.enqueue_request.remote(
            service=self.endpoint_name,
            request_args=(),
            request_kwargs=kwargs,
            request_context=TaskContext.Python)
        return ray.ObjectID(ray.get(enqueue_ref))