Code Example #1
File: registry.py Project: jamescasbon/ray
 def get(self, category, key):
     if _internal_kv_initialized():
         value = _internal_kv_get(_make_key(category, key))
         if value is None:
             raise ValueError(
                 "Registry value for {}/{} doesn't exist.".format(
                     category, key))
         return pickle.loads(value)
     else:
         return pickle.loads(self._to_flush[(category, key)])
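The get() above only covers the read path. For reference, the write side would pickle the value before storing it under the same key; here is a minimal sketch, assuming an _internal_kv_put helper that mirrors the _internal_kv_get used above (a hypothetical simplification of the real registry):

import pickle


def put(self, category, key, value):
    # get() round-trips this payload via pickle.loads().
    serialized = pickle.dumps(value)
    if _internal_kv_initialized():
        _internal_kv_put(_make_key(category, key), serialized)
    else:
        # Buffer locally until the internal KV store becomes available.
        self._to_flush[(category, key)] = serialized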
Code Example #2
File: import_thread.py Project: robertnishihara/ray
    def fetch_and_execute_function_to_run(self, key):
        """Run on arbitrary function on the worker."""
        (driver_id, serialized_function,
         run_on_other_drivers) = self.redis_client.hmget(
             key, ["driver_id", "function", "run_on_other_drivers"])

        if (utils.decode(run_on_other_drivers) == "False"
                and self.worker.mode == ray.SCRIPT_MODE
                and driver_id != self.worker.task_driver_id.binary()):
            return

        try:
            # Deserialize the function.
            function = pickle.loads(serialized_function)
            # Run the function.
            function({"worker": self.worker})
        except Exception:
            # If an exception was thrown when the function was run, we record
            # the traceback and notify the scheduler of the failure.
            traceback_str = traceback.format_exc()
            # Log the error message.
            utils.push_error_to_driver(
                self.worker,
                ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
                traceback_str,
                driver_id=ray.DriverID(driver_id))
Code Example #3
File: monitor.py Project: jamescasbon/ray
    def _maybe_flush_gcs(self):
        """Experimental: issue a flush request to the GCS.

        The purpose of this feature is to control GCS memory usage.

        To activate this feature, Ray must be compiled with the flag
        RAY_USE_NEW_GCS set, and Ray must be started at run time with the flag
        as well.
        """
        if not self.issue_gcs_flushes:
            return
        if self.gcs_flush_policy is None:
            serialized = self.redis.get("gcs_flushing_policy")
            if serialized is None:
                # Client has not set any policy; by default flushing is off.
                return
            self.gcs_flush_policy = pickle.loads(serialized)

        if not self.gcs_flush_policy.should_flush(self.redis_shard):
            return

        max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush()
        num_flushed = self.redis_shard.execute_command(
            "HEAD.FLUSH {}".format(max_entries_to_flush))
        logger.info("num_flushed {}".format(num_flushed))

        # This flushes event log and log files.
        ray.experimental.flush_redis_unsafe(self.redis)

        self.gcs_flush_policy.record_flush()
Code Example #4
    def _load_actor_class_from_gcs(self, driver_id, function_descriptor):
        """Load actor class from GCS."""
        key = (b"ActorClass:" + driver_id.binary() + b":" +
               function_descriptor.function_id.binary())
        # Wait for the actor class key to have been imported by the
        # import thread. TODO(rkn): It shouldn't be possible to end
        # up in an infinite loop here, but we should push an error to
        # the driver if too much time is spent here.
        while key not in self.imported_actor_classes:
            time.sleep(0.001)

        # Fetch raw data from GCS.
        (driver_id_str, class_name, module, pickled_class,
         actor_method_names) = self._worker.redis_client.hmget(
             key, [
                 "driver_id", "class_name", "module", "class",
                 "actor_method_names"
             ])

        class_name = ensure_str(class_name)
        module_name = ensure_str(module)
        driver_id = ray.DriverID(driver_id_str)
        actor_method_names = json.loads(ensure_str(actor_method_names))

        actor_class = None
        try:
            with self._worker.lock:
                actor_class = pickle.loads(pickled_class)
        except Exception:
            logger.exception(
                "Failed to load actor class %s.", class_name)
            # The actor class failed to be unpickled, create a fake actor
            # class instead (just to produce error messages and to prevent
            # the driver from hanging).
            actor_class = self._create_fake_actor_class(
                class_name, actor_method_names)
            # If an exception was thrown when the actor was imported, we record
            # the traceback and notify the scheduler of the failure.
            traceback_str = ray.utils.format_error_message(
                traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR,
                "Failed to unpickle actor class '{}' for actor ID {}. "
                "Traceback:\n{}".format(class_name,
                                        self._worker.actor_id.hex(),
                                        traceback_str), driver_id)
            # TODO(rkn): In the future, it might make sense to have the worker
            # exit here. However, currently that would lead to hanging if
            # someone calls ray.get on a method invoked on the actor.

        # The below line is necessary. Because in the driver process,
        # if the function is defined in the file where the python script
        # was started from, its module is `__main__`.
        # However in the worker process, the `__main__` module is a
        # different module, which is `default_worker.py`
        actor_class.__module__ = module_name
        return actor_class
Code Example #5
    def fetch_and_register_remote_function(self, key):
        """Import a remote function."""
        (driver_id_str, function_id_str, function_name, serialized_function,
         num_return_vals, module, resources,
         max_calls) = self._worker.redis_client.hmget(key, [
             "driver_id", "function_id", "name", "function", "num_return_vals",
             "module", "resources", "max_calls"
         ])
        function_id = ray.FunctionID(function_id_str)
        driver_id = ray.DriverID(driver_id_str)
        function_name = decode(function_name)
        max_calls = int(max_calls)
        module = decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered.
        def f():
            raise Exception("This function was not imported properly.")

        self._function_execution_info[driver_id][function_id] = (
            FunctionExecutionInfo(
                function=f, function_name=function_name, max_calls=max_calls))
        self._num_task_executions[driver_id][function_id] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # If an exception was thrown when the remote function was imported,
            # we record the traceback and notify the scheduler of the failure.
            traceback_str = format_error_message(traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                "Failed to unpickle the remote function '{}' with function ID "
                "{}. Traceback:\n{}".format(function_name, function_id.hex(),
                                            traceback_str),
                driver_id=driver_id)
        else:
            # The below line is necessary. Because in the driver process,
            # if the function is defined in the file where the python script
            # was started from, its module is `__main__`.
            # However in the worker process, the `__main__` module is a
            # different module, which is `default_worker.py`
            function.__module__ = module
            self._function_execution_info[driver_id][function_id] = (
                FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Add the function to the function table.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.binary(),
                self._worker.worker_id)
Code Example #6
File: trial.py Project: robertnishihara/ray
    def __setstate__(self, state):
        logger_started = state.pop("__logger_started__")
        state["resources"] = json_to_resources(state["resources"])
        if state["status"] == Trial.RUNNING:
            state["status"] = Trial.PENDING
        for key in self._nonjson_fields:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        self.__dict__.update(state)
        Trial._registration_check(self.trainable_name)
        if logger_started:
            self.init_logger()
Code Example #7
File: trial.py Project: jamescasbon/ray
    def __setstate__(self, state):
        logger_started = state.pop("__logger_started__")
        state["resources"] = json_to_resources(state["resources"])
        for key in [
                "_checkpoint", "config", "custom_loggers", "sync_function",
                "last_result"
        ]:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        self.__dict__.update(state)
        Trial._registration_check(self.trainable_name)
        if logger_started:
            self.init_logger()
Code Example #8
File: named_actors.py Project: jamescasbon/ray
def get_actor(name):
    """Get a named actor which was previously created.

    If the actor doesn't exist, an exception will be raised.

    Args:
        name: The name of the named actor.

    Returns:
        The ActorHandle object corresponding to the name.
    """
    actor_name = _calculate_key(name)
    pickled_state = _internal_kv_get(actor_name)
    if pickled_state is None:
        raise ValueError("The actor with name={} doesn't exist".format(name))
    handle = pickle.loads(pickled_state)
    return handle
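This lookup is only half of the named-actor workflow; the handle first has to be registered under that name. A hedged usage sketch, assuming the same module also exposes a register_actor() counterpart, as older Ray releases did:

import ray
from ray.experimental import named_actors

ray.init()


@ray.remote
class Counter:
    def __init__(self):
        self.n = 0

    def increment(self):
        self.n += 1
        return self.n


counter = Counter.remote()
named_actors.register_actor("global_counter", counter)  # pickle and store the handle
handle = named_actors.get_actor("global_counter")  # later, from any driver or worker
print(ray.get(handle.increment.remote()))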
Code Example #9
File: trial_runner.py Project: linwukang/ray
 def _from_cloudpickle(self, obj):
     return cloudpickle.loads(hex_to_binary(obj["value"]))
Code Example #10
File: signal.py Project: robertnishihara/ray
def receive(sources, timeout=None):
    """Get all outstanding signals from sources.

    A source can be either (1) an object ID returned by a task from which we
    want to receive signals, or (2) an actor handle.

    When invoked by the same entity E (where E can be an actor, task or
    driver), for each source S in sources, this function returns all signals
    generated by S since the last receive() was invoked by E on S. If this is
    the first call on S, this function returns all past signals generated by S
    so far. Note that different actors, tasks or drivers that call receive()
    on the same source S will get independent copies of the signals generated
    by S.

    Args:
        sources: List of sources from which the caller waits for signals.
            A source is either an object ID returned by a task (in this case
            the object ID is used to identify that task), or an actor handle.
            If the user passes the IDs of multiple objects returned by the
            same task, this function returns a copy of the signals generated
            by that task for each object ID.
        timeout: Maximum time (in seconds) this function waits to get a signal
            from a source in sources. If None, the timeout is infinite.

    Returns:
        A list of pairs (S, sig), where S is a source in the sources argument,
            and sig is a signal generated by S since the last time receive()
            was called on S. Thus, for each S in sources, the return list can
            contain zero or multiple entries.
    """

    # If None, initialize the timeout to a huge value (i.e., over 30,000 years
    # in this case) to "approximate" infinity.
    if timeout is None:
        timeout = 10**12

    if timeout < 0:
        raise ValueError("The 'timeout' argument cannot be less than 0.")

    if not hasattr(ray.worker.global_worker, "signal_counters"):
        ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0")

    signal_counters = ray.worker.global_worker.signal_counters

    # Map the ID of each source task to the source itself.
    task_id_to_sources = defaultdict(lambda: [])
    for s in sources:
        task_id_to_sources[_get_task_id(s).hex()].append(s)

    # Construct the redis query.
    query = "XREAD BLOCK "
    # Multiply by 1000x since timeout is in sec and redis expects ms.
    query += str(1000 * timeout)
    query += " STREAMS "
    query += " ".join([task_id for task_id in task_id_to_sources])
    query += " "
    query += " ".join([
        ray.utils.decode(signal_counters[ray.utils.hex_to_binary(task_id)])
        for task_id in task_id_to_sources
    ])

    answers = ray.worker.global_worker.redis_client.execute_command(query)
    if not answers:
        return []

    results = []
    # Decoding is a little bit involved. Iterate through all the answers:
    for i, answer in enumerate(answers):
        # Make sure the answer corresponds to a source, s, in sources.
        task_id = ray.utils.decode(answer[0])
        task_source_list = task_id_to_sources[task_id]
        # The list of results for source s is stored in answer[1]
        for r in answer[1]:
            for s in task_source_list:
                # Now it gets tricky: r[0] is the redis internal sequence id
                signal_counters[ray.utils.hex_to_binary(task_id)] = r[0]
                # r[1] contains a list with elements (key, value), in our case
                # we only have one key "signal" and the value is the signal.
                signal = cloudpickle.loads(ray.utils.hex_to_binary(r[1][1]))
                results.append((s, signal))

    return results
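receive() is normally paired with a send() call inside the task or actor that produces the signals. A hedged usage sketch, assuming the send() and UserSignal helpers from the same experimental signal module:

import ray
import ray.experimental.signal as signal

ray.init()


@ray.remote
def producer():
    # Emit a user-defined signal that callers can pick up with receive().
    signal.send(signal.UserSignal("step finished"))
    return 1


obj = producer.remote()
# Returns (source, signal) pairs generated since this caller's last receive().
for source, sig in signal.receive([obj], timeout=5):
    print(source, sig)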
Code Example #11
 def _load_trial_info(self, trial_info):
     trial_info["config"] = cloudpickle.loads(
         hex_to_binary(trial_info["config"]))
     trial_info["result"] = cloudpickle.loads(
         hex_to_binary(trial_info["result"]))
Code Example #12
File: function_manager.py Project: wuisawesome/ray
    def _load_actor_class_from_gcs(self, job_id,
                                   actor_creation_function_descriptor):
        """Load actor class from GCS."""
        key = make_function_table_key(
            b"ActorClass",
            job_id,
            actor_creation_function_descriptor.function_id.binary(),
        )
        # Only wait for the actor class if it was exported from the same job.
        # It will hang if the job id mismatches, since we isolate actor class
        # exports from the import thread. It's important to wait since this
        # guarantees import order, though we fetch the actor class directly.
        # Import order isn't important across jobs, as we only need to fetch
        # the class for `ray.get_actor()`.
        if job_id.binary() == self._worker.current_job_id.binary():
            # Wait for the actor class key to have been imported by the
            # import thread. TODO(rkn): It shouldn't be possible to end
            # up in an infinite loop here, but we should push an error to
            # the driver if too much time is spent here.
            while key not in self.imported_actor_classes:
                try:
                    # If we're in the process of deserializing an ActorHandle
                    # and we hold the function_manager lock, we may be blocking
                    # the import_thread from loading the actor class. Use wait
                    # to temporarily yield control to the import thread.
                    self.cv.wait()
                except RuntimeError:
                    # We don't hold the function_manager lock, just sleep
                    time.sleep(0.001)

        # Fetch raw data from GCS.
        vals = self._worker.gcs_client.internal_kv_get(
            key, KV_NAMESPACE_FUNCTION_TABLE)
        fields = [
            "job_id", "class_name", "module", "class", "actor_method_names"
        ]
        if vals is None:
            vals = {}
        else:
            vals = pickle.loads(vals)
        (job_id_str, class_name, module, pickled_class,
         actor_method_names) = (vals.get(field) for field in fields)

        class_name = ensure_str(class_name)
        module_name = ensure_str(module)
        job_id = ray.JobID(job_id_str)
        actor_method_names = json.loads(ensure_str(actor_method_names))

        actor_class = None
        try:
            with self.lock:
                actor_class = pickle.loads(pickled_class)
        except Exception:
            logger.debug("Failed to load actor class %s.", class_name)
            # If an exception was thrown when the actor was imported, we record
            # the traceback and notify the scheduler of the failure.
            traceback_str = format_error_message(traceback.format_exc())
            # The actor class failed to be unpickled, create a fake actor
            # class instead (just to produce error messages and to prevent
            # the driver from hanging).
            actor_class = self._create_fake_actor_class(
                class_name, actor_method_names, traceback_str)

        # The below line is necessary. Because in the driver process,
        # if the function is defined in the file where the python script
        # was started from, its module is `__main__`.
        # However in the worker process, the `__main__` module is a
        # different module, which is `default_worker.py`
        actor_class.__module__ = module_name
        return actor_class
Code Example #13
def fetch_and_register_actor(actor_class_key, worker):
    """Import an actor.

    This will be called by the worker's import thread when the worker receives
    the actor_class export, assuming that the worker is an actor for that
    class.

    Args:
        actor_class_key: The key in Redis to use to fetch the actor.
        worker: The worker to use.
    """
    actor_id_str = worker.actor_id
    (driver_id, class_id, class_name, module, pickled_class,
     checkpoint_interval,
     actor_method_names) = worker.redis_client.hmget(actor_class_key, [
         "driver_id", "class_id", "class_name", "module", "class",
         "checkpoint_interval", "actor_method_names"
     ])

    class_name = decode(class_name)
    module = decode(module)
    checkpoint_interval = int(checkpoint_interval)
    actor_method_names = json.loads(decode(actor_method_names))

    # Create a temporary actor with some temporary methods so that if the actor
    # fails to be unpickled, the temporary actor can be used (just to produce
    # error messages and to prevent the driver from hanging).
    class TemporaryActor(object):
        pass

    worker.actors[actor_id_str] = TemporaryActor()
    worker.actor_checkpoint_interval = checkpoint_interval

    def temporary_actor_method(*xs):
        raise Exception("The actor with name {} failed to be imported, and so "
                        "cannot execute this method".format(class_name))

    # Register the actor method executors.
    for actor_method_name in actor_method_names:
        function_id = compute_actor_method_function_id(class_name,
                                                       actor_method_name).id()
        temporary_executor = make_actor_method_executor(worker,
                                                        actor_method_name,
                                                        temporary_actor_method,
                                                        actor_imported=False)
        worker.function_execution_info[driver_id][function_id] = (
            ray.worker.FunctionExecutionInfo(function=temporary_executor,
                                             function_name=actor_method_name,
                                             max_calls=0))
        worker.num_task_executions[driver_id][function_id] = 0

    try:
        unpickled_class = pickle.loads(pickled_class)
        worker.actor_class = unpickled_class
    except Exception:
        # If an exception was thrown when the actor was imported, we record the
        # traceback and notify the scheduler of the failure.
        traceback_str = ray.utils.format_error_message(traceback.format_exc())
        # Log the error message.
        push_error_to_driver(worker,
                             ray_constants.REGISTER_ACTOR_PUSH_ERROR,
                             traceback_str,
                             driver_id,
                             data={"actor_id": actor_id_str})
        # TODO(rkn): In the future, it might make sense to have the worker exit
        # here. However, currently that would lead to hanging if someone calls
        # ray.get on a method invoked on the actor.
    else:
        # TODO(pcm): Why is the below line necessary?
        unpickled_class.__module__ = module
        worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)

        def pred(x):
            return (inspect.isfunction(x) or inspect.ismethod(x)
                    or is_cython(x))

        actor_methods = inspect.getmembers(unpickled_class, predicate=pred)
        for actor_method_name, actor_method in actor_methods:
            function_id = compute_actor_method_function_id(
                class_name, actor_method_name).id()
            executor = make_actor_method_executor(worker,
                                                  actor_method_name,
                                                  actor_method,
                                                  actor_imported=True)
            worker.function_execution_info[driver_id][function_id] = (
                ray.worker.FunctionExecutionInfo(
                    function=executor,
                    function_name=actor_method_name,
                    max_calls=0))
Code Example #14
File: api.py Project: tchordia/ray
    def decorator(cls):
        if not inspect.isclass(cls):
            raise ValueError("@serve.ingress must be used with a class.")

        if issubclass(cls, collections.abc.Callable):
            raise ValueError(
                "Class passed to @serve.ingress may not have __call__ method."
            )

        # Sometimes there are decorators on the methods. We want to fix
        # the fast api routes here.
        if isinstance(app, (FastAPI, APIRouter)):
            make_fastapi_class_based_view(app, cls)

        # Free the state of the app so subsequent modification won't affect
        # this ingress deployment. We don't use copy.copy here to avoid
        # recursion issue.
        ensure_serialization_context()
        frozen_app = cloudpickle.loads(cloudpickle.dumps(app))

        class ASGIAppWrapper(cls):
            async def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

                install_serve_encoders_to_fastapi()

                self._serve_app = frozen_app

                # Use uvicorn's lifespan handling code to properly deal with
                # startup and shutdown event.
                self._serve_asgi_lifespan = LifespanOn(
                    Config(self._serve_app, lifespan="on")
                )
                # Replace uvicorn logger with our own.
                self._serve_asgi_lifespan.logger = logger
                # LifespanOn's logger logs in INFO level thus becomes spammy
                # Within this block we temporarily uplevel for cleaner logging
                with LoggingContext(
                    self._serve_asgi_lifespan.logger, level=logging.WARNING
                ):
                    await self._serve_asgi_lifespan.startup()

            async def __call__(self, request: Request):
                sender = ASGIHTTPSender()
                await self._serve_app(
                    request.scope,
                    request.receive,
                    sender,
                )
                return sender.build_asgi_response()

            # NOTE: __del__ must be async so that we can run asgi shutdown
            # in the same event loop.
            async def __del__(self):
                # LifespanOn's logger logs in INFO level thus becomes spammy
                # Within this block we temporarily uplevel for cleaner logging
                with LoggingContext(
                    self._serve_asgi_lifespan.logger, level=logging.WARNING
                ):
                    await self._serve_asgi_lifespan.shutdown()

                # Make sure to call user's del method as well.
                super_cls = super()
                if hasattr(super_cls, "__del__"):
                    super_cls.__del__()

        ASGIAppWrapper.__name__ = cls.__name__
        return ASGIAppWrapper
Code Example #15
        async def __init__(self, deployment_name, replica_tag, init_args,
                           init_kwargs, deployment_config_proto_bytes: bytes,
                           version: DeploymentVersion, controller_name: str,
                           detached: bool):
            deployment_def = cloudpickle.loads(serialized_deployment_def)
            deployment_config = DeploymentConfig.from_proto_bytes(
                deployment_config_proto_bytes)

            if inspect.isfunction(deployment_def):
                is_function = True
            elif inspect.isclass(deployment_def):
                is_function = False
            else:
                assert False, ("deployment_def must be function, class, or "
                               "corresponding import path.")

            # Set the controller name so that serve.connect() in the user's
            # code will connect to the instance that this deployment is running
            # in.
            ray.serve.api._set_internal_replica_context(deployment_name,
                                                        replica_tag,
                                                        controller_name,
                                                        servable_object=None)

            assert controller_name, "Must provide a valid controller_name"

            controller_namespace = ray.serve.api._get_controller_namespace(
                detached)
            controller_handle = ray.get_actor(controller_name,
                                              namespace=controller_namespace)

            # This closure initializes user code and finalizes replica
            # startup. By splitting the initialization step like this,
            # we can already access this actor before the user code
            # has finished initializing.
            # The supervising state manager can then wait
            # for allocation of this replica by using the `is_allocated`
            # method. After that, it calls `reconfigure` to trigger
            # user code initialization.
            async def initialize_replica():
                if is_function:
                    _callable = deployment_def
                else:
                    # This allows deployments to define an async __init__
                    # method (required for FastAPI).
                    _callable = deployment_def.__new__(deployment_def)
                    await sync_to_async(_callable.__init__)(*init_args,
                                                            **init_kwargs)
                # Setting the context again to update the servable_object.
                ray.serve.api._set_internal_replica_context(
                    deployment_name,
                    replica_tag,
                    controller_name,
                    servable_object=_callable)

                self.replica = RayServeReplica(_callable, deployment_name,
                                               replica_tag, deployment_config,
                                               deployment_config.user_config,
                                               version, is_function,
                                               controller_handle)

            # Is it fine that replica is None here?
            # Should we add a check in all methods that use self.replica
            # or, alternatively, create an async get_replica() method?
            self.replica = None
            self._initialize_replica = initialize_replica

            # asyncio.Event used to signal that the replica is shutting down.
            self.shutdown_event = asyncio.Event()
Code Example #16
 def PutObject(self, request, context=None):
     obj = cloudpickle.loads(request.data)
     objectref = ray.put(obj)
     self.object_refs[objectref.binary()] = objectref
     logger.info("put: %s" % objectref)
     return ray_client_pb2.PutResponse(id=objectref.binary())
Code Example #17
 def object_hook(self, obj):
     if obj.get("_type") == "function":
         return cloudpickle.loads(hex_to_binary(obj["value"]))
     return obj
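This object_hook is the decoding half of a JSON round-trip for functions. A minimal sketch of a hypothetical encoder counterpart, with the hex helper defined inline rather than assuming its import path:

import binascii
import json

import cloudpickle  # Ray vendors its own copy as ray.cloudpickle


def binary_to_hex(data):
    return binascii.hexlify(data).decode()


def function_default(obj):
    # Tag callables so the object_hook above knows to cloudpickle.loads() them.
    if callable(obj):
        return {"_type": "function", "value": binary_to_hex(cloudpickle.dumps(obj))}
    raise TypeError("Object of type {} is not JSON serializable".format(type(obj)))


encoded = json.dumps({"fn": lambda x: x + 1}, default=function_default)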
Code Example #18
    def __init__(self, checkpoint: bytes = None):
        self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
        self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()

        if checkpoint is not None:
            self.routes, self.traffic_policies = pickle.loads(checkpoint)
Code Example #19
File: signal.py Project: anke522/ray-1
def receive(sources, timeout=10**12):
    """Get all outstanding signals from sources.

    A source can be either (1) an object id returned by a task from which we
    want to receive signals, or (2) an actor handle.

    For each source S, this function returns all signals associated to S
    since the last receive() or forget() were invoked on S. If this is the
    first call on S, this function returns all past signals generated by S
    so far.

    Args:
        sources: list of sources from which caller waits for signals.
            A source is either an object id identifying the task returning
            the object, or an actor handle.
        timeout: Time (in seconds) this function waits to get a signal from
            a source in sources. If no signal arrives, return when the
            timeout expires.

    Returns:
        The list of signals generated by each source in sources.
        This list contains pairs (source, signal). There can be
        more than one signal associated with the same source.
    """
    if not hasattr(ray.worker.global_worker, "signal_counters"):
        ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0")

    signal_counters = ray.worker.global_worker.signal_counters

    # Construct the redis query.
    query = "XREAD BLOCK "
    # Multiply by 1000x since timeout is in sec and redis expects ms.
    query += str(1000 * timeout)
    query += " STREAMS "
    query += " ".join([_get_task_id(source).hex() for source in sources])
    query += " "
    query += " ".join([
        ray.utils.decode(signal_counters[_get_task_id(source)])
        for source in sources
    ])

    answers = ray.worker.global_worker.redis_client.execute_command(query)
    if not answers:
        return []
    # There will be one answer per source. If there is no signal for a given
    # source, redis will return an empty list for that source.
    assert len(answers) == len(sources)

    results = []
    # Decoding is a little bit involved. Iterate through all the sources:
    for i, answer in enumerate(answers):
        # Make sure the answer corresponds to the source
        assert ray.utils.decode(answer[0]) == _get_task_id(sources[i]).hex()
        # The list of results for that source is stored in answer[1]
        for r in answer[1]:
            # Now it gets tricky: r[0] is the redis internal sequence id
            signal_counters[_get_task_id(sources[i])] = r[0]
            # r[1] contains a list with elements (key, value), in our case
            # we only have one key "signal" and the value is the signal.
            signal = cloudpickle.loads(ray.utils.hex_to_binary(r[1][1]))
            results.append((sources[i], signal))

    return results
Code Example #20
 def load_checkpoint(self, checkpoint):
     self.func = cloudpickle.loads(checkpoint["func"])
     self.data = checkpoint["data"]
     self.iter = checkpoint["iter"]
     np.random.set_state(checkpoint["seed"])
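The checkpoint restored here is a plain dictionary, so the saving side mirrors it field by field. A hedged sketch of what the matching save_checkpoint() might look like, inferred from the keys read above (hypothetical):

import cloudpickle
import numpy as np


def save_checkpoint(self):
    return {
        "func": cloudpickle.dumps(self.func),  # functions need cloudpickle, not plain pickle
        "data": self.data,
        "iter": self.iter,
        "seed": np.random.get_state(),  # restored above via np.random.set_state()
    }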
Code Example #21
File: serialization.py Project: stjordanis/ray
def _load_ref_helper(key: str, storage: storage.Storage):
    # TODO(Alex): We should stream the data directly into `cloudpickle.load`.
    serialized = asyncio.get_event_loop().run_until_complete(storage.get(key))
    return cloudpickle.loads(serialized)
Code Example #22
File: web_server.py Project: robertnishihara/ray
 def _load_trial_info(self, trial_info):
     trial_info["config"] = cloudpickle.loads(
         hex_to_binary(trial_info["config"]))
     trial_info["result"] = cloudpickle.loads(
         hex_to_binary(trial_info["result"]))
Code Example #23
        self.friends = friends
        self.p = Pool(threadCount)

    def __getstate__(self):
        state_dict = self.__dict__.copy()
        # We can't move the threads but we can move the info to make a pool of the same size
        state_dict["p"] = len(self.p._pool)
        return state_dict

    def __setstate__(self, state):
        self.__dict__.update(state)
        # state["p"] holds the saved pool size; recreate a pool of that size.
        self.p = Pool(self.p)


k = LessBadClass(5, ["boo", "boris"])
pickle.loads(pickle.dumps(k))
#end::custom_serializer[]

#tag::custom_serializer_not_own_class[]


def custom_serializer(bad):
    return {"threads": len(bad.p._pool), "friends": bad.friends}


def custom_deserializer(params):
    return BadClass(params["threads"], params["friends"])
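These two functions only describe how BadClass should be flattened and rebuilt; Ray still has to be told to use them. A hedged registration sketch, assuming ray.util.register_serializer is available (it is in recent Ray releases) and that BadClass and Pool are defined as in the surrounding example:

import ray

ray.init()

ray.util.register_serializer(
    BadClass,
    serializer=custom_serializer,  # BadClass -> plain dict of picklable fields
    deserializer=custom_deserializer,  # plain dict -> fresh BadClass
)

bad = BadClass(5, ["boo", "boris"])
restored = ray.get(ray.put(bad))  # now round-trips without modifying BadClass itself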

Code Example #24
File: partition.py Project: yuishihara/ray
def deserialize(partition_bytes):
    """Deserialize the binary partition function serialized by
    :func:`serialize`"""
    return cloudpickle.loads(partition_bytes)
Code Example #25
def deserialize(func_bytes):
    """Deserialize a binary function serialized by `serialize` method."""
    return cloudpickle.loads(func_bytes)
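The corresponding serialize() is the mirror image, turning the function into bytes with cloudpickle. A minimal sketch under that assumption:

import cloudpickle


def serialize(func):
    """Serialize a function into bytes that deserialize() can load back."""
    return cloudpickle.dumps(func)


payload = serialize(lambda x: x * 2)
assert deserialize(payload)(21) == 42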
Code Example #26
 def _lazy_deserialize(self):
     """
     This should be called after ray has been initialized.
     """
     assert ray.is_initialized()
     self._block_holder = rpickle.loads(self._serialized_data)
Code Example #28
    def fetch_and_register_actor(self, actor_class_key):
        """Import an actor.

        This will be called by the worker's import thread when the worker
        receives the actor_class export, assuming that the worker is an actor
        for that class.

        Args:
            actor_class_key: The key in Redis to use to fetch the actor.
        """
        actor_id = self._worker.actor_id
        (driver_id_str, class_name, module, pickled_class, checkpoint_interval,
         actor_method_names) = self._worker.redis_client.hmget(
             actor_class_key, [
                 "driver_id", "class_name", "module", "class",
                 "checkpoint_interval", "actor_method_names"
             ])

        class_name = decode(class_name)
        module = decode(module)
        driver_id = ray.DriverID(driver_id_str)
        checkpoint_interval = int(checkpoint_interval)
        actor_method_names = json.loads(decode(actor_method_names))

        # In Python 2, json loads strings as unicode, so convert them back to
        # strings.
        if sys.version_info < (3, 0):
            actor_method_names = [
                method_name.encode("ascii")
                for method_name in actor_method_names
            ]

        # Create a temporary actor with some temporary methods so that if
        # the actor fails to be unpickled, the temporary actor can be used
        # (just to produce error messages and to prevent the driver from
        # hanging).
        class TemporaryActor(object):
            pass

        self._worker.actors[actor_id] = TemporaryActor()
        self._worker.actor_checkpoint_interval = checkpoint_interval

        def temporary_actor_method(*xs):
            raise Exception(
                "The actor with name {} failed to be imported, "
                "and so cannot execute this method".format(class_name))

        # Register the actor method executors.
        for actor_method_name in actor_method_names:
            function_descriptor = FunctionDescriptor(module, actor_method_name,
                                                     class_name)
            function_id = function_descriptor.function_id
            temporary_executor = self._make_actor_method_executor(
                actor_method_name,
                temporary_actor_method,
                actor_imported=False)
            self._function_execution_info[driver_id][function_id] = (
                FunctionExecutionInfo(
                    function=temporary_executor,
                    function_name=actor_method_name,
                    max_calls=0))
            self._num_task_executions[driver_id][function_id] = 0

        try:
            unpickled_class = pickle.loads(pickled_class)
            self._worker.actor_class = unpickled_class
        except Exception:
            # If an exception was thrown when the actor was imported, we record
            # the traceback and notify the scheduler of the failure.
            traceback_str = ray.utils.format_error_message(
                traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_ACTOR_PUSH_ERROR,
                traceback_str,
                driver_id,
                data={"actor_id": actor_id.binary()})
            # TODO(rkn): In the future, it might make sense to have the worker
            # exit here. However, currently that would lead to hanging if
            # someone calls ray.get on a method invoked on the actor.
        else:
            # TODO(pcm): Why is the below line necessary?
            unpickled_class.__module__ = module
            self._worker.actors[actor_id] = unpickled_class.__new__(
                unpickled_class)

            actor_methods = inspect.getmembers(
                unpickled_class, predicate=is_function_or_method)
            for actor_method_name, actor_method in actor_methods:
                function_descriptor = FunctionDescriptor(
                    module, actor_method_name, class_name)
                function_id = function_descriptor.function_id
                executor = self._make_actor_method_executor(
                    actor_method_name, actor_method, actor_imported=True)
                self._function_execution_info[driver_id][function_id] = (
                    FunctionExecutionInfo(
                        function=executor,
                        function_name=actor_method_name,
                        max_calls=0))
Code Example #29
File: function_manager.py Project: wuisawesome/ray
    def fetch_and_register_remote_function(self, key):
        """Import a remote function."""
        vals = self._worker.gcs_client.internal_kv_get(
            key, KV_NAMESPACE_FUNCTION_TABLE)
        if vals is None:
            vals = {}
        else:
            vals = pickle.loads(vals)
        fields = [
            "job_id",
            "function_id",
            "function_name",
            "function",
            "module",
            "max_calls",
        ]
        (
            job_id_str,
            function_id_str,
            function_name,
            serialized_function,
            module,
            max_calls,
        ) = (vals.get(field) for field in fields)

        function_id = ray.FunctionID(function_id_str)
        job_id = ray.JobID(job_id_str)
        max_calls = int(max_calls)

        # This function is called by ImportThread. This operation needs to be
        # atomic. Otherwise, there is a race condition. Another thread may use
        # the temporary function above before the real function is ready.
        with self.lock:
            self._num_task_executions[function_id] = 0

            try:
                function = pickle.loads(serialized_function)
            except Exception:

                # If an exception was thrown when the remote function was
                # imported, we record the traceback and notify the scheduler
                # of the failure.
                traceback_str = format_error_message(traceback.format_exc())

                def f(*args, **kwargs):
                    raise RuntimeError(
                        "The remote function failed to import on the "
                        "worker. This may be because needed library "
                        "dependencies are not installed in the worker "
                        "environment:\n\n{}".format(traceback_str))

                # Use a placeholder method when unpickling the function failed.
                self._function_execution_info[
                    function_id] = FunctionExecutionInfo(
                        function=f,
                        function_name=function_name,
                        max_calls=max_calls)

                # Log the error message. Log at DEBUG level to avoid overly
                # spamming the log on import failure. The user gets the error
                # via the RuntimeError message above.
                logger.debug("Failed to unpickle the remote function "
                             f"'{function_name}' with "
                             f"function ID {function_id.hex()}. "
                             f"Job ID:{job_id}."
                             f"Traceback:\n{traceback_str}. ")
            else:
                # The below line is necessary. Because in the driver process,
                # if the function is defined in the file where the python
                # script was started from, its module is `__main__`.
                # However in the worker process, the `__main__` module is a
                # different module, which is `default_worker.py`
                function.__module__ = module
                self._function_execution_info[
                    function_id] = FunctionExecutionInfo(
                        function=function,
                        function_name=function_name,
                        max_calls=max_calls)
Code Example #30
    def __init__(self, checkpoint: bytes = None):
        self.backends: Dict[BackendTag, BackendInfo] = dict()

        if checkpoint is not None:
            self.backends = pickle.loads(checkpoint)
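The bytes passed in here come from pickling the same mapping on the way out. A hedged sketch of the hypothetical checkpoint() counterpart:

import pickle


def checkpoint(self):
    # Serialize the backend table so __init__ can restore it verbatim.
    return pickle.dumps(self.backends)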
Code Example #31
    def fetch_and_register_remote_function(self, key):
        """Import a remote function."""
        (job_id_str, function_id_str, function_name, serialized_function,
         num_return_vals, module, resources,
         max_calls) = self._worker.redis_client.hmget(key, [
             "job_id", "function_id", "function_name", "function",
             "num_return_vals", "module", "resources", "max_calls"
         ])
        function_id = ray.FunctionID(function_id_str)
        job_id = ray.JobID(job_id_str)
        function_name = decode(function_name)
        max_calls = int(max_calls)
        module = decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered.
        def f(*args, **kwargs):
            raise Exception("This function was not imported properly.")

        # This function is called by ImportThread. This operation needs to be
        # atomic. Otherwise, there is a race condition. Another thread may use
        # the temporary function above before the real function is ready.
        with self.lock:
            self._function_execution_info[job_id][function_id] = (
                FunctionExecutionInfo(
                    function=f,
                    function_name=function_name,
                    max_calls=max_calls))
            self._num_task_executions[job_id][function_id] = 0

            try:
                function = pickle.loads(serialized_function)
            except Exception:
                # If an exception was thrown when the remote function was
                # imported, we record the traceback and notify the scheduler
                # of the failure.
                traceback_str = format_error_message(traceback.format_exc())
                # Log the error message.
                push_error_to_driver(
                    self._worker,
                    ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                    "Failed to unpickle the remote function '{}' with "
                    "function ID {}. Traceback:\n{}".format(
                        function_name, function_id.hex(), traceback_str),
                    job_id=job_id)
            else:
                # The below line is necessary. Because in the driver process,
                # if the function is defined in the file where the python
                # script was started from, its module is `__main__`.
                # However in the worker process, the `__main__` module is a
                # different module, which is `default_worker.py`
                function.__module__ = module
                self._function_execution_info[job_id][function_id] = (
                    FunctionExecutionInfo(
                        function=function,
                        function_name=function_name,
                        max_calls=max_calls))
                # Add the function to the function table.
                self._worker.redis_client.rpush(
                    b"FunctionTable:" + function_id.binary(),
                    self._worker.worker_id)
Code Example #32
File: kv_store_service.py Project: yuishihara/ray
 def get_init_args(self, backend_tag):
     return pickle.loads(self.backend_init_args.get(backend_tag))
Code Example #33
File: router.py Project: zommiommy/ray
 def ray_deserialize(value):
     kwargs = pickle.loads(value)
     return Query(**kwargs)
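This deserializer assumes a matching serializer that pickles the Query's constructor kwargs. A hedged sketch of what that ray_serialize side might look like (hypothetical; the real implementation likely strips fields that cannot be pickled first):

import pickle


def ray_serialize(query):
    # Persist only the keyword arguments; Query(**kwargs) rebuilds the object.
    return pickle.dumps(query.__dict__)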
Code Example #34
File: kv_store_service.py Project: yuishihara/ray
 def get_backend_creator(self, backend_tag):
     return pickle.loads(self.backend_table.get(backend_tag))
Code Example #35
 def id_deserializer(serialized_obj):
     return pickle.loads(serialized_obj)
Code Example #36
        async def __init__(
            self,
            deployment_name,
            replica_tag,
            serialized_deployment_def: bytes,
            serialized_init_args: bytes,
            serialized_init_kwargs: bytes,
            deployment_config_proto_bytes: bytes,
            version: DeploymentVersion,
            controller_name: str,
            controller_namespace: str,
            detached: bool,
        ):
            configure_component_logger(
                component_type="deployment",
                component_name=deployment_name,
                component_id=replica_tag,
            )

            deployment_def = cloudpickle.loads(serialized_deployment_def)

            if isinstance(deployment_def, str):
                import_path = deployment_def
                module_name, attr_name = parse_import_path(import_path)
                deployment_def = getattr(import_module(module_name), attr_name)
                # For ray or serve decorated class or function, strip to return
                # original body
                if isinstance(deployment_def, RemoteFunction):
                    deployment_def = deployment_def._function
                elif isinstance(deployment_def, ActorClass):
                    deployment_def = deployment_def.__ray_metadata__.modified_class
                elif isinstance(deployment_def, Deployment):
                    logger.warning(
                        f'The import path "{import_path}" contains a '
                        "decorated Serve deployment. The decorator's settings "
                        "are ignored when deploying via import path.")
                    deployment_def = deployment_def.func_or_class

            init_args = cloudpickle.loads(serialized_init_args)
            init_kwargs = cloudpickle.loads(serialized_init_kwargs)

            deployment_config = DeploymentConfig.from_proto_bytes(
                deployment_config_proto_bytes)

            if inspect.isfunction(deployment_def):
                is_function = True
            elif inspect.isclass(deployment_def):
                is_function = False
            else:
                assert False, (
                    "deployment_def must be function, class, or "
                    "corresponding import path. Instead, it's type was "
                    f"{type(deployment_def)}.")

            # Set the controller name so that serve.connect() in the user's
            # code will connect to the instance that this deployment is running
            # in.
            ray.serve.context.set_internal_replica_context(
                deployment_name,
                replica_tag,
                controller_name,
                controller_namespace,
                servable_object=None,
            )

            assert controller_name, "Must provide a valid controller_name"

            controller_handle = ray.get_actor(controller_name,
                                              namespace=controller_namespace)

            # This closure initializes user code and finalizes replica
            # startup. By splitting the initialization step like this,
            # we can already access this actor before the user code
            # has finished initializing.
            # The supervising state manager can then wait
            # for allocation of this replica by using the `is_allocated`
            # method. After that, it calls `reconfigure` to trigger
            # user code initialization.
            async def initialize_replica():
                if is_function:
                    _callable = deployment_def
                else:
                    # This allows deployments to define an async __init__
                    # method (required for FastAPI).
                    _callable = deployment_def.__new__(deployment_def)
                    await sync_to_async(_callable.__init__)(*init_args,
                                                            **init_kwargs)

                # Setting the context again to update the servable_object.
                ray.serve.context.set_internal_replica_context(
                    deployment_name,
                    replica_tag,
                    controller_name,
                    controller_namespace,
                    servable_object=_callable,
                )

                self.replica = RayServeReplica(
                    _callable,
                    deployment_name,
                    replica_tag,
                    deployment_config,
                    deployment_config.user_config,
                    version,
                    is_function,
                    controller_handle,
                )

            # Is it fine that replica is None here?
            # Should we add a check in all methods that use self.replica
            # or, alternatively, create an async get_replica() method?
            self.replica = None
            self._initialize_replica = initialize_replica
Code Example #37
 def ray_deserialize(value):
     kwargs = pickle.loads(value)
     return RequestMetadata(**kwargs)
Code Example #38
    async def _recover_from_checkpoint(self, checkpoint_bytes):
        """Recover the instance state from the provided checkpoint.

        Performs the following operations:
            1) Deserializes the internal state from the checkpoint.
            2) Pushes the latest configuration to the HTTP proxy and router
               in case we crashed before updating them.
            3) Starts/stops any worker replicas that are pending creation or
               deletion.

        NOTE: this requires that self.write_lock is already acquired and will
        release it before returning.
        """
        assert self.write_lock.locked()

        start = time.time()
        logger.info("Recovering from checkpoint")

        # Load internal state from the checkpoint data.
        (
            self.routes,
            self.backends,
            self.traffic_policies,
            self.replicas,
            self.replicas_to_start,
            self.replicas_to_stop,
            self.backends_to_remove,
            self.endpoints_to_remove,
        ) = pickle.loads(checkpoint_bytes)

        # Fetch actor handles for all of the backend replicas in the system.
        # All of these workers are guaranteed to already exist because they
        # would not be written to a checkpoint in self.workers until they
        # were created.
        for backend_tag, replica_tags in self.replicas.items():
            for replica_tag in replica_tags:
                replica_name = format_actor_name(replica_tag,
                                                 self.instance_name)
                self.workers[backend_tag][replica_tag] = ray.get_actor(
                    replica_name)

        # Push configuration state to the router.
        # TODO(edoakes): should we make this a pull-only model for simplicity?
        for endpoint, traffic_policy in self.traffic_policies.items():
            await self.router.set_traffic.remote(endpoint, traffic_policy)

        for backend_tag, replica_dict in self.workers.items():
            for replica_tag, worker in replica_dict.items():
                await self.router.add_new_worker.remote(
                    backend_tag, replica_tag, worker)

        for backend, (_, backend_config, _) in self.backends.items():
            await self.router.set_backend_config.remote(
                backend, backend_config)

        # Push configuration state to the HTTP proxy.
        await self.http_proxy.set_route_table.remote(self.routes)

        # Start/stop any pending backend replicas.
        await self._start_pending_replicas()
        await self._stop_pending_replicas()

        # Remove any pending backends and endpoints.
        await self._remove_pending_backends()
        await self._remove_pending_endpoints()

        logger.info("Recovered from checkpoint in {:.3f}s".format(time.time() -
                                                                  start))

        self.write_lock.release()
Code Example #39
File: server.py Project: tseiger1/ray
 def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
     obj = cloudpickle.loads(request.data)
     objectref = self._put_and_retain_obj(obj)
     pickled_ref = cloudpickle.dumps(objectref)
     return ray_client_pb2.PutResponse(
         ref=make_remote_ref(objectref.binary(), pickled_ref))
Code Example #40
    async def __init__(self,
                       controller_name: str,
                       http_host: str,
                       http_port: str,
                       http_middlewares: List[Any],
                       detached: bool = False):
        # Used to read/write checkpoints.
        self.kv_store = RayInternalKVStore(namespace=controller_name)
        self.actor_reconciler = ActorStateReconciler(controller_name, detached)

        # backend -> AutoscalingPolicy
        self.autoscaling_policies = dict()

        # Dictionary of backend_tag -> proxy_name -> most recent queue length.
        self.backend_stats = defaultdict(lambda: defaultdict(dict))

        # Used to ensure that only a single state-changing operation happens
        # at any given time.
        self.write_lock = asyncio.Lock()

        self.http_host = http_host
        self.http_port = http_port
        self.http_middlewares = http_middlewares

        # If starting the actor for the first time, starts up the other system
        # components. If recovering, fetches their actor handles.
        self.actor_reconciler._start_http_proxies_if_needed(
            self.http_host, self.http_port, self.http_middlewares)

        # Map of awaiting results
        # TODO(ilr): Checkpoint this once this becomes asynchronous
        self.inflight_results: Dict[UUID, asyncio.Event] = dict()
        self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()

        checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
        if checkpoint_bytes is None:
            logger.debug("No checkpoint found")
            self.backend_state = BackendState()
            self.endpoint_state = EndpointState()
        else:
            checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
            self.backend_state = BackendState(
                checkpoint=checkpoint.backend_state_checkpoint)
            self.endpoint_state = EndpointState(
                checkpoint=checkpoint.endpoint_state_checkpoint)
            await self._recover_from_checkpoint(checkpoint)

        # NOTE(simon): Currently we do all-to-all broadcast. This means
        # any listeners will receive notification for all changes. This
        # can be a problem at scale, e.g. updating a single backend config
        # will send over the entire set of configs. In the future, we should
        # optimize the logic to support subscription by key.
        self.long_poll_host = LongPollHost()

        # The configs pushed out here get updated by
        # self._recover_from_checkpoint in the failure scenario, so that must
        # be run before we notify the changes.
        self.notify_backend_configs_changed()
        self.notify_replica_handles_changed()
        self.notify_traffic_policies_changed()
        self.notify_route_table_changed()

        asyncio.get_event_loop().create_task(self.run_control_loop())