def _create_worker_monitor(self, worker_address):
    """Heartbeat loop run by the master for a single connected worker.

    Opens a REQ socket to the worker and pings it repeatedly. Each reply is
    forwarded to the cluster monitor together with the worker's CPU counts.
    The loop ends when the worker stops answering (it is then dropped from
    the job center and the cluster monitor), when the master is no longer
    alive, or when the socket raises another ZMQ error.

    Args:
        worker_address(str): Address appended to "tcp://" to reach the
            worker's heartbeat socket.
    """
    heartbeat = self.ctx.socket(zmq.REQ)
    heartbeat.linger = 0
    heartbeat.setsockopt(zmq.RCVTIMEO,
                         remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
    heartbeat.connect("tcp://" + worker_address)

    worker_alive = True
    while worker_alive and self.master_is_alive:
        try:
            heartbeat.send_multipart([remote_constants.HEARTBEAT_TAG])
            status = heartbeat.recv_multipart()
            vacant = self.job_center.get_vacant_cpu(worker_address)
            total = self.job_center.get_total_cpu(worker_address)
            self.cluster_monitor.update_worker_status(
                status, worker_address, vacant, total)
            time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
        except zmq.error.Again:
            # No reply within the receive timeout: treat the worker as lost.
            self.job_center.drop_worker(worker_address)
            self.cluster_monitor.drop_worker_status(worker_address)
            logger.warning("\n[Master] Cannot connect to the worker " +
                           "{}. ".format(worker_address) +
                           "Worker_pool will drop this worker.")
            self._print_workers()
            worker_alive = False
        except zmq.error.ZMQError:
            # Any other socket failure ends the monitor without dropping.
            break

    heartbeat.close(0)
    logger.warning("Exit worker monitor from master.")
def _kill_job(self, job_address):
    """Kill a job process and report its replacement to the master.

    If the job is successfully removed from the worker status, a fresh job is
    taken from the job buffer, registered locally, and announced to the
    master together with the address of the job it replaces.

    Args:
        job_address(str): Address of the job to be killed.
    """
    if not self.worker_status.remove_job(job_address):
        return

    while True:
        initialized_job = self.job_buffer.get()
        initialized_job.worker_address = self.master_heartbeat_address
        if initialized_job.is_alive:
            self.worker_status.add_job(initialized_job)
            # `is_alive` is re-read on purpose: it may have flipped since
            # the check above, in which case the job is discarded again.
            if not initialized_job.is_alive:
                self.worker_status.remove_job(initialized_job.job_address)
                continue
        else:
            logger.warning(
                "[Worker] a dead job found. The job buffer will not accept this one."
            )
        if initialized_job.is_alive:
            break

    # `with` guarantees the lock is released even if the request to the
    # master raises (e.g. a receive timeout); the previous acquire/release
    # pair would have left the lock held forever in that case.
    with self.lock:
        self.request_master_socket.send_multipart([
            remote_constants.NEW_JOB_TAG,
            cloudpickle.dumps(initialized_job),
            to_byte(job_address)
        ])
        _ = self.request_master_socket.recv_multipart()
def run(self):
    """Serve messages from workers and clients until the master stops.

    The master node handles four types of messages:

    1. A new worker connects to the master node.
    2. A connected worker sends a new job address after it kills an old job.
    3. A new client connects to the master node.
    4. A connected client submits a job after a remote object is created.
    """
    self.client_socket.linger = 0
    self.client_socket.setsockopt(
        zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)

    while self.master_is_alive:
        try:
            self._receive_message()
        except zmq.error.Again:
            # The receive timed out; loop so that `self.master_is_alive`
            # is re-checked periodically.
            continue

    logger.warning("[Master] Exit master.")
def _create_job_monitor(self, job):
    """Ping a job with heartbeat signals until it stops responding.

    The loop runs while the job, the master and the worker are all flagged
    alive. When the job misses a heartbeat it is marked dead and, if master
    and worker are still alive, killed through `_kill_job`.

    Args:
        job: Job descriptor; this method reads `worker_heartbeat_address`
            and `job_address` and writes `is_alive`.
    """
    # REQ socket used to send heartbeat signals to the job.
    probe = self.ctx.socket(zmq.REQ)
    probe.linger = 0
    probe.setsockopt(zmq.RCVTIMEO,
                     remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
    probe.connect("tcp://" + job.worker_heartbeat_address)

    job.is_alive = True
    while job.is_alive and self.master_is_alive and self.worker_is_alive:
        try:
            probe.send_multipart([remote_constants.HEARTBEAT_TAG])
            probe.recv_multipart()
            time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
        except zmq.error.Again:
            job.is_alive = False
            logger.warning("[Worker] lost connection with the job:{}".format(
                job.job_address))
            if self.master_is_alive and self.worker_is_alive:
                self._kill_job(job.job_address)
        except zmq.error.ZMQError:
            break

    probe.close(0)
def get_ip_address():
    """Get the IP address of the host.

    On Windows the hostname is resolved directly. On other platforms a UDP
    socket is "connected" to a public address to learn which local address
    would be used for outgoing traffic; if that fails, the hostname is
    resolved instead.

    Returns:
        The IP address as a string, or None (after logging a warning) when
        no address could be determined or only a loopback address
        ('127.0.0.1' / '127.0.1.1') was found.
    """
    if _IS_WINDOWS:
        local_ip = socket.gethostbyname(socket.gethostname())
    else:
        # Linux and MacOS
        local_ip = None
        try:
            # First way, tested in Ubuntu and MacOS.
            probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                probe.connect(("8.8.8.8", 80))
                local_ip = probe.getsockname()[0]
            finally:
                # Close the socket even when connect() fails; previously it
                # leaked on the error path.
                probe.close()
        except (socket.error, UnicodeError):
            # Second way, tested in CentOS.
            try:
                local_ip = socket.gethostbyname(socket.gethostname())
            except (socket.error, UnicodeError):
                # Best effort: leave `local_ip` as None; reported below.
                local_ip = None

    if local_ip is None or local_ip in ('127.0.0.1', '127.0.1.1'):
        logger.warning(
            'get_ip_address failed, please set ip address manually.')
        return None

    return local_ip
def step(self, action, **kwargs):
    """Step the wrapped env and end the episode once the target changes.

    Args:
        action: Action forwarded to the wrapped environment.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Tuple `(obs, r, done, info)` from the wrapped environment. When
        `info['target_changed']` is truthy, `done` is forced to True and
        `info['timeout']` is set to True.
    """
    obs, r, done, info = self.env.step(action, **kwargs)
    if not info['target_changed']:
        return obs, r, done, info

    # early stop condition
    info['timeout'] = True
    logger.warning(
        '[FirstTarget Wrapper] early stop since first target is finished.')
    return obs, r, True, info
def clear(self):
    """Remove all the jobs.

    Sends SIGTERM to every tracked job process and resets the job table.
    A job whose process can no longer be signalled is logged and skipped.
    """
    # `with` releases the lock even if something inside the loop raises;
    # the previous acquire()/release() pair would have left it held.
    with self._lock:
        for job in self.jobs.values():
            try:
                os.kill(job.pid, signal.SIGTERM)
            except OSError:
                logger.warning("job:{} has been killed before".format(
                    job.pid))
            logger.info("[Worker] kills a job:{}".format(job.pid))
        self.jobs = {}
def call(*args, **kwargs):
    """Forward a call to the VisualDL writer, creating it on first use.

    The module-level `_writer` is built lazily. If the logger has no
    directory yet, one is set automatically and the user is told where the
    files go. `func_name` is taken from the enclosing scope, which is not
    part of this block.

    Args:
        *args: Positional arguments passed to the writer method.
        **kwargs: Keyword arguments passed to the writer method.
    """
    global _writer
    if _writer is None:
        logdir = logger.get_dir()
        if logdir is None:
            logdir = logger.auto_set_dir(action='d')
            hint = (
                "[VisualDL] logdir is None, will save VisualDL files to {}\nView the data using: visualdl --logdir=./{} --host={}"
                .format(logdir, logdir, get_ip_address()))
            logger.warning(hint)
        _writer = LogWriter(logdir=logger.get_dir())
    getattr(_writer, func_name)(*args, **kwargs)
    _writer.flush()
def call(*args, **kwargs):
    """Forward a call to the tensorboard writer, creating it on first use.

    The module-level `_writer` is built lazily. If the logger has no
    directory yet, one is set automatically and the user is told where the
    files go. `func_name` is taken from the enclosing scope, which is not
    part of this block.

    Args:
        *args: Positional arguments passed to the writer method.
        **kwargs: Keyword arguments passed to the writer method.
    """
    global _writer
    if _writer is None:
        logdir = logger.get_dir()
        if logdir is None:
            logdir = logger.auto_set_dir(action='d')
            hint = (
                "[tensorboard] logdir is None, will save tensorboard files to {}\nView the data using: tensorboard --logdir=./{} --host={}"
                .format(logdir, logdir, get_ip_address()))
            logger.warning(hint)
        _writer = SummaryWriter(logdir=logger.get_dir())
    getattr(_writer, func_name)(*args, **kwargs)
    _writer.flush()
def request_cpu_resource(self, global_client, max_memory):
    """Ask the client for a job address, retrying up to 300 times.

    This function itself does not sleep between attempts; any pacing has to
    come from `global_client.submit_job`.

    Args:
        global_client: Object exposing `submit_job(max_memory)`.
        max_memory: Value forwarded unchanged to `submit_job`.

    Returns:
        The first non-None job address returned by `submit_job`, or None if
        every attempt returned None.
    """
    for remaining in range(300, 0, -1):
        job_address = global_client.submit_job(max_memory)
        if job_address is not None:
            return job_address
        # Only report every 30th failed attempt to avoid flooding the log.
        if remaining % 30 == 0:
            logger.warning("No vacant cpu resources at the moment, "
                           "will try {} times later.".format(remaining))
    return None
def _reply_worker_heartbeat(self, socket):
    """Answer heartbeat signals from the worker until they stop.

    Once a receive times out the worker is considered gone: the socket is
    closed and the whole process exits with status 1.

    Args:
        socket: Socket on which the worker's heartbeats arrive and are
            answered.
    """
    worker_reachable = True
    while worker_reachable:
        try:
            socket.recv_multipart()
            socket.send_multipart([remote_constants.HEARTBEAT_TAG])
        except zmq.error.Again:
            logger.warning("[Job] Cannot connect to the worker{}. ".format(
                self.worker_address) + "Job will quit.")
            worker_reachable = False

    socket.close(0)
    os._exit(1)
def loadTrainExamples(self):
    """Load the train-example history stored next to the model file.

    The examples file is `<folder>/<file>.examples`, built from
    `self.args.load_folder_file`. If it is missing, the user is asked
    whether to continue and the process exits unless the answer is "y".
    """
    folder, file_name = (self.args.load_folder_file[0],
                         self.args.load_folder_file[1])
    examplesFile = os.path.join(folder, file_name) + ".examples"

    if os.path.isfile(examplesFile):
        logger.info("File with trainExamples found. Loading it...")
        # NOTE(review): unpickling runs arbitrary code from the file; only
        # load examples files from a trusted source.
        with open(examplesFile, "rb") as f:
            self.trainExamplesHistory = Unpickler(f).load()
        logger.info('Loading done!')
        return

    logger.warning(
        "File {} with trainExamples not found!".format(examplesFile))
    answer = input("Continue? [y|n]")
    if answer != "y":
        sys.exit()
def is_gpu_available():
    """Check whether parl can access a GPU.

    Returns:
        True if a gpu device can be found. When Paddle is installed but was
        not compiled with CUDA, a warning is logged and False is returned
        even if a device was found.
    """
    ret = get_gpu_count() > 0
    if _HAS_FLUID:
        from paddle import fluid
        if ret and not fluid.is_compiled_with_cuda():
            # The message is built with implicit concatenation: the previous
            # backslash continuations embedded the source indentation in the
            # logged text.
            logger.warning(
                "Found non-empty CUDA_VISIBLE_DEVICES. "
                "But PARL found that Paddle was not compiled with CUDA, "
                "which may cause issues. "
                "Thus PARL will not use GPU.")
            return False
    return ret
def _reply_heartbeat(self, target):
    """Reply to the master's heartbeats; shut the worker down when they stop.

    Binds a REP socket on a random port, publishes its address as
    `self.master_heartbeat_address`, points the logger at a per-worker
    directory and signals `heartbeat_socket_initialized`. Afterwards every
    heartbeat is answered with the current worker status. When the loop
    ends, the jobs are cleared, the log server process is stopped and the
    worker exits.

    Args:
        target: Not used inside this method.
    """
    reply_socket = self.ctx.socket(zmq.REP)
    reply_socket.linger = 0
    reply_socket.setsockopt(zmq.RCVTIMEO,
                            remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
    port = reply_socket.bind_to_random_port("tcp://*")
    self.master_heartbeat_address = "{}:{}".format(self.worker_ip, port)

    worker_log_dir = '~/.parl_data/worker/{}'.format(
        self.master_heartbeat_address.replace(':', '_'))
    logger.set_dir(os.path.expanduser(worker_log_dir))

    self.heartbeat_socket_initialized.set()
    logger.info("[Worker] Connect to the master node successfully. "
                "({} CPUs)".format(self.cpu_num))

    while self.master_is_alive and self.worker_is_alive:
        try:
            reply_socket.recv_multipart()
            status = self._get_worker_status()
            reply = [
                remote_constants.HEARTBEAT_TAG,
                to_byte(str(status[0])),
                to_byte(str(status[1])),
                to_byte(status[2]),
                to_byte(str(status[3]))
            ]
            reply_socket.send_multipart(reply)
        except zmq.error.Again:
            # No heartbeat within the receive timeout: the master is gone.
            self.master_is_alive = False
        except zmq.error.ContextTerminated:
            break

    reply_socket.close(0)
    logger.warning(
        "[Worker] lost connection with the master, will exit reply heartbeat for master."
    )
    self.worker_status.clear()
    self.log_server_proc.kill()
    self.log_server_proc.wait()

    # exit the worker
    self.worker_is_alive = False
    self.exit()
def remove_job(self, killed_job):
    """Remove a job from the internal job pool.

    The job's process is sent SIGTERM; a process that can no longer be
    signalled is logged and otherwise ignored.

    Args:
        killed_job(str): Job address to be removed.

    Returns:
        True if removing the job succeeds, False if the address is unknown.
    """
    # `with` releases the lock even if something inside raises; the previous
    # acquire()/release() pair would have left it held.
    with self._lock:
        if killed_job not in self.jobs:
            return False
        pid = self.jobs.pop(killed_job).pid
        try:
            os.kill(pid, signal.SIGTERM)
        except OSError:
            logger.warning("job:{} has been killed before".format(pid))
        logger.info("[Worker] kills a job:{}".format(killed_job))
        return True
def _create_sockets(self, master_address):
    """Create the client's sockets and register the client with the master.

    Each client starts with one socket:
    (1) submit_job_socket: submits jobs to the master node.

    A daemon thread running `_reply_heartbeat` is also started, and this
    method waits until that thread has signalled
    `heartbeat_socket_initialized` before contacting the master.

    Args:
        master_address(str): Address of the master, appended to "tcp://".

    Raises:
        Exception: If the master does not answer within the receive timeout.
    """
    # submit_job_socket: submits job to master
    self.submit_job_socket = self.ctx.socket(zmq.REQ)
    self.submit_job_socket.linger = 0
    self.submit_job_socket.setsockopt(
        zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
    self.submit_job_socket.connect("tcp://{}".format(master_address))
    self.start_time = time.time()

    thread = threading.Thread(target=self._reply_heartbeat)
    # `daemon` attribute instead of the deprecated `setDaemon()`.
    thread.daemon = True
    thread.start()
    self.heartbeat_socket_initialized.wait()

    self.client_id = self.reply_master_heartbeat_address.replace(':', '_') + \
        '_' + str(int(time.time()))

    # check if the master is connected properly
    try:
        self.submit_job_socket.send_multipart([
            remote_constants.CLIENT_CONNECT_TAG,
            to_byte(self.reply_master_heartbeat_address),
            to_byte(socket.gethostname()),
            to_byte(self.client_id),
        ])
        message = self.submit_job_socket.recv_multipart()
        self.log_monitor_url = to_str(message[1])
    except zmq.error.Again:
        logger.warning("[Client] Can not connect to the master, please "
                       "check if master is started and ensure the input "
                       "address {} is correct.".format(master_address))
        self.master_is_alive = False
        raise Exception("Client can not connect to the master, please "
                        "check if master is started and ensure the input "
                        "address {} is correct.".format(master_address))
def step(self, action, **kwargs):
    """Repeat `action` for up to `self.skip_num` frames and merge results.

    Rewards are summed. Info entries whose key contains 'reward' are summed
    across frames; every other entry keeps the value of the last frame.
    The repetition stops early when the target changes or the episode ends.

    Args:
        action: Action applied on every skipped frame.
        **kwargs: Extra keyword arguments forwarded to the wrapped env.

    Returns:
        Tuple `(obs, r, done, merge_info)` where `obs` and `done` come from
        the last executed frame, `r` is the summed reward and `merge_info`
        additionally carries the running 'frame_count'.
    """
    total_reward = 0.0
    merge_info = {}
    for _ in range(self.skip_num):
        self.frame_count += 1
        obs, reward, done, info = self.env.step(action, **kwargs)
        total_reward += reward

        for key, value in info.items():
            if 'reward' in key:
                merge_info[key] = merge_info.get(key, 0.0) + value
            else:
                merge_info[key] = value

        if info['target_changed']:
            logger.warning(
                "[FrameSkip Wrapper] early break since target is changed")
            break
        if done:
            break

    merge_info['frame_count'] = self.frame_count
    return obs, total_reward, done, merge_info
def _create_client_monitor(self, client_heartbeat_address):
    """Heartbeat loop run by the master for a single connected client.

    Opens a REQ socket to the client and pings it repeatedly, forwarding
    each reply to the cluster monitor. When the client stops answering, its
    status is dropped from the cluster monitor and the loop ends.

    Args:
        client_heartbeat_address(str): Address appended to "tcp://" to reach
            the client's heartbeat socket; also the key into
            `self.client_hostname`.
    """
    monitor_socket = self.ctx.socket(zmq.REQ)
    monitor_socket.linger = 0
    monitor_socket.setsockopt(zmq.RCVTIMEO,
                              remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
    monitor_socket.connect("tcp://" + client_heartbeat_address)

    keep_polling = True
    while keep_polling and self.master_is_alive:
        try:
            monitor_socket.send_multipart([remote_constants.HEARTBEAT_TAG])
            status = monitor_socket.recv_multipart()
            self.cluster_monitor.update_client_status(
                status, client_heartbeat_address,
                self.client_hostname[client_heartbeat_address])
        except zmq.error.Again:
            keep_polling = False
            self.cluster_monitor.drop_client_status(client_heartbeat_address)
            logger.warning("[Master] cannot connect to the client " +
                           "{}. ".format(client_heartbeat_address) +
                           "Please check if it is still alive.")
        time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)

    logger.warning("Master exits client monitor for {}.\n".format(
        client_heartbeat_address))
    logger.info(
        "Master connects to {} workers and have {} vacant CPUs.\n".format(
            self.worker_num, self.cpu_num))
    monitor_socket.close(0)
def _reply_client_heartbeat(self, socket):
    """Answer heartbeat signals from the client.

    Each reply carries whether the job must stop because of its memory
    usage, plus the job address. The process exits (status 1) in two cases:

    * the memory check asks the job to stop, or
    * the client stops sending heartbeats; the worker is then told to kill
      this job before the process exits.

    Args:
        socket: Socket on which the client's heartbeats arrive and are
            answered.
    """
    while True:
        try:
            socket.recv_multipart()
            stop_job = self._check_used_memory()
            socket.send_multipart([
                remote_constants.HEARTBEAT_TAG,
                to_byte(str(stop_job)),
                to_byte(self.job_address)
            ])
            if stop_job:
                # Message previously read "will exist" (typo).
                logger.error(
                    "Memory used by this job exceeds {}. This job will exit."
                    .format(self.max_memory))
                # Delay before exiting; presumably so the client can read
                # the reply first -- TODO confirm.
                time.sleep(5)
                socket.close(0)
                os._exit(1)
        except zmq.error.Again:
            logger.warning(
                "[Job] Cannot connect to the client. This job will exit and inform the worker."
            )
            break

    socket.close(0)
    with self.lock:
        self.kill_job_socket.send_multipart(
            [remote_constants.KILLJOB_TAG,
             to_byte(self.job_address)])
        try:
            _ = self.kill_job_socket.recv_multipart()
        except zmq.error.Again:
            # Best effort: the process exits below whether or not the
            # worker acknowledged the kill request.
            pass
    logger.warning("[Job]lost connection with the client, will exit")
    os._exit(1)
def _reply_heartbeat(self):
    """Reply heartbeat signals to the master node.

    Binds a REP socket on a random port, publishes its address as
    `self.reply_master_heartbeat_address` and signals
    `heartbeat_socket_initialized`. Every heartbeat is then answered with
    the client's executable path, actor count, elapsed time and log monitor
    URL. When no heartbeat arrives within the receive timeout, the master is
    marked dead and the loop ends.
    """
    reply_socket = self.ctx.socket(zmq.REP)
    reply_socket.linger = 0
    reply_socket.setsockopt(zmq.RCVTIMEO,
                            remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
    reply_master_heartbeat_port = reply_socket.bind_to_random_port(
        addr="tcp://*")
    self.reply_master_heartbeat_address = "{}:{}".format(
        get_ip_address(), reply_master_heartbeat_port)
    self.heartbeat_socket_initialized.set()

    # Becomes True once at least one heartbeat has been received, so the
    # timeout message can tell "master died" from "master never reached us".
    # Previously this flag was never set and the first message was dead code.
    connected = False
    while self.client_is_alive and self.master_is_alive:
        try:
            reply_socket.recv_multipart()
            connected = True
            elapsed_time = datetime.timedelta(
                seconds=int(time.time() - self.start_time))
            reply_socket.send_multipart([
                remote_constants.HEARTBEAT_TAG,
                to_byte(self.executable_path),
                to_byte(str(self.actor_num)),
                to_byte(str(elapsed_time)),
                to_byte(str(self.log_monitor_url)),
            ])  # TODO: remove additional information
        except zmq.error.Again:
            if connected:
                logger.warning("[Client] Cannot connect to the master. "
                               "Please check if it is still alive.")
            else:
                logger.warning(
                    "[Client] Cannot connect to the master. "
                    "Please check the firewall between client and master.(e.g., ping the master IP)"
                )
            self.master_is_alive = False

    reply_socket.close(0)
    logger.warning("Client exit replying heartbeat for master.")
def single_task(self, obj, reply_socket, job_address):
    """Serve commands from the remote object until the job must stop.

    Each job receives two kinds of message from the remote object:

    1. When the remote object calls a function or gets/sets an attribute,
       the job performs it on the local instance and returns the result.
    2. When the remote object is deleted, the job quits and releases the
       related computation resources.

    Args:
        obj: Local instance the commands are applied to.
        reply_socket (socket): main socket to accept commands of the remote
            object.
        job_address (String): address of reply_socket.

    Raises:
        AttributeError, SerializeError, DeserializeError: Re-raised after
            the matching exception tag has been sent to the remote object.
        NotImplementedError: If a message with an unknown tag is received.
    """
    # Exceptions that have a dedicated reply tag. Matching is on the exact
    # type (not subclasses), as before.
    tagged_errors = {
        AttributeError: remote_constants.ATTRIBUTE_EXCEPTION_TAG,
        SerializeError: remote_constants.SERIALIZE_EXCEPTION_TAG,
        DeserializeError: remote_constants.DESERIALIZE_EXCEPTION_TAG,
    }

    while True:
        message = reply_socket.recv_multipart()
        tag = message[0]

        if tag in [
                remote_constants.CALL_TAG, remote_constants.GET_ATTRIBUTE,
                remote_constants.SET_ATTRIBUTE
        ]:
            try:
                # Redirect stdout to stdout.log temporarily
                logfile_path = os.path.join(self.log_dir, 'stdout.log')

                if tag == remote_constants.CALL_TAG:
                    function_name = to_str(message[1])
                    args, kwargs = loads_argument(message[2])
                    with redirect_stdout_to_file(logfile_path):
                        ret = getattr(obj, function_name)(*args, **kwargs)
                    reply_socket.send_multipart(
                        [remote_constants.NORMAL_TAG,
                         dumps_return(ret)])

                elif tag == remote_constants.GET_ATTRIBUTE:
                    attribute_name = to_str(message[1])
                    with redirect_stdout_to_file(logfile_path):
                        ret = getattr(obj, attribute_name)
                    reply_socket.send_multipart(
                        [remote_constants.NORMAL_TAG,
                         dumps_return(ret)])

                else:
                    attribute_name = to_str(message[1])
                    attribute_value = loads_return(message[2])
                    with redirect_stdout_to_file(logfile_path):
                        setattr(obj, attribute_name, attribute_value)
                    reply_socket.send_multipart(
                        [remote_constants.NORMAL_TAG])

            except Exception as e:
                # reset the job
                error_str = str(e)
                logger.error(error_str)

                error_tag = tagged_errors.get(type(e))
                if error_tag is not None:
                    reply_socket.send_multipart(
                        [error_tag, to_byte(error_str)])
                    # Re-raise the original exception so its message and
                    # traceback survive; raising the bare class discarded
                    # both (and would fail if the class requires
                    # constructor arguments).
                    raise

                traceback_str = str(traceback.format_exc())
                logger.error("traceback:\n{}".format(traceback_str))
                reply_socket.send_multipart([
                    remote_constants.EXCEPTION_TAG,
                    to_byte(error_str + "\ntraceback:\n" + traceback_str)
                ])
                break

        # receive DELETE_TAG from actor, and stop replying worker heartbeat
        elif tag == remote_constants.KILLJOB_TAG:
            reply_socket.send_multipart([remote_constants.NORMAL_TAG])
            logger.warning("An actor exits and this job {} will exit.".format(
                job_address))
            break
        else:
            logger.error(
                "The job receives an unknown message: {}".format(message))
            raise NotImplementedError
def learn(self):
    """Run the training loop for `self.args.numIters` iterations.

    Each iteration:

    1. Performs numEps episodes of self-play.
    2. Retrains neural network with examples in trainExamplesHistory
       (which has a maximum length of numItersForTrainExamplesHistory).
    3. Evaluates the new neural network with the test dataset.
    4. Pits the new neural network against the old one and accepts it
       only if it wins >= updateThreshold fraction of games.

    Work is handed to the remote actors by putting task dicts on their
    signal queues; results are read back from the shared return queue, one
    per actor.
    """

    # create remote actors to run tasks (self-play/pitting/evaluate_test_dataset) in parallel.
    self._create_remote_actors()

    for iteration in range(1, self.args.numIters + 1):
        logger.info('Starting Iter #{} ...'.format(iteration))

        ####################
        logger.info('Step1: self-play in parallel...')
        iterationTrainExamples = []
        # update weights of remote actors to the latest weights, and ask them to run self-play task
        for signal_queue in self.remote_actors_signal_queues:
            signal_queue.put({"task": "self-play"})
        # wait for all remote actors (a total of self.args.actors_num) to return the self-play results
        for _ in range(self.args.actors_num):
            result = self.remote_actors_return_queue.get()
            iterationTrainExamples.extend(result["self-play"])

        # save the iteration examples to the history
        self.trainExamplesHistory.append(iterationTrainExamples)
        if len(self.trainExamplesHistory
               ) > self.args.numItersForTrainExamplesHistory:
            logger.warning("Removing the oldest entry in trainExamples.")
            self.trainExamplesHistory.pop(0)
        self.saveTrainExamples(iteration)  # backup history to a file

        ####################
        logger.info('Step2: train neural network...')
        # shuffle examples before training
        trainExamples = []
        for e in self.trainExamplesHistory:
            trainExamples.extend(e)
        shuffle(trainExamples)

        # training new network, keeping a copy of the old one
        # (the current weights go through 'temp.pth.tar' into previous_agent)
        self.current_agent.save(
            os.path.join(self.args.checkpoint, 'temp.pth.tar'))
        self.previous_agent.restore(
            os.path.join(self.args.checkpoint, 'temp.pth.tar'))

        self.current_agent.learn(trainExamples)

        ####################
        logger.info('Step3: evaluate test dataset in parallel...')
        cnt = 0  # total number of test samples handed out
        # update weights of remote actors to the latest weights, and ask them to evaluate assigned test dataset
        # NOTE(review): assumes split_group yields at most actors_num groups
        # and that cnt ends up non-zero (it is a divisor below) -- confirm.
        for i, data in enumerate(
                split_group(self.test_dataset,
                            len(self.test_dataset) // self.args.actors_num)):
            self.remote_actors_signal_queues[i].put({
                "task": "evaluate_test_dataset",
                "test_dataset": data
            })
            cnt += len(data)
        perfect_moves_cnt, good_moves_cnt = 0, 0
        # wait for all remote actors (a total of self.args.actors_num) to return the evaluating results
        for _ in range(self.args.actors_num):
            (perfect_moves,
             good_moves) = self.remote_actors_return_queue.get(
             )["evaluate_test_dataset"]
            perfect_moves_cnt += perfect_moves
            good_moves_cnt += good_moves
        logger.info('perfect moves rate: {}, good moves rate: {}'.format(
            perfect_moves_cnt / cnt, good_moves_cnt / cnt))
        tensorboard.add_scalar('perfect_moves_rate', perfect_moves_cnt / cnt,
                               iteration)
        tensorboard.add_scalar('good_moves_rate', good_moves_cnt / cnt,
                               iteration)

        ####################
        logger.info(
            'Step4: pitting against previous generation in parallel...')
        # transfer weights of previous generation and current generation to the remote actors, and ask them to pit.
        for signal_queue in self.remote_actors_signal_queues:
            signal_queue.put({"task": "pitting"})
        previous_wins, current_wins, draws = 0, 0, 0
        for _ in range(self.args.actors_num):
            (pwins_, cwins_,
             draws_) = self.remote_actors_return_queue.get()["pitting"]
            previous_wins += pwins_
            current_wins += cwins_
            draws += draws_

        logger.info('NEW/PREV WINS : %d / %d ; DRAWS : %d' %
                    (current_wins, previous_wins, draws))
        # Reject when no game was decided, or when the win ratio over
        # decided games (draws excluded) is below the threshold.
        if previous_wins + current_wins == 0 or float(current_wins) / (
                previous_wins + current_wins) < self.args.updateThreshold:
            logger.info('REJECTING NEW MODEL')
            self.current_agent.restore(
                os.path.join(self.args.checkpoint, 'temp.pth.tar'))
        else:
            logger.info('ACCEPTING NEW MODEL')
            self.current_agent.save(
                os.path.join(self.args.checkpoint, 'best.pth.tar'))
        # NOTE(review): indentation was lost in this copy of the source; the
        # per-iteration checkpoint is assumed to be saved for both outcomes
        # -- confirm against the original file.
        self.current_agent.save(
            os.path.join(self.args.checkpoint,
                         self.getCheckpointFile(iteration)))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from parl.utils.utils import _HAS_FLUID, _HAS_TORCH from parl.utils import logger if _HAS_FLUID: from parl.algorithms.fluid import * elif _HAS_TORCH: from parl.algorithms.torch import * else: logger.warning( "No deep learning framework was found, but it's ok for parallel computation." )
def decorator(cls):
    """Wrap `cls` so that its instances live in a remote job.

    When the environment variable XPARL equals 'True' (i.e. inside job.py)
    the class is returned unchanged. Otherwise a `RemoteWrapper` class is
    returned whose instances forward method calls and attribute access to a
    job process over a ZMQ socket. `max_memory` is taken from the enclosing
    scope, which is not part of this block.

    Args:
        cls: The class to be wrapped.

    Returns:
        `cls` itself or the generated `RemoteWrapper` class; the latter
        exposes the original class as `RemoteWrapper._original`.
    """
    # we are not going to create a remote actor in job.py
    if 'XPARL' in os.environ and os.environ['XPARL'] == 'True':
        logger.warning("Note: this object will be runnning as a local object")
        return cls

    # Reply tags that signal a failure on the job side, mapped to the
    # exception raised on the client side.
    remote_error_types = {
        remote_constants.EXCEPTION_TAG: RemoteError,
        remote_constants.ATTRIBUTE_EXCEPTION_TAG: RemoteAttributeError,
        remote_constants.SERIALIZE_EXCEPTION_TAG: RemoteSerializeError,
        remote_constants.DESERIALIZE_EXCEPTION_TAG: RemoteDeserializeError,
    }

    def remote_failure(actor, attr, message):
        """Mark `actor`'s job as shut down and build the exception to raise."""
        error_cls = remote_error_types.get(message[0])
        if error_cls is None:
            actor.job_shutdown = True
            return NotImplementedError()
        error_str = to_str(message[1])
        actor.job_shutdown = True
        return error_cls(attr, error_str)

    class RemoteWrapper(object):
        """
        Wrapper for remote class in client side.
        """

        def __init__(self, *args, **kwargs):
            """
            Args:
                args, kwargs: arguments for the initialization of the unwrapped
                class.

            Raises:
                Exception: If the master is not alive.
                ResourceError: If no job could be obtained from the master.
                RemoteError: If initializing the remote object failed.
            """
            self.GLOBAL_CLIENT = get_global_client()

            self.ctx = self.GLOBAL_CLIENT.ctx

            # GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat
            # finds the master is dead.
            if self.GLOBAL_CLIENT.master_is_alive:
                job_address = self.request_cpu_resource(
                    self.GLOBAL_CLIENT, max_memory)
            else:
                raise Exception("Can not submit job to the master. "
                                "Please check if master is still alive.")

            if job_address is None:
                raise ResourceError("Cannot submit the job to the master. "
                                    "Please add more CPU resources to the "
                                    "master or try again later.")

            self.internal_lock = threading.Lock()

            # Send actor commands like `init` and `call` to the job.
            self.job_socket = self.ctx.socket(zmq.REQ)
            self.job_socket.linger = 0
            self.job_socket.connect("tcp://{}".format(job_address))
            self.job_address = job_address
            self.job_shutdown = False

            self.send_file(self.job_socket)
            # Strip the 3-character suffix (".py") from the source path.
            file_name = inspect.getfile(cls)[:-3]
            cls_source = inspect.getsourcelines(cls)
            # First source line of the class plus its number of lines.
            end_of_file = cls_source[1] + len(cls_source[0])
            class_name = cls.__name__
            self.job_socket.send_multipart([
                remote_constants.INIT_OBJECT_TAG,
                cloudpickle.dumps([file_name, class_name, end_of_file]),
                cloudpickle.dumps([args, kwargs]),
            ])
            message = self.job_socket.recv_multipart()
            tag = message[0]
            if tag == remote_constants.EXCEPTION_TAG:
                traceback_str = to_str(message[1])
                self.job_shutdown = True
                raise RemoteError('__init__', traceback_str)

        def __del__(self):
            """Delete the remote class object and release remote resources."""
            # If __init__ failed early, `job_socket` / `job_shutdown` resolve
            # through __getattr__ to a call wrapper; the AttributeError and
            # TypeError handlers below keep that case quiet.
            try:
                self.job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000)
            except AttributeError:
                pass
            if not self.job_shutdown:
                try:
                    self.job_socket.send_multipart(
                        [remote_constants.KILLJOB_TAG])
                    _ = self.job_socket.recv_multipart()
                    self.job_socket.close(0)
                except AttributeError:
                    pass
                except zmq.error.ZMQError:
                    pass
                except TypeError:
                    pass

        def send_file(self, socket):
            """Send the client's python files to the job over `socket`."""
            try:
                socket.send_multipart([
                    remote_constants.SEND_FILE_TAG,
                    self.GLOBAL_CLIENT.pyfiles
                ])
                _ = socket.recv_multipart()
            except zmq.error.Again:
                logger.error("Send python files failed.")

        def request_cpu_resource(self, global_client, max_memory):
            """Ask the client for a job address, retrying up to 300 times.

            Returns:
                The job address, or None if every attempt returned None.
            """
            cnt = 300
            while cnt > 0:
                job_address = global_client.submit_job(max_memory)
                if job_address is not None:
                    return job_address
                if cnt % 30 == 0:
                    logger.warning(
                        "No vacant cpu resources at the moment, "
                        "will try {} times later.".format(cnt))
                cnt -= 1
            return None

        def __setattr__(self, attr, value):
            """Set a wrapper attribute locally or an actor attribute remotely.

            Names found in the `__dict__` of a fresh `cls()` instance are
            sent to the job; everything else is stored on the wrapper.
            """
            if attr not in cls().__dict__:
                super().__setattr__(attr, value)
                return

            # `with` releases the lock even when send/recv raises.
            with self.internal_lock:
                self.job_socket.send_multipart([
                    remote_constants.SET_ATTRIBUTE,
                    to_byte(attr),
                    dumps_return(value)
                ])
                message = self.job_socket.recv_multipart()

            if message[0] != remote_constants.NORMAL_TAG:
                self.job_shutdown = True
                raise NotImplementedError()

        def __getattr__(self, attr):
            """Fetch a remote attribute or build a remote method caller.

            Names found in the `__dict__` of a fresh `cls()` instance are
            read from the job immediately; any other name yields a function
            that performs the remote call when invoked.
            """
            # check if attr is a function or not
            if attr in cls().__dict__:
                # Previously the lock was released only on the success path,
                # so one failed request blocked all later ones.
                with self.internal_lock:
                    self.job_socket.send_multipart(
                        [remote_constants.GET_ATTRIBUTE,
                         to_byte(attr)])
                    message = self.job_socket.recv_multipart()
                    if message[0] == remote_constants.NORMAL_TAG:
                        return loads_return(message[1])
                raise remote_failure(self, attr, message)

            def wrapper(*args, **kwargs):
                if self.job_shutdown:
                    raise RemoteError(
                        attr, "This actor losts connection with the job.")

                # The lock is now also released when serialization, the
                # socket, or the remote side fails.
                with self.internal_lock:
                    data = dumps_argument(*args, **kwargs)
                    self.job_socket.send_multipart(
                        [remote_constants.CALL_TAG,
                         to_byte(attr), data])
                    message = self.job_socket.recv_multipart()
                    if message[0] == remote_constants.NORMAL_TAG:
                        return loads_return(message[1])
                raise remote_failure(self, attr, message)

            return wrapper

    RemoteWrapper._original = cls
    return RemoteWrapper