def task_table(self, task_id=None):
    """Return parsed task-table data for a single task or for all tasks.

    Args:
        task_id: Hex string identifying the task to look up. When None,
            every key stored under the raylet task-table prefix is fetched.

    Returns:
        For a single task, whatever ``self._task_table`` returns for it.
        Otherwise a dict mapping each task ID (hex string) to that result.
    """
    self._check_connected()

    if task_id is not None:
        return self._task_table(ray.ObjectID(hex_to_binary(task_id)))

    prefix = ray.gcs_utils.TablePrefix_RAYLET_TASK_string
    prefix_length = len(prefix)
    # Strip the table prefix from every key to recover the binary task IDs.
    binary_ids = [key[prefix_length:] for key in self._keys(prefix + "*")]
    return {
        binary_to_hex(binary_id): self._task_table(ray.ObjectID(binary_id))
        for binary_id in binary_ids
    }
def fetch_and_register_remote_function(self, key):
    """Import a remote function exported under ``key`` and register it.

    The function's metadata and pickled body are read from the Redis hash
    at ``key``. A placeholder that raises is registered first. It is
    replaced by the real function only if unpickling succeeds, after which
    this worker is appended to the function's ``FunctionTable`` list. If
    unpickling fails, the traceback is pushed to the driver instead.

    Args:
        key: Redis key of the hash holding the exported function.
    """
    field_names = [
        "driver_id", "function_id", "name", "function", "num_return_vals",
        "module", "resources", "max_calls"
    ]
    (driver_id_str, function_id_str, function_name, serialized_function,
     num_return_vals, module, resources,
     max_calls) = self._worker.redis_client.hmget(key, field_names)
    function_id = ray.ObjectID(function_id_str)
    driver_id = ray.ObjectID(driver_id_str)
    function_name = decode(function_name)
    max_calls = int(max_calls)
    module = decode(module)

    # Placeholder used in case the function can't be unpickled; it is
    # overwritten below once the real function has been loaded.
    def f():
        raise Exception("This function was not imported properly.")

    self._function_execution_info[driver_id][function_id] = (
        FunctionExecutionInfo(
            function=f, function_name=function_name, max_calls=max_calls))
    self._num_task_executions[driver_id][function_id] = 0

    try:
        function = pickle.loads(serialized_function)
    except Exception:
        # Unpickling failed: report the formatted traceback to the driver
        # and leave the placeholder registered.
        push_error_to_driver(
            self._worker,
            ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
            format_error_message(traceback.format_exc()),
            driver_id=driver_id,
            data={
                "function_id": function_id.id(),
                "function_name": function_name
            })
        return

    # In the driver, a function defined in the script that was launched has
    # module `__main__`. In a worker, `__main__` is a different module
    # (`default_worker.py`), so restore the module recorded at export time.
    function.__module__ = module
    self._function_execution_info[driver_id][function_id] = (
        FunctionExecutionInfo(
            function=function,
            function_name=function_name,
            max_calls=max_calls))
    # Add the function to the function table.
    self._worker.redis_client.rpush(
        b"FunctionTable:" + function_id.id(), self._worker.worker_id)
def test_single_prod_cons_queue(serve_instance):
    # One producer ("svc") linked to one consumer ("backend"). The backend
    # dequeues the request body, and a value stored into the work item's
    # result slot becomes visible through the producer's result ID.
    queues = CentralizedQueues()
    queues.link("svc", "backend")

    result_oid = queues.enqueue_request("svc", 1)
    work_oid = queues.dequeue_request("backend")

    work = ray.get(ray.ObjectID(work_oid))
    assert work.request_body == 1

    ray.worker.global_worker.put_object(work.result_object_id, 2)
    assert ray.get(ray.ObjectID(result_oid)) == 2
def test_single_prod_cons_queue(serve_instance):
    # A single producer/consumer pair. The backend receives the request's
    # args and kwargs, and a value stored into the work item's result slot
    # becomes the producer's result.
    queues = CentralizedQueues()
    queues.link("svc", "backend")

    result_oid = queues.enqueue_request("svc", 1, "kwargs", None)
    work_oid = queues.dequeue_request("backend")

    work = ray.get(ray.ObjectID(work_oid))
    assert work.request_args == 1
    assert work.request_kwargs == "kwargs"

    ray.worker.global_worker.put_object(2, work.result_object_id)
    assert ray.get(ray.ObjectID(result_oid)) == 2
def test_split_traffic(serve_instance):
    queues = CentralizedQueues()

    # Two requests are enqueued before "svc" has any traffic policy, and
    # then an empty policy is installed.
    for _ in range(2):
        queues.enqueue_request("svc", 1, "kwargs", None)
    queues.set_traffic("svc", {})

    # Both backends ask for work while nothing is routed to them ...
    first_work_oid = queues.dequeue_request("backend-1")
    second_work_oid = queues.dequeue_request("backend-2")

    # ... and only afterwards is traffic split evenly between them. The
    # test expects both work items to resolve to one of the requests.
    queues.set_traffic("svc", {"backend-1": 0.5, "backend-2": 0.5})

    work_items = ray.get(
        [ray.ObjectID(first_work_oid),
         ray.ObjectID(second_work_oid)])
    assert [item.request_args for item in work_items] == [1, 1]
def test_split_traffic(serve_instance):
    q = CentralizedQueues()
    q.set_traffic("svc", {"backend-1": 0.5, "backend-2": 0.5})
    # With a 50/50 split, the probability that all 20 requests go to the
    # same backend (either one) is 2 * 0.5^20 ~= 2e-6. Each backend is
    # therefore practically certain to receive at least one request.
    for _ in range(20):
        q.enqueue_request("svc", 1, "kwargs", None)
    work_object_id_1 = q.dequeue_request("backend-1")
    work_object_id_2 = q.dequeue_request("backend-2")

    got_work = ray.get(
        [ray.ObjectID(work_object_id_1),
         ray.ObjectID(work_object_id_2)])
    assert [g.request_args for g in got_work] == [1, 1]
def flush_profile_data(self):
    """Push the logged profiling data to the global control store.

    By default, profiling information for a given task won't appear in the
    timeline until after the task has completed. For very long-running
    tasks, we may want profiling information to appear more quickly. In
    such cases, this function can be called. Note that as an alternative,
    we could start a thread in the background on workers that calls this
    automatically.
    """
    # Swap out the buffered events while holding the lock; the batch is
    # sent after the lock is released.
    with self.lock:
        events = self.events
        self.events = []

    if not self.worker.use_raylet:
        event_log_key = b"event_log:" + self.worker.worker_id
        event_log_value = json.dumps(events)
        self.worker.local_scheduler_client.log_event(
            event_log_key, event_log_value, time.time())
        return

    component_type = ("worker"
                      if self.worker.mode == ray.WORKER_MODE else "driver")
    self.worker.local_scheduler_client.push_profile_events(
        component_type, ray.ObjectID(self.worker.worker_id),
        self.worker.node_ip_address, events)
def remote(self, *args, **kwargs):
    """Enqueue a Python-context request for this handle's endpoint.

    Args:
        *args: Must be empty; positional arguments are rejected.
        **kwargs: Forwarded as the request's keyword arguments. The optional
            ``slo_ms`` entry is removed from them and sent separately as the
            request SLO in milliseconds.

    Returns:
        ray.ObjectID built from the bytes returned by the router's
        ``enqueue_request``.

    Raises:
        RayServeException: if positional arguments are given, or if
            ``slo_ms`` cannot be converted to a float, is negative, or is
            NaN.
    """
    if len(args) != 0:
        raise RayServeException(
            "handle.remote must be invoked with keyword arguments.")

    # get slo_ms before enqueuing the query
    request_slo_ms = kwargs.pop("slo_ms", None)
    if request_slo_ms is not None:
        try:
            request_slo_ms = float(request_slo_ms)
        except (TypeError, ValueError) as e:
            # TypeError covers non-numeric types such as lists or dicts,
            # which previously escaped as a raw TypeError.
            raise RayServeException(str(e)) from e
        # Written as `not >= 0` so that NaN is rejected as well as
        # negative values; 0 is allowed.
        if not request_slo_ms >= 0:
            raise RayServeException(
                "Request SLO must be non-negative, it is {}".format(
                    request_slo_ms))

    result_object_id_bytes = ray.get(
        self.router_handle.enqueue_request.remote(
            service=self.endpoint_name,
            request_args=(),
            request_kwargs=kwargs,
            request_context=TaskContext.Python,
            request_slo_ms=request_slo_ms))
    return ray.ObjectID(result_object_id_bytes)
def test_ray_serve_mixin(serve_instance):
    router = CentralizedQueuesActor.remote()

    consumer_name = "runner-cls"
    producer_name = "prod-cls"

    class MyAdder:
        def __init__(self, inc):
            self.increment = inc

        def __call__(self, flask_request, i=None):
            return i + self.increment

    # Combine a plain callable class with RayServeMixin to make a runner.
    @ray.remote
    class CustomActor(MyAdder, RayServeMixin):
        pass

    runner = CustomActor.remote(3)
    runner._ray_serve_setup.remote(consumer_name, router, runner)
    runner._ray_serve_main_loop.remote()
    router.link.remote(producer_name, consumer_name)

    # Each query should come back incremented by the constructor's 3.
    for query in (333, 444, 555):
        enqueue_ref = router.enqueue_request.remote(
            producer_name,
            request_args=None,
            request_kwargs={"i": query},
            request_context=context.TaskContext.Python)
        result_token = ray.ObjectID(ray.get(enqueue_ref))
        assert ray.get(result_token) == query + 3
def test_put_pins_object(ray_start_object_store_memory):
    x_id = ray.put("HI")
    # x_copy is rebuilt from x_id's raw bytes rather than returned by put.
    x_copy = ray.ObjectID(x_id.binary())
    assert ray.get(x_copy) == "HI"

    # x cannot be evicted while x_id is alive, since x_id pins it.
    for _ in range(10):
        ray.put(np.zeros(10 * 1024 * 1024))
    assert ray.get(x_id) == "HI"
    assert ray.get(x_copy) == "HI"

    # Once x_id is deleted, x can be evicted: x_copy alone does not pin it.
    del x_id
    for _ in range(10):
        ray.put(np.zeros(10 * 1024 * 1024))
    with pytest.raises(ray.exceptions.UnreconstructableError):
        ray.get(x_copy)

    # An object put with weakref=True is not pinned by its ID. It can be
    # evicted even though y_id is still alive.
    y_id = ray.put("HI", weakref=True)
    for _ in range(10):
        ray.put(np.zeros(10 * 1024 * 1024))
    with pytest.raises(ray.exceptions.UnreconstructableError):
        ray.get(y_id)

    # An ObjectID passed into a task (nested in a list) has no buffer
    # reference inside the task, unlike the driver's own ID.
    @ray.remote
    def check_no_buffer_ref(x):
        assert x[0].get_buffer_ref() is None

    z_id = ray.put("HI")
    assert z_id.get_buffer_ref() is not None
    ray.get(check_no_buffer_ref.remote([z_id]))
def test_object_id_properties():
    raw_id = b"00112233445566778899"

    oid = ray.ObjectID(raw_id)
    assert oid.binary() == raw_id

    nil_oid = ray.ObjectID.nil()
    assert nil_oid.is_nil()

    # IDs must be exactly 20 bytes; both too-long and too-short inputs are
    # rejected.
    for bad_bytes in (raw_id + b"1234", b"0123456789"):
        with pytest.raises(ValueError, match=r".*needs to have length 20.*"):
            ray.ObjectID(bad_bytes)

    random_oid = ray.ObjectID.from_random()
    assert not random_oid.is_nil()
    assert random_oid.binary() != raw_id

    # Round-tripping through pickle preserves equality.
    restored = pickle.loads(pickle.dumps(random_oid))
    assert restored == random_oid
def _object_table(self, object_id):
    """Look up and decode the object-table entry for a single object.

    Args:
        object_id: A ray.ObjectID, or its hex-string form.

    Returns:
        ``{"DataSize": ..., "Manager": ...}`` read from the first entry
        stored for the object, or ``{}`` if the lookup returned nothing.
    """
    if not isinstance(object_id, ray.ObjectID):
        # Hex string: convert it to an ObjectID first.
        object_id = ray.ObjectID(hex_to_binary(object_id))

    message = self._execute_command(
        object_id, "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.OBJECT, "",
        object_id.binary())
    if message is None:
        return {}

    table_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
        message, 0)
    assert table_entry.EntriesLength() > 0
    first = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
        table_entry.Entries(0), 0)
    return {
        "DataSize": first.ObjectSize(),
        "Manager": first.Manager(),
    }
async def __call__(self, scope, receive, send):
    """ASGI entry point for the HTTP router.

    Behaviour by scope and path:

    * ``lifespan`` scopes run ``_async_init`` and start the background
      ``route_checker`` task.
    * ``/`` returns the routing table as JSON.
    * A known route enqueues the ASGI scope for its endpoint and returns
      the result, or a generic error if the result is a ``RayTaskError``.
    * Any other path returns a 404 JSON error.
    """
    # NOTE: This implements ASGI protocol specified in
    # https://asgi.readthedocs.io/en/latest/specs/index.html
    if scope["type"] == "lifespan":
        await _async_init()
        asyncio.ensure_future(
            self.route_checker(interval=HTTP_ROUTER_CHECKER_INTERVAL_S))
        return

    current_path = scope["path"]

    if current_path == "/":
        await JSONResponse(self.route_table)(scope, receive, send)
        return

    if current_path not in self.route_table:
        error_message = ("Path {} not found. "
                         "Please ping http://.../ for routing table"
                         ).format(current_path)
        not_found = JSONResponse({"error": error_message}, status_code=404)
        await not_found(scope, receive, send)
        return

    endpoint_name = self.route_table[current_path]
    result_object_id_bytes = await as_future(
        self.router.enqueue_request.remote(endpoint_name, scope))
    result = await as_future(ray.ObjectID(result_object_id_bytes))

    if isinstance(result, ray.exceptions.RayTaskError):
        response = JSONResponse({
            "error": "internal error, please use python API to debug"
        })
    else:
        response = JSONResponse({"result": result})
    await response(scope, receive, send)
def object_table(self, object_id=None):
    """Fetch and parse the object table info for one or more object IDs.

    Args:
        object_id: Hex string of the object ID to fetch information about.
            If this is None, then the entire object table is fetched.

    Returns:
        For a single object, the result of ``self._gen_object_info`` for
        it, or ``{}`` if the object is not found. Otherwise a dict mapping
        each object ID (hex string) to that result.
    """
    self._check_connected()

    if object_id is not None:
        object_id = ray.ObjectID(hex_to_binary(object_id))
        object_info = self.global_state_accessor.get_object_info(object_id)
        if object_info is None:
            return {}
        object_location_info = gcs_utils.ObjectLocationInfo.FromString(
            object_info)
        return self._gen_object_info(object_location_info)

    results = {}
    # Iterate the serialized entries directly instead of by index.
    for serialized_info in self.global_state_accessor.get_object_table():
        object_location_info = gcs_utils.ObjectLocationInfo.FromString(
            serialized_info)
        results[binary_to_hex(object_location_info.object_id)] = (
            self._gen_object_info(object_location_info))
    return results
def test_runner_actor(serve_instance):
    router = CentralizedQueuesActor.remote()

    def echo(flask_request, i=None):
        return i

    consumer = "runner"
    producer = "prod"

    runner = TaskRunnerActor.remote(echo)
    runner._ray_serve_setup.remote(consumer, router, runner)
    runner._ray_serve_main_loop.remote()
    router.link.remote(producer, consumer)

    # Each query should be echoed back unchanged through router and runner.
    for query in (333, 444, 555):
        token_bytes = ray.get(
            router.enqueue_request.remote(
                producer,
                request_args=None,
                request_kwargs={"i": query},
                request_context=context.TaskContext.Python))
        assert ray.get(ray.ObjectID(token_bytes)) == query
def function_id(self):
    """Return this descriptor's function ID as a ``ray.ObjectID``.

    Returns:
        A ray.ObjectID wrapping the stored ``self._function_id``.
    """
    raw_function_id = self._function_id
    return ray.ObjectID(raw_function_id)
def __init__(self, channel_id_str: str):
    """Build a channel ID from its string form.

    Args:
        channel_id_str: string representation of channel id. It is kept
            as-is and also converted to bytes to build ``object_qid``.
    """
    self.channel_id_str = channel_id_str
    channel_id_bytes = channel_id_str_to_bytes(channel_id_str)
    self.object_qid = ray.ObjectID(channel_id_bytes)
def fetch_and_execute_function_to_run(self, key):
    """Run an exported function on this worker.

    The pickled function is read from the Redis hash at ``key`` and called
    with ``{"worker": self.worker}``. Any exception raised while unpickling
    or running it is reported to the exporting driver rather than
    propagated.

    Args:
        key: Redis key of the hash holding the exported function.
    """
    (driver_id, serialized_function,
     run_on_other_drivers) = self.redis_client.hmget(
         key, ["driver_id", "function", "run_on_other_drivers"])

    # In SCRIPT_MODE, skip functions exported by a different driver unless
    # they were exported with run_on_other_drivers enabled.
    if (utils.decode(run_on_other_drivers) == "False"
            and self.worker.mode == ray.SCRIPT_MODE
            and driver_id != self.worker.task_driver_id.id()):
        return

    # Bound before the try so the error path can read it even when
    # unpickling fails, instead of probing locals().
    function = None
    try:
        # Deserialize the function. NOTE(review): pickle.loads executes
        # arbitrary code; this assumes the Redis contents are trusted.
        function = pickle.loads(serialized_function)
        # Run the function.
        function({"worker": self.worker})
    except Exception:
        # If an exception was thrown when the function was run, we record
        # the traceback and notify the scheduler of the failure.
        traceback_str = traceback.format_exc()
        # Empty name if unpickling failed or the callable has no __name__.
        name = getattr(function, "__name__", "")
        utils.push_error_to_driver(
            self.worker,
            ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
            traceback_str,
            driver_id=ray.ObjectID(driver_id),
            data={"name": name})
def _object_table(self, object_id):
    """Look up and decode the object-table entry for a single object.

    Args:
        object_id: A ray.ObjectID, or its hex-string form.

    Returns:
        ``{"DataSize": ..., "Manager": ...}`` read from the first entry
        stored for the object, or ``{}`` if the lookup returned nothing.
    """
    if not isinstance(object_id, ray.ObjectID):
        # Hex string: convert it to an ObjectID first.
        object_id = ray.ObjectID(hex_to_binary(object_id))

    message = self._execute_command(
        object_id, "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("OBJECT"),
        "", object_id.binary())
    if message is None:
        return {}

    entries = gcs_utils.GcsEntry.FromString(message).entries
    assert len(entries) > 0
    first = gcs_utils.ObjectTableData.FromString(entries[0])
    return {
        "DataSize": first.object_size,
        "Manager": first.manager,
    }
def error_messages(self, job_id=None):
    """Return error messages for one job or for every job.

    Args:
        job_id: The job whose errors should be returned. When None, errors
            are returned for every job found under the error-info table
            prefix.

    Returns:
        With ``job_id`` given, whatever ``self._error_messages(job_id)``
        returns. Otherwise a dict mapping each job ID (hex string) to the
        error messages for that job.
    """
    if job_id is not None:
        return self._error_messages(job_id)

    prefix = ray.gcs_utils.TablePrefix_ERROR_INFO_string
    error_keys = self.redis_client.keys(prefix + "*")
    # Strip the table prefix from each key to recover the binary job IDs.
    binary_job_ids = [key[len(prefix):] for key in error_keys]
    return {
        binary_to_hex(binary_job_id): self._error_messages(
            ray.ObjectID(binary_job_id))
        for binary_job_id in binary_job_ids
    }
def test_task_runner_check_context(serve_instance):
    router = CentralizedQueuesActor.remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    consumer = "runner"
    producer = "producer"

    runner = TaskRunnerActor.remote(echo)
    runner._ray_serve_setup.remote(consumer, router, runner)
    runner._ray_serve_main_loop.remote()
    router.link.remote(producer, consumer)

    enqueue_ref = router.enqueue_request.remote(
        producer,
        request_args=None,
        request_kwargs={"i": 42},
        request_context=context.TaskContext.Python)
    result_token = ray.ObjectID(ray.get(enqueue_ref))

    # The failure inside echo is expected to surface as a RayTaskError.
    with pytest.raises(ray.exceptions.RayTaskError):
        ray.get(result_token)
def plasma_free(oids):
    """
    Free the given ray objects with ``ray.internal.free``.

    :param oids: Raw object IDs; each is wrapped in ``ray.ObjectID``
        before being freed.
    """
    object_ids = list(map(ray.ObjectID, oids))
    ray.internal.free(object_ids)
    # Explicitly drop the local list of ObjectID wrappers.
    del object_ids
def test_alter_backend(serve_instance):
    queues = CentralizedQueues()

    def route_and_complete(backend):
        # Send one request, dequeue it from ``backend``, complete it with
        # the value 2 and check that the producer sees that value.
        result_oid = queues.enqueue_request("svc", 1, "kwargs", None)
        work_oid = queues.dequeue_request(backend)
        work = ray.get(ray.ObjectID(work_oid))
        assert work.request_args == 1
        ray.worker.global_worker.put_object(work.result_object_id, 2)
        assert ray.get(ray.ObjectID(result_oid)) == 2

    queues.set_traffic("svc", {"backend-1": 1})
    route_and_complete("backend-1")

    # Move all traffic to a second backend and repeat.
    queues.set_traffic("svc", {"backend-2": 1})
    route_and_complete("backend-2")
def plasma_prefetch(oids):
    """
    Prefetch the given ray objects through the raylet client's
    ``fetch_or_reconstruct``.

    :param oids: Raw object IDs; each is wrapped in ``ray.ObjectID``.
    """
    client = ray.worker.global_worker.raylet_client
    object_ids = list(map(ray.ObjectID, oids))
    # NOTE(review): the positional True presumably means "fetch only";
    # confirm against the raylet client's signature.
    client.fetch_or_reconstruct(object_ids, True)
    # Explicitly drop the local list of ObjectID wrappers.
    del object_ids
def test_alter_backend(serve_instance, task_runner_mock_actor):
    policy_queue = RandomPolicyQueue()

    def route_and_complete(backend):
        # Send one request and let ``backend`` hand it to the mock runner.
        # Then complete it with the value 2 and check the producer sees it.
        result_oid = policy_queue.enqueue_request("svc", 1, "kwargs", None)
        policy_queue.dequeue_request(backend, task_runner_mock_actor)
        work = ray.get(task_runner_mock_actor.get_recent_call.remote())
        assert work.request_args == 1
        ray.worker.global_worker.put_object(2, work.result_object_id)
        assert ray.get(ray.ObjectID(result_oid)) == 2

    policy_queue.set_traffic("svc", {"backend-1": 1})
    route_and_complete("backend-1")

    # Move all traffic to a second backend and repeat.
    policy_queue.set_traffic("svc", {"backend-2": 1})
    route_and_complete("backend-2")
def fetch(oids):
    """Request each object in ``oids``, one call per object.

    With the raylet enabled, ``reconstruct_objects`` is called on the local
    scheduler client. Otherwise ``fetch`` is called on the plasma client.

    Args:
        oids: Raw object IDs.
    """
    if not ray.global_state.use_raylet:
        for raw_oid in oids:
            plasma_oid = ray.pyarrow.plasma.ObjectID(raw_oid)
            ray.worker.global_worker.plasma_client.fetch([plasma_oid])
        return

    scheduler_client = ray.worker.global_worker.local_scheduler_client
    for raw_oid in oids:
        # NOTE(review): True presumably means "fetch only"; confirm.
        scheduler_client.reconstruct_objects([ray.ObjectID(raw_oid)], True)
def _object_table_shard(shard_index):
    """Return object-table info for every object-location key in a shard.

    Args:
        shard_index: Index into ``ray.global_state.redis_clients``.

    Returns:
        Dict mapping each object ID (hex string) to
        ``ray.global_state._object_table`` for that object.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_length = len(OBJECT_LOCATION_PREFIX)
    binary_ids = [
        key[prefix_length:]
        for key in shard_client.keys(OBJECT_LOCATION_PREFIX + b"*")
    ]
    return {
        binary_to_hex(binary_id): ray.global_state._object_table(
            ray.ObjectID(binary_id))
        for binary_id in binary_ids
    }
def _task_table_shard(shard_index):
    """Return task-table info for every task key stored in one shard.

    Args:
        shard_index: Index into ``ray.global_state.redis_clients``.

    Returns:
        Dict mapping each task ID (hex string) to
        ``ray.global_state._task_table`` for that task.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_length = len(TASK_PREFIX)
    binary_ids = [
        key[prefix_length:] for key in shard_client.keys(TASK_PREFIX + b"*")
    ]
    return {
        binary_to_hex(binary_id): ray.global_state._task_table(
            ray.ObjectID(binary_id))
        for binary_id in binary_ids
    }
def _ray_serve_main_loop(self, my_handle):
    """Process at most one queued request, then reschedule this loop.

    If no work token is cached, a new one is requested from the router via
    ``dequeue_request``. The token is then waited on for up to 0.5 s. If it
    is not ready, it is cached and the loop reschedules itself without doing
    any work. Otherwise the request is executed with ``self.__call__``, and
    the result (or the wrapped exception) is stored under the request's
    ``result_object_id``.

    Args:
        my_handle: Handle to this actor. It is stored on the instance and
            used to call ``_ray_serve_main_loop`` again remotely.
    """
    assert self._ray_serve_setup_completed
    self._ray_serve_self_handle = my_handle

    # Only retrieve the next task if we have completed previous task.
    if self._ray_serve_cached_work_token is None:
        # NOTE(review): "requestr" looks like a typo, but it must match the
        # attribute name set elsewhere (presumably in _ray_serve_setup).
        work_token = ray.get(
            self._ray_serve_router_handle.dequeue_request.remote(
                self._ray_serve_dequeue_requestr_name))
    else:
        work_token = self._ray_serve_cached_work_token

    work_token_id = ray.ObjectID(work_token)
    # Poll briefly instead of blocking so the loop can yield and reschedule.
    ready, not_ready = ray.wait(
        [work_token_id], num_returns=1, timeout=0.5)
    if len(ready) == 1:
        work_item = ray.get(work_token_id)
        self._ray_serve_cached_work_token = None
    else:
        # Not ready yet: keep the token for the next iteration.
        self._ray_serve_cached_work_token = work_token
        self._ray_serve_self_handle._ray_serve_main_loop.remote(my_handle)
        return

    if work_item.request_context == TaskContext.Web:
        # Web requests carry (asgi_scope, body_bytes), which are rebuilt
        # into a flask request passed as the single positional argument.
        # NOTE(review): serve_context.web stays True if
        # build_flask_request raises.
        serve_context.web = True
        asgi_scope, body_bytes = work_item.request_args
        flask_request = build_flask_request(asgi_scope, body_bytes)
        args = (flask_request, )
        kwargs = {}
    else:
        # Python-context requests get a FakeFlaskQuest (defined elsewhere)
        # as the first argument, plus their keyword arguments.
        serve_context.web = False
        args = (FakeFlaskQuest(), )
        kwargs = work_item.request_kwargs

    result_object_id = work_item.result_object_id

    start_timestamp = time.time()
    try:
        result = self.__call__(*args, **kwargs)
        ray.worker.global_worker.put_object(result_object_id, result)
    except Exception as e:
        # Failures from __call__ (or from storing its result) are wrapped
        # and stored under the same result ID, and counted as errors.
        wrapped_exception = wrap_to_ray_error(e)
        self._serve_metric_error_counter += 1
        ray.worker.global_worker.put_object(result_object_id,
                                            wrapped_exception)
    self._serve_metric_latency_list.append(time.time() - start_timestamp)

    serve_context.web = False
    # The worker finished one unit of work.
    # It will now tail recursively schedule the main_loop again.

    # TODO(simon): remove tail recursion, ask router to callback instead
    self._ray_serve_self_handle._ray_serve_main_loop.remote(my_handle)
def remote(self, *args, **kwargs):
    """Enqueue a Python-context request for this handle's endpoint.

    Args:
        *args: Not accepted; passing any raises RayServeException.
        **kwargs: Sent as the request's keyword arguments.

    Returns:
        ray.ObjectID built from the bytes returned by the router's
        ``enqueue_request``.
    """
    if args:
        raise RayServeException(
            "handle.remote must be invoked with keyword arguments.")

    enqueue_ref = self.router_handle.enqueue_request.remote(
        service=self.endpoint_name,
        request_args=(),
        request_kwargs=kwargs,
        request_context=TaskContext.Python)
    return ray.ObjectID(ray.get(enqueue_ref))