class ApexAgent(AsyncAgent):
    """Apex, an async actor-learner architecture.

    See http://arxiv.org/abs/1803.00933 for details.

    On top of the async actor -> memory -> learner pipeline, enabling
    ``config["prioritized_replay"]`` adds a learner -> memory channel that
    carries ``(indexes, td_error)`` pairs so that memory hosts can update the
    priorities of the transitions they sampled.
    """

    _agent_name = "Apex"
    _default_model_class = "DQN"
    _valid_model_classes = [DQNModel, DDPGModel]
    _valid_model_class_names = ["DQN", "DDPG"]

    def _get_out_queue_meta(self):
        """Extend the parent out_queue (memory --> learner) meta.

        With prioritized replay, an extra int32 placeholder of shape
        ``(batch_size,)`` is appended for the replay-buffer indexes of the
        sampled batch, so the learner can report TD errors for them later.

        Returns:
            (dtypes, shapes, phs): parallel lists describing the queue elements.
        """
        dtypes, shapes, phs = super(ApexAgent, self)._get_out_queue_meta()
        if self.config.get("prioritized_replay", False):
            ph = tf.placeholder(dtype=tf.int32,
                                shape=(self.config["batch_size"], ))
            dtypes.append(ph.dtype)
            shapes.append(ph.shape)
            phs.append(ph)
        return dtypes, shapes, phs

    def _build_communication(self, job_name, task_index):
        """Build the parent's queues plus the priority-update queues.

        One ``updatequeue{i}`` FIFOQueue (capacity 8) is placed on each memory
        host. Each element is a pair of ``(batch_size,)`` tensors: int32
        indexes and float32 TD errors. Only memory and learner jobs build
        these queues.
        """
        super(ApexAgent, self)._build_communication(job_name=job_name,
                                                    task_index=task_index)
        # DQN and DDPG need to update priorities
        self.update_phs = [
            tf.placeholder(dtype=tf.int32,
                           shape=(self.config["batch_size"], )),
            tf.placeholder(dtype=tf.float32,
                           shape=(self.config["batch_size"], ))
        ]
        if job_name in ["memory", "learner"]:
            self._update_queues = list()
            self._en_update_queues = list()
            self._de_update_queues = list()
            self._close_update_queues = list()
            for i in range(self.distributed_handler.num_memory_hosts):
                with tf.device("/job:memory/task:{}".format(i)):
                    update_q = tf.FIFOQueue(
                        8, [tf.int32, tf.float32],
                        [(self.config["batch_size"], ),
                         (self.config["batch_size"], )],
                        shared_name="updatequeue{}".format(i))
                    self._update_queues.append(update_q)
                    en_q = update_q.enqueue(self.update_phs)
                    self._en_update_queues.append(en_q)
                    de_q = update_q.dequeue()
                    self._de_update_queues.append(de_q)
                    self._close_update_queues.append(
                        update_q.close(cancel_pending_enqueues=True))

    def _setup_learner(self):
        """Start the learner-side sender thread for priority updates."""
        super(ApexAgent, self)._setup_learner()
        if self.config.get("prioritized_replay", False):
            # Producer (`learn`) and consumer (the ProxySender) share this
            # process, like the other proxy queues in this module, so an
            # in-process `queue.Queue` is enough. `multiprocessing.Queue`
            # would pickle every batch and has an unreliable `empty()`.
            self._learner2mem_q = queue.Queue(8)
            self._stop_learner2mem_indicator = threading.Event()
            self._learner2mem = ProxySender(self._learner2mem_q,
                                            self.executor,
                                            self._en_update_queues,
                                            self.update_phs,
                                            self._stop_learner2mem_indicator,
                                            choose_buffer_index=True)
            self._learner2mem.start()

    def _setup_communication(self):
        """Start the memory-side receiver thread for priority updates."""
        super(ApexAgent, self)._setup_communication()
        if self.config.get("prioritized_replay", False):
            # See `_setup_learner` for why this is an in-process queue.
            self._learner2mem_q = queue.Queue(8)
            self._stop_learner2mem_indicator = threading.Event()
            self._learner2mem = ProxyReceiver(
                self.executor,
                self._de_update_queues[self.distributed_handler.task_index],
                self._learner2mem_q, self._stop_learner2mem_indicator)
            self._learner2mem.start()

    def _create_buffer(self):
        """Return a prioritized or uniform replay buffer, based on config."""
        if self.config.get("prioritized_replay", False):
            buffer = PrioritizedReplayBuffer(
                self.config["buffer_size"],
                self.config["prioritized_replay_alpha"])
        else:
            buffer = ReplayBuffer(self.config["buffer_size"])
        return buffer

    def send_experience(self,
                        obs,
                        actions,
                        rewards,
                        next_obs,
                        dones,
                        weights=None,
                        is_vectorized_env=False,
                        num_env=1):
        """Send collected experience to the memory host(s).

        The trajectory is n-step adjusted via `n_step_adjustment`, truncated to
        ``sample_batch_size * num_env`` rows and put on the actor2mem queue
        (waiting up to 30 seconds). Afterwards the input lists are trimmed
        IN PLACE, keeping only their last ``n_step - 1`` entries.

        Args:
            obs, actions, rewards, next_obs, dones: lists of per-step data.
                With ``is_vectorized_env`` their first two axes are swapped
                before per-env processing, so axis 1 is assumed to index the
                env (TODO confirm layout with callers).
            weights: optional list of per-step weights. ``None`` or an empty
                list means all-ones weights.
            is_vectorized_env: whether data comes from ``num_env`` parallel envs.
            num_env: number of parallel envs.
        """
        n_step = self._model_config.get("n_step", 1)
        gamma = self._model_config.get("gamma", 0.99)
        # Explicit check: `weights or ...` raises for numpy arrays.
        has_weights = weights is not None and len(weights) > 0

        if is_vectorized_env:
            # unstack batch data
            obs_ = np.swapaxes(np.asarray(obs), 0, 1)
            dones_ = np.swapaxes(np.asarray(dones), 0, 1)
            rewards_ = np.swapaxes(np.asarray(rewards), 0, 1)
            actions_ = np.swapaxes(np.asarray(actions), 0, 1)
            next_obs_ = np.swapaxes(np.asarray(next_obs), 0, 1)
            if has_weights:
                weights_ = np.swapaxes(np.asarray(weights), 0, 1)

            obs_list, actions_list, rewards_list = list(), list(), list()
            next_obs_list, dones_list, weights_list = list(), list(), list()
            for i in range(num_env):
                _obs, _actions, _rewards, _next_obs, _dones = n_step_adjustment(
                    obs_[i], actions_[i], rewards_[i], next_obs_[i],
                    dones_[i], gamma, n_step)
                obs_list.append(_obs)
                actions_list.append(_actions)
                rewards_list.append(_rewards)
                next_obs_list.append(_next_obs)
                dones_list.append(_dones)
                if has_weights:
                    weights_list.append(weights_[i][:len(_obs)])

            obs_stack = np.concatenate(obs_list, axis=0)
            actions_stack = np.concatenate(actions_list, axis=0)
            rewards_stack = np.concatenate(rewards_list, axis=0)
            next_obs_stack = np.concatenate(next_obs_list, axis=0)
            dones_stack = np.concatenate(dones_list, axis=0)
            if has_weights:
                # Concatenate (not stack) so that weights are flat along
                # axis 0 and align row-for-row with the other fields.
                weights_stack = np.concatenate(weights_list, axis=0)
            else:
                weights_stack = np.ones(len(rewards_stack), dtype=np.float32)
        else:
            obs_stack, actions_stack, rewards_stack, next_obs_stack, dones_stack = n_step_adjustment(
                obs, actions, rewards, next_obs, dones, gamma, n_step)
            if has_weights:
                weights_stack = np.asarray(weights[:len(rewards_stack)])
            else:
                weights_stack = np.ones(len(rewards_stack), dtype=np.float32)

        batch_limit = self.config["sample_batch_size"] * num_env
        try:
            self._actor2mem_q.put([
                arr[:batch_limit] for arr in [
                    obs_stack, actions_stack, rewards_stack, next_obs_stack,
                    dones_stack, weights_stack
                ]
            ],
                                  timeout=30)
        except queue.Full:
            logger.warn(
                "actor2mem thread has not sent even one batch for 30 seconds. It is necessary to increase the number of memory hosts."
            )
        self._ready_to_send = False

        # Clear the lists, keeping the trailing n_step - 1 steps (presumably
        # so the next n-step adjustment can span this send boundary).
        del obs[:len(obs) - n_step + 1]
        del actions[:len(actions) - n_step + 1]
        del rewards[:len(rewards) - n_step + 1]
        del next_obs[:len(next_obs) - n_step + 1]
        del dones[:len(dones) - n_step + 1]
        if weights is not None:
            del weights[:len(weights) - n_step + 1]

    def communicate(self):
        """Run this method on memory hosts.

        Receive transitions from actors and add the data to replay buffers.
        Sample from the replay buffers and send the samples to learners.
        Apply any pending priority updates sent back by learners.
        """
        if not self._actor2mem_q.empty():
            samples = self._actor2mem_q.get()
            obs, actions, rewards, next_obs, dones, weights = samples
            # NOTE(review): the received `weights` are not forwarded to the
            # buffer (weights=None is passed); confirm this is intended.
            self._buffer.add(obs=obs,
                             actions=actions,
                             rewards=rewards,
                             next_obs=next_obs,
                             dones=dones,
                             weights=None)
            self._act_count += np.shape(rewards)[0]
            self._receive_count += np.shape(rewards)[0]
            # Bump the shared counter once per 10000 received timesteps;
            # `_update_num_sampled_timesteps` adds 10000 per run.
            if self._receive_count // 10000 > self._last_receive_record:
                self._last_receive_record = self._receive_count // 10000
                self.executor.run(self._update_num_sampled_timesteps, {})

        if self._act_count >= max(0, self.config["learning_starts"]
                                  ) and not self._mem2learner_q.full():
            if isinstance(self._buffer, PrioritizedReplayBuffer):
                batch_data = self._buffer.sample(
                    self.config["batch_size"],
                    self.config["prioritized_replay_beta"])
                self._mem2learner_q.put([
                    batch_data["obs"], batch_data["actions"],
                    batch_data["rewards"], batch_data["next_obs"],
                    batch_data["dones"], batch_data["weights"],
                    batch_data["indexes"]
                ])
            else:
                batch_data = self._buffer.sample(self.config["batch_size"])
                rewards = batch_data["rewards"]
                self._mem2learner_q.put([
                    batch_data["obs"], batch_data["actions"], rewards,
                    batch_data["next_obs"], batch_data["dones"],
                    np.ones_like(rewards)
                ])

        if self.config.get("prioritized_replay", False):
            while not self._learner2mem_q.empty():
                indexes, td_error = self._learner2mem_q.get()
                # small epsilon keeps every priority strictly positive
                new_priorities = (np.abs(td_error) + 1e-6)
                self._buffer.update_priorities(indexes, new_priorities)

    def receive_experience(self):
        """Try to receive collected experience from the memory host(s).

        Returns:
            None if nothing is queued, otherwise a dict with ``buffer_id``,
            ``obs``, ``actions``, ``rewards``, ``next_obs``, ``dones``,
            ``weights`` and, with prioritized replay, ``indexes``.
        """
        if self._receive_q.empty():
            return None

        # one queue for one thread and thus there is no racing case.
        # once there exists one batch, the calling of `get` method
        # won't be deadly blocked.
        samples = self._receive_q.get()
        if self.config.get("prioritized_replay", False):
            buffer_id, obs, actions, rewards, next_obs, dones, weights, indexes = samples
            return dict(buffer_id=buffer_id,
                        obs=obs,
                        actions=actions,
                        rewards=rewards,
                        next_obs=next_obs,
                        dones=dones,
                        weights=weights,
                        indexes=indexes)
        buffer_id, obs, actions, rewards, next_obs, dones, weights = samples
        # Uniform replay: the received weights are replaced by float32 ones.
        return dict(buffer_id=buffer_id,
                    obs=obs,
                    actions=actions,
                    rewards=rewards,
                    next_obs=next_obs,
                    dones=dones,
                    weights=np.ones(rewards.shape, dtype=np.float32))

    def learn(self, batch_data):
        """Update upon a batch and send the td_errors to memories if needed.

        Returns:
            extra_results (dict): contains the fields computed during an
                update, plus the ``buffer_id`` the batch came from.
        """
        buffer_id = batch_data.pop("buffer_id", 0)
        extra_results = super(ApexAgent, self).learn(
            batch_data, is_chief=self.distributed_handler.is_chief)
        extra_results["buffer_id"] = buffer_id

        # Best effort: if the queue is already full, this update is skipped.
        if self.config.get("prioritized_replay",
                           False) and not self._learner2mem_q.full():
            try:
                self._learner2mem_q.put([
                    int(buffer_id), batch_data["indexes"],
                    extra_results["td_error"]
                ],
                                        timeout=30)
            except queue.Full:
                logger.warn(
                    "learner2mem thread has not sent even one batch for 30 seconds. It is necessary to increase the number of memory hosts."
                )

        return extra_results

    def _init_act_count(self):
        # Start at -(n_step - 1), presumably to offset the n_step - 1 steps
        # that `send_experience` keeps and re-sends.
        self._act_count = -self._model_config.get("n_step", 1) + 1
class SyncAgent(ActorLearnerAgent):
    """Actors and learners exchange data and model parameters in a synchronous way.

    For on-policy algorithms, e.g., D-PPO and ES.

    Synchronisation uses one barrier FIFOQueue per actor host. After each
    `learn` call, the learner enqueues a `True` token into every actor's
    barrier queue. An actor-side receiver thread dequeues the token for its
    own actor. Only a single learner host is supported.
    """

    def _init(self, model_config, ckpt_dir, custom_model):
        """Build models, the sync barrier and communication graph, then set up the executor.

        Args:
            model_config: dict (or None) of model options; ``"type"`` selects
                the model class when `custom_model` is None.
            ckpt_dir: checkpoint directory passed to `executor.setup`.
            custom_model: optional model class; must subclass one of
                `_valid_model_classes`.
        """
        assert self.distributed_handler.num_learner_hosts == 1, "SyncAgent only support one learner currently"
        # One learner consumes one sample batch from every actor per iteration.
        self.config[
            "batch_size"] = self.distributed_handler.num_actor_hosts * self.config[
                "sample_batch_size"]

        self._setup_sync_barrier()

        model_config = model_config or {}

        if custom_model is not None:
            assert np.any([
                issubclass(custom_model, e) for e in self._valid_model_classes
            ])
            model_class = custom_model
        else:
            model_name = model_config.get("type", self._default_model_class)
            assert model_name in self._valid_model_class_names, "{} does NOT support {} model.".format(
                self._agent_name, model_name)
            model_class = models.models[model_name]

        with self.distributed_handler.device:
            is_replica = (self.distributed_handler.job_name in ["actor"])
            self.model = model_class(self.executor.ob_ph_spec,
                                     self.executor.action_ph_spec,
                                     model_config=model_config,
                                     is_replica=is_replica)
            # get the order of elements defined in `learn_feed` which return an OrderedDict
            self._element_names = self.model.learn_feed.keys()

            if self.distributed_handler.job_name in ["actor"]:
                # Actors get a separate local "behavior" copy on their own CPU.
                with tf.device("/job:{}/task:{}/cpu".format(
                        self.distributed_handler.job_name,
                        self.distributed_handler.task_index)):
                    self._behavior_model = model_class(
                        self.executor.ob_ph_spec,
                        self.executor.action_ph_spec,
                        model_config=model_config,
                        is_replica=is_replica,
                        scope='behavior')
            else:
                self._behavior_model = self.model

            self._build_communication(
                job_name=self.distributed_handler.job_name,
                task_index=self.distributed_handler.task_index)

        global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if is_replica:
            # The actor's behavior-model variables are local, not global.
            global_vars = [
                v for v in global_vars
                if v not in self.behavior_model.all_variables
            ]
        self.executor.setup(
            self.distributed_handler.master,
            self.distributed_handler.is_chief,
            self.model.global_step,
            ckpt_dir,
            self.model.summary_ops,
            global_vars=global_vars,
            local_vars=self.behavior_model.all_variables
            if is_replica else None,
            # Exclude the shutdown-coordination flags from checkpoints.
            save_var_list=[
                var for var in global_vars if var not in [
                    self._learner_done_flags, self._actor_done_flags,
                    self._ready_to_exit
                ]
            ],
            save_steps=10,
            job_name=self.distributed_handler.job_name,
            task_index=self.distributed_handler.task_index,
            async_mode=False)

        super(SyncAgent, self)._init()

    def _create_buffer(self):
        """Return an AggregateBuffer sized by ``config["buffer_size"]`` (default 10000)."""
        return AggregateBuffer(self.config.get("buffer_size", 10000))

    def _setup_sync_barrier(self):
        """setup sync barrier for actors, in order to ensure the order of execution between actors and learner.

        Creates:
            * a shared int64 counter plus ops to add ``sample_batch_size`` to it
              and to reset it to 0;
            * one bool FIFOQueue (capacity ``num_actor_hosts``) on each actor
              host;
            * learner only: an op that enqueues `True` into every barrier queue;
            * actor only: the dequeue op for its own barrier queue and close
              ops for all barrier queues.
        """
        with self.distributed_handler.get_replica_device():
            self._global_num_sampled_per_iteration = tf.get_variable(
                name="global_num_sampled_per_iteration",
                dtype=tf.int64,
                shape=())
            self._update_global_num_sampled_per_iteration = tf.assign_add(
                self._global_num_sampled_per_iteration,
                np.int64(self.config["sample_batch_size"]),
                use_locking=True)
            self._reset_global_num_sampled_per_iteration = tf.assign(
                self._global_num_sampled_per_iteration,
                np.int64(0),
                use_locking=True)

        self._actor_barrier_q_list = []
        for i in range(self.distributed_handler.num_actor_hosts):
            with tf.device("/job:actor/task:{}".format(i)):
                self._actor_barrier_q_list.append(
                    tf.FIFOQueue(self.distributed_handler.num_actor_hosts,
                                 dtypes=[tf.bool],
                                 shapes=[()],
                                 shared_name="actor_barrier_q{}".format(i)))

        if self.distributed_handler.job_name == "learner":
            self._en_actor_barrier_q_list = [
                e.enqueue(tf.constant(True, dtype=tf.bool))
                for e in self._actor_barrier_q_list
            ]
            self._en_actor_barrier_q_op = tf.group(
                *self._en_actor_barrier_q_list)
        elif self.distributed_handler.job_name == "actor":
            # Each actor dequeues one token per learner host from each queue.
            self._de_actor_barrier_q_list = [
                e.dequeue_many(self.distributed_handler.num_learner_hosts)
                for e in self._actor_barrier_q_list
            ]
            self._de_actor_barrier_q_op = self._de_actor_barrier_q_list[
                self.distributed_handler.task_index]
            self._close_actor_barrier_q_list = [
                barrier_q.close(cancel_pending_enqueues=True)
                for barrier_q in self._actor_barrier_q_list
            ]

    def learn(self, batch_data):
        """Run one parent update, then open the barrier for all actors.

        After the update, enqueues a `True` token into every actor's barrier
        queue and resets the shared per-iteration sample counter to 0.
        """
        extra_results = super(SyncAgent, self).learn(batch_data)
        self.executor.run([
            self._en_actor_barrier_q_op,
            self._reset_global_num_sampled_per_iteration
        ], {})
        return extra_results

    def _setup_actor(self):
        """Start the parent actor threads plus a barrier-receiver thread.

        The receiver runs `_de_actor_barrier_q_op` and forwards results into
        `_learner2actor_q` (presumably consumed by the actor loop to wait for
        the learner — not shown here).
        """
        super(SyncAgent, self)._setup_actor()
        if hasattr(self, "_de_actor_barrier_q_op"):
            self._learner2actor_q = queue.Queue(8)
            self._stop_learner2actor_indicator = threading.Event()
            self._learner2actor = ProxyReceiver(
                self.executor, self._de_actor_barrier_q_op,
                self._learner2actor_q, self._stop_learner2actor_indicator)
            self._learner2actor.start()

    def _should_memory_stop(self):
        """Memory-host stop check; returns True once the session is closed.

        Same sequence as the parent implementation, but it does not close the
        update/barrier queues, and it sleeps 5 seconds and logs before closing
        the session.
        """
        if self._stop_indicator.is_set():
            # as the monitor thread has set the event
            self._monitor.join()
            should_stop, ready_to_exit = self.executor.run(
                [self._should_stop, self._ready_to_exit], {})
            if should_stop:
                # need to close the queues so that the threads that
                # execute `session.run()` would not be deadly blocked
                fetches = [self._close_in_queues, self._close_out_queues]
                self.executor.run(fetches, {})
                # Even though these threads are running `enqueue` or
                # `dequeue` op, they won't be deadly blocked. Instead,
                # as we have closed the TF FIFOQueues, these threads
                # will throw corresponding exceptions as we expected.
                self._stop_mem2learner_indicator.set()
                self._mem2learner.join()
                self._stop_actor2mem_indicator.set()
                self._actor2mem.join()
                if hasattr(self, "_learner2mem"):
                    self._stop_learner2mem_indicator.set()
                    self._learner2mem.join()
                time.sleep(5)
                self.executor.session.close()
                logger.info("session closed.")
                return True
            if self.distributed_handler.task_index == 0 and not ready_to_exit:
                # notify actors and learners to exit first
                self.executor.run([self._set_ready_to_exit], {})
        return False

    def _should_learner_stop(self):
        """Learner stop check, spread over several calls.

        First call after `ready_to_exit`: only signals the receiver thread.
        Later calls keep returning False until the receiver thread has exited
        and `_receive_q` is drained. Then learner 0 exports the saved model,
        this learner sets its done flag, waits for all done flags and closes
        the session.
        """
        ready_to_exit = self.executor.run(self._ready_to_exit, {})
        if ready_to_exit:
            if not self._stop_receiver_indicator.is_set():
                # Learner was notified at the first time.
                # Notify the receiver thread to stop but the main thread
                # continues
                self._stop_receiver_indicator.set()
            else:
                if not self._receiver.is_alive():
                    # The receiver thread has left `run()` method.
                    # Threads are allowed to be joined for more than
                    # once.
                    self._receiver.join()
                    logger.info("threads joined.")
                    # See if there still data need to be consumed
                    # As the thread (i.e., the producer) has done and
                    # consumer is this main thread itself, there is no
                    # inconsistent issue here.
                    if self._receive_q.empty():
                        if self.distributed_handler.task_index == 0:
                            # chief worker (i.e., learner_0) is responsible for exporting
                            # saved_model.
                            self.export_saved_model()
                        self.executor.run(self._set_stop_flag, {})
                        # Busy-wait until every learner and actor has set
                        # its done flag.
                        should_stop = False
                        while not should_stop:
                            should_stop = self.executor.run(
                                self._should_stop, {})
                        logger.info("all actors and learners have done.")
                        self.executor.session.close()
                        logger.info("session closed.")
                        return should_stop
        return False

    def _should_actor_stop(self):
        """Actor stop check.

        Acts only once `ready_to_exit` is set AND `_stop_sender_indicator`
        is set (by code not shown here). It joins the sender thread, closes
        this actor's barrier queue, stops the barrier receiver, sets this
        actor's done flag, waits for all done flags and closes the session.
        """
        ready_to_exit = self.executor.run(self._ready_to_exit, {})
        if ready_to_exit and self._stop_sender_indicator.is_set():
            self._actor2mem.join()
            logger.info("actor2mem thread joined.")

            if not self._learner2actor.is_alive():
                # Means that we have executed the following code snippet
                return True

            # close the sync barrier queue.
            self.executor.run(
                self._close_actor_barrier_q_list[
                    self.distributed_handler.task_index], {})
            self._stop_learner2actor_indicator.set()
            self._learner2actor.join()
            logger.info("learner2actor thread joined.")

            # notify the memory
            self.executor.run(self._set_stop_flag, {})
            should_stop = False
            while not should_stop:
                should_stop = self.executor.run(self._should_stop, {})
            logger.info("all actors and learners have done.")
            time.sleep(5)
            self.executor.session.close()
            logger.info("session closed.")
            return should_stop
        return False
class ActorLearnerAgent(AgentBase):
    """Actor-learner architecture.

    We regard some workers as learners and some as actors. Meanwhile, some ps
    do the job of parameter servers and some ps work for caching data (the
    "memory" job).

    We declare model parameters on both the local host and the parameter
    servers. Learners push changes of model parameters to the servers. Actors
    pull the latest model parameters from the servers.

    Data flows actor --> memory through one ``inqueue{i}`` FIFOQueue per memory
    host, and memory --> learner through one ``outqueue{i}`` FIFOQueue per
    learner host. Background proxy threads move data between these TF queues
    and local ``queue.Queue`` buffers so the main loop never blocks in
    ``session.run``.
    """

    def _init(self):
        """Start the job-specific communication threads."""
        if self.distributed_handler.job_name == "memory":
            self._setup_communication()
        elif self.distributed_handler.job_name == "learner":
            self._setup_learner()
        elif self.distributed_handler.job_name == "actor":
            self._setup_actor()

        # As a distributed computing paradigm, we need complicated mechanism
        # (encapsulated in `should_stop()` method) for making decisions on
        # when to stop. Like a context object, once stopped, the instance
        # would not work normally anymore.
        self._stopped = False

    def _setup_actor(self):
        """Start the actor --> memory sender thread."""
        self._send_cost = self._put_cost = 0
        self._actor_mem_cost = [0, 0]
        self._actor2mem_q = queue.Queue(8)
        self._stop_sender_indicator = threading.Event()
        self._actor2mem = ProxySender(self._actor2mem_q, self.executor,
                                      self._en_in_queues, self.in_queue_phs,
                                      self._stop_sender_indicator)
        self._actor2mem.start()

    def _setup_learner(self):
        """Start the memory --> learner receiver thread."""
        # create threads for non-blocking communication
        self._receive_q = queue.Queue(8)
        self._stop_receiver_indicator = threading.Event()
        self._receiver = ProxyReceiver(
            self.executor,
            self._de_out_queues[self.distributed_handler.task_index %
                                len(self._out_queues)], self._receive_q,
            self._stop_receiver_indicator)
        self._receiver.start()

    def _create_buffer(self):
        """create buffer according to the specific model"""
        raise NotImplementedError

    def _setup_communication(self):
        """Memory host: create the buffer and start the monitor and proxy threads."""
        # Under single machine setting, we create buffer object as the class attribute
        # The type of buffer should be determined by the model type
        self._buffer = self._create_buffer()

        self._stop_indicator = threading.Event()

        # create a thread for monitoring the training courses
        self._monitor = StopMonitor(
            self.executor, self._num_sampled_timesteps, self._in_queue_size,
            self._out_queue_size,
            self.config.get("scheduled_timesteps", 1000000),
            self.config.get("scheduled_global_steps", 1000),
            self._stop_indicator)
        self._monitor.start()

        # create threads for non-blocking communication
        self._actor2mem_q = queue.Queue(8)
        self._stop_actor2mem_indicator = threading.Event()
        self._actor2mem = ProxyReceiver(
            self.executor,
            self._de_in_queues[self.distributed_handler.task_index],
            self._actor2mem_q, self._stop_actor2mem_indicator)
        self._actor2mem.start()

        self._mem2learner_q = queue.Queue(8)
        self._stop_mem2learner_indicator = threading.Event()
        self._mem2learner = ProxySender(
            self._mem2learner_q,
            self.executor,
            self._en_out_queues[self.distributed_handler.task_index %
                                len(self._out_queues)],
            self.out_queue_phs,
            self._stop_mem2learner_indicator,
            send_buffer_index=self.distributed_handler.task_index)
        self._mem2learner.start()

    def _build_element_placeholders(self, batch_size):
        """Create one batched placeholder per element of `learn_feed`.

        "obs" and "next_obs" are fed as float32 with shape
        ``(batch_size,) + executor.flattened_ob_shape``. Every other element
        keeps its `learn_feed` dtype and trailing dims, with the leading dim
        replaced by `batch_size`.

        Returns:
            (dtypes, shapes, phs): three parallel lists, in `_element_names` order.
        """
        phs = list()
        dtypes = list()
        shapes = list()
        for name in self._element_names:
            v = self.model.learn_feed[name]
            if name in ["obs", "next_obs"]:
                ph = tf.placeholder(dtype=tf.float32,
                                    shape=(batch_size, ) +
                                    self.executor.flattened_ob_shape)
            else:
                ph = tf.placeholder(dtype=v.dtype,
                                    shape=[batch_size] +
                                    v.shape.as_list()[1:])
            phs.append(ph)
            dtypes.append(ph.dtype)
            shapes.append(ph.shape)
        return dtypes, shapes, phs

    def _get_in_queue_meta(self):
        """Determine the type, shape, and input placeholders for in_queue (actor --> memory).

        The batch dimension is ``sample_batch_size * num_env``.
        """
        num_env = self.config.get("num_env", 1)
        return self._build_element_placeholders(
            self.config["sample_batch_size"] * num_env)

    def _get_out_queue_meta(self):
        """Determine the type, shape, and input placeholders for out_queue (memory --> learner).

        The first element is a scalar int32 memory (buffer) index, followed by
        one ``batch_size``-batched placeholder per `learn_feed` element.
        """
        # add index of memory (created first to keep placeholder order)
        mem_index_ph = tf.placeholder(dtype=tf.int32, shape=())
        dtypes, shapes, phs = self._build_element_placeholders(
            self.config["batch_size"])
        return ([mem_index_ph.dtype] + dtypes, [mem_index_ph.shape] + shapes,
                [mem_index_ph] + phs)

    def _build_communication(self, job_name, task_index):
        """Build the subgraph for communication between actors, memories, and learners.

        Also creates the actor var-sync op and the shared variables used for
        monitoring and coordinated shutdown (sampled-timestep counter,
        per-host done flags, global ready-to-exit flag).
        """
        if job_name in ["actor", "memory"]:
            # data flow: actor --> memory
            dtypes, shapes, self.in_queue_phs = self._get_in_queue_meta()
            self._in_queues = list()
            self._en_in_queues = list()
            self._de_in_queues = list()
            self._close_in_queues = list()
            for i in range(self.distributed_handler.num_memory_hosts):
                with tf.device("/job:memory/task:{}".format(i)):
                    in_q = tf.FIFOQueue(8,
                                        dtypes,
                                        shapes,
                                        shared_name="inqueue{}".format(i))
                    self._in_queues.append(in_q)
                    en_q = in_q.enqueue(self.in_queue_phs)
                    self._en_in_queues.append(en_q)
                    de_q = in_q.dequeue()
                    self._de_in_queues.append(de_q)
                    self._close_in_queues.append(
                        in_q.close(cancel_pending_enqueues=True))
            self._in_queue_size = self._in_queues[
                self.distributed_handler.task_index %
                len(self._in_queues)].size()

        # data flow: memory --> learner
        dtypes, shapes, self.out_queue_phs = self._get_out_queue_meta()
        self._out_queues = list()
        self._en_out_queues = list()
        self._de_out_queues = list()
        self._close_out_queues = list()
        if job_name == "memory":
            for i in range(self.distributed_handler.num_learner_hosts):
                with tf.device("/job:learner/task:{}".format(i)):
                    out_q = tf.FIFOQueue(8,
                                         dtypes,
                                         shapes,
                                         shared_name="outqueue{}".format(i))
                    self._out_queues.append(out_q)
                    en_q = out_q.enqueue(self.out_queue_phs)
                    self._en_out_queues.append(en_q)
                    de_q = out_q.dequeue()
                    self._de_out_queues.append(de_q)
                    self._close_out_queues.append(
                        out_q.close(cancel_pending_enqueues=True))
            self._out_queue_size = self._out_queues[
                self.distributed_handler.task_index %
                len(self._out_queues)].size()

        if job_name == "learner":
            # A learner only needs handles to its own queue.
            with tf.device("/job:learner/task:{}".format(
                    self.distributed_handler.task_index)):
                out_q = tf.FIFOQueue(8,
                                     dtypes,
                                     shapes,
                                     shared_name="outqueue{}".format(
                                         self.distributed_handler.task_index))
                self._out_queues.append(out_q)
                en_q = out_q.enqueue(self.out_queue_phs)
                self._en_out_queues.append(en_q)
                de_q = out_q.dequeue()
                self._de_out_queues.append(de_q)
                self._close_out_queues.append(
                    out_q.close(cancel_pending_enqueues=True))

        # create an op for actors to obtain the latest vars
        sync_var_ops = list()
        for des, src in zip(self.behavior_model.actor_sync_variables,
                            self.model.actor_sync_variables):
            sync_var_ops.append(tf.assign(des, src))
        self._sync_var_op = tf.group(*sync_var_ops)

        # create some vars and queues for monitoring the training courses
        self._num_sampled_timesteps = tf.get_variable("num_sampled_timesteps",
                                                      dtype=tf.int64,
                                                      initializer=np.int64(0))

        # Builtin `bool` replaces the `np.bool` alias removed in NumPy 1.24;
        # the values are identical (`np.bool` was `bool`).
        self._learner_done_flags = tf.get_variable(
            "learner_done_flags",
            dtype=tf.bool,
            initializer=np.asarray(self.distributed_handler.num_learner_hosts *
                                   [False],
                                   dtype=bool))
        self._actor_done_flags = tf.get_variable(
            "actor_done_flags",
            dtype=tf.bool,
            initializer=np.asarray(self.distributed_handler.num_actor_hosts *
                                   [False],
                                   dtype=bool))
        # True once every learner and every actor has set its done flag.
        self._should_stop = tf.logical_and(
            tf.reduce_all(self._learner_done_flags),
            tf.reduce_all(self._actor_done_flags))
        if self.distributed_handler.job_name == "learner":
            self._set_stop_flag = tf.assign(
                self._learner_done_flags[self.distributed_handler.task_index],
                True,
                use_locking=True)
        if self.distributed_handler.job_name == "actor":
            self._set_stop_flag = tf.assign(
                self._actor_done_flags[self.distributed_handler.task_index],
                True,
                use_locking=True)

        self._ready_to_exit = tf.get_variable("global_ready_to_exit",
                                              dtype=tf.bool,
                                              initializer=False)
        self._set_ready_to_exit = tf.assign(self._ready_to_exit,
                                            True,
                                            use_locking=True)

        self._update_num_sampled_timesteps = tf.assign_add(
            self._num_sampled_timesteps, np.int64(10000), use_locking=True)

    def sync_vars(self):
        """Sync with the latest vars"""
        self.executor.run(self._sync_var_op, {})

    def join(self):
        """Call `server.join()` if the agent object serves as a parameter server."""
        self.distributed_handler.server.join()

    def communicate(self):
        raise NotImplementedError

    def should_stop(self):
        """Judge whether the agent should stop.

        Dispatches to the job-specific check until one returns True; after
        that, returns True without re-running any check. Jobs other than
        memory/learner/actor (e.g. ps) always get False here.
        """
        if self._stopped:
            return True
        job_name = self.distributed_handler.job_name
        if job_name == "memory":
            self._stopped = self._should_memory_stop()
        elif job_name == "learner":
            self._stopped = self._should_learner_stop()
        elif job_name == "actor":
            self._stopped = self._should_actor_stop()
        # ps host won't exit
        return self._stopped

    def _should_memory_stop(self):
        """Memory-host stop check; returns True once the session is closed.

        After the monitor thread signals `_stop_indicator`: memory 0 sets the
        global ready-to-exit flag. Once all learners and actors report done,
        the TF queues are closed, the proxy threads are joined and the session
        is closed.
        """
        if self._stop_indicator.is_set():
            # as the monitor thread has set the event
            self._monitor.join()
            should_stop, ready_to_exit = self.executor.run(
                [self._should_stop, self._ready_to_exit], {})
            if should_stop:
                # need to close the queues so that the threads that
                # execute `session.run()` would not be deadly blocked
                fetches = [self._close_in_queues, self._close_out_queues]
                if hasattr(self, "_close_update_queues"):
                    fetches.append(self._close_update_queues)
                # Barrier close ops are stored as `_close_actor_barrier_q_list`;
                # the legacy `_close_actor_barrier_q_op` name is still honored.
                for attr in ("_close_actor_barrier_q_list",
                             "_close_actor_barrier_q_op"):
                    if hasattr(self, attr):
                        fetches.append(getattr(self, attr))
                self.executor.run(fetches, {})
                # Even though these threads are running `enqueue` or
                # `dequeue` op, they won't be deadly blocked. Instead,
                # as we have closed the TF FIFOQueues, these threads
                # will throw corresponding exceptions as we expected.
                self._stop_mem2learner_indicator.set()
                self._mem2learner.join()
                self._stop_actor2mem_indicator.set()
                self._actor2mem.join()
                if hasattr(self, "_learner2mem"):
                    self._stop_learner2mem_indicator.set()
                    self._learner2mem.join()
                self.executor.session.close()
                return True
            if self.distributed_handler.task_index == 0 and not ready_to_exit:
                # notify actors and learners to exit first
                self.executor.run([self._set_ready_to_exit], {})
        return False

    def _should_learner_stop(self):
        """Learner stop check.

        Once `ready_to_exit` is set, this joins the proxy threads. Learner 0
        exports the saved model. The learner then sets its done flag, waits
        for all done flags and closes the session. The chief waits an extra
        30 seconds before closing.
        """
        ready_to_exit = self.executor.run(self._ready_to_exit, {})
        if ready_to_exit:
            self._stop_receiver_indicator.set()
            self._receiver.join()
            if hasattr(self, "_learner2mem"):
                self._stop_learner2mem_indicator.set()
                self._learner2mem.join()
            logger.info("threads joined.")

            if self.distributed_handler.task_index == 0:
                # chief worker (i.e., learner_0) is responsible for exporting
                # saved_model.
                self.export_saved_model()

            self.executor.run(self._set_stop_flag, {})
            should_stop = False
            while not should_stop:
                should_stop = self.executor.run(self._should_stop, {})
            logger.info("all actors and learners have done.")
            if self.distributed_handler.is_chief:
                time.sleep(30)
            self.executor.session.close()
            logger.info("session closed.")
            return should_stop
        return False

    def _should_actor_stop(self):
        """Actor stop check.

        Once `ready_to_exit` is set, this stops and joins the sender thread,
        sets this actor's done flag, waits for all done flags and closes the
        session.
        """
        ready_to_exit = self.executor.run(self._ready_to_exit, {})
        if ready_to_exit:
            self._stop_sender_indicator.set()
            # If the Queue is full, the thread must not enter its `if`
            # branch and thus will exit `run()` immediately. Otherwise,
            # as memory hosts have not stopped, the enqueue op would not
            # be blocked.
            self._actor2mem.join()
            logger.info("thread joined.")

            # notify the memory
            self.executor.run(self._set_stop_flag, {})
            should_stop = False
            while not should_stop:
                should_stop = self.executor.run(self._should_stop, {})
            logger.info("all actors and learners have done.")
            self.executor.session.close()
            logger.info("session closed.")
            return should_stop
        return False