def _worker_loop(dataset, job_queue: mp.Queue, result_queue: mp.Queue, interrupt_event: mp.Event): logger = logging.getLogger("worker_loop") logger.debug("Worker started.") while True: logger.debug("Trying to fetch from job_queue.") if interrupt_event.is_set(): logger.debug("Received interrupt signal, breaking.") break try: # This assumes that the job_queue is fully populated before the worker is started. index = job_queue.get_nowait() logger.debug("Fetch successful.") except Empty: logger.debug("Queue empty, setting up poison pill.") index = None if index is None or interrupt_event.is_set(): logger.debug( "Fetched poison pill or received interrupt signal, breaking.") break try: logger.debug("Sampling index {} from dataset.".format(index)) sample = dataset[index] except Exception: logger.debug("Dataset threw an exception.".format(index), exc_info=1) result_queue.put((index, ExceptionWrapper(sys.exc_info()))) else: logger.debug( "Putting sample at index {} in the result queue.".format( index)) result_queue.put((index, sample))
def log_fn(self, stop_event: Event):
    """Logger-process entry point.

    Creates the real loggers via ``self._super_create_loggers()``, reports
    their configuration back through ``self.resposne_queue``, then drains
    ``self.draw_queue`` until ``stop_event`` is set, acknowledging every
    processed command with ``True`` on the response queue.
    """
    try:
        self._super_create_loggers()
        config = {
            key: self.__dict__[key]
            for key in ("save_dir", "tb_logdir", "is_sweep")
        }
        self.resposne_queue.put(config)
        finished = False
        while not finished:
            try:
                command = self.draw_queue.get(True, 0.1)
            except EmptyQueue:
                # Timed out: leave the loop only once a stop was requested.
                finished = stop_event.is_set()
            else:
                self._super_log(*command)
                self.resposne_queue.put(True)
    except:
        print("Logger process crashed.")
        raise
    finally:
        print("Logger: syncing")
        if self.use_wandb:
            wandb.join()
        stop_event.set()
        print("Logger process terminating...")
def _prefetch(in_queue: mp.Queue, out_queue: mp.Queue, batchsize: int,
              shutdown_event: mp.Event, target_device, waiting_time=5):
    """Continuously prefetches complete trajectories dropped by the
    :py:class:`~.TrajectoryStore` for training.

    As long as shutdown is not set, this method pulls :py:attr:`batchsize`
    trajectories from :py:attr:`in_queue`, transforms them into batches
    using :py:meth:`~_to_batch()` and puts them onto the
    :py:attr:`out_queue`.

    This usually runs as an asynchronous :py:obj:`multiprocessing.Process`.

    Parameters
    ----------
    in_queue: :py:obj:`multiprocessing.Queue`
        A queue that delivers dropped trajectories from
        :py:class:`~.TrajectoryStore`.
    out_queue: :py:obj:`multiprocessing.Queue`
        A queue that delivers batches to :py:meth:`_loop()`.
    batchsize: `int`
        The number of trajectories that shall be processed into a batch.
    shutdown_event: :py:obj:`multiprocessing.Event`
        An event that breaks this methods internal loop.
    target_device: :py:obj:`torch.device`
        The target device of the batch.
    waiting_time: `float`
        Time the methods loop sleeps between each iteration.
    """
    while not shutdown_event.is_set():
        try:
            # NOTE(review): if the queue runs dry mid-collection, the
            # trajectories gathered so far in this comprehension are lost.
            trajectories = [
                in_queue.get(timeout=waiting_time) for _ in range(batchsize)
            ]
        except queue.Empty:
            continue
        batch = Learner._to_batch(trajectories, target_device)
        # delete Tensors after usage to free memory (see torch multiprocessing)
        del trajectories
        try:
            out_queue.put(batch)
        except (AssertionError, ValueError):  # queue closed
            continue
        # delete Tensors after usage to free memory (see torch multiprocessing)
        del batch
        # Fix: removed a dead trailing ``del trajectories`` guarded by
        # ``except UnboundLocalError`` — trajectories is always deleted
        # above on this path, so that block could never do anything.
class WorkerManager:
    """Spawns ``n_workers`` worker processes that push collected data onto
    a shared queue and coordinates synchronous collection rounds through
    a shared ``collect_event`` plus one event per worker."""

    def __init__(self, n_workers, actor, args):
        # Shared episode counter ('i' = signed int), updated by workers.
        self._now_episode = Value('i', 0)
        self.queue = Queue()
        # Set by collect() to start a round; presumably cleared by the
        # workers when the round is done — TODO confirm against Worker.run.
        self.collect_event = Event()
        self.worker = []
        for i in range(n_workers):
            self.worker.append(
                Worker(self.queue, self.collect_event, actor, args))
            # Stagger worker construction by one second.
            time.sleep(1)
        self.process = [
            Process(target=self.worker[i].run, args=(self._now_episode, ))
            for i in range(n_workers)
        ]
        for p in self.process:
            p.start()
        print(f'Start {n_workers} workers.')

    def collect(self):
        """Trigger one collection round and drain every result from the
        shared queue. Returns the list of collected items."""
        result = []
        self.collect_event.set()
        # NOTE(review): busy-wait (no sleep) until workers clear the event.
        while self.collect_event.is_set():  # WAIT FOR DATA COLLECT END
            pass
        # Wait until each worker has signalled that its data is enqueued.
        for w in self.worker:
            w.event.wait()
        while not self.queue.empty():
            result.append(self.queue.get())
        # Re-arm the per-worker events for the next round.
        for w in self.worker:
            w.event.clear()
        return result

    def now_episode(self):
        # Snapshot of the shared episode counter.
        value = self._now_episode.value
        return value
class AsyncLogger(Logger):
    """Logger that offloads actual logging (plot rendering, wandb/tb
    writes) to a background process so the training loop is not blocked.

    Protocol: the parent pushes ``(data, step)`` commands on
    ``draw_queue``; the child acknowledges each with ``True`` on
    ``resposne_queue`` (sic — attribute name kept as-is). ``waiting``
    counts commands sent but not yet acknowledged."""

    @staticmethod
    def log_fn(self, stop_event: Event):
        # NOTE: declared @staticmethod yet receives the logger instance —
        # it is passed explicitly as a Process argument (see
        # create_loggers) because this body runs in the child process.
        try:
            self._super_create_loggers()
            # Report the created loggers' configuration back to the parent.
            self.resposne_queue.put({
                k: self.__dict__[k]
                for k in ["save_dir", "tb_logdir", "is_sweep"]
            })
            while True:
                try:
                    cmd = self.draw_queue.get(True, 0.1)
                except EmptyQueue:
                    # Only exit once a stop was requested and the queue
                    # has drained.
                    if stop_event.is_set():
                        break
                    else:
                        continue
                self._super_log(*cmd)
                # Acknowledge each processed command.
                self.resposne_queue.put(True)
        except:
            print("Logger process crashed.")
            raise
        finally:
            print("Logger: syncing")
            if self.use_wandb:
                wandb.join()
            stop_event.set()
            print("Logger process terminating...")

    def create_loggers(self):
        # Defer real logger creation to the child process; keep a handle
        # to the parent-class implementation for log_fn to invoke there.
        self._super_create_loggers = super().create_loggers
        self.stop_event = Event()
        self.proc = Process(target=self.log_fn, args=(self, self.stop_event))
        self.proc.start()
        atexit.register(self.finish)

    def __init__(self, *args, **kwargs):
        # Pending (plotlist, step) pairs not yet handed to the child.
        self.queue = []
        self.draw_queue = Queue()
        self.resposne_queue = Queue()
        self._super_log = super().log
        # Commands sent to the child but not yet acknowledged.
        self.waiting = 0
        super().__init__(*args, **kwargs)
        # Block until the child reports save_dir/tb_logdir/is_sweep.
        self.__dict__.update(self.resposne_queue.get(True))

    def log(self, plotlist, step=None):
        """Queue plot objects for asynchronous logging at ``step``."""
        if self.stop_event.is_set():
            return
        if not isinstance(plotlist, list):
            plotlist = [plotlist]
        plotlist = [p for p in plotlist if p]
        if not plotlist:
            return
        # Detach tensors to CPU so they can be pickled to the child.
        plotlist = U.apply_to_tensors(plotlist, lambda x: x.detach().cpu())
        self.queue.append((plotlist, step))
        self.flush(wait=False)

    def enqueue(self, data, step: Optional[int]):
        # Hand one command to the child and count it as outstanding.
        self.draw_queue.put((data, step))
        self.waiting += 1

    def wait_logger(self, wait=False):
        """Consume acknowledgements from the child.

        wait=True drains until all outstanding commands are acked;
        wait=False only consumes acks that are already available."""
        cond = (lambda: not self.resposne_queue.empty()) if not wait else (
            lambda: self.waiting > 0)
        already_printed = False
        while cond() and not self.stop_event.is_set():
            will_wait = self.resposne_queue.empty()
            if will_wait and not already_printed:
                already_printed = True
                sys.stdout.write("Warning: waiting for logger... ")
                sys.stdout.flush()
            try:
                self.resposne_queue.get(True, 0.2)
            except EmptyQueue:
                continue
            self.waiting -= 1
        if already_printed:
            print("done.")

    def flush(self, wait: bool = True):
        """Resolve pending PlotAsync objects and forward finished
        plotlists to the child; optionally wait for acknowledgements."""
        while self.queue:
            plotlist, step = self.queue[0]
            for i, p in enumerate(plotlist):
                if isinstance(p, PlotAsync):
                    res = p.get(wait)
                    if res is not None:
                        plotlist[i] = res
                    else:
                        if wait:
                            assert p.failed  # Exception in the worker thread
                            print(
                                "Exception detected in a PlotAsync object. Syncing logger and ignoring further plots."
                            )
                            self.wait_logger(True)
                            self.stop_event.set()
                            self.proc.join()
                        # Plot not ready (or failed): stop flushing for now.
                        return
            self.queue.pop(0)
            self.enqueue(plotlist, step)
        self.wait_logger(wait)

    def finish(self):
        # Idempotent shutdown: flush everything, then stop the child.
        if self.stop_event.is_set():
            return
        self.flush(True)
        self.stop_event.set()
        self.proc.join()
class DataQueue(object):
    '''Queue for data prefetching

    DataQueue launch a subprocess to avoid python's GIL

    # Arguments
        generator: instance of generator which feeds data infinitely
        max_queue_size: maximum queue size
        nb_worker: control concurrency, only take effect when do preprocessing
    '''

    def __init__(self, generator, max_queue_size=5, nb_worker=1):
        self.generator = generator
        self.nb_worker = nb_worker
        self.max_queue_size = max_queue_size
        self._queue = Queue()
        # Set when the producer terminates (shutdown or error).
        self._signal = Event()
        # Notified when an item becomes available in the queue.
        self._available_cv = Condition()
        # Notified when space becomes available in the queue.
        self._full_cv = Condition()
        args = (generator, self._queue, self._signal, self._available_cv,
                self._full_cv, self.nb_worker, self.max_queue_size)
        self.working_process = Process(target=self.generator_process,
                                       args=args)
        self.working_process.daemon = True
        self.working_process.start()

    def get(self, timeout=None):
        """Block until one prefetched element is available and return it.

        Raises an Exception if the producer process has terminated.
        NOTE(review): ``timeout`` is accepted but never used.
        """
        with self._available_cv:
            if not self._signal.is_set() and self._queue.qsize() == 0:
                self._available_cv.wait()
        if self._signal.is_set():
            raise Exception("prefetch process terminated!")
        try:
            data = self._queue.get()
            # Tell producers that a slot freed up.
            with self._full_cv:
                self._full_cv.notify()
        except Exception as e:
            # On failure, signal shutdown and wake all producers.
            with self._full_cv:
                self._signal.set()
                self._full_cv.notify_all()
            raise e
        return data

    def qsize(self):
        # Approximate number of prefetched items currently queued.
        return self._queue.qsize()

    def __del__(self):
        # Signal shutdown and wake blocked producers, then wait for the
        # producer process to exit.
        with self._full_cv:
            self._signal.set()
            self._full_cv.notify_all()
        #self.working_process.terminate()
        self.working_process.join()

    @staticmethod
    def generator_process(generator, queue, signal, available_cv, full_cv,
                          nb_worker, max_qsize):
        """Producer-process body: spawns ``nb_worker`` threads that pull
        from the generator (under a lock), preprocess in parallel and
        feed the shared queue, respecting ``max_qsize``."""
        preprocess = generator.preprocess
        generator = BackgroundGenerator(generator())  # invoke call()

        # put data in the queue
        def enqueue_fn(generator, preprocess, queue, signal, available_cv,
                       full_cv, lock, max_qsize):
            while True:
                try:
                    # Only the generator pull is serialized; preprocessing
                    # runs concurrently across the worker threads.
                    with lock:
                        data = next(generator)
                    data = preprocess(data)
                    if not isinstance(data, types.GeneratorType):
                        data = [data]
                    for ele in data:
                        ele = np2tensor(ele)  # numpy array to pytorch's tensor
                        with full_cv:
                            # Back off while the queue is full, unless a
                            # shutdown was signalled.
                            while not signal.is_set(
                            ) and queue.qsize() >= max_qsize:
                                full_cv.wait()
                            if signal.is_set():
                                return
                            queue.put(ele)
                        with available_cv:
                            available_cv.notify()
                except Exception as e:
                    # Propagate failure: signal shutdown and wake both the
                    # consumer and any blocked producers.
                    print("Error Message", e, file=sys.stderr)
                    with full_cv:
                        signal.set()
                        full_cv.notify_all()
                    with available_cv:
                        signal.set()
                        available_cv.notify_all()
                    raise Exception("generator thread went wrong!")

        # start threads
        lock = threading.Lock()
        args = (generator, preprocess, queue, signal, available_cv, full_cv,
                lock, max_qsize)
        generator_threads = [
            threading.Thread(target=enqueue_fn, args=args)
            for _ in range(nb_worker)
        ]
        for thread in generator_threads:
            thread.daemon = True
            thread.start()
        for thread in generator_threads:
            thread.join()
class Worker(Process):
    """Environment worker process: runs a contiguous slice of the
    environments and steps them in lockstep with a master via shared
    tensors and a byte-handshake over a Pipe."""

    def __init__(self, worker_id, args):
        super().__init__()
        self.id = worker_id
        self.args = args
        # for master use, for worker use
        self.pipe_master, self.pipe_worker = Pipe()
        self.exit_event = Event()
        # determine n_e: distribute args.n_e environments over args.n_w
        # workers; the last worker absorbs the remainder.
        q, r = divmod(args.n_e, args.n_w)
        if r:
            print('Warning: n_e % n_w != 0')
        if worker_id == args.n_w - 1:
            self.n_e = n_e = q + r
        else:
            self.n_e = n_e = q
        print('Worker', self.id, '] n_e = %d' % n_e)
        # Absolute index range of this worker's environments.
        self.env_start = worker_id * q
        self.env_slice = slice(self.env_start, self.env_start + n_e)
        self.env_range = range(self.env_start, self.env_start + n_e)
        self.envs = None
        # Start the process immediately at construction time.
        self.start()

    def make_environments(self):
        # Environments are created inside the child process (see run()).
        envs = []
        for _ in range(self.n_e):
            envs.append(gym.make(self.args.env, hack='train'))
        return envs

    def put_shared_tensors(self, actions, obs, rewards, terminals):
        """Master side: hand the shared tensors to the worker process."""
        assert (actions.is_shared() and obs.is_shared()
                and rewards.is_shared() and terminals.is_shared())
        self.pipe_master.send((actions, obs, rewards, terminals))

    def get_shared_tensors(self):
        """Worker side: receive the shared tensors from the master."""
        actions, obs, rewards, terminals = self.pipe_worker.recv()
        assert (actions.is_shared() and obs.is_shared()
                and rewards.is_shared() and terminals.is_shared())
        return actions, obs, rewards, terminals

    # One-byte handshakes pairing worker "step done" with master
    # "action done" notifications over the same pipe.
    def set_step_done(self):
        self.pipe_worker.send_bytes(b'1')

    def wait_step_done(self):
        self.pipe_master.recv_bytes(1)

    def set_action_done(self):
        self.pipe_master.send_bytes(b'1')

    def wait_action_done(self):
        self.pipe_worker.recv_bytes(1)

    def run(self):
        """Child-process main loop: reset all envs, then repeatedly wait
        for the master's actions, step every env, write observations,
        rewards and terminal flags into the shared tensors, and signal
        step completion — until exit_event is set."""
        preprocess = PAACNet.preprocess
        envs = self.envs = self.make_environments()
        env_start = self.env_start
        t_max = self.args.t_max
        # t indexes the current timestep row in the shared rollout tensors.
        t = 0
        # dones is indexed by absolute env index (shared layout across
        # workers), hence length args.n_e rather than self.n_e.
        dones = [False] * self.args.n_e
        # get shared tensor
        actions, obs, rewards, terminals = self.get_shared_tensors()
        for i, env in enumerate(envs, start=env_start):
            obs[i] = preprocess(env.reset())
        self.set_step_done()
        while not self.exit_event.is_set():
            self.wait_action_done()
            for i, env in enumerate(envs, start=env_start):
                if not dones[i]:
                    ob, reward, done, info = env.step(actions[i])
                else:
                    # Episode ended last step: reset and record a neutral
                    # transition for this slot.
                    ob, reward, done, info = env.reset(), 0, False, None
                obs[i] = preprocess(ob)
                rewards[t, i] = reward
                terminals[t, i] = dones[i] = done
            self.set_step_done()
            t += 1
            # Wrap the timestep index at the rollout horizon.
            if t == t_max:
                t = 0
class Sink(Process):
    """Result-collection process: receives partial outputs from workers
    over a PULL socket, assembles them per job, and publishes finished
    results to clients over a PUB socket."""

    def __init__(self, port_out, front_sink_addr, verbose=False):
        super().__init__()
        self.port = port_out
        self.exit_flag = Event()
        self.logger = set_logger(colored('SINK', 'green'), verbose)
        self.front_sink_addr = front_sink_addr
        self.is_ready = Event()
        self.verbose = verbose

    def close(self):
        """Terminate the sink process and wait for it to exit."""
        self.logger.info('shutting down...')
        self.is_ready.clear()
        self.exit_flag.set()
        self.terminate()
        self.join()
        self.logger.info('terminated!')

    def run(self):
        self._run()

    # Decorators inject the three zmq sockets as arguments (bottom-up:
    # receiver=PULL, frontend=PAIR, sender=PUB).
    @zmqd.socket(zmq.PULL)
    @zmqd.socket(zmq.PAIR)
    @zmqd.socket(zmq.PUB)
    def _run(self, receiver, frontend, sender):
        receiver_addr = auto_bind(receiver)
        frontend.connect(self.front_sink_addr)
        sender.bind('tcp://*:%d' % self.port)

        # Jobs keyed by job id; missing keys create empty SinkJobs.
        pending_jobs: Dict[str, SinkJob] = defaultdict(lambda: SinkJob())

        poller = zmq.Poller()
        poller.register(frontend, zmq.POLLIN)
        poller.register(receiver, zmq.POLLIN)

        # send worker receiver address back to frontend
        frontend.send(receiver_addr.encode('ascii'))

        # Windows does not support logger in MP environment, thus get a new logger
        # inside the process for better compability
        logger = set_logger(colored('SINK', 'green'), self.verbose)
        logger.info('ready')
        self.is_ready.set()

        while not self.exit_flag.is_set():
            socks = dict(poller.poll())
            if socks.get(receiver) == zmq.POLLIN:
                msg = receiver.recv_multipart()
                job_id = msg[0]
                # parsing job_id and partial_id; a missing "@partial"
                # suffix defaults the partial id to 0.
                job_info = job_id.split(b'@')
                job_id = job_info[0]
                partial_id = int(job_info[1]) if len(job_info) == 2 else 0

                if msg[2] == ServerCmd.data_embed:
                    x = jsonapi.loads(msg[1])
                    pending_jobs[job_id].add_output(x, partial_id)
                else:
                    logger.error(
                        'received a wrongly-formatted request (expected 4 frames, got %d)'
                        % len(msg))
                    logger.error('\n'.join('field %d: %s' % (idx, k)
                                           for idx, k in enumerate(msg)),
                                 exc_info=True)

                logger.info('collect %s %s (E:%d/A:%d)' %
                            (msg[2], job_id,
                             pending_jobs[job_id].progress_outputs,
                             pending_jobs[job_id].checksum))

                # check if there are finished jobs, then send it back to workers
                finished = [(k, v) for k, v in pending_jobs.items()
                            if v.is_done]
                for job_info, tmp in finished:
                    # Job ids are "client_addr#req_id".
                    client_addr, req_id = job_info.split(b'#')
                    x = tmp.result
                    sender.send_multipart([client_addr, x, req_id])
                    logger.info('send back\tsize: %d\tjob id: %s' %
                                (tmp.checksum, job_info))
                    # release the job
                    tmp.clear()
                    pending_jobs.pop(job_info)

            if socks.get(frontend) == zmq.POLLIN:
                client_addr, msg_type, msg_info, req_id = \
                    frontend.recv_multipart()
                if msg_type == ServerCmd.new_job:
                    job_info = client_addr + b'#' + req_id
                    # register a new job: checksum holds the expected size.
                    pending_jobs[job_info].checksum = int(msg_info)
                    logger.info('job register\tsize: %d\tjob id: %s' %
                                (int(msg_info), job_info))
                elif msg_type == ServerCmd.show_config:
                    # dirty fix of slow-joiner: sleep so that client receiver can connect.
                    time.sleep(0.1)
                    logger.info('send config\tclient %s' % client_addr)
                    sender.send_multipart([client_addr, msg_info, req_id])
class BaseComponent:
    """Base class for pipeline components: owns the component's queues,
    routines, metrics collector and the stop event controlling their
    lifecycle. Configured from a ``{name: parameters}`` dict."""

    def __init__(self, component_config, start_component=False):
        self.name = ""
        self.ROUTINES_FOLDER_PATH = "pipert/contrib/routines"
        self.MONITORING_SYSTEMS_FOLDER_PATH = "pipert/contrib/metrics_collectors"
        self.use_memory = False
        # Set means "not running"; cleared by run_comp().
        self.stop_event = Event()
        self.stop_event.set()
        self.queues = {}
        self._routines = {}
        self.metrics_collector = NullCollector()
        self.parent_logger = None
        self.logger = None
        self.setup_component(component_config)
        self.metrics_collector.setup()
        if start_component:
            self.run_comp()

    def setup_component(self, component_config):
        """Parse the configuration dict and create loggers, queues, the
        monitoring system, shared memory (if requested) and the routines.
        Silently returns on a missing/empty/non-dict config."""
        if (component_config is None) or (type(component_config) is not dict) or\
                (component_config == {}):
            return
        component_name, component_parameters = list(component_config.items())[0]
        self.name = component_name
        self.parent_logger = create_parent_logger(self.name)
        self.logger = self.parent_logger.getChild(self.name)
        if ("shared_memory" in component_parameters) and \
                (component_parameters["shared_memory"]):
            self.use_memory = True
            self.generator = smGen(self.name)
        if "monitoring_system" in component_parameters:
            self.set_monitoring_system(component_parameters["monitoring_system"])
        for queue in component_parameters["queues"]:
            self.create_queue(queue_name=queue, queue_size=1)
        routine_factory = ClassFactory(self.ROUTINES_FOLDER_PATH)
        for routine_name, routine_parameters_real in component_parameters["routines"].items():
            # Copy so the caller's config dict is not mutated.
            routine_parameters = routine_parameters_real.copy()
            routine_parameters["name"] = routine_name
            routine_parameters['metrics_collector'] = self.metrics_collector
            routine_parameters["logger"] = self.parent_logger.getChild(routine_name)
            routine_class = routine_factory.get_class(routine_parameters.pop("routine_type_name", ""))
            if routine_class is None:
                # Unknown routine type: skip it.
                continue
            try:
                self._replace_queue_names_with_queue_objects(routine_parameters)
            except QueueDoesNotExist as e:
                # Skip routines referencing queues this component lacks.
                continue
            routine_parameters["component_name"] = self.name
            self.register_routine(routine_class(**routine_parameters).as_thread())

    def _replace_queue_names_with_queue_objects(self, routine_parameters_kwargs):
        # Any parameter whose key contains 'queue' is resolved from a
        # queue name to the actual queue object (in place).
        for key, value in routine_parameters_kwargs.items():
            if 'queue' in key.lower():
                routine_parameters_kwargs[key] = self.get_queue(queue_name=value)

    def _start(self):
        """
        Goes over the component's routines registered in self.routines
        and starts running them.
        """
        self.logger.info("Running all routines")
        for routine in self._routines.values():
            routine.start()
            self.logger.info("{0} Started".format(routine.name))

    def run_comp(self):
        """
        Starts running all the component's routines.
        """
        self.logger.info("Running component")
        self.stop_event.clear()
        # Python < 3.8 lacks multiprocessing.shared_memory; the generator
        # provides the fallback allocation path.
        if self.use_memory and sys.version_info.minor < 8:
            self.generator.create_memories()
        self._start()
        gevent.signal_handler(signal.SIGTERM, self.stop_run)

    def register_routine(self, routine: Union[Routine, Process, Thread]):
        """
        Registers routine to the list of component's routines

        Args:
            routine: the routine to register
        """
        self.logger.info("Registering routine")
        self.logger.info(routine)
        # TODO - write this function in a cleaner way?
        if isinstance(routine, Routine):
            if routine.name in self._routines:
                self.logger.error("Routine name already exist")
                raise RegisteredException("routine name already exist")
            if routine.stop_event is None:
                # Routines without their own stop event share the
                # component's one.
                routine.stop_event = self.stop_event
                if self.use_memory:
                    routine.use_memory = self.use_memory
                    routine.generator = self.generator
            else:
                self.logger.error("Routine is already registered")
                raise RegisteredException("routine is already registered")
            self.logger.info("Routine registered")
            self._routines[routine.name] = routine
        else:
            # Plain Process/Thread objects are keyed by their repr.
            self.logger.info("Routine registered")
            self._routines[routine.__str__()] = routine

    def _teardown_callback(self, *args, **kwargs):
        """
        Implemented by subclasses of BaseComponent. Used for stopping or
        tearing down things that are not stopped by setting the
        stop_event.

        Returns: None
        """
        pass

    def stop_run(self):
        """
        Signals all the component's routines to stop.
        Returns 0 on success, 1 on RuntimeError.
        """
        self.logger.info("Stopping component")
        if self.stop_event.is_set():
            # Already stopped.
            return 0
        self.stop_event.set()
        try:
            self._teardown_callback()
            if self.use_memory:
                self.logger.info("Cleaning shared memory")
                self.generator.cleanup()
            for routine in self._routines.values():
                self.logger.info("Stopping routine {0}".format(routine.name))
                if isinstance(routine, Routine):
                    routine.runner.join()
                elif isinstance(routine, (Process, Thread)):
                    routine.join()
                self.logger.info("Routine {0} stopped".format(routine.name))
            return 0
        except RuntimeError:
            return 1

    def create_queue(self, queue_name, queue_size=1):
        """
        Create a new queue for the component.
        Returns True if created or False otherwise

        Args:
            queue_name: the name of the queue, must be unique
            queue_size: the size of the queue
        """
        if queue_name in self.queues:
            return False
        self.queues[queue_name] = Queue(maxsize=queue_size)
        return True

    def get_queue(self, queue_name):
        """
        Returns the queue object by its name

        Args:
            queue_name: the name of the queue
        Raises:
            KeyError - if no queue has the name
        """
        try:
            return self.queues[queue_name]
        except KeyError:
            raise QueueDoesNotExist(queue_name)

    def get_all_queue_names(self):
        """
        Returns the list of names of queues that the component expose.
        """
        return list(self.queues.keys())

    def does_queue_exist(self, queue_name):
        """
        Returns True the component has a queue named queue_name or False
        otherwise

        Args:
            queue_name: the name of the queue to check
        """
        return queue_name in self.queues

    def delete_queue(self, queue_name):
        """
        Deletes a queue with the name queue_name.
        Returns True if succeeded, False if a routine still uses it.

        Args:
            queue_name: the name of the queue to delete
        Raises:
            KeyError - if no queue has the name queue_name
        """
        if queue_name not in self.queues:
            raise QueueDoesNotExist(queue_name)
        if self.does_routines_use_queue(queue_name=queue_name):
            # Refuse to delete a queue still wired into a routine.
            return False
        try:
            del self.queues[queue_name]
            return True
        except KeyError:
            raise QueueDoesNotExist(queue_name)

    def does_routine_name_exist(self, routine_name):
        return routine_name in self._routines

    def remove_routine(self, routine_name):
        # Returns True when a routine was removed, False when absent.
        if self.does_routine_name_exist(routine_name):
            del self._routines[routine_name]
            return True
        else:
            return False

    def does_routines_use_queue(self, queue_name):
        # True if any registered routine holds a reference to this queue.
        for routine in self._routines.values():
            if routine.does_routine_use_queue(self.queues[queue_name]):
                return True
        return False

    def is_component_running(self):
        return not self.stop_event.is_set()

    def get_routines(self):
        return self._routines

    def get_component_configuration(self):
        """Rebuild a configuration dict ({name: parameters}) describing
        this component's current queues and routines."""
        component_dict = {
            "shared_memory": self.use_memory,
            "queues": list(self.get_all_queue_names()),
            "routines": {}
        }
        # Subclasses record their concrete type for reconstruction.
        if type(self).__name__ != BaseComponent.__name__:
            component_dict["component_type_name"] = type(self).__name__
        for current_routine_object in self._routines.values():
            routine_creation_dict = \
                self._get_routine_creation(current_routine_object)
            routine_name = routine_creation_dict.pop("name")
            component_dict["routines"][routine_name] = \
                routine_creation_dict
        return {self.name: component_dict}

    def _get_routine_creation(self, routine):
        # Inverse of _replace_queue_names_with_queue_objects: map queue
        # objects back to their registered names (identity comparison).
        routine_dict = routine.get_creation_dictionary()
        routine_dict["routine_type_name"] = routine.__class__.__name__
        for routine_param_name in routine_dict.keys():
            if "queue" in routine_param_name:
                for queue_name in self.queues.keys():
                    if getattr(routine, routine_param_name) is \
                            self.queues[queue_name]:
                        routine_dict[routine_param_name] = queue_name
        return routine_dict

    def set_monitoring_system(self, monitoring_system_parameters):
        """Instantiate the metrics collector named by the config; leaves
        the NullCollector in place on any failure."""
        monitoring_system_factory = ClassFactory(self.MONITORING_SYSTEMS_FOLDER_PATH)
        if "name" not in monitoring_system_parameters:
            print("No name parameter found inside the monitoring system")
            return
        monitoring_system_name = monitoring_system_parameters.pop("name") + "Collector"
        monitoring_system_class = monitoring_system_factory.get_class(monitoring_system_name)
        if monitoring_system_class is None:
            return
        try:
            self.metrics_collector = monitoring_system_class(**monitoring_system_parameters)
        except TypeError:
            print("Bad parameters given for the monitoring system " +
                  monitoring_system_name)

    def set_routine_attribute(self, routine_name, attribute_name, attribute_value):
        # No-op when the routine name is unknown.
        routine = self._routines.get(routine_name, None)
        if routine is not None:
            setattr(routine, attribute_name, attribute_value)