示例#1
0
    def test_dumps_loads_return_with_custom_class(self):
        """Round-trip an instance of a locally defined class."""

        class A(object):
            def __init__(self):
                self.a = 3

        payload = dumps_return(A())
        restored = loads_return(payload)

        assert restored.a == 3
示例#2
0
                def wrapper(*args, **kwargs):
                    """Call method ``attr`` on the remote job and return its result.

                    The arguments are serialized with ``dumps_argument`` and sent
                    over ``self.job_socket`` with CALL_TAG; a NORMAL_TAG reply is
                    decoded with ``loads_return`` and returned. The exchange runs
                    while holding ``self.internal_lock``, which is released on
                    every exit path so a failed call cannot block later callers.

                    Raises:
                        RemoteError: the connection was already marked as shut
                            down, or the job replied with EXCEPTION_TAG.
                        RemoteAttributeError, RemoteSerializeError,
                        RemoteDeserializeError: the job replied with the matching
                            error tag.
                        NotImplementedError: the job replied with an unknown tag.
                    """
                    if self.job_shutdown:
                        raise RemoteError(
                            attr, "This actor losts connection with the job.")

                    # Error tags the job may reply with, paired with the
                    # exception raised on this side for each of them.
                    error_types = (
                        (remote_constants.EXCEPTION_TAG, RemoteError),
                        (remote_constants.ATTRIBUTE_EXCEPTION_TAG,
                         RemoteAttributeError),
                        (remote_constants.SERIALIZE_EXCEPTION_TAG,
                         RemoteSerializeError),
                        (remote_constants.DESERIALIZE_EXCEPTION_TAG,
                         RemoteDeserializeError),
                    )

                    self.internal_lock.acquire()
                    try:
                        data = dumps_argument(*args, **kwargs)

                        self.job_socket.send_multipart(
                            [remote_constants.CALL_TAG,
                             to_byte(attr), data])

                        message = self.job_socket.recv_multipart()
                        tag = message[0]

                        if tag == remote_constants.NORMAL_TAG:
                            return loads_return(message[1])

                        for error_tag, error_type in error_types:
                            if tag == error_tag:
                                error_str = to_str(message[1])
                                # Any error reply marks the connection dead so
                                # later calls fail fast in the check above.
                                self.job_shutdown = True
                                raise error_type(attr, error_str)

                        self.job_shutdown = True
                        raise NotImplementedError()
                    finally:
                        # Released on success and on every raise above,
                        # including failures in dumps_argument/loads_return.
                        self.internal_lock.release()
示例#3
0
    def test_speed_of_dumps_loads_return(self):
        """Print the average dumps_return + loads_return time per payload.

        Three payloads are measured: a nested dict of string keys mapping to
        lists of sets, a list holding five references to one (100, 300)
        random array, and a list holding one (400, 100, 3000) random array.
        """
        # Portable high-resolution timer: time.perf_counter on Python 3.3+,
        # and the best available clock on older interpreters. It is not
        # affected by wall-clock adjustments the way time.time() is.
        from timeit import default_timer as timer

        repeat = 10

        data1 = {
            i: {
                100 * str(j): [set(range(100)) for _ in range(10)]
                for j in range(100)
            }
            for i in range(100)
        }

        data2 = [np.random.RandomState(0).randn(100, 300)] * 5

        # NOTE: 400 * 100 * 3000 float64 values is about 960 MB before
        # serialization, so this case needs several GB of free memory.
        data3 = [np.random.RandomState(0).randn(400, 100, 3000)]

        for i, data in enumerate([data1, data2, data3]):
            start = timer()
            for _ in range(repeat):
                serialize_bytes = dumps_return(data)
                loads_return(serialize_bytes)
            print('Case {}, Average dump and load return time:'.format(i),
                  (timer() - start) / repeat)
示例#4
0
 def run(i):
     """Round-trip ``A(i)`` and check that attribute ``a`` still equals ``i``."""
     restored = loads_return(dumps_return(A(i)))
     assert restored.a == i
示例#5
0
            def __getattr__(self, attr):
                """Proxy attribute access on this actor to the remote job.

                If ``attr`` appears in the ``__dict__`` of a locally built
                ``cls()`` instance, its value is fetched from the job with a
                GET_ATTRIBUTE request and returned directly. Otherwise ``attr``
                is treated as a method name, and a ``wrapper`` callable is
                returned that forwards each call to the job with CALL_TAG.

                Every request/reply exchange on ``self.job_socket`` runs while
                holding ``self.internal_lock``. The lock is released on every
                exit path, so a failed exchange cannot block later ones.

                Raises:
                    RemoteError: the connection is already marked as shut
                        down, or the job replied with EXCEPTION_TAG.
                    RemoteAttributeError, RemoteSerializeError,
                    RemoteDeserializeError: the job replied with the matching
                        error tag.
                    NotImplementedError: the job replied with an unknown tag.
                """
                # Error tags the job may reply with, paired with the
                # exception raised on this side for each of them.
                error_types = (
                    (remote_constants.EXCEPTION_TAG, RemoteError),
                    (remote_constants.ATTRIBUTE_EXCEPTION_TAG,
                     RemoteAttributeError),
                    (remote_constants.SERIALIZE_EXCEPTION_TAG,
                     RemoteSerializeError),
                    (remote_constants.DESERIALIZE_EXCEPTION_TAG,
                     RemoteDeserializeError),
                )

                def request(frames):
                    """Send ``frames`` to the job and return the decoded reply.

                    The caller must hold ``self.internal_lock``. Any non-normal
                    reply marks the connection as shut down and raises.
                    """
                    self.job_socket.send_multipart(frames)
                    message = self.job_socket.recv_multipart()
                    tag = message[0]

                    if tag == remote_constants.NORMAL_TAG:
                        return loads_return(message[1])

                    for error_tag, error_type in error_types:
                        if tag == error_tag:
                            error_str = to_str(message[1])
                            self.job_shutdown = True
                            raise error_type(attr, error_str)

                    self.job_shutdown = True
                    raise NotImplementedError()

                # NOTE(review): this builds a local instance of the wrapped
                # class on every attribute lookup just to inspect its
                # __dict__. That is costly and fails if the constructor needs
                # arguments; consider caching the attribute names once.
                if attr in cls().__dict__:
                    # Fail fast like ``wrapper`` does: once an error reply
                    # has been received, the job no longer answers requests.
                    if self.job_shutdown:
                        raise RemoteError(
                            attr, "This actor losts connection with the job.")
                    self.internal_lock.acquire()
                    try:
                        return request(
                            [remote_constants.GET_ATTRIBUTE,
                             to_byte(attr)])
                    finally:
                        self.internal_lock.release()

                def wrapper(*args, **kwargs):
                    """Call method ``attr`` on the remote job and return its result."""
                    if self.job_shutdown:
                        raise RemoteError(
                            attr, "This actor losts connection with the job.")
                    self.internal_lock.acquire()
                    try:
                        data = dumps_argument(*args, **kwargs)
                        return request(
                            [remote_constants.CALL_TAG,
                             to_byte(attr), data])
                    finally:
                        # Released even if serialization or the job fails.
                        self.internal_lock.release()

                return wrapper
示例#6
0
文件: job.py 项目: YuechengLiu/PARL
    def single_task(self, obj, reply_socket, job_address):
        """Serve commands from the remote object until it exits or fails.

        Each message read from ``reply_socket`` is dispatched on its tag:

        1. ``CALL_TAG``: call ``obj.<name>(*args, **kwargs)`` and reply with
           the serialized return value.
        2. ``GET_ATTRIBUTE``: reply with the serialized value of
           ``obj.<name>``.
        3. ``SET_ATTRIBUTE``: deserialize the value, ``setattr`` it on
           ``obj`` and reply with ``NORMAL_TAG`` alone.
        4. ``KILLJOB_TAG``: acknowledge and leave the loop, after which the
           job can release its computation resources.

        While ``obj`` is being accessed, stdout is redirected to
        ``os.path.join(self.log_dir, 'stdout.log')``.

        If handling a command raises, an error reply with a matching tag is
        sent back. ``AttributeError``, ``SerializeError`` and
        ``DeserializeError`` (exact types) are then re-raised; any other
        exception is logged with its traceback and ends the loop.

        Args:
            obj: local instance whose methods and attributes are served.
            reply_socket (socket): main socket to accept commands of remote object.
            job_address (String): address of reply_socket.

        Raises:
            AttributeError, SerializeError, DeserializeError: re-raised with
                their original message after the error reply is sent.
            NotImplementedError: a message with an unknown tag was received.
        """
        served_tags = (remote_constants.CALL_TAG,
                       remote_constants.GET_ATTRIBUTE,
                       remote_constants.SET_ATTRIBUTE)
        # Exceptions reported with a dedicated tag and then re-raised. They
        # are matched on the exact type (not isinstance), so subclasses take
        # the generic path below.
        reraised_error_tags = {
            AttributeError: remote_constants.ATTRIBUTE_EXCEPTION_TAG,
            SerializeError: remote_constants.SERIALIZE_EXCEPTION_TAG,
            DeserializeError: remote_constants.DESERIALIZE_EXCEPTION_TAG,
        }

        while True:
            message = reply_socket.recv_multipart()

            tag = message[0]

            if tag in served_tags:
                try:
                    # stdout produced while touching obj goes to this file.
                    logfile_path = os.path.join(self.log_dir, 'stdout.log')

                    if tag == remote_constants.CALL_TAG:
                        function_name = to_str(message[1])
                        args, kwargs = loads_argument(message[2])
                        with redirect_stdout_to_file(logfile_path):
                            ret = getattr(obj, function_name)(*args, **kwargs)
                        reply_socket.send_multipart(
                            [remote_constants.NORMAL_TAG, dumps_return(ret)])

                    elif tag == remote_constants.GET_ATTRIBUTE:
                        attribute_name = to_str(message[1])
                        with redirect_stdout_to_file(logfile_path):
                            ret = getattr(obj, attribute_name)
                        reply_socket.send_multipart(
                            [remote_constants.NORMAL_TAG, dumps_return(ret)])

                    else:  # SET_ATTRIBUTE
                        attribute_name = to_str(message[1])
                        attribute_value = loads_return(message[2])
                        with redirect_stdout_to_file(logfile_path):
                            setattr(obj, attribute_name, attribute_value)
                        reply_socket.send_multipart(
                            [remote_constants.NORMAL_TAG])

                except Exception as e:
                    error_str = str(e)
                    logger.error(error_str)

                    error_tag = reraised_error_tags.get(type(e))
                    if error_tag is not None:
                        reply_socket.send_multipart(
                            [error_tag, to_byte(error_str)])
                        # Bare ``raise`` keeps the original message and
                        # traceback instead of raising an empty new instance.
                        raise

                    traceback_str = str(traceback.format_exc())
                    logger.error("traceback:\n{}".format(traceback_str))
                    reply_socket.send_multipart([
                        remote_constants.EXCEPTION_TAG,
                        to_byte(error_str + "\ntraceback:\n" + traceback_str)
                    ])
                    break

            elif tag == remote_constants.KILLJOB_TAG:
                # The remote object is gone: acknowledge, then stop serving.
                reply_socket.send_multipart([remote_constants.NORMAL_TAG])
                logger.warning(
                    "An actor exits and this job {} will exit.".format(
                        job_address))
                break

            else:
                logger.error(
                    "The job receives an unknown message: {}".format(message))
                raise NotImplementedError