def test_dumps_loads_return_with_custom_class(self):
    """An instance of a locally defined class survives a dumps/loads round trip."""

    class A(object):
        def __init__(self):
            self.a = 3

    restored = loads_return(dumps_return(A()))
    assert restored.a == 3
def wrapper(*args, **kwargs):
    """Forward a call of remote method ``attr`` to the job and return its result.

    ``self`` (the actor proxy) and ``attr`` (the method name) are free
    variables of this closure; presumably they are bound by an enclosing
    ``__getattr__`` -- confirm against the surrounding code.

    Returns:
        The value decoded with ``loads_return`` from a ``NORMAL_TAG`` reply.

    Raises:
        RemoteError: the actor has already lost its job, or the job reported
            a generic exception.
        RemoteAttributeError, RemoteSerializeError, RemoteDeserializeError:
            the job reported the corresponding failure.
        NotImplementedError: the job replied with an unknown tag.
    """
    if self.job_shutdown:
        raise RemoteError(
            attr, "This actor losts connection with the job.")

    # Release the lock in ``finally`` so a failure while serializing,
    # sending, receiving or decoding can never leave it held and deadlock
    # every later request made through this actor.
    self.internal_lock.acquire()
    try:
        data = dumps_argument(*args, **kwargs)
        self.job_socket.send_multipart(
            [remote_constants.CALL_TAG, to_byte(attr), data])
        message = self.job_socket.recv_multipart()
        tag = message[0]
        if tag == remote_constants.NORMAL_TAG:
            return loads_return(message[1])

        error_classes = {
            remote_constants.EXCEPTION_TAG: RemoteError,
            remote_constants.ATTRIBUTE_EXCEPTION_TAG: RemoteAttributeError,
            remote_constants.SERIALIZE_EXCEPTION_TAG: RemoteSerializeError,
            remote_constants.DESERIALIZE_EXCEPTION_TAG: RemoteDeserializeError,
        }
        error_class = error_classes.get(tag)
        # Any non-normal reply means the job has stopped serving this actor,
        # so later calls fail fast via the ``job_shutdown`` guard above.
        self.job_shutdown = True
        if error_class is None:
            raise NotImplementedError()
        raise error_class(attr, to_str(message[1]))
    finally:
        self.internal_lock.release()
def test_speed_of_dumps_loads_return(self):
    """Print the average wall-clock time of a dumps_return/loads_return round trip.

    Three payloads are timed, ten round trips each:
      0. a nested dict (100 x 100 entries) whose values are lists of sets;
      1. a list repeating one (100, 300) random array five times;
      2. a list holding one (400, 100, 3000) random array.
    """
    repeats = 10

    nested_sets = {
        outer: {
            100 * str(inner): [set(range(100)) for _ in range(10)]
            for inner in range(100)
        }
        for outer in range(100)
    }
    repeated_array = [np.random.RandomState(0).randn(100, 300)] * 5
    # NOTE: 400 * 100 * 3000 float64 values is roughly 960 MB before
    # serialization; this case is memory-heavy.
    large_array = [np.random.RandomState(0).randn(400, 100, 3000)]

    payloads = [nested_sets, repeated_array, large_array]
    for case_id, payload in enumerate(payloads):
        began = time.time()
        for _ in range(repeats):
            packed = dumps_return(payload)
            unpacked = loads_return(packed)
        elapsed = time.time() - began
        print('Case {}, Average dump and load return time:'.format(case_id),
              elapsed / repeats)
def run(i):
    """Round-trip ``A(i)`` through dumps_return/loads_return and check ``.a == i``."""
    restored = loads_return(dumps_return(A(i)))
    assert restored.a == i
def __getattr__(self, attr):
    """Proxy an attribute access on the actor to the remote job.

    If ``attr`` is an instance attribute of the wrapped class ``cls``, its
    value is fetched from the remote job (``GET_ATTRIBUTE``) and returned.
    Otherwise ``attr`` is treated as a method name, and a callable is
    returned that forwards each call and its arguments to the job
    (``CALL_TAG``).

    Any non-normal reply marks the actor as shut down
    (``self.job_shutdown = True``) and raises the matching remote error;
    once shut down, further requests fail fast with ``RemoteError``.

    Raises:
        RemoteError: the actor has already lost its job, or the job reported
            a generic exception.
        RemoteAttributeError, RemoteSerializeError, RemoteDeserializeError:
            the job reported the corresponding failure.
        NotImplementedError: the job replied with an unknown tag.
    """

    def _check_alive():
        """Raise RemoteError if an earlier request already shut the job down."""
        if self.job_shutdown:
            raise RemoteError(
                attr, "This actor losts connection with the job.")

    def _unpack_reply(message):
        """Return the decoded payload of a NORMAL reply; raise for anything else."""
        tag = message[0]
        if tag == remote_constants.NORMAL_TAG:
            return loads_return(message[1])
        error_classes = {
            remote_constants.EXCEPTION_TAG: RemoteError,
            remote_constants.ATTRIBUTE_EXCEPTION_TAG: RemoteAttributeError,
            remote_constants.SERIALIZE_EXCEPTION_TAG: RemoteSerializeError,
            remote_constants.DESERIALIZE_EXCEPTION_TAG: RemoteDeserializeError,
        }
        error_class = error_classes.get(tag)
        # Any non-normal reply means the job has stopped serving this actor.
        self.job_shutdown = True
        if error_class is None:
            raise NotImplementedError()
        raise error_class(attr, to_str(message[1]))

    def _request(frames):
        """Send ``frames`` to the job under the lock and return the unpacked reply.

        The lock is released in ``finally`` so a failed send, receive or
        decode cannot leave it held and deadlock later requests.
        """
        self.internal_lock.acquire()
        try:
            self.job_socket.send_multipart(frames)
            return _unpack_reply(self.job_socket.recv_multipart())
        finally:
            self.internal_lock.release()

    # NOTE(review): this builds a fresh local instance of ``cls`` on every
    # proxied lookup just to inspect its instance attributes. That is costly
    # and only works if ``cls`` can be constructed without arguments.
    # Consider computing the attribute names once -- confirm with callers.
    if attr in cls().__dict__:
        _check_alive()
        return _request([remote_constants.GET_ATTRIBUTE, to_byte(attr)])

    def wrapper(*args, **kwargs):
        """Forward a call of remote method ``attr`` and return its result."""
        _check_alive()
        data = dumps_argument(*args, **kwargs)
        return _request([remote_constants.CALL_TAG, to_byte(attr), data])

    return wrapper
def single_task(self, obj, reply_socket, job_address):
    """An infinite loop serving commands from the remote object.

    Each job handles these kinds of message from the remote object:

    1. ``CALL_TAG``: call the named function on the local instance ``obj``
       and send the serialized return value back.
    2. ``GET_ATTRIBUTE``: send back the serialized value of the named
       attribute of ``obj``.
    3. ``SET_ATTRIBUTE``: deserialize a value and assign it to the named
       attribute of ``obj``.
    4. ``KILLJOB_TAG``: the remote object has been deleted. The job
       acknowledges and returns so related computation resources can be
       released.

    Stdout produced while accessing ``obj`` is redirected to
    ``<self.log_dir>/stdout.log``.

    When handling a request fails, the matching exception tag and the error
    message are sent back, and the loop ends. For an exception whose exact
    type is ``AttributeError``, ``SerializeError`` or ``DeserializeError``,
    the original exception is re-raised. For any other exception, its
    traceback is logged and sent along, and the method returns.

    Args:
        obj: local instance whose functions/attributes the remote object uses.
        reply_socket (socket): main socket to accept commands of the remote
            object.
        job_address (str): address of ``reply_socket`` (used in log messages).

    Raises:
        AttributeError, SerializeError, DeserializeError: see above.
        NotImplementedError: an unknown message tag was received.
    """
    request_tags = (
        remote_constants.CALL_TAG,
        remote_constants.GET_ATTRIBUTE,
        remote_constants.SET_ATTRIBUTE,
    )
    # Exceptions with a dedicated reply tag, matched on the exact type.
    # Subclasses fall through to the generic path, which also ships a
    # traceback.
    dedicated_tags = {
        AttributeError: remote_constants.ATTRIBUTE_EXCEPTION_TAG,
        SerializeError: remote_constants.SERIALIZE_EXCEPTION_TAG,
        DeserializeError: remote_constants.DESERIALIZE_EXCEPTION_TAG,
    }

    def _serve_request(tag, message):
        """Run one CALL/GET/SET request against ``obj`` and send the reply."""
        logfile_path = os.path.join(self.log_dir, 'stdout.log')
        if tag == remote_constants.CALL_TAG:
            function_name = to_str(message[1])
            args, kwargs = loads_argument(message[2])
            # Redirect stdout to stdout.log while user code runs.
            with redirect_stdout_to_file(logfile_path):
                ret = getattr(obj, function_name)(*args, **kwargs)
            reply_socket.send_multipart(
                [remote_constants.NORMAL_TAG, dumps_return(ret)])
        elif tag == remote_constants.GET_ATTRIBUTE:
            attribute_name = to_str(message[1])
            with redirect_stdout_to_file(logfile_path):
                ret = getattr(obj, attribute_name)
            reply_socket.send_multipart(
                [remote_constants.NORMAL_TAG, dumps_return(ret)])
        else:  # remote_constants.SET_ATTRIBUTE
            attribute_name = to_str(message[1])
            attribute_value = loads_return(message[2])
            with redirect_stdout_to_file(logfile_path):
                setattr(obj, attribute_name, attribute_value)
            reply_socket.send_multipart([remote_constants.NORMAL_TAG])

    while True:
        message = reply_socket.recv_multipart()
        tag = message[0]

        if tag in request_tags:
            try:
                _serve_request(tag, message)
            except Exception as e:
                # Report the failure to the remote object, then stop the job.
                error_str = str(e)
                logger.error(error_str)
                reply_tag = dedicated_tags.get(type(e))
                if reply_tag is not None:
                    reply_socket.send_multipart(
                        [reply_tag, to_byte(error_str)])
                    # Re-raise the original exception, keeping its message.
                    raise
                traceback_str = str(traceback.format_exc())
                logger.error("traceback:\n{}".format(traceback_str))
                reply_socket.send_multipart([
                    remote_constants.EXCEPTION_TAG,
                    to_byte(error_str + "\ntraceback:\n" + traceback_str)
                ])
                break

        # KILLJOB_TAG from the actor: acknowledge and stop serving this job.
        elif tag == remote_constants.KILLJOB_TAG:
            reply_socket.send_multipart([remote_constants.NORMAL_TAG])
            logger.warning(
                "An actor exits and this job {} will exit.".format(
                    job_address))
            break

        else:
            logger.error(
                "The job receives an unknown message: {}".format(message))
            raise NotImplementedError