Example #1
    def _setup_communication(self):
        # Under the single-machine setting, we create the buffer object as a class attribute
        # The type of buffer should be determined by the model type
        self._buffer = self._create_buffer()

        self._stop_indicator = threading.Event()
        # create a thread for monitoring the training progress
        self._monitor = StopMonitor(
            self.executor, self._num_sampled_timesteps, self._in_queue_size,
            self._out_queue_size,
            self.config.get("scheduled_timesteps", 1000000),
            self.config.get("scheduled_global_steps",
                            1000), self._stop_indicator)
        self._monitor.start()

        # create threads for non-blocking communication
        self._actor2mem_q = queue.Queue(8)
        self._stop_actor2mem_indicator = threading.Event()
        self._actor2mem = ProxyReceiver(
            self.executor,
            self._de_in_queues[self.distributed_handler.task_index],
            self._actor2mem_q, self._stop_actor2mem_indicator)
        self._actor2mem.start()

        self._mem2learner_q = queue.Queue(8)
        self._stop_mem2learner_indicator = threading.Event()
        self._mem2learner = ProxySender(
            self._mem2learner_q,
            self.executor,
            self._en_out_queues[self.distributed_handler.task_index %
                                len(self._out_queues)],
            self.out_queue_phs,
            self._stop_mem2learner_indicator,
            send_buffer_index=self.distributed_handler.task_index)
        self._mem2learner.start()
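Note: the ProxySender and ProxyReceiver classes used above come from elsewhere in the project and are not shown in these examples. As a rough illustration only, a receiver-style proxy thread that bridges a TF dequeue op into a local queue.Queue might look like the sketch below; the class name, constructor signature, and error handling are assumptions, not the actual implementation.

import threading

import tensorflow as tf


class SketchProxyReceiver(threading.Thread):
    """Illustrative stand-in for ProxyReceiver (assumed behaviour)."""

    def __init__(self, executor, dequeue_op, local_q, stop_event):
        super(SketchProxyReceiver, self).__init__()
        self.daemon = True
        self._executor = executor
        self._dequeue_op = dequeue_op
        self._local_q = local_q
        self._stop_event = stop_event

    def run(self):
        while not self._stop_event.is_set():
            try:
                # blocks inside TF until an item arrives or the queue is closed
                data = self._executor.run(self._dequeue_op, {})
            except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
                # the FIFOQueue was closed during shutdown
                break
            self._local_q.put(data)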
Example #2
 def _setup_actor(self):
     super(SyncAgent, self)._setup_actor()
     if hasattr(self, "_de_actor_barrier_q_op"):
         self._learner2actor_q = queue.Queue(8)
         self._stop_learner2actor_indicator = threading.Event()
         self._learner2actor = ProxyReceiver(
             self.executor, self._de_actor_barrier_q_op,
             self._learner2actor_q, self._stop_learner2actor_indicator)
         self._learner2actor.start()
Example #3
 def _setup_communication(self):
     super(ApexAgent, self)._setup_communication()
     if self.config.get("prioritized_replay", False):
         self._learner2mem_q = mp.Queue(8)
         self._stop_learner2mem_indicator = threading.Event()
         self._learner2mem = ProxyReceiver(
             self.executor,
             self._de_update_queues[self.distributed_handler.task_index],
             self._learner2mem_q, self._stop_learner2mem_indicator)
         self._learner2mem.start()
Example #4
 def _setup_learner(self):
     # create threads for non-blocking communication
     self._receive_q = queue.Queue(8)
     self._stop_receiver_indicator = threading.Event()
     self._receiver = ProxyReceiver(
         self.executor,
         self._de_out_queues[self.distributed_handler.task_index %
                             len(self._out_queues)], self._receive_q,
         self._stop_receiver_indicator)
     self._receiver.start()
Example #5
 def _setup_actor(self):
     self._send_cost = self._put_cost = 0
     self._actor_mem_cost = [0, 0]
     self._actor2mem_q = queue.Queue(8)
     self._stop_sender_indicator = threading.Event()
     self._actor2mem = ProxySender(self._actor2mem_q, self.executor,
                                   self._en_in_queues, self.in_queue_phs,
                                   self._stop_sender_indicator)
     self._actor2mem.start()
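For symmetry, a sender-style proxy thread could be sketched as follows. This is only an assumed shape for ProxySender (drain the local queue, build a feed_dict for the enqueue placeholders, run the enqueue op); the real class also supports options such as send_buffer_index and choose_buffer_index that are omitted here.

import queue
import threading

import tensorflow as tf


class SketchProxySender(threading.Thread):
    """Illustrative stand-in for ProxySender (assumed behaviour)."""

    def __init__(self, local_q, executor, enqueue_op, placeholders, stop_event):
        super(SketchProxySender, self).__init__()
        self.daemon = True
        self._local_q = local_q
        self._executor = executor
        self._enqueue_op = enqueue_op
        self._placeholders = placeholders
        self._stop_event = stop_event

    def run(self):
        while not self._stop_event.is_set():
            try:
                arrays = self._local_q.get(timeout=1)
            except queue.Empty:
                continue
            feed_dict = dict(zip(self._placeholders, arrays))
            try:
                self._executor.run(self._enqueue_op, feed_dict)
            except (tf.errors.CancelledError, tf.errors.AbortedError):
                # the target FIFOQueue was closed during shutdown
                break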
Example #6
 def _setup_learner(self):
     super(ApexAgent, self)._setup_learner()
     if self.config.get("prioritized_replay", False):
         self._learner2mem_q = mp.Queue(8)
         self._stop_learner2mem_indicator = threading.Event()
         self._learner2mem = ProxySender(self._learner2mem_q,
                                         self.executor,
                                         self._en_update_queues,
                                         self.update_phs,
                                         self._stop_learner2mem_indicator,
                                         choose_buffer_index=True)
         self._learner2mem.start()
Example #7
class SyncAgent(ActorLearnerAgent):
    """Actors and learners  exchange data and model parameters in a synchronous way.

    For on-policy algorithms, e.g., D-PPO and ES.
    """
    def _init(self, model_config, ckpt_dir, custom_model):
        assert self.distributed_handler.num_learner_hosts == 1, "SyncAgent only supports one learner currently"
        self.config["batch_size"] = (
            self.distributed_handler.num_actor_hosts *
            self.config["sample_batch_size"])
        self._setup_sync_barrier()

        model_config = model_config or {}

        if custom_model is not None:
            assert np.any([
                issubclass(custom_model, e) for e in self._valid_model_classes
            ])
            model_class = custom_model
        else:
            model_name = model_config.get("type", self._default_model_class)
            assert model_name in self._valid_model_class_names, "{} does NOT support {} model.".format(
                self._agent_name, model_name)
            model_class = models.models[model_name]

        with self.distributed_handler.device:
            is_replica = (self.distributed_handler.job_name in ["actor"])
            self.model = model_class(self.executor.ob_ph_spec,
                                     self.executor.action_ph_spec,
                                     model_config=model_config,
                                     is_replica=is_replica)
            # get the order of elements defined in `learn_feed`, which returns an OrderedDict
            self._element_names = self.model.learn_feed.keys()

            if self.distributed_handler.job_name in ["actor"]:
                with tf.device("/job:{}/task:{}/cpu".format(
                        self.distributed_handler.job_name,
                        self.distributed_handler.task_index)):
                    self._behavior_model = model_class(
                        self.executor.ob_ph_spec,
                        self.executor.action_ph_spec,
                        model_config=model_config,
                        is_replica=is_replica,
                        scope='behavior')
            else:
                self._behavior_model = self.model

            self._build_communication(
                job_name=self.distributed_handler.job_name,
                task_index=self.distributed_handler.task_index)

        global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        if is_replica:
            global_vars = [
                v for v in global_vars
                if v not in self.behavior_model.all_variables
            ]

        self.executor.setup(self.distributed_handler.master,
                            self.distributed_handler.is_chief,
                            self.model.global_step,
                            ckpt_dir,
                            self.model.summary_ops,
                            global_vars=global_vars,
                            local_vars=self.behavior_model.all_variables
                            if is_replica else None,
                            save_var_list=[
                                var for var in global_vars if var not in [
                                    self._learner_done_flags,
                                    self._actor_done_flags, self._ready_to_exit
                                ]
                            ],
                            save_steps=10,
                            job_name=self.distributed_handler.job_name,
                            task_index=self.distributed_handler.task_index,
                            async_mode=False)

        super(SyncAgent, self)._init()

    def _create_buffer(self):
        return AggregateBuffer(self.config.get("buffer_size", 10000))

    def _setup_sync_barrier(self):
        """setup sync barrier for actors, in order to ensure the order of execution
        between actors and learner.
        """
        with self.distributed_handler.get_replica_device():
            self._global_num_sampled_per_iteration = tf.get_variable(
                name="global_num_sampled_per_iteration",
                dtype=tf.int64,
                shape=())

        self._update_global_num_sampled_per_iteration = tf.assign_add(
            self._global_num_sampled_per_iteration,
            np.int64(self.config["sample_batch_size"]),
            use_locking=True)

        self._reset_global_num_sampled_per_iteration = tf.assign(
            self._global_num_sampled_per_iteration,
            np.int64(0),
            use_locking=True)

        self._actor_barrier_q_list = []
        for i in range(self.distributed_handler.num_actor_hosts):
            with tf.device("/job:actor/task:{}".format(i)):
                self._actor_barrier_q_list.append(
                    tf.FIFOQueue(self.distributed_handler.num_actor_hosts,
                                 dtypes=[tf.bool],
                                 shapes=[()],
                                 shared_name="actor_barrier_q{}".format(i)))

        if self.distributed_handler.job_name == "learner":
            self._en_actor_barrier_q_list = [
                e.enqueue(tf.constant(True, dtype=tf.bool))
                for e in self._actor_barrier_q_list
            ]
            self._en_actor_barrier_q_op = tf.group(
                *self._en_actor_barrier_q_list)

        elif self.distributed_handler.job_name == "actor":
            self._de_actor_barrier_q_list = [
                e.dequeue_many(self.distributed_handler.num_learner_hosts)
                for e in self._actor_barrier_q_list
            ]
            self._de_actor_barrier_q_op = self._de_actor_barrier_q_list[
                self.distributed_handler.task_index]
            self._close_actor_barrier_q_list = [
                barrier_q.close(cancel_pending_enqueues=True)
                for barrier_q in self._actor_barrier_q_list
            ]

    def learn(self, batch_data):
        extra_results = super(SyncAgent, self).learn(batch_data)
        self.executor.run([
            self._en_actor_barrier_q_op,
            self._reset_global_num_sampled_per_iteration
        ], {})
        return extra_results

    def _setup_actor(self):
        super(SyncAgent, self)._setup_actor()
        if hasattr(self, "_de_actor_barrier_q_op"):
            self._learner2actor_q = queue.Queue(8)
            self._stop_learner2actor_indicator = threading.Event()
            self._learner2actor = ProxyReceiver(
                self.executor, self._de_actor_barrier_q_op,
                self._learner2actor_q, self._stop_learner2actor_indicator)
            self._learner2actor.start()

    def _should_memory_stop(self):
        if self._stop_indicator.is_set():
            # as the monitor thread has set the event
            self._monitor.join()
            should_stop, ready_to_exit = self.executor.run(
                [self._should_stop, self._ready_to_exit], {})
            if should_stop:
                # we need to close the queues so that the threads that
                # execute `session.run()` are not blocked forever
                fetches = [self._close_in_queues, self._close_out_queues]
                self.executor.run(fetches, {})
                # Even though these threads may be running an `enqueue` or
                # `dequeue` op, they won't be blocked forever. Instead,
                # as we have closed the TF FIFOQueues, these threads
                # will raise the corresponding exceptions as expected.
                self._stop_mem2learner_indicator.set()
                self._mem2learner.join()
                self._stop_actor2mem_indicator.set()
                self._actor2mem.join()
                if hasattr(self, "_learner2mem"):
                    self._stop_learner2mem_indicator.set()
                    self._learner2mem.join()
                time.sleep(5)
                self.executor.session.close()
                logger.info("session closed.")
                return True
            if self.distributed_handler.task_index == 0 and not ready_to_exit:
                # notify actors and learners to exit first
                self.executor.run([self._set_ready_to_exit], {})
        return False

    def _should_learner_stop(self):
        ready_to_exit = self.executor.run(self._ready_to_exit, {})
        if ready_to_exit:
            if not self._stop_receiver_indicator.is_set():
                # The learner has been notified for the first time.
                # Notify the receiver thread to stop, but let the main
                # thread continue.
                self._stop_receiver_indicator.set()
            else:
                if not self._receiver.is_alive():
                    # The receiver thread has left its `run()` method.
                    # Threads are allowed to be joined more than once.
                    self._receiver.join()
                    logger.info("threads joined.")
                    # See if there is still data that needs to be consumed.
                    # As that thread (i.e., the producer) has finished and the
                    # consumer is this main thread itself, there is no
                    # consistency issue here.
                    if self._receive_q.empty():
                        if self.distributed_handler.task_index == 0:
                            # chief worker (i.e., learner_0) is responsible for exporting
                            # saved_model.
                            self.export_saved_model()
                        self.executor.run(self._set_stop_flag, {})
                        should_stop = False
                        while not should_stop:
                            should_stop = self.executor.run(
                                self._should_stop, {})
                        logger.info("all actors and learners have done.")
                        self.executor.session.close()
                        logger.info("session closed.")
                        return should_stop
        return False

    def _should_actor_stop(self):
        ready_to_exit = self.executor.run(self._ready_to_exit, {})
        if ready_to_exit and self._stop_sender_indicator.is_set():
            self._actor2mem.join()
            logger.info("actor2mem thread joined.")
            if not self._learner2actor.is_alive():
                # Means that we have already executed the following code snippet.
                return True
            # close the sync barrier queue.
            self.executor.run(
                self._close_actor_barrier_q_list[
                    self.distributed_handler.task_index], {})
            self._stop_learner2actor_indicator.set()
            self._learner2actor.join()
            logger.info("learner2actor thread joined.")
            # notify the memory
            self.executor.run(self._set_stop_flag, {})
            should_stop = False
            while not should_stop:
                should_stop = self.executor.run(self._should_stop, {})
            logger.info("all actors and learners have done.")
            time.sleep(5)
            self.executor.session.close()
            logger.info("session closed.")
            return should_stop
        return False
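To show how the barrier above gates the actors, here is a hypothetical actor-side loop for SyncAgent. The real driver script is not part of these examples, so sample_one_batch and the overall loop structure are assumptions.

# Hypothetical actor-side loop for SyncAgent.
while not agent.should_stop():
    agent.sync_vars()                        # pull the latest model parameters
    rollout = sample_one_batch(env, agent)   # assumed sampling helper
    agent._actor2mem_q.put(rollout)          # forwarded by ProxySender to a memory host
    # Wait for the learner: learn() enqueues one token into every actor's
    # barrier queue, and the learner2actor ProxyReceiver delivers it here.
    agent._learner2actor_q.get()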
Example #8
class ApexAgent(AsyncAgent):
    """Apex, an async actor-learner architecture. see http://arxiv.org/abs/1803.00933 for details.
    """

    _agent_name = "Apex"
    _default_model_class = "DQN"
    _valid_model_classes = [DQNModel, DDPGModel]
    _valid_model_class_names = ["DQN", "DDPG"]

    def _get_out_queue_meta(self):
        dtypes, shapes, phs = super(ApexAgent, self)._get_out_queue_meta()
        if self.config.get("prioritized_replay", False):
            ph = tf.placeholder(dtype=tf.int32,
                                shape=(self.config["batch_size"], ))
            dtypes.append(ph.dtype)
            shapes.append(ph.shape)
            phs.append(ph)
        return dtypes, shapes, phs

    def _build_communication(self, job_name, task_index):
        super(ApexAgent, self)._build_communication(job_name=job_name,
                                                    task_index=task_index)
        # DQN and DDPG need to update priorities
        self.update_phs = [
            tf.placeholder(dtype=tf.int32,
                           shape=(self.config["batch_size"], )),
            tf.placeholder(dtype=tf.float32,
                           shape=(self.config["batch_size"], ))
        ]
        if job_name in ["memory", "learner"]:
            self._update_queues = list()
            self._en_update_queues = list()
            self._de_update_queues = list()
            self._close_update_queues = list()
            for i in range(self.distributed_handler.num_memory_hosts):
                with tf.device("/job:memory/task:{}".format(i)):
                    update_q = tf.FIFOQueue(
                        8, [tf.int32, tf.float32],
                        [(self.config["batch_size"], ),
                         (self.config["batch_size"], )],
                        shared_name="updatequeue{}".format(i))
                    self._update_queues.append(update_q)
                    en_q = update_q.enqueue(self.update_phs)
                    self._en_update_queues.append(en_q)
                    de_q = update_q.dequeue()
                    self._de_update_queues.append(de_q)
                    self._close_update_queues.append(
                        update_q.close(cancel_pending_enqueues=True))

    def _setup_learner(self):
        super(ApexAgent, self)._setup_learner()
        if self.config.get("prioritized_replay", False):
            self._learner2mem_q = mp.Queue(8)
            self._stop_learner2mem_indicator = threading.Event()
            self._learner2mem = ProxySender(self._learner2mem_q,
                                            self.executor,
                                            self._en_update_queues,
                                            self.update_phs,
                                            self._stop_learner2mem_indicator,
                                            choose_buffer_index=True)
            self._learner2mem.start()

    def _setup_communication(self):
        super(ApexAgent, self)._setup_communication()
        if self.config.get("prioritized_replay", False):
            self._learner2mem_q = mp.Queue(8)
            self._stop_learner2mem_indicator = threading.Event()
            self._learner2mem = ProxyReceiver(
                self.executor,
                self._de_update_queues[self.distributed_handler.task_index],
                self._learner2mem_q, self._stop_learner2mem_indicator)
            self._learner2mem.start()

    def _create_buffer(self):
        if self.config.get("prioritized_replay", False):
            buffer = PrioritizedReplayBuffer(
                self.config["buffer_size"],
                self.config["prioritized_replay_alpha"])
        else:
            buffer = ReplayBuffer(self.config["buffer_size"])

        return buffer

    def send_experience(self,
                        obs,
                        actions,
                        rewards,
                        next_obs,
                        dones,
                        weights=None,
                        is_vectorized_env=False,
                        num_env=1):
        """Send collected experience to the memory host(s).
        """
        n_step = self._model_config.get("n_step", 1)
        gamma = self._model_config.get("gamma", 0.99)

        if is_vectorized_env:
            # unstack batch data
            obs_ = np.swapaxes(np.asarray(obs), 0, 1)
            dones_ = np.swapaxes(np.asarray(dones), 0, 1)
            rewards_ = np.swapaxes(np.asarray(rewards), 0, 1)
            actions_ = np.swapaxes(np.asarray(actions), 0, 1)
            next_obs_ = np.swapaxes(np.asarray(next_obs), 0, 1)

            if weights is not None:
                weights_ = np.swapaxes(np.asarray(weights), 0, 1)

            obs_list, actions_list, rewards_list, next_obs_list, dones_list, weights_list = list(), \
                list(), list(), list(), list(), list()
            for i in range(num_env):
                _obs, _actions, _rewards, _next_obs, _dones = n_step_adjustment(
                    obs_[i], actions_[i], rewards_[i], next_obs_[i], dones_[i],
                    gamma, n_step)
                obs_list.append(_obs)
                actions_list.append(_actions)
                rewards_list.append(_rewards)
                next_obs_list.append(_next_obs)
                dones_list.append(_dones)
                if weights is not None:
                    weights_list.append(weights_[i][:len(_obs)])

            obs_stack = np.concatenate(obs_list, axis=0)
            actions_stack = np.concatenate(actions_list, axis=0)
            rewards_stack = np.concatenate(rewards_list, axis=0)
            next_obs_stack = np.concatenate(next_obs_list, axis=0)
            dones_stack = np.concatenate(dones_list, axis=0)

            if weights is not None:
                # concatenate per-env weights to match the other stacked arrays
                weights_stack = np.concatenate(weights_list, axis=0)
            else:
                weights_stack = np.ones(len(rewards_stack), dtype=np.float32)
        else:
            obs_stack, actions_stack, rewards_stack, next_obs_stack, dones_stack = n_step_adjustment(
                obs, actions, rewards, next_obs, dones, gamma, n_step)
            weights_stack = weights or np.ones(len(rewards_stack),
                                               dtype=np.float32)
            weights_stack = np.asarray(weights_stack[:len(rewards_stack)])
        try:
            self._actor2mem_q.put([
                arr[:self.config["sample_batch_size"] * num_env] for arr in [
                    obs_stack, actions_stack, rewards_stack, next_obs_stack,
                    dones_stack, weights_stack
                ]
            ],
                                  timeout=30)
        except queue.Full:
            logger.warn(
                "actor2mem thread has not been able to send a batch for 30 "
                "seconds. Consider increasing the number of memory hosts.")

        self._ready_to_send = False

        # drop the transitions already sent, keeping the last n_step - 1 steps for the next n-step adjustment
        del obs[:len(obs) - n_step + 1]
        del actions[:len(actions) - n_step + 1]
        del rewards[:len(rewards) - n_step + 1]
        del next_obs[:len(next_obs) - n_step + 1]
        del dones[:len(dones) - n_step + 1]
        if weights is not None:
            del weights[:len(weights) - n_step + 1]

    def communicate(self):
        """Run this method on memory hosts

        Receive transitions from actors and add the data to replay buffers.
        Sample from the replay buffers and send the samples to learners.
        """

        if not self._actor2mem_q.empty():
            samples = self._actor2mem_q.get()
            obs, actions, rewards, next_obs, dones, weights = samples

            self._buffer.add(obs=obs,
                             actions=actions,
                             rewards=rewards,
                             next_obs=next_obs,
                             dones=dones,
                             weights=None)
            self._act_count += np.shape(rewards)[0]
            self._receive_count += np.shape(rewards)[0]
            if int(self._receive_count / 10000) > self._last_receive_record:
                self._last_receive_record = int(self._receive_count / 10000)
                self.executor.run(self._update_num_sampled_timesteps, {})

        if (self._act_count >= max(0, self.config["learning_starts"])
                and not self._mem2learner_q.full()):
            if isinstance(self._buffer, PrioritizedReplayBuffer):
                batch_data = self._buffer.sample(
                    self.config["batch_size"],
                    self.config["prioritized_replay_beta"])
                obs, actions, rewards, next_obs, dones, weights, indexes = batch_data["obs"], \
                    batch_data["actions"], batch_data["rewards"], batch_data["next_obs"], batch_data["dones"], \
                    batch_data["weights"], batch_data["indexes"]
                self._mem2learner_q.put(
                    [obs, actions, rewards, next_obs, dones, weights, indexes])
            else:
                batch_data = self._buffer.sample(self.config["batch_size"])
                obs, actions, rewards, next_obs, dones = batch_data["obs"], \
                    batch_data["actions"], batch_data["rewards"], batch_data["next_obs"], batch_data["dones"]
                weights = np.ones_like(rewards)
                self._mem2learner_q.put(
                    [obs, actions, rewards, next_obs, dones, weights])

        if self.config.get("prioritized_replay", False):
            while not self._learner2mem_q.empty():
                data = self._learner2mem_q.get()
                indexes, td_error = data
                new_priorities = (np.abs(td_error) + 1e-6)
                self._buffer.update_priorities(indexes, new_priorities)

    def receive_experience(self):
        """Try to receive collected experience from the memory host(s).
        """
        if self._receive_q.empty():
            return None
        # One queue per thread, so there is no race condition.
        # Once at least one batch exists, the call to the `get` method
        # will not block forever.
        samples = self._receive_q.get()
        if self.config.get("prioritized_replay", False):
            buffer_id, obs, actions, rewards, next_obs, dones, weights, indexes = samples
            return dict(buffer_id=buffer_id,
                        obs=obs,
                        actions=actions,
                        rewards=rewards,
                        next_obs=next_obs,
                        dones=dones,
                        weights=weights,
                        indexes=indexes)
        else:
            buffer_id, obs, actions, rewards, next_obs, dones, weights = samples
            return dict(buffer_id=buffer_id,
                        obs=obs,
                        actions=actions,
                        rewards=rewards,
                        next_obs=next_obs,
                        dones=dones,
                        weights=np.ones(rewards.shape, dtype=np.float32))

    def learn(self, batch_data):
        """Update upon a batch and send the td_errors to memories if needed

        Returns:
            extra_results (dict): contains the fields computed during an update.
        """
        buffer_id = batch_data.pop("buffer_id", 0)
        extra_results = super(ApexAgent, self).learn(
            batch_data, is_chief=self.distributed_handler.is_chief)
        extra_results["buffer_id"] = buffer_id

        if self.config.get("prioritized_replay",
                           False) and not self._learner2mem_q.full():
            try:
                self._learner2mem_q.put([
                    int(buffer_id), batch_data["indexes"],
                    extra_results["td_error"]
                ],
                                        timeout=30)
            except queue.Full:
                logger.warn(
                    "learner2mem thread has not been able to send a batch for "
                    "30 seconds. Consider increasing the number of memory hosts.")

        return extra_results

    def _init_act_count(self):
        self._act_count = -self._model_config.get("n_step", 1) + 1
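send_experience relies on n_step_adjustment to fold n-step returns into the rewards before shipping transitions to the memory hosts, which is also why the trailing n_step - 1 transitions are kept for the next call. The helper itself is not shown in these examples; the sketch below is an assumption about what such a function could do (standard n-step return), not its actual code.

import numpy as np


def n_step_adjustment_sketch(obs, actions, rewards, next_obs, dones, gamma, n_step):
    """Assumed behaviour: replace each reward with the discounted sum of up to
    n_step following rewards and point next_obs/dones at the n-step successor.
    Only the first len(rewards) - n_step + 1 transitions are returned."""
    length = max(len(rewards) - n_step + 1, 0)
    new_rewards, new_next_obs, new_dones = [], [], []
    for i in range(length):
        ret, done = 0.0, False
        for k in range(n_step):
            ret += (gamma ** k) * rewards[i + k]
            done = bool(dones[i + k])
            if done:
                break
        new_rewards.append(ret)
        new_next_obs.append(next_obs[i + k])
        new_dones.append(done)
    return (np.asarray(obs)[:length], np.asarray(actions)[:length],
            np.asarray(new_rewards), np.asarray(new_next_obs),
            np.asarray(new_dones))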
Example #9
class ActorLearnerAgent(AgentBase):
    """Actor-learner architecture.

    We regard some workers as learners and others as actors.
    Meanwhile, some ps hosts act as parameter servers and others cache data.
    We declare model parameters on both the local host and the parameter servers.
    Learners push changes of model parameters to the servers.
    Actors pull the latest model parameters from the servers.
    """
    def _init(self):
        if self.distributed_handler.job_name == "memory":
            self._setup_communication()
        elif self.distributed_handler.job_name == "learner":
            self._setup_learner()
        elif self.distributed_handler.job_name == "actor":
            self._setup_actor()
        # As a distributed computing paradigm, we need a complicated mechanism
        # (encapsulated in the `should_stop()` method) for deciding when to
        # stop. Like a context object, once stopped, the instance no longer
        # works normally.
        self._stopped = False

    def _setup_actor(self):
        self._send_cost = self._put_cost = 0
        self._actor_mem_cost = [0, 0]
        self._actor2mem_q = queue.Queue(8)
        self._stop_sender_indicator = threading.Event()
        self._actor2mem = ProxySender(self._actor2mem_q, self.executor,
                                      self._en_in_queues, self.in_queue_phs,
                                      self._stop_sender_indicator)
        self._actor2mem.start()

    def _setup_learner(self):
        # create threads for non-blocking communication
        self._receive_q = queue.Queue(8)
        self._stop_receiver_indicator = threading.Event()
        self._receiver = ProxyReceiver(
            self.executor,
            self._de_out_queues[self.distributed_handler.task_index %
                                len(self._out_queues)], self._receive_q,
            self._stop_receiver_indicator)
        self._receiver.start()

    def _create_buffer(self):
        """create buffer according to the specific model"""
        raise NotImplementedError

    def _setup_communication(self):
        # Under the single-machine setting, we create the buffer object as a class attribute
        # The type of buffer should be determined by the model type
        self._buffer = self._create_buffer()

        self._stop_indicator = threading.Event()
        # create a thread for monitoring the training progress
        self._monitor = StopMonitor(
            self.executor, self._num_sampled_timesteps, self._in_queue_size,
            self._out_queue_size,
            self.config.get("scheduled_timesteps", 1000000),
            self.config.get("scheduled_global_steps",
                            1000), self._stop_indicator)
        self._monitor.start()

        # create threads for non-blocking communication
        self._actor2mem_q = queue.Queue(8)
        self._stop_actor2mem_indicator = threading.Event()
        self._actor2mem = ProxyReceiver(
            self.executor,
            self._de_in_queues[self.distributed_handler.task_index],
            self._actor2mem_q, self._stop_actor2mem_indicator)
        self._actor2mem.start()

        self._mem2learner_q = queue.Queue(8)
        self._stop_mem2learner_indicator = threading.Event()
        self._mem2learner = ProxySender(
            self._mem2learner_q,
            self.executor,
            self._en_out_queues[self.distributed_handler.task_index %
                                len(self._out_queues)],
            self.out_queue_phs,
            self._stop_mem2learner_indicator,
            send_buffer_index=self.distributed_handler.task_index)
        self._mem2learner.start()

    def _get_in_queue_meta(self):
        """Determine the type, shape, and input placeholders for in_queue (actor --> memory)
        """
        phs = list()
        dtypes = list()
        shapes = list()
        num_env = self.config.get("num_env", 1)
        for name in self._element_names:
            v = self.model.learn_feed[name]
            if name in ["obs", "next_obs"]:
                ph = tf.placeholder(
                    dtype=tf.float32,
                    shape=(self.config["sample_batch_size"] * num_env, ) +
                    self.executor.flattened_ob_shape)
            else:
                ph = tf.placeholder(
                    dtype=v.dtype,
                    shape=[self.config["sample_batch_size"] * num_env] +
                    v.shape.as_list()[1:])
            phs.append(ph)
            dtypes.append(ph.dtype)
            shapes.append(ph.shape)
        return dtypes, shapes, phs

    def _get_out_queue_meta(self):
        """Determine the type, shape, and input placeholders for out_queue (memory --> learner)
        """
        phs = list()
        dtypes = list()
        shapes = list()

        # add index of memory
        mem_index_ph = tf.placeholder(dtype=tf.int32, shape=())
        phs.append(mem_index_ph)
        dtypes.append(mem_index_ph.dtype)
        shapes.append(mem_index_ph.shape)

        for name in self._element_names:
            v = self.model.learn_feed[name]
            if name in ["obs", "next_obs"]:
                ph = tf.placeholder(dtype=tf.float32,
                                    shape=(self.config["batch_size"], ) +
                                    self.executor.flattened_ob_shape)
            else:
                ph = tf.placeholder(dtype=v.dtype,
                                    shape=[self.config["batch_size"]] +
                                    v.shape.as_list()[1:])
            phs.append(ph)
            dtypes.append(ph.dtype)
            shapes.append(ph.shape)
        return dtypes, shapes, phs

    def _build_communication(self, job_name, task_index):
        """Build the subgraph for communication between actors, memories, and learners
        """
        if job_name in ["actor", "memory"]:
            # data flow: actor --> memory
            dtypes, shapes, self.in_queue_phs = self._get_in_queue_meta()
            self._in_queues = list()
            self._en_in_queues = list()
            self._de_in_queues = list()
            self._close_in_queues = list()
            for i in range(self.distributed_handler.num_memory_hosts):
                with tf.device("/job:memory/task:{}".format(i)):
                    in_q = tf.FIFOQueue(8,
                                        dtypes,
                                        shapes,
                                        shared_name="inqueue{}".format(i))
                    self._in_queues.append(in_q)
                    en_q = in_q.enqueue(self.in_queue_phs)
                    self._en_in_queues.append(en_q)
                    de_q = in_q.dequeue()
                    self._de_in_queues.append(de_q)
                    self._close_in_queues.append(
                        in_q.close(cancel_pending_enqueues=True))
            self._in_queue_size = self._in_queues[
                self.distributed_handler.task_index %
                len(self._in_queues)].size()

        # data flow: memory --> learner
        dtypes, shapes, self.out_queue_phs = self._get_out_queue_meta()
        self._out_queues = list()
        self._en_out_queues = list()
        self._de_out_queues = list()
        self._close_out_queues = list()
        if job_name == "memory":
            for i in range(self.distributed_handler.num_learner_hosts):
                with tf.device("/job:learner/task:{}".format(i)):
                    out_q = tf.FIFOQueue(8,
                                         dtypes,
                                         shapes,
                                         shared_name="outqueue{}".format(i))
                    self._out_queues.append(out_q)
                    en_q = out_q.enqueue(self.out_queue_phs)
                    self._en_out_queues.append(en_q)
                    de_q = out_q.dequeue()
                    self._de_out_queues.append(de_q)
                    self._close_out_queues.append(
                        out_q.close(cancel_pending_enqueues=True))
            self._out_queue_size = self._out_queues[
                self.distributed_handler.task_index %
                len(self._out_queues)].size()

        if job_name == "learner":
            with tf.device("/job:learner/task:{}".format(
                    self.distributed_handler.task_index)):
                out_q = tf.FIFOQueue(8,
                                     dtypes,
                                     shapes,
                                     shared_name="outqueue{}".format(
                                         self.distributed_handler.task_index))
                self._out_queues.append(out_q)
                en_q = out_q.enqueue(self.out_queue_phs)
                self._en_out_queues.append(en_q)
                de_q = out_q.dequeue()
                self._de_out_queues.append(de_q)
                self._close_out_queues.append(
                    out_q.close(cancel_pending_enqueues=True))

        # create an op for actors to obtain the latest vars
        sync_var_ops = list()
        for des, src in zip(self.behavior_model.actor_sync_variables,
                            self.model.actor_sync_variables):
            sync_var_ops.append(tf.assign(des, src))
        self._sync_var_op = tf.group(*sync_var_ops)

        # create some vars and queues for monitoring the training progress
        self._num_sampled_timesteps = tf.get_variable("num_sampled_timesteps",
                                                      dtype=tf.int64,
                                                      initializer=np.int64(0))

        self._learner_done_flags = tf.get_variable(
            "learner_done_flags",
            dtype=tf.bool,
            initializer=np.asarray(self.distributed_handler.num_learner_hosts *
                                   [False],
                                   dtype=np.bool))
        self._actor_done_flags = tf.get_variable(
            "actor_done_flags",
            dtype=tf.bool,
            initializer=np.asarray(self.distributed_handler.num_actor_hosts *
                                   [False],
                                   dtype=np.bool))
        self._should_stop = tf.logical_and(
            tf.reduce_all(self._learner_done_flags),
            tf.reduce_all(self._actor_done_flags))
        if self.distributed_handler.job_name == "learner":
            self._set_stop_flag = tf.assign(
                self._learner_done_flags[self.distributed_handler.task_index],
                np.bool(1),
                use_locking=True)
        if self.distributed_handler.job_name == "actor":
            self._set_stop_flag = tf.assign(
                self._actor_done_flags[self.distributed_handler.task_index],
                np.bool(1),
                use_locking=True)

        self._ready_to_exit = tf.get_variable("global_ready_to_exit",
                                              dtype=tf.bool,
                                              initializer=np.bool(0))
        self._set_ready_to_exit = tf.assign(self._ready_to_exit,
                                            np.bool(1),
                                            use_locking=True)

        self._update_num_sampled_timesteps = tf.assign_add(
            self._num_sampled_timesteps, np.int64(10000), use_locking=True)

    def sync_vars(self):
        """Sync with the latest vars
        """
        self.executor.run(self._sync_var_op, {})

    def join(self):
        """Call `server.join()` if the agent object serves as a parameter server.
        """

        self.distributed_handler.server.join()

    def communicate(self):
        raise NotImplementedError

    def should_stop(self):
        """Judge whether the agent should stop.
        """
        if not self._stopped:
            if self.distributed_handler.job_name == "memory":
                self._stopped = self._should_memory_stop()
            elif self.distributed_handler.job_name == "learner":
                self._stopped = self._should_learner_stop()
            elif self.distributed_handler.job_name == "actor":
                self._stopped = self._should_actor_stop()
            return self._stopped
            # ps host won't exit
        else:
            return True

    def _should_memory_stop(self):
        if self._stop_indicator.is_set():
            # as the monitor thread has set the event
            self._monitor.join()
            should_stop, ready_to_exit = self.executor.run(
                [self._should_stop, self._ready_to_exit], {})
            if should_stop:
                # we need to close the queues so that the threads that
                # execute `session.run()` are not blocked forever
                fetches = [self._close_in_queues, self._close_out_queues]
                if hasattr(self, "_close_update_queues"):
                    fetches.append(self._close_update_queues)
                if hasattr(self, "_close_actor_barrier_q_op"):
                    fetches.append(self._close_actor_barrier_q_op)
                self.executor.run(fetches, {})
                # Even though these threads may be running an `enqueue` or
                # `dequeue` op, they won't be blocked forever. Instead,
                # as we have closed the TF FIFOQueues, these threads
                # will raise the corresponding exceptions as expected.
                self._stop_mem2learner_indicator.set()
                self._mem2learner.join()
                self._stop_actor2mem_indicator.set()
                self._actor2mem.join()
                if hasattr(self, "_learner2mem"):
                    self._stop_learner2mem_indicator.set()
                    self._learner2mem.join()
                self.executor.session.close()
                return True
            if self.distributed_handler.task_index == 0 and not ready_to_exit:
                # notify actors and learners to exit first
                self.executor.run([self._set_ready_to_exit], {})
        return False

    def _should_learner_stop(self):
        ready_to_exit = self.executor.run(self._ready_to_exit, {})
        if ready_to_exit:
            self._stop_receiver_indicator.set()
            self._receiver.join()
            if hasattr(self, "_learner2mem"):
                self._stop_learner2mem_indicator.set()
                self._learner2mem.join()
            logger.info("threads joined.")
            if self.distributed_handler.task_index == 0:
                # chief worker (i.e., learner_0) is responsible for exporting
                # saved_model.
                self.export_saved_model()
            self.executor.run(self._set_stop_flag, {})
            should_stop = False
            while not should_stop:
                should_stop = self.executor.run(self._should_stop, {})
            logger.info("all actors and learners have done.")
            if self.distributed_handler.is_chief:
                time.sleep(30)
            self.executor.session.close()
            logger.info("session closed.")
            return should_stop
        return False

    def _should_actor_stop(self):
        ready_to_exit = self.executor.run(self._ready_to_exit, {})
        if ready_to_exit:
            self._stop_sender_indicator.set()
            # If the queue is full, the thread will not enter its `if`
            # branch and thus will exit `run()` immediately. Otherwise,
            # since the memory hosts have not stopped, the enqueue op will
            # not be blocked.
            self._actor2mem.join()
            logger.info("thread joined.")
            # notify the memory
            self.executor.run(self._set_stop_flag, {})
            should_stop = False
            while not should_stop:
                should_stop = self.executor.run(self._should_stop, {})
            logger.info("all actors and learners have done.")
            self.executor.session.close()
            logger.info("session closed.")
            return should_stop
        return False
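Putting the roles together, a possible top-level dispatch over the job names handled in _init() is sketched below. The surrounding training script is not shown in these examples, so this loop structure (and the use of the ApexAgent-style receive_experience/learn pair on the learner) is an assumption rather than the project's actual entry point.

# Hypothetical role dispatch; agent construction happens elsewhere.
job = agent.distributed_handler.job_name
if job == "ps":
    agent.join()                    # parameter servers only serve variables
elif job == "memory":
    while not agent.should_stop():
        agent.communicate()         # actors -> replay buffer -> learners
elif job == "learner":
    while not agent.should_stop():
        batch = agent.receive_experience()
        if batch is not None:
            agent.learn(batch)
elif job == "actor":
    while not agent.should_stop():
        agent.sync_vars()
        # ... sample from the environment and call send_experience(...) ...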