def __init__(self, pipe_c2s, pipe_s2c):
    """Bind both ZMQ sockets and start the background sender thread.

    Args:
        pipe_c2s: endpoint the client-to-server PULL socket binds to.
        pipe_s2c: endpoint the server-to-client ROUTER socket binds to.

    Raises:
        AssertionError: if ``os.name`` is ``'nt'``.
    """
    super(SimulatorMaster, self).__init__()
    assert os.name != 'nt', "Doesn't support windows!"
    self.daemon = True
    self.name = 'SimulatorMaster'

    self.context = zmq.Context()

    # The high water mark is configured *before* bind(): pyzmq documents
    # that new HWM values only take effect for subsequent bind/connect.
    self.c2s_socket = self.context.socket(zmq.PULL)
    self.c2s_socket.set_hwm(20)
    self.c2s_socket.bind(pipe_c2s)

    self.s2c_socket = self.context.socket(zmq.ROUTER)
    self.s2c_socket.set_hwm(20)
    self.s2c_socket.bind(pipe_s2c)

    # queueing messages to client
    self.send_queue = queue.Queue(maxsize=1000)

    def f():
        # Take one queued multipart message and send it on the ROUTER socket.
        msg = self.send_queue.get()
        self.s2c_socket.send_multipart(msg, copy=False)
    self.send_thread = LoopThread(f)
    self.send_thread.daemon = True
    self.send_thread.start()

    # make sure socket get closed at the end
    def clean_context(soks, context):
        for s in soks:
            s.close()
        context.term()
    import atexit
    atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)
class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce action for each simulator, as well as
        defining callbacks when a transition or an episode is finished.
    """
    class ClientState(object):
        """Per-client bookkeeping kept by the master."""

        def __init__(self):
            self.memory = []    # list of Experience
            self.ident = None   # filled in by run() on the first message

    def __init__(self, pipe_c2s, pipe_s2c):
        """Bind both ZMQ sockets and start the background sender thread.

        Args:
            pipe_c2s: endpoint the client-to-server PULL socket binds to.
            pipe_s2c: endpoint the server-to-client ROUTER socket binds to.

        Raises:
            AssertionError: if ``os.name`` is ``'nt'``.
        """
        super(SimulatorMaster, self).__init__()
        assert os.name != 'nt', "Doesn't support windows!"
        self.daemon = True
        self.name = 'SimulatorMaster'

        self.context = zmq.Context()

        # The high water mark is configured *before* bind(): pyzmq documents
        # that new HWM values only take effect for subsequent bind/connect.
        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.set_hwm(10)
        self.c2s_socket.bind(pipe_c2s)

        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.set_hwm(10)
        self.s2c_socket.bind(pipe_s2c)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=100)

        def f():
            # Take one queued multipart message and send it on the ROUTER socket.
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)
        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure socket get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()
        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)

    def run(self):
        """Receive client messages forever and hand each one to ``_process_msg``.

        Each message is unpacked as ``(ident, state, reward, isOver)``. The
        loop ends when the ZMQ context is terminated.
        """
        self.clients = defaultdict(self.ClientState)
        try:
            while True:
                msg = loads(self.c2s_socket.recv(copy=False).bytes)
                ident, state, reward, isOver = msg
                client = self.clients[ident]
                if client.ident is None:
                    client.ident = ident
                # maybe check history and warn about dead client?
                # NOTE(review): _process_msg is not defined in this class;
                # presumably subclasses provide it.
                self._process_msg(client, state, reward, isOver)
        except zmq.ContextTerminated:
            logger.info("[Simulator] Context was terminated.")

    def __del__(self):
        # __init__ may have raised before the context was created; in that
        # case there is nothing to destroy.
        context = getattr(self, 'context', None)
        if context is not None:
            context.destroy(linger=0)
def get_recv_thread(self):
    """Create (without starting) the thread that receives simulator messages.

    Returns:
        A ``LoopThread`` built with ``pausable=False`` and named
        "recv thread"; its job reads one message from ``sim2mgr_socket``,
        deserializes it and puts it on ``self.queue``.
    """
    def receive_one():
        payload = self.sim2mgr_socket.recv(copy=False).bytes
        self.queue.put(loads(payload))

    receiver = LoopThread(receive_one, pausable=False)
    # The daemon flag is left at its default here.
    receiver.name = "recv thread"
    return receiver
def get_simulator_thread(self):
    """Create (without starting) the thread that runs the policy.

    The original author notes that running the policy in a separate thread
    can speed things up by about 1.3x.

    Returns:
        A ``LoopThread`` built with ``pausable=False`` and named
        "SimulatorThread".
    """
    def run_one_job():
        # Block until a job token is available, then collect
        # `update_frequency` experiences under the trainer's session.
        self._populate_job_queue.get()
        with self.trainer.sess.as_default():
            for _ in range(self.update_frequency):
                self._populate_exp()

    worker = LoopThread(run_one_job, pausable=False)
    worker.name = "SimulatorThread"
    return worker
def _before_train(self):
    """Pick the evaluation session, sync weights once and start the loop thread.

    When callbacks are registered, the trainer's session is wrapped in a
    ``tf.train.MonitoredSession`` carrying them as hooks; otherwise the
    trainer's session is used directly.
    """
    from tensorpack.callbacks.hooks import CallbackToHook
    from tensorpack.train.base import ReuseSessionCreator
    if len(self._callbacks) > 0:
        self._sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.trainer.sess),
            hooks=[CallbackToHook(cb) for cb in self._callbacks])
    else:
        self._sess = self.trainer.sess
    # The sync op is run on the trainer's own session, not on self._sess.
    self.trainer.sess.run(self._op_sync_weights)
    self._thread = LoopThread(self._run_loop)
    self._thread.start()
def get_recv_thread(self):
    """Create (without starting) the thread that receives messages for this agent.

    Returns:
        A ``LoopThread`` built with ``pausable=False`` and named
        "recv thread".
    """
    def receive_one():
        decoded = loads(self.sim2exp_socket.recv(copy=False).bytes)
        print('{}: received msg'.format(self.agent_name))
        try:
            # Non-blocking put: any failure is only logged.
            self.queue.put_nowait(decoded)
        except Exception:
            logger.info('put queue failed!')
        # send response or not?

    receiver = LoopThread(receive_one, pausable=False)
    # The daemon flag is left at its default here.
    receiver.name = "recv thread"
    return receiver
def _before_train(self):
    """Start every predictor, then start the coordinator's receive thread.

    The receive job reads one message from ``sim2coord_socket``, takes
    ``msg[0]`` as the simulator name and ``msg[1]`` as the agent name, and
    submits ``msg[2:]`` to that agent's predictor. The predictor's result is
    sent back to the simulator from the completion callback.
    """
    for key in self.predictors:
        self.predictors[key].start()

    def receive_one():
        msg = loads(self.sim2coord_socket.recv(copy=False).bytes)
        sim_name = msg[0]
        agent_name = msg[1]

        def on_done(future):
            try:
                result = future.result()
            except CancelledError:
                logger.info("{} cancelled.".format(sim_name))
                return
            print('coordinator sending', sim_name.encode('utf-8'), result[0].shape)
            reply = [sim_name.encode('utf-8'), dumps(result[0])]
            self.coord2sim_socket.send_multipart(reply)

        self.predictors[agent_name].put_task(msg[2:], on_done)

    self.recv_thread = ShareSessionThread(LoopThread(receive_one, pausable=False))
    # The daemon flag is left at its default here.
    self.recv_thread.name = 'coordinator recv'
    self.recv_thread.start()
    logger.info('Coordinator started')
def __init__(self, pipe_c2s, pipe_s2c):
    """Bind both ZMQ sockets and start the background sender thread.

    Args:
        pipe_c2s: endpoint the client-to-server PULL socket binds to.
        pipe_s2c: endpoint the server-to-client ROUTER socket binds to.

    Raises:
        AssertionError: if ``os.name`` is ``'nt'``.
    """
    super(SimulatorMaster, self).__init__()
    assert os.name != 'nt', "Doesn't support windows!"
    self.daemon = True
    self.name = 'SimulatorMaster'

    self.context = zmq.Context()

    # The high water mark is configured *before* bind(): pyzmq documents
    # that new HWM values only take effect for subsequent bind/connect.
    self.c2s_socket = self.context.socket(zmq.PULL)
    self.c2s_socket.set_hwm(10)
    self.c2s_socket.bind(pipe_c2s)

    self.s2c_socket = self.context.socket(zmq.ROUTER)
    self.s2c_socket.set_hwm(10)
    self.s2c_socket.bind(pipe_s2c)

    # queueing messages to client
    self.send_queue = queue.Queue(maxsize=100)

    def f():
        # Take one queued multipart message and send it on the ROUTER socket.
        msg = self.send_queue.get()
        self.s2c_socket.send_multipart(msg, copy=False)
    self.send_thread = LoopThread(f)
    self.send_thread.daemon = True
    self.send_thread.start()

    # make sure socket get closed at the end
    def clean_context(soks, context):
        for s in soks:
            s.close()
        context.term()
    import atexit
    atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)
def _create_simulator_thread(self, idx):
    """Create (without starting) the policy thread for runner ``idx``.

    Args:
        idx: index into ``self._runners``; also used in the thread name.

    Returns:
        A ``ShareSessionThread`` wrapping a non-pausable ``LoopThread``,
        named ``"SimulatorThread-<idx>"``.
    """
    def step_once():
        # Fetch the job first, then look up the runner that handles it.
        job = self._populate_job_queue.get()
        self._runners[idx].step(job)

    loop = LoopThread(step_once, pausable=False)
    worker = ShareSessionThread(loop)
    worker.name = "SimulatorThread-{}".format(idx)
    return worker
def get_simulator_thread(self):
    """Create (without starting) the thread that runs the policy.

    Returns:
        A ``ShareSessionThread`` wrapping a non-pausable ``LoopThread``,
        named "SimulatorThread".
    """
    def run_one_job():
        # Block until a job token is available, then collect
        # `update_frequency` experiences.
        self._populate_job_queue.get()
        for _ in range(self.update_frequency):
            self._populate_exp()

    loop = LoopThread(run_one_job, pausable=False)
    worker = ShareSessionThread(loop)
    worker.name = "SimulatorThread"
    return worker
def get_work_thread(self):
    """Create (without starting) the thread that serves queued simulator requests.

    Each job takes one message from ``self.queue``; ``msg[0]`` is the
    simulator name and ``msg[1]`` the message type. LOCK / UNLOCK messages
    update ``self.locked_sim``. While a lock is held, messages from other
    simulators are put back on the queue. SCREEN and CLICK messages are
    served after switching context to the requesting simulator.

    Returns:
        A ``LoopThread`` built with ``pausable=False`` and named
        "work thread".
    """
    def serve_one():
        msg = self.queue.get()
        sim_name = msg[0]

        def reply(payload):
            # Address the reply to the requesting simulator.
            self.mgr2sim_socket.send_multipart(
                [sim_name.encode('utf-8'), dumps(payload)])

        # Grant the lock only when nobody currently holds it.
        if msg[1] == SimulatorManager.MSG_TYPE.LOCK and self.locked_sim is None:
            self.locked_sim = sim_name
            reply('lock')
            time.sleep(0.2)
            return

        if self.locked_sim is not None:
            if sim_name != self.locked_sim:
                # Another simulator holds the lock: wait, then re-queue.
                time.sleep(0.2)
                self.queue.put(msg)
                return
            if msg[1] == SimulatorManager.MSG_TYPE.UNLOCK:
                self.locked_sim = None
                reply('unlock')
                time.sleep(0.2)
                return

        self.cxt_switch(sim_name)
        if msg[1] == SimulatorManager.MSG_TYPE.SCREEN:
            reply(grab_screen())
        elif msg[1] == SimulatorManager.MSG_TYPE.CLICK:
            click(msg[2][0], msg[2][1])
            reply('click')

    worker = LoopThread(serve_one, pausable=False)
    worker.name = "work thread"
    return worker
def get_simulator_thread(self):
    """Create (without starting) the thread that runs the policy.

    Returns:
        A ``ShareSessionThread`` wrapping a non-pausable ``LoopThread``,
        named "SimulatorThread".
    """
    def run_one_job():
        self._populate_job_queue.get()
        # HITL UPDATE (original author's note): as self.update_frequency = 0
        # during pretraining, no workers will be initialized.
        # update_frequency is coerced to int before being used as a count.
        steps = int(self.update_frequency)
        for _ in range(steps):
            self._populate_exp()

    loop = LoopThread(run_one_job, pausable=False)
    worker = ShareSessionThread(loop)
    worker.name = "SimulatorThread"
    return worker
# get_simulator_thread (synchronous-training variant): returns a
# ShareSessionThread wrapping a non-pausable LoopThread named "SimulatorThread".
# The job takes one item from self._populate_job_queue, then counts successful
# self._populate_exp() calls (truthy return) until self.update_frequency is reached.
# NOTE(review): the line breaks and indentation of this snippet were lost, so
# whether time.sleep(0.1) runs inside the `if`, once per `while` iteration, or
# after the loop cannot be determined here. Confirm against the original file
# before reformatting.
def get_simulator_thread(self): # spawn a separate thread to run policy def populate_job_func(): self._populate_job_queue.get() i = 0 # synchronous training while i < self.update_frequency: if self._populate_exp(): i += 1 time.sleep(0.1) # for _ in range(self.update_frequency): # self._populate_exp() th = ShareSessionThread(LoopThread(populate_job_func, pausable=False)) th.name = "SimulatorThread" return th
def setup(self):
    """Run the base setup, then create one paused loop thread per auxiliary train op.

    For every entry of ``self._train_ops_aux`` a session is attached to the
    op (a hooked session when the op has callbacks, otherwise ``self.sess``)
    and a ``LoopThread`` running that op is created, paused, started and
    appended to ``self._training_aux_threads``.
    """
    super(MultiGPUTrainer, self).setup()

    self._training_aux_threads = []
    self._training_aux_running = False
    self._training_aux_step_counter = itertools.count()

    from tensorpack.callbacks.hooks import CallbackToHook
    from tensorflow.python.training.monitored_session \
        import _HookedSession as HookedSession

    for thread_index, key in enumerate(self._train_ops_aux):
        aux_op = self._train_ops_aux[key]
        logger.info("Create aux train ops {}".format(aux_op._name))

        if len(aux_op._callbacks) > 0:
            hooks = [CallbackToHook(cb) for cb in aux_op._callbacks]
            aux_op._sess = HookedSession(self.sess, hooks=hooks)
        else:
            aux_op._sess = self.sess

        def run_step(op=aux_op):    # default argument avoids late binding
            try:
                # TODO this won't work with StageInput
                op._sess.run([op._train_op])
            except (RuntimeError,   # exited
                    tf.errors.CancelledError,
                    tf.errors.AbortedError):
                pass

        aux_thread = LoopThread(run_step)
        aux_thread.name = "AsyncLoopThread-{}".format(thread_index)
        # Pause before start so the thread begins in the paused state.
        aux_thread.pause()
        aux_thread.start()
        logger.info("Start aux thread {}".format(aux_op._name))
        self._training_aux_threads.append(aux_thread)
class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce action for each simulator, as well as
        defining callbacks when a transition or an episode is finished.
    """
    class ClientState(object):
        """Per-client bookkeeping kept by the master."""

        def __init__(self):
            self.memory = []    # list of Experience

    def __init__(self, pipe_c2s, pipe_s2c):
        """Bind both ZMQ sockets and start the background sender thread.

        Args:
            pipe_c2s: endpoint the client-to-server PULL socket binds to.
            pipe_s2c: endpoint the server-to-client ROUTER socket binds to.
        """
        super(SimulatorMaster, self).__init__()
        self.daemon = True
        self.name = 'SimulatorMaster'

        self.context = zmq.Context()

        # The high water mark is configured *before* bind(): pyzmq documents
        # that new HWM values only take effect for subsequent bind/connect.
        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.set_hwm(10)
        self.c2s_socket.bind(pipe_c2s)

        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.set_hwm(10)
        self.s2c_socket.bind(pipe_s2c)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=100)

        def f():
            # Take one queued multipart message and send it on the ROUTER socket.
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)
        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure socket get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()
        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)

    def run(self):
        """Receive client messages forever and dispatch the callbacks.

        Each message is unpacked as ``(ident, state, reward, isOver)``. The
        loop ends when the ZMQ context is terminated.
        """
        self.clients = defaultdict(self.ClientState)
        try:
            while True:
                msg = loads(self.c2s_socket.recv(copy=False).bytes)
                ident, state, reward, isOver = msg
                # TODO check history and warn about dead client
                client = self.clients[ident]

                # check if reward&isOver is valid
                # in the first message, only state is valid
                if len(client.memory) > 0:
                    client.memory[-1].reward = reward
                    if isOver:
                        self._on_episode_over(ident)
                    else:
                        self._on_datapoint(ident)
                # feed state and return action
                self._on_state(state, ident)
        except zmq.ContextTerminated:
            logger.info("[Simulator] Context was terminated.")

    # NOTE(review): no ABCMeta metaclass is declared on this class, so the
    # @abstractmethod decorators below are presumably documentation only.
    @abstractmethod
    def _on_state(self, state, ident):
        """response to state sent by ident. Preferrably an async call"""

    @abstractmethod
    def _on_episode_over(self, client):
        """ callback when the client just finished an episode.
            You may want to clear the client's memory in this callback.
        """

    def _on_datapoint(self, client):
        """ callback when the client just finished a transition
        """

    def __del__(self):
        # __init__ may have raised before the context was created; in that
        # case there is nothing to destroy.
        context = getattr(self, 'context', None)
        if context is not None:
            context.destroy(linger=0)
class SimulatorMaster(threading.Thread):
    """ A base thread to communicate with all StateExchangeSimulatorProcess.
        It should produce action for each simulator, as well as
        defining callbacks when a transition or an episode is finished.
    """
    class ClientState(object):
        """Per-client bookkeeping kept by the master."""

        def __init__(self):
            self.memory = []    # list of Experience

    def __init__(self, pipe_c2s, pipe_s2c):
        """Bind both ZMQ sockets and start the background sender thread.

        Args:
            pipe_c2s: endpoint the client-to-server PULL socket binds to.
            pipe_s2c: endpoint the server-to-client ROUTER socket binds to.

        Raises:
            AssertionError: if ``os.name`` is ``'nt'``.
        """
        super(SimulatorMaster, self).__init__()
        assert os.name != 'nt', "Doesn't support windows!"
        self.daemon = True
        self.name = 'SimulatorMaster'

        self.context = zmq.Context()

        # The high water mark is configured *before* bind(): pyzmq documents
        # that new HWM values only take effect for subsequent bind/connect.
        self.c2s_socket = self.context.socket(zmq.PULL)
        self.c2s_socket.set_hwm(10)
        self.c2s_socket.bind(pipe_c2s)

        self.s2c_socket = self.context.socket(zmq.ROUTER)
        self.s2c_socket.set_hwm(10)
        self.s2c_socket.bind(pipe_s2c)

        # queueing messages to client
        self.send_queue = queue.Queue(maxsize=100)

        def f():
            # Take one queued multipart message and send it on the ROUTER socket.
            msg = self.send_queue.get()
            self.s2c_socket.send_multipart(msg, copy=False)
        self.send_thread = LoopThread(f)
        self.send_thread.daemon = True
        self.send_thread.start()

        # make sure socket get closed at the end
        def clean_context(soks, context):
            for s in soks:
                s.close()
            context.term()
        import atexit
        atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)

    def run(self):
        """Receive client messages forever and dispatch the callbacks.

        Each message is unpacked as ``(ident, state, reward, isOver)``. The
        loop ends when the ZMQ context is terminated.
        """
        self.clients = defaultdict(self.ClientState)
        try:
            while True:
                msg = loads(self.c2s_socket.recv(copy=False).bytes)
                ident, state, reward, isOver = msg
                # TODO check history and warn about dead client
                client = self.clients[ident]

                # check if reward&isOver is valid
                # in the first message, only state is valid
                if len(client.memory) > 0:
                    client.memory[-1].reward = reward
                    if isOver:
                        self._on_episode_over(ident)
                    else:
                        self._on_datapoint(ident)
                # feed state and return action
                self._on_state(state, ident)
        except zmq.ContextTerminated:
            logger.info("[Simulator] Context was terminated.")

    # NOTE(review): no ABCMeta metaclass is declared on this class, so the
    # @abstractmethod decorators below are presumably documentation only.
    @abstractmethod
    def _on_state(self, state, ident):
        """response to state sent by ident. Preferrably an async call"""

    @abstractmethod
    def _on_episode_over(self, client):
        """ callback when the client just finished an episode.
            You may want to clear the client's memory in this callback.
        """

    def _on_datapoint(self, client):
        """ callback when the client just finished a transition
        """

    def __del__(self):
        # __init__ may have raised before the context was created (e.g. the
        # platform assertion failed); in that case there is nothing to destroy.
        context = getattr(self, 'context', None)
        if context is not None:
            context.destroy(linger=0)
class EvaluatorBase(Callback):
    """Callback that runs an evaluation loop in a background thread.

    The evaluator obtains its tensor IO from a ``TensorIO_AgentPools``
    instance, copies the training weights into its own weights through a
    sync op, and steps the tensor IO in a ``LoopThread`` until the trainer's
    hooked session reports that it should stop.
    """

    def __init__(
            self,
            data_io,
            name,
            is_training=False,
            idx=0,
            batch_size=-1,
            sync_step=-1,  # sync only in every epoch
            **kwargs):
        """
        Args:
            data_io: a ``TensorIO_AgentPools`` instance (asserted).
            name: evaluator name; prefixed with ``'evaluate/'`` if missing.
            is_training: forwarded to the tensor IO as ``is_training``.
            idx: stored as ``self._idx``.
            batch_size: must be positive (asserted); note that the default
                of -1 therefore always fails the assertion.
            sync_step: when positive, the weight-sync op is also run every
                ``sync_step`` local steps (see ``_before_run``).
            **kwargs: extra options merged into the ``getTensorIO`` call.
        """
        if not name.startswith('evaluate/'):
            name = 'evaluate/' + name
        self._name = name
        self._idx = idx
        from ..dataflow.tensor_io import TensorIO_AgentPools
        assert (isinstance(data_io, TensorIO_AgentPools))
        self._data_io = data_io  # type: TensorIO_AgentPools
        self._is_training = is_training
        assert (batch_size > 0)
        logger.info(
            "Eval: {} create, batchSize={}, is_train={}, sync_step={}".format(
                name, batch_size, is_training, sync_step))
        self._batch_size = batch_size
        self._kwargs = kwargs
        self._tensor_io = None
        self._pool_name = name.replace('/', '_')
        self._sync_step = sync_step

    @property
    def name(self):
        return self._name

    @property
    def batch_size(self):
        return self._batch_size

    def _setup_graph(self):
        """Build the weight-sync op and set up the summary callbacks and tensor IO.

        Requires ``set_weights`` to have been called and the tensor IO to
        have been created beforehand.
        """
        model = self.trainer.model
        weights_train = model.getWeightsTrain()
        # Pairwise assign: each evaluator weight receives its training counterpart.
        self._op_sync_weights = tf.group(
            *[d.assign(s) for d, s in zip(self._weights, weights_train)])
        self._callbacks = []
        from tensorpack.callbacks.summary import MergeAllSummaries_RunWithOp, MovingAverageSummary
        c_vars = tf.get_collection(self._name + '-ema_op')
        if len(c_vars) > 0:
            self._callbacks.append(MovingAverageSummary(self._name + '-ema_op'))
        self._callbacks.append(MergeAllSummaries_RunWithOp(0, self._name))
        for c in self._callbacks:
            c.setup_graph(self.trainer)
        self._tensor_io._setup_graph()

    def getTensorIO(self, input_desc, **kwargs):
        """Return the tensor IO, creating it from ``input_desc`` on first use."""
        if not self._tensor_io:
            self.get_input_tensor(input_desc, **kwargs)
        return self._tensor_io

    def get_input_tensor(self, input_desc, **kwargs):
        """Return the input tensors, creating the tensor IO on first use.

        Options given at construction time (``self._kwargs``) override
        same-named ``kwargs`` passed here.
        """
        if self._tensor_io:
            return self._tensor_io.getInputTensors()
        self._input_desc = input_desc
        kwargs = kwargs.copy()
        kwargs.update(self._kwargs)
        self._tensor_io = self._data_io.getTensorIO(
            self._pool_name + '/pred',
            input_desc,
            queue_size=0,
            is_training=self._is_training,
            allow_no_full_batch=True,
            **kwargs)
        return self._tensor_io.getInputTensors()

    def set_output_tensor(self, *outputs):
        """Register the output tensors on the (already created) tensor IO."""
        assert (self._tensor_io is not None)
        self._tensor_io.setOutputTensors(*outputs)

    def set_weights(self, weights):
        """Store the evaluator-side weights that the sync op assigns to."""
        self._weights = weights

    def _before_train(self):
        """Pick the evaluation session, sync weights once and start the loop thread."""
        from tensorpack.callbacks.hooks import CallbackToHook
        from tensorpack.train.base import ReuseSessionCreator
        if len(self._callbacks) > 0:
            self._sess = tf.train.MonitoredSession(
                session_creator=ReuseSessionCreator(self.trainer.sess),
                hooks=[CallbackToHook(cb) for cb in self._callbacks])
        else:
            self._sess = self.trainer.sess
        self.trainer.sess.run(self._op_sync_weights)
        self._thread = LoopThread(self._run_loop)
        self._thread.start()

    def _trigger(self):
        """Re-sync the evaluator weights from the training weights."""
        self.trainer.sess.run(self._op_sync_weights)

    def _before_run(self, ctx):
        """Return the sync op as an extra fetch every ``sync_step`` local steps."""
        if self._sync_step > 0 and self.local_step % self._sync_step == 0:
            return [self._op_sync_weights]
        return None

    def _after_train(self):
        if self._tensor_io:
            self._tensor_io.close()

    def _run_loop(self):
        """Step the tensor IO until the trainer's hooked session should stop.

        Returns immediately once ``self._sess`` has been reset to ``None``
        by a previous run.
        """
        if self._sess is None:
            return
        logger.info(
            "Evaluator {} thread start, fetch tensors = {}, batch = {}".format(
                self._name, len(self._tensor_io._output_tensors),
                self._batch_size))
        hooked_sess = self.trainer.hooked_sess
        sess = self._sess
        try:
            tensor_io = self._tensor_io
            while not hooked_sess.should_stop():
                tensor_io._loopStep(sess)
        except (tf.errors.CancelledError, tf.errors.OutOfRangeError):
            pass
        except Exception as e:
            logger.exception("Exception in Evaluator Thread: {}".format(e))
        finally:
            # Best-effort close: a failure here must not mask the loop's outcome.
            try:
                self._tensor_io.close()
            except Exception:
                pass
            logger.info("Evaluator {} Thread Exited.".format(self._name))
            self._sess = None