class AsyncOffPolicyDriver(OffPolicyDriver):
    """
    A driver that enables async training. The definition of 'async' is that
    prediction and learning are not intervened like below (synchronous):
        pred -> learn -> pred -> learn -> ...
    Instead they are decoupled:
        pred -> pred -> pred -> ...
                   |
            (synchronize periodically)
                   |
        learn -> learn -> learn -> ...
    And more importantly, a learner or predictor may only operate on a subset of
    environments at a time.
    """

    def __init__(self,
                 envs,
                 algorithm: OffPolicyAlgorithm,
                 num_actor_queues=1,
                 unroll_length=8,
                 learn_queue_cap=1,
                 actor_queue_cap=1,
                 observers=[],
                 metrics=[],
                 exp_replayer="one_time"):
        """
        Args:
            envs (list[TFEnvironment]):  list of TFEnvironment
            algorithm (OffPolicyAlgorithm):
            num_actor_queues (int): number of actor queues. Each queue is
                exclusively owned by just one actor thread.
            unroll_length (int): number of time steps each environment proceeds
                before sending the steps to the learner queue
            learn_queue_cap (int): the learner queue capacity determines how many
                environments contribute to the training data for each training
                iteration
            actor_queue_cap (int): the actor queue capacity determines how many
                environments contribute to the data for each prediction forward
                pass in an `ActorThread`. To prevent deadlock, it's required
                that `actor_queue_cap` * `num_actor_queues` <= `num_envs`.
            observers (list[Callable]): An optional list of observers that are
                updated after every step in the environment. Each observer is a
                callable(time_step.Trajectory).
            metrics (list[TFStepMetric]): An optional list of metrics.
            exp_replayer (str): a string that indicates which ExperienceReplayer
                to use.
        """
        super(AsyncOffPolicyDriver, self).__init__(
            env=envs[0],
            num_envs=len(envs),
            algorithm=algorithm,
            exp_replayer=exp_replayer,
            observers=observers,
            metrics=metrics)

        # create threads
        self._coord = tf.train.Coordinator()
        num_envs = len(envs)
        policy_step_spec = PolicyStep(
            action=algorithm.action_spec,
            state=algorithm.train_state_spec,
            info=algorithm.rollout_info_spec)
        self._tfq = TFQueues(
            num_envs,
            self._env.batch_size,
            learn_queue_cap,
            actor_queue_cap,
            time_step_spec=algorithm.time_step_spec,
            policy_step_spec=policy_step_spec,
            unroll_length=unroll_length,
            num_actor_queues=num_actor_queues)
        actor_threads = [
            ActorThread(
                name="actor{}".format(i),
                coord=self._coord,
                algorithm=self._algorithm,
                tf_queues=self._tfq,
                id=i) for i in range(num_actor_queues)
        ]
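        # env threads are assigned to the actor queues in round-robin fashion
        # via `actor_id=i % num_actor_queues`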
        env_threads = [
            EnvThread(
                name="env{}".format(i),
                coord=self._coord,
                env=envs[i],
                tf_queues=self._tfq,
                unroll_length=unroll_length,
                id=i,
                actor_id=i % num_actor_queues) for i in range(num_envs)
        ]
        self._log_thread = LogThread(
            name="logging",
            num_envs=num_envs,
            env_batch_size=self._env.batch_size,
            observers=observers,
            metrics=metrics,
            coord=self._coord,
            queue=self._tfq.log_queue)
        self._threads = actor_threads + env_threads + [self._log_thread]
        algorithm.set_metrics(self.get_metrics())

    def get_step_metrics(self):
        """See PolicyDriver.get_step_metrics()"""
        return self._log_thread.metrics[:2]

    def get_metrics(self):
        """See PolicyDriver.get_metrics()"""
        return self._log_thread.metrics

    def start(self):
        """Starts all env, actor, and log threads."""
        for th in self._threads:
            th.daemon = True
            th.start()
        logging.info("All threads started")

    @tf.function
    def get_training_exps(self):
        """
        Get training experiences from the learning queue

        Returns:
            exp (Experience): shapes are [Q, T, B, ...], where Q is
                learn_queue_cap
            env_id (tf.tensor): if not None, has the shape of (`num_envs`). Each
                element of `env_ids` indicates which batched env the data come from.
            steps (int): how many environment steps this batch of exps contain
        """
        batch = self._tfq.learn_queue.dequeue_all()
        # convert the batch to the experience format
        exp = make_experience(batch.time_step, batch.policy_step, batch.state)
        # count the total number of environment steps contained in this batch
        num_envs, unroll_length, env_batch_size \
            = batch.time_step.reward.shape[:3]
        steps = num_envs * unroll_length * env_batch_size
        return exp, steps

    def run_async(self):
        """
        Each call of run_async() will wait for a learning batch to be filled in
        by the env threads.
        Running in the eager mode. The reason is that currently
        OnetimeExperienceReplayer is incompatible with Graph mode because it
        replays by a temporary variable.

        Output:
            steps (int): the total number of unrolled steps
        """
        exp, steps = self.get_training_exps()
        self._algorithm.observe(exp)
        self._algorithm.summarize_metrics()
        return steps

    def _run(self, *args, **kwargs):
        raise RuntimeError(
            "You should call self.run_async instead for async drivers")

    def stop(self):
        # finishes the entire program
        self._coord.request_stop()
        # Cancel all pending requests (including enqueues and dequeues),
        # so that no thread hangs before calling coord.should_stop()
        self._tfq.close_all()
        self._coord.join(self._threads)
        logging.info("All threads stopped")
class AsyncOffPolicyDriver(OffPolicyDriver):
    """
    A driver that enables async training. The definition of 'async' is that
    prediction and learning are not intervened like below (synchronous):
        pred -> learn -> pred -> learn -> ...
    Instead they are decoupled:
        pred -> pred -> pred -> ...
                   |
            (synchronize periodically)
                   |
        learn -> learn -> learn -> ...
    And more importantly, a learner or predictor may only operate on a subset of
    environments at a time.
    """
    def __init__(self,
                 env_f: Callable,
                 algorithm: OffPolicyAlgorithm,
                 num_envs=1,
                 num_actor_queues=1,
                 unroll_length=8,
                 learn_queue_cap=1,
                 actor_queue_cap=1,
                 observers=[],
                 metrics=[],
                 exp_replayer="one_time",
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 train_step_counter=None):
        """
        Args:
            env_f (Callable): a function with 0 args that creates an environment
            algorithm (OffPolicyAlgorithm):
            num_envs (int): the number of environments to run asynchronously.
                Note: each environment itself could be a tf_agent batched
                environment. So the actual total number of environments is
                `num_envs` * `env_f().batch_size`. However, `env_f().batch_size`
                is transparent to this driver. So all the queues operate on the
                assumption of `num_envs` environments. Each environment is
                exclusively owned by only one `EnvThread`.
            num_actor_queues (int): number of actor queues. Each queue is
                exclusively owned by just one actor thread.
            unroll_length (int): number of time steps each environment proceeds
                before sending the steps to the learner queue
            learn_queue_cap (int): the learner queue capacity determines how many
                environments contribute to the training data for each training
                iteration
            actor_queue_cap (int): the actor queue capacity determines how many
                environments contribute to the data for each prediction forward
                pass in an `ActorThread`. To prevent deadlock, it's required
                that `actor_queue_cap` * `num_actor_queues` <= `num_envs` (see
                the configuration-check sketch after this class).
            observers (list[Callable]): An optional list of observers that are
                updated after every step in the environment. Each observer is a
                callable(time_step.Trajectory).
            metrics (list[TFStepMetric]): An optional list of metrics.
            exp_replayer (str): a string that indicates which ExperienceReplayer
                to use.
            debug_summaries (bool): A bool to gather debug summaries.
            summarize_grads_and_vars (bool): If True, gradient and network
                variable summaries will be written during training.
            train_step_counter (tf.Variable): An optional counter to increment
                every time a new iteration is started. If None, it will use
                tf.summary.experimental.get_step(). If this is still None, a
                counter will be created.
        """
        super(AsyncOffPolicyDriver,
              self).__init__(env=env_f(),
                             algorithm=algorithm,
                             exp_replayer=exp_replayer,
                             observers=observers,
                             metrics=metrics,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)

        # create threads
        self._coord = tf.train.Coordinator()
        self._tfq = TFQueues(num_envs,
                             self._env.batch_size,
                             learn_queue_cap,
                             actor_queue_cap,
                             time_step_spec=self._time_step_spec,
                             policy_step_spec=self._pred_policy_step_spec,
                             act_dist_param_spec=self._action_dist_param_spec,
                             unroll_length=unroll_length,
                             num_actor_queues=num_actor_queues)
        actor_threads = [
            ActorThread(name="actor{}".format(i),
                        coord=self._coord,
                        algorithm=self._algorithm,
                        tf_queues=self._tfq,
                        id=i,
                        observation_transformer=self._observation_transformer)
            for i in range(num_actor_queues)
        ]
        env_threads = [
            EnvThread(name="env{}".format(i),
                      coord=self._coord,
                      env_f=env_f,
                      tf_queues=self._tfq,
                      unroll_length=unroll_length,
                      id=i,
                      actor_id=i % num_actor_queues) for i in range(num_envs)
        ]
        self._log_thread = LogThread(name="logging",
                                     num_envs=num_envs,
                                     env_batch_size=self._env.batch_size,
                                     observers=observers,
                                     metrics=metrics,
                                     coord=self._coord,
                                     queue=self._tfq.log_queue)
        self._threads = actor_threads + env_threads + [self._log_thread]

    def get_step_metrics(self):
        """See PolicyDriver.get_step_metrics()"""
        return self._log_thread.metrics[:2]

    def get_metrics(self):
        """See PolicyDriver.get_metrics()"""
        return self._log_thread.metrics

    def start(self):
        """Starts all env, actor, and log threads."""
        for th in self._threads:
            th.daemon = True
            th.start()
        logging.info("All threads started")

    @tf.function
    def get_training_exps(self):
        """
        Get training experiences from the learning queue

        Returns:
            exp (Experience):
            env_id (tf.tensor): if not None, has the shape of (`num_envs`). Each
                element of `env_ids` indicates which batched env the data come from.
            steps (int): how many environment steps this batch of exps contain
        """
        batch = make_learning_batch(*self._tfq.learn_queue.dequeue_all())
        # convert the batch to the experience format
        exp = make_experience(batch.time_step, batch.policy_step,
                              batch.act_dist_param)
        # make the exp batch major for each environment
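        # the dequeued tensors are assumed to be laid out as
        # [num_envs, unroll_length, env_batch_size, ...]; swapping dims 1 and 2
        # yields [num_envs, env_batch_size, unroll_length, ...]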
        exp = tf.nest.map_structure(lambda e: common.transpose2(e, 1, 2), exp)
        num_envs, unroll_length, env_batch_size \
            = batch.time_step.observation.shape[:3]
        steps = num_envs * unroll_length * env_batch_size
        return exp, batch.env_id, steps

    def run_async(self):
        """
        Each call of run_async() will wait for a learning batch to be filled in
        by the env threads.
        Running in the eager mode. The reason is that currently
        OnetimeExperienceReplayer is incompatible with Graph mode because it
        replays by a temporary variable.

        Output:
            steps (int): the total number of unrolled steps
        """
        exp, env_id, steps = self.get_training_exps()
        for ob in self._exp_observers:
            ob(exp, env_id)
        return steps

    def _run(self, *args, **kwargs):
        raise RuntimeError(
            "You should call self.run_async instead for async drivers")

    def stop(self):
        # finishes the entire program
        self._coord.request_stop()
        # Cancel all pending requests (including enqueues and dequeues),
        # so that no thread hangs before calling coord.should_stop()
        self._tfq.close_all()
        self._coord.join(self._threads)
        logging.info("All threads stopped")