Esempio n. 1
0
    def fetch_and_register_remote_function(self, key):
        """Fetch a pickled remote function from Redis and register it.

        Reads the function's metadata and pickled body from the Redis hash
        stored under ``key``, unpickles it, and records it in
        ``self._function_execution_info`` for the owning job. If unpickling
        fails, a placeholder that raises ``RuntimeError`` when called is
        registered instead and the failure is pushed to the driver.

        Args:
            key: Redis key of the hash describing the exported function.
        """
        requested_fields = [
            "job_id", "function_id", "function_name", "function", "module",
            "max_calls"
        ]
        (raw_job_id, raw_function_id, raw_name, pickled_function, raw_module,
         raw_max_calls) = self._worker.redis_client.hmget(
             key, requested_fields)
        function_id = ray.FunctionID(raw_function_id)
        job_id = ray.JobID(raw_job_id)
        function_name = decode(raw_name)
        max_calls = int(raw_max_calls)
        module = decode(raw_module)

        # This method is called by ImportThread, so the whole registration
        # (counter reset + execution info) happens under self.lock. Otherwise
        # another thread could see a partially registered function before
        # the real one is ready. NOTE(review): assumes readers of these
        # tables take the same lock.
        with self.lock:
            self._num_task_executions[job_id][function_id] = 0

            try:
                function = pickle.loads(pickled_function)
            except Exception:
                # Unpickling failed: register a stand-in that fails loudly
                # whenever it is invoked, with whatever arguments.
                def f(*args, **kwargs):
                    raise RuntimeError(
                        "This function was not imported properly.")

                self._function_execution_info[job_id][function_id] = (
                    FunctionExecutionInfo(
                        function=f,
                        function_name=function_name,
                        max_calls=max_calls))
                # Report the import failure (with its traceback) to the
                # driver that owns this job.
                formatted_tb = format_error_message(traceback.format_exc())
                push_error_to_driver(
                    self._worker,
                    ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                    "Failed to unpickle the remote function '{}' with "
                    "function ID {}. Traceback:\n{}".format(
                        function_name, function_id.hex(), formatted_tb),
                    job_id=job_id)
                return

            # A function defined in the driver's entry script pickles with
            # __module__ == "__main__", but in the worker "__main__" is
            # default_worker.py, so restore the module name sent by the driver.
            function.__module__ = module
            self._function_execution_info[job_id][function_id] = (
                FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Record in the function table that this worker has the function.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.binary(),
                self._worker.worker_id)
Esempio n. 2
0
    def fetch_and_register_remote_function(self, key):
        """Import a remote function and register it for its driver.

        Reads the function's metadata and pickled body from the Redis hash at
        ``key``. A placeholder is registered first and is replaced by the real
        function if unpickling succeeds. If unpickling fails, the formatted
        traceback is pushed to the driver and the placeholder stays
        registered.

        Args:
            key: Redis key of the hash describing the exported function.
        """
        # "num_return_vals" and "resources" are fetched alongside the other
        # fields but are not used here.
        (driver_id_str, function_id_str, function_name, serialized_function,
         _, module, _,
         max_calls) = self._worker.redis_client.hmget(key, [
             "driver_id", "function_id", "name", "function", "num_return_vals",
             "module", "resources", "max_calls"
         ])
        function_id = ray.FunctionID(function_id_str)
        driver_id = ray.DriverID(driver_id_str)
        function_name = decode(function_name)
        max_calls = int(max_calls)
        module = decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered. It
        # accepts arbitrary arguments so that invoking it raises the message
        # below instead of a TypeError about the number of arguments.
        # NOTE(review): no lock is held here, so another thread could pick up
        # the placeholder before it is overwritten; confirm that callers
        # serialize imports.
        def f(*args, **kwargs):
            raise Exception("This function was not imported properly.")

        self._function_execution_info[driver_id][function_id] = (
            FunctionExecutionInfo(
                function=f, function_name=function_name, max_calls=max_calls))
        self._num_task_executions[driver_id][function_id] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # Unpickling failed: push the traceback to the driver, attaching
            # the function's identity so it can be reported.
            traceback_str = format_error_message(traceback.format_exc())
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                traceback_str,
                driver_id=driver_id,
                data={
                    "function_id": function_id.binary(),
                    "function_name": function_name
                })
        else:
            # The below line is necessary. Because in the driver process,
            # if the function is defined in the file where the python script
            # was started from, its module is `__main__`.
            # However in the worker process, the `__main__` module is a
            # different module, which is `default_worker.py`
            function.__module__ = module
            self._function_execution_info[driver_id][function_id] = (
                FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Add the function to the function table.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.binary(),
                self._worker.worker_id)
Esempio n. 3
0
    def fetch_and_register_remote_function(self, key):
        """Import a remote function and register it for its driver.

        Reads the function's metadata and pickled body from the Redis hash at
        ``key``. A placeholder is registered first and is replaced by the real
        function if unpickling succeeds. If unpickling fails, an error message
        including the traceback is pushed to the driver and the placeholder
        stays registered.

        Args:
            key: Redis key of the hash describing the exported function.
        """
        # "num_return_vals" and "resources" are fetched alongside the other
        # fields but are not used here.
        (driver_id_str, function_id_str, function_name, serialized_function,
         _, module, _,
         max_calls) = self._worker.redis_client.hmget(key, [
             "driver_id", "function_id", "name", "function", "num_return_vals",
             "module", "resources", "max_calls"
         ])
        function_id = ray.FunctionID(function_id_str)
        driver_id = ray.DriverID(driver_id_str)
        function_name = decode(function_name)
        max_calls = int(max_calls)
        module = decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered. It
        # accepts arbitrary arguments so that invoking it raises the message
        # below instead of a TypeError about the number of arguments.
        # NOTE(review): no lock is held here, so another thread could pick up
        # the placeholder before it is overwritten; confirm that callers
        # serialize imports.
        def f(*args, **kwargs):
            raise Exception("This function was not imported properly.")

        self._function_execution_info[driver_id][function_id] = (
            FunctionExecutionInfo(
                function=f, function_name=function_name, max_calls=max_calls))
        self._num_task_executions[driver_id][function_id] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # Unpickling failed: push a descriptive error, including the
            # traceback, to the driver.
            traceback_str = format_error_message(traceback.format_exc())
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                "Failed to unpickle the remote function '{}' with function ID "
                "{}. Traceback:\n{}".format(function_name, function_id.hex(),
                                            traceback_str),
                driver_id=driver_id)
        else:
            # The below line is necessary. Because in the driver process,
            # if the function is defined in the file where the python script
            # was started from, its module is `__main__`.
            # However in the worker process, the `__main__` module is a
            # different module, which is `default_worker.py`
            function.__module__ = module
            self._function_execution_info[driver_id][function_id] = (
                FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Add the function to the function table.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.binary(),
                self._worker.worker_id)
Esempio n. 4
0
    def fetch_and_register_remote_function(self, key):
        """Import a remote function and register it on this worker.

        Reads the function's metadata and pickled body from the Redis hash at
        ``key``. A placeholder is registered first and is replaced by the real
        function if unpickling succeeds. If unpickling fails, the formatted
        traceback is pushed to the driver and the placeholder stays
        registered.

        Args:
            key: Redis key of the hash describing the exported function.
        """
        # Local import, presumably to avoid a circular import with ray.worker.
        from ray.worker import FunctionExecutionInfo
        # "num_return_vals" and "resources" are fetched alongside the other
        # fields but are not used here.
        (driver_id, function_id_str, function_name, serialized_function,
         _, module, _,
         max_calls) = self.redis_client.hmget(key, [
             "driver_id", "function_id", "name", "function", "num_return_vals",
             "module", "resources", "max_calls"
         ])
        function_id = ray.ObjectID(function_id_str)
        function_name = utils.decode(function_name)
        max_calls = int(max_calls)
        module = utils.decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered. It
        # accepts arbitrary arguments so that invoking it raises the message
        # below instead of a TypeError about the number of arguments.
        def f(*args, **kwargs):
            raise Exception("This function was not imported properly.")

        self.worker.function_execution_info[driver_id][function_id.id()] = (
            FunctionExecutionInfo(function=f,
                                  function_name=function_name,
                                  max_calls=max_calls))
        self.worker.num_task_executions[driver_id][function_id.id()] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # Unpickling failed: push the traceback to the driver, attaching
            # the function's identity so it can be reported.
            traceback_str = utils.format_error_message(traceback.format_exc())
            utils.push_error_to_driver(
                self.worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                traceback_str,
                driver_id=driver_id,
                data={
                    "function_id": function_id.id(),
                    "function_name": function_name
                })
        else:
            # A function defined in the driver's entry script pickles with
            # __module__ == "__main__", but in the worker process "__main__"
            # is a different module (the worker's own entry script), so
            # restore the module name sent by the driver.
            function.__module__ = module
            self.worker.function_execution_info[driver_id][
                function_id.id()] = (FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Add the function to the function table.
            self.redis_client.rpush(b"FunctionTable:" + function_id.id(),
                                    self.worker.worker_id)
Esempio n. 5
0
    def execute(self, function, function_name, args, kwargs, num_return_vals):
        """Synchronously executes a "remote" function or actor method.

        Stores results directly in the generated and returned
        LocalModeObjectIDs. Any exceptions raised during function execution
        will be stored under all returned object IDs and later raised by the
        worker.

        Arguments are resolved before the call: ObjectID arguments are
        replaced by ``ray.get`` of them, and all other arguments are
        deep-copied individually. This resolution happens outside the
        ``try``, so exceptions raised by it propagate to the caller instead
        of being stored in the returned IDs.

        Args:
            function: The function to execute.
            function_name: Name of the function to execute.
            args: Arguments to the function. These will not be modified by
                the function execution.
            kwargs: Keyword arguments to the function.
            num_return_vals: Number of expected return values specified in the
                function's decorator.

        Returns:
            LocalModeObjectIDs corresponding to the function return values.
        """
        return_ids = [
            LocalModeObjectID.from_random() for _ in range(num_return_vals)
        ]

        def resolve(value):
            # Fetch the stored value of an ObjectID; otherwise hand the
            # function a private deep copy so the caller's object is safe.
            if isinstance(value, ObjectID):
                return ray.get(value)
            return copy.deepcopy(value)

        new_args = [resolve(arg) for arg in args]
        new_kwargs = {k: resolve(v) for k, v in kwargs.items()}

        try:
            results = function(*new_args, **new_kwargs)
            if num_return_vals == 1:
                return_ids[0].value = results
            else:
                # NOTE(review): zip() stops at the shorter input, so if the
                # function returns fewer values than num_return_vals the
                # trailing IDs get no value assigned here — confirm whether
                # that should be reported as an error.
                for object_id, result in zip(return_ids, results):
                    object_id.value = result
        except Exception as e:
            backtrace = format_error_message(traceback.format_exc())
            task_error = RayTaskError(function_name, backtrace, e.__class__)
            for object_id in return_ids:
                object_id.value = task_error

        return return_ids
Esempio n. 6
0
    def execute(self, function, function_descriptor, args, kwargs,
                num_return_vals):
        """Run a "remote" function or actor method inline and wrap its output.

        Creates ``num_return_vals`` fresh LocalModeObjectIDs and stores the
        function's result(s) directly on them. If anything inside the
        execution raises, including copying the arguments, a RayTaskError
        built from the traceback is stored on every returned ID instead, to
        be raised later by the worker.

        Args:
            function: The callable to invoke.
            function_descriptor: Metadata about the function; only its
                ``function_name`` is used, and only when reporting an error.
            args: Positional arguments. They are deep-copied before the call,
                so the function cannot modify the caller's objects.
            kwargs: Keyword arguments, deep-copied the same way.
            num_return_vals: Number of return values declared in the
                function's decorator.

        Returns:
            List of LocalModeObjectIDs holding the return values.
        """
        return_ids = [
            LocalModeObjectID.from_random() for _ in range(num_return_vals)
        ]
        try:
            # Each container is deep-copied as a whole, so aliasing among its
            # elements is preserved in the copy.
            call_args = copy.deepcopy(args)
            call_kwargs = copy.deepcopy(kwargs)
            results = function(*call_args, **call_kwargs)
            if num_return_vals != 1:
                for object_id, result in zip(return_ids, results):
                    object_id.value = result
            else:
                return_ids[0].value = results
        except Exception as e:
            failure = RayTaskError(
                function_descriptor.function_name,
                format_error_message(traceback.format_exc()), e.__class__)
            for object_id in return_ids:
                object_id.value = failure

        return return_ids