Example #1
    def set_exp_replayer(self, exp_replayer: str):
        """Set experience replayer."""

        if exp_replayer == "one_time":
            self._exp_replayer = OnetimeExperienceReplayer()
        elif exp_replayer == "uniform":
            self._exp_replayer = SyncUniformExperienceReplayer(
                self._experience_spec, self._env_batch_size)
        else:
            raise ValueError("invalid experience replayer name")
        self.add_experience_observer(self._exp_replayer.observe)
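
A minimal usage sketch (hedged; `algorithm` is assumed to be an instance of a class exposing this method, such as the OffPolicyAlgorithm shown in Example #5): the string argument simply selects which replayer backs training.

```python
# Hypothetical usage: `algorithm` is assumed to be an OffPolicyAlgorithm
# instance whose experience specs have already been prepared.
algorithm.set_exp_replayer("uniform")   # or "one_time"
```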
Example #2
    def __init__(self,
                 env: TFEnvironment,
                 algorithm: OffPolicyAlgorithm,
                 exp_replayer: str,
                 observers=[],
                 use_rollout_state=False,
                 metrics=[],
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 train_step_counter=None):
        """Create an OffPolicyDriver.

        Args:
            env (TFEnvironment): A TFEnvironment
            algorithm (OffPolicyAlgorithm): The algorithm for training
            exp_replayer (str): a string that indicates which ExperienceReplayer
                to use. Either "one_time" or "uniform".
            observers (list[Callable]): An optional list of observers that are
                updated after every step in the environment. Each observer is a
                callable(time_step.Trajectory).
            use_rollout_state (bool): If True, the state collected during
                rollout is stored in the experience and used as the initial
                state for off-policy training.
            metrics (list[TFStepMetric]): An optional list of metrics.
            debug_summaries (bool): A bool to gather debug summaries.
            summarize_grads_and_vars (bool): If True, gradient and network
                variable summaries will be written during training.
            train_step_counter (tf.Variable): An optional counter to increment
                every time a new iteration is started. If None, it will use
                tf.summary.experimental.get_step(). If this is still None, a
                counter will be created.
        """
        super(OffPolicyDriver, self).__init__(
            env=env,
            algorithm=algorithm,
            observers=observers,
            use_rollout_state=use_rollout_state,
            metrics=metrics,
            training=False,  # training can only be done by calling self.train()!
            greedy_predict=False,  # always use OnPolicyDriver for play/eval!
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step_counter)

        self._prepare_specs(algorithm)
        self._trainable_variables = algorithm.trainable_variables
        if exp_replayer == "one_time":
            self._exp_replayer = OnetimeExperienceReplayer()
        elif exp_replayer == "uniform":
            self._exp_replayer = SyncUniformExperienceReplayer(
                self._experience_spec, self._env.batch_size)
        else:
            raise ValueError("invalid experience replayer name")
        self.add_experience_observer(self._exp_replayer.observe)
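
A hedged construction sketch (the names `my_env` and `my_algorithm` are placeholders, not from the source): the constructor wires the chosen replayer into the experience pipeline via add_experience_observer().

```python
# Hypothetical setup: my_env is a TFEnvironment and my_algorithm an
# OffPolicyAlgorithm created elsewhere; only arguments shown in the
# signature above are used.
driver = OffPolicyDriver(
    env=my_env,
    algorithm=my_algorithm,
    exp_replayer="uniform",
    debug_summaries=True)
```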
Example #3
    def set_exp_replayer(self, exp_replayer: str, num_envs):
        """Set experience replayer.

        Args:
            exp_replayer (str): type of experience replayer. One of ("one_time",
                "uniform")
            num_envs (int): the total number of environments from all batched
                environments.
        """
        if exp_replayer == "one_time":
            self._exp_replayer = OnetimeExperienceReplayer()
        elif exp_replayer == "uniform":
            exp_spec = nest_utils.to_distribution_param_spec(
                self.experience_spec)
            self._exp_replayer = SyncUniformExperienceReplayer(
                exp_spec, num_envs)
        else:
            raise ValueError("invalid experience replayer name")
        self.add_experience_observer(self._exp_replayer.observe)
Example #4
    def set_exp_replayer(self,
                         exp_replayer: str,
                         num_envs,
                         num_actors=0,
                         unroll_length=0,
                         learn_queue_cap=0):
        """Set experience replayer.

        Args:
            exp_replayer (str): type of experience replayer. One of ("one_time",
                "uniform", "cycle_one_time")
            num_envs (int): the total number of environments from all batched
                environments/actors, which is num_actors * batch_size.
            num_actors (int): number of async actors, required to be positive
                for cycle_one_time replayer.
            unroll_length (int): number of env steps to unroll.  Used in
                cycle_one_time replayer.
            learn_queue_cap (int): number of actors to use for each mini-batch.
        """
        if exp_replayer == "one_time":
            self._exp_replayer = OnetimeExperienceReplayer()
        else:
            exp_spec = nest_utils.to_distribution_param_spec(
                self.experience_spec)
            if exp_replayer == "uniform":
                self._exp_replayer = SyncUniformExperienceReplayer(
                    exp_spec, num_envs)
            elif exp_replayer == "cycle_one_time":
                assert num_actors > 0
                assert unroll_length > 0
                self._exp_replayer = CyclicOneTimeExperienceReplayer(
                    exp_spec, num_envs, num_actors, unroll_length,
                    learn_queue_cap)
            else:
                raise ValueError("invalid experience replayer name")
        self.add_experience_observer(self._exp_replayer.observe)
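
A usage sketch for the asynchronous case (the numbers are illustrative assumptions): "cycle_one_time" requires positive num_actors and unroll_length, matching the assertions above, and num_envs is num_actors * batch_size per the docstring.

```python
# Hypothetical asynchronous configuration: 4 actors, each driving a batch
# of 8 environments, unrolled 16 steps per rollout, with 2 actors feeding
# each mini-batch.
algorithm.set_exp_replayer(
    "cycle_one_time",
    num_envs=4 * 8,       # num_actors * batch_size
    num_actors=4,
    unroll_length=16,
    learn_queue_cap=2)
```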
Example #5
class OffPolicyAlgorithm(RLAlgorithm):
    """
       OffPolicyAlgorithm works with alf.drivers.off_policy_driver to do
       training.

       The user needs to implement rollout() and train_step().

       rollout() is called to generate actions for every environment step.

       train_step() is called to generate necessary information for training.

       The following pseudo-code illustrates how OffPolicyAlgorithm is used
       with OffPolicyDriver:

       ```python
        # (1) collect stage
        for _ in range(steps_per_collection):
            # collect experience and store to replay buffer
            policy_step = rollout(time_step, policy_step.state)
            experience = make_experience(time_step, policy_step)
            store experience to replay buffer
            action = sample action from policy_step.action
            time_step = env.step(action)

        # (2) train stage
        for _ in range(training_per_collection):
            # sample experiences and perform training
            experiences = sample batch from replay_buffer
            with tf.GradientTape() as tape:
                batched_training_info = []
                for experience in experiences:
                    policy_step = train_step(experience, state)
                    train_info = make_training_info(info, ...)
                    write train_info to batched_training_info
                train_complete(tape, batched_training_info,...)
       ```
    """
    @property
    def exp_replayer(self):
        """Return experience replayer."""
        return self._exp_replayer

    def predict(self, time_step: ActionTimeStep, state=None):
        """Default implementation of predict.

        Subclass may override.
        """
        policy_step = self._rollout_partial_state(time_step, state)
        return policy_step._replace(info=())

    def rollout(self, time_step: ActionTimeStep, state=None):
        """Base implementation of rollout for OffPolicyAlgorithm.

        Calls _rollout_full_state or _rollout_partial_state based on
        use_rollout_state.

        Subclass may override.

        Args:
            time_step (ActionTimeStep):
            state (nested Tensor): should be consistent with train_state_spec
        Returns:
            policy_step (PolicyStep):
              action (nested tf.distribution): should be consistent with
                `action_distribution_spec`
              state (nested Tensor): should be consistent with `train_state_spec`
              info (nested Tensor): everything necessary for training. Note that
                ("action_distribution", "action", "reward", "discount",
                "is_last") are automatically collected by OnPolicyDriver. So
                the user only needs to put other information (e.g. value
                estimates) into `policy_step.info`
        """
        if self._use_rollout_state and self._is_rnn:
            return self._rollout_full_state(time_step, state)
        else:
            return self._rollout_partial_state(time_step, state)

    def _rollout_partial_state(self, time_step: ActionTimeStep, state=None):
        """Rollout without the full state for train_step().

        It is used for non-RNN model or RNN model without computating all states
        in train_state_spec. In the returned PolicyStep.state, you can use an
        empty tuple as a placeholder for those states that are not necessary for
        rollout.

        User needs to override this if _rollout_full_state() is not implemented.
        Args:
            time_step (ActionTimeStep):
            state (nested Tensor): should be consistent with train_state_spec
        Returns:
            policy_step (PolicyStep):
              action (nested tf.distribution): should be consistent with
                `action_distribution_spec`
              state (nested Tensor): should be consistent with `train_state_spec`.
              info (nested Tensor): everything necessary for training. Note that
                ("action_distribution", "action", "reward", "discount",
                "is_last") are automatically collected by OnPolicyDriver. So
                the user only needs to put other information (e.g. value
                estimates) into `policy_step.info`
        """
        return self._rollout_full_state(time_step, state)
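
    # Illustrative sketch (an assumption, not part of the original class): an
    # RNN-based subclass might compute only the actor state during rollout and
    # use () as a placeholder for critic states that train_step() recomputes:
    #
    #   def _rollout_partial_state(self, time_step: ActionTimeStep, state=None):
    #       # self._actor_network and the `actor`/`critic` state fields are
    #       # hypothetical names used only for illustration.
    #       action_dist, actor_state = self._actor_network(
    #           time_step.observation, network_state=state.actor)
    #       return PolicyStep(
    #           action=action_dist,
    #           state=state._replace(actor=actor_state, critic=()),
    #           info=())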

    def _rollout_full_state(self, time_step: ActionTimeStep, state=None):
        """Rollout with full state for train_step().

        If you want to use the rollout state for off-policy training (by
        setting TrainerConfig.use_rollout_state=True), you need to implement
        this function and compute all the states for the returned
        PolicyStep.state.

        Args:
            time_step (ActionTimeStep):
            state (nested Tensor): should be consistent with train_state_spec
        Returns:
            policy_step (PolicyStep):
              action (nested tf.distribution): should be consistent with
                `action_distribution_spec`
              state (nested Tensor): should be consistent with `train_state_spec`.
              info (nested Tensor): everything necessary for training. Note that
                ("action_distribution", "action", "reward", "discount",
                "is_last") are automatically collected by OnPolicyDriver. So
                the user only needs to put other information (e.g. value
                estimates) into `policy_step.info`
        """
        raise NotImplementedError("_rollout_full_state is not implemented")

    @abc.abstractmethod
    def train_step(self, experience: Experience, state):
        """Perform one step of training computation.

        Args:
            experience (Experience):
            state (nested Tensor): should be consistent with train_state_spec

        Returns (PolicyStep):
            action (nested tf.distribution): should be consistent with
                `action_distribution_spec`
            state (nested Tensor): should be consistent with `train_state_spec`
            info (nested Tensor): everything necessary for training. Note that
                ("action_distribution", "action", "reward", "discount",
                "is_last") are automatically collected by OffPolicyDriver. So
                the user only needs to put other information (e.g. value
                estimates) into `policy_step.info`
        """
        pass

    def preprocess_experience(self, experience: Experience):
        """Preprocess experience.

        preprocess_experience() is called on the experiences retrieved from the
        replay buffer. One example is calculating advantages and returns in
        PPOAlgorithm.

        The shapes of the tensors in experience are assumed to be (B, T, ...).

        Args:
            experience (Experience): original experience
        Returns:
            processed experience
        """
        return experience
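
    # Illustrative sketch (assumption, not from the original class): a
    # PPO-style subclass could override preprocess_experience() to attach
    # advantages and returns to the info field. Tensors are shaped (B, T, ...),
    # so the computation runs along the time axis. `compute_gae` and the
    # `advantages`/`returns`/`value` info fields are hypothetical names.
    #
    #   def preprocess_experience(self, experience: Experience):
    #       advantages = compute_gae(experience)
    #       returns = advantages + experience.info.value
    #       return experience._replace(
    #           info=experience.info._replace(
    #               advantages=advantages, returns=returns))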

    def set_exp_replayer(self, exp_replayer: str):
        """Set experience replayer."""

        if exp_replayer == "one_time":
            self._exp_replayer = OnetimeExperienceReplayer()
        elif exp_replayer == "uniform":
            self._exp_replayer = SyncUniformExperienceReplayer(
                self._experience_spec, self._env_batch_size)
        else:
            raise ValueError("invalid experience replayer name")
        self.add_experience_observer(self._exp_replayer.observe)

    def observe(self, exp: Experience):
        """An algorithm can override to manipulate experience."""
        for observer in self._exp_observers:
            observer(exp)

    def prepare_off_policy_specs(self, time_step: ActionTimeStep):
        """Prepare various tensor specs for off_policy training.

        prepare_off_policy_specs is called by OffPolicyDriver._prepare_spec().

        """

        self._env_batch_size = time_step.step_type.shape[0]
        self._time_step_spec = common.extract_spec(time_step)
        initial_state = common.get_initial_policy_state(
            self._env_batch_size, self.train_state_spec)
        transformed_timestep = self.transform_timestep(time_step)
        policy_step = self.rollout(transformed_timestep, initial_state)
        info_spec = common.extract_spec(policy_step.info)

        self._action_distribution_spec = tf.nest.map_structure(
            common.to_distribution_spec, self.action_distribution_spec)
        self._action_dist_param_spec = tf.nest.map_structure(
            lambda spec: spec.input_params_spec,
            self._action_distribution_spec)

        self._experience_spec = Experience(
            step_type=self._time_step_spec.step_type,
            reward=self._time_step_spec.reward,
            discount=self._time_step_spec.discount,
            observation=self._time_step_spec.observation,
            prev_action=self._action_spec,
            action=self._action_spec,
            info=info_spec,
            action_distribution=self._action_dist_param_spec,
            state=self.train_state_spec if self._use_rollout_state else ())

        action_dist_params = common.zero_tensor_from_nested_spec(
            self._experience_spec.action_distribution, self._env_batch_size)
        action_dist = nested_distributions_from_specs(
            self._action_distribution_spec, action_dist_params)

        exp = Experience(step_type=time_step.step_type,
                         reward=time_step.reward,
                         discount=time_step.discount,
                         observation=time_step.observation,
                         prev_action=time_step.prev_action,
                         action=time_step.prev_action,
                         info=policy_step.info,
                         action_distribution=action_dist,
                         state=initial_state if self._use_rollout_state else
                         ())

        transformed_exp = self.transform_timestep(exp)
        processed_exp = self.preprocess_experience(transformed_exp)
        self._processed_experience_spec = self._experience_spec._replace(
            observation=common.extract_spec(processed_exp.observation),
            info=common.extract_spec(processed_exp.info))

        policy_step = common.algorithm_step(
            algorithm_step_func=self.train_step,
            time_step=processed_exp,
            state=initial_state)
        info_spec = common.extract_spec(policy_step.info)
        self._training_info_spec = TrainingInfo(
            action_distribution=self._action_dist_param_spec, info=info_spec)

    def train(self,
              num_updates=1,
              mini_batch_size=None,
              mini_batch_length=None,
              clear_replay_buffer=True):
        """Train algorithm.

        Args:
            num_updates (int): number of optimization steps
            mini_batch_size (int): number of sequences for each minibatch
            mini_batch_length (int): the length of the sequence for each
                sample in the minibatch
            clear_replay_buffer (bool): whether to use all the data in the
                replay buffer to perform one update and then clear the buffer

        Returns:
            train_steps (int): the actual number of time steps that have been
                trained on (a step might be trained on multiple times)
        """

        if mini_batch_size is None:
            mini_batch_size = self._exp_replayer.batch_size
        if clear_replay_buffer:
            experience = self._exp_replayer.replay_all()
            self._exp_replayer.clear()
        else:
            experience, _ = self._exp_replayer.replay(
                sample_batch_size=mini_batch_size,
                mini_batch_length=mini_batch_length)

        return self._train(experience, num_updates, mini_batch_size,
                           mini_batch_length)

    @tf.function
    def _train(self, experience, num_updates, mini_batch_size,
               mini_batch_length):
        """Train using experience."""

        experience = self.transform_timestep(experience)
        experience = self.preprocess_experience(experience)

        length = experience.step_type.shape[1]
        mini_batch_length = (mini_batch_length or length)
        assert length % mini_batch_length == 0, (
            "length=%s not a multiple of mini_batch_length=%s" %
            (length, mini_batch_length))

        if len(tf.nest.flatten(
                self.train_state_spec)) > 0 and not self._use_rollout_state:
            if mini_batch_length == 1:
                logging.fatal(
                    "Should use TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN when minibatch_length==1.")
            else:
                common.warning_once(
                    "Consider using TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN.")

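        # Reshape (B, T, ...) tensors into (B * T / mini_batch_length,
        # mini_batch_length, ...) so that each row of the result is one
        # training sequence of length mini_batch_length.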
        experience = tf.nest.map_structure(
            lambda x: tf.reshape(
                x, common.concat_shape([-1, mini_batch_length],
                                       tf.shape(x)[2:])), experience)

        batch_size = tf.shape(experience.step_type)[0]
        mini_batch_size = (mini_batch_size or batch_size)

        def _make_time_major(nest):
            """Put the time dim to axis=0."""
            return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                         nest)

        for u in tf.range(num_updates):
            if mini_batch_size < batch_size:
                indices = tf.random.shuffle(
                    tf.range(tf.shape(experience.step_type)[0]))
                experience = tf.nest.map_structure(
                    lambda x: tf.gather(x, indices), experience)
            for b in tf.range(0, batch_size, mini_batch_size):
                batch = tf.nest.map_structure(
                    lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                    experience)
                batch = _make_time_major(batch)
                training_info, loss_info, grads_and_vars = self._update(
                    batch,
                    weight=tf.cast(tf.shape(batch.step_type)[1], tf.float32) /
                    float(mini_batch_size))
                common.get_global_counter().assign_add(1)
                self.training_summary(training_info, loss_info, grads_and_vars)

        self.metric_summary()
        train_steps = batch_size * mini_batch_length * num_updates
        return train_steps

    def _update(self, experience, weight):
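        # `experience` arrives time-major here: step_type has shape
        # (num_steps, batch_size) after _make_time_major() in _train().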
        batch_size = tf.shape(experience.step_type)[1]
        counter = tf.zeros((), tf.int32)
        initial_train_state = common.get_initial_policy_state(
            batch_size, self.train_state_spec)
        if self._use_rollout_state:
            first_train_state = tf.nest.map_structure(
                lambda state: state[0, ...], experience.state)
        else:
            first_train_state = initial_train_state
        num_steps = tf.shape(experience.step_type)[0]

        def create_ta(s):
            # TensorArray cannot use Tensor (batch_size) as element_shape
            ta_batch_size = experience.step_type.shape[1]
            return tf.TensorArray(dtype=s.dtype,
                                  size=num_steps,
                                  element_shape=tf.TensorShape(
                                      [ta_batch_size]).concatenate(s.shape))

        experience_ta = tf.nest.map_structure(create_ta,
                                              self._processed_experience_spec)
        experience_ta = tf.nest.map_structure(
            lambda elem, ta: ta.unstack(elem), experience, experience_ta)
        training_info_ta = tf.nest.map_structure(create_ta,
                                                 self._training_info_spec)

        def _train_loop_body(counter, policy_state, training_info_ta):
            exp = tf.nest.map_structure(lambda ta: ta.read(counter),
                                        experience_ta)
            collect_action_distribution_param = exp.action_distribution
            collect_action_distribution = nested_distributions_from_specs(
                self._action_distribution_spec,
                collect_action_distribution_param)
            exp = exp._replace(action_distribution=collect_action_distribution)

            policy_state = common.reset_state_if_necessary(
                policy_state, initial_train_state,
                tf.equal(exp.step_type, StepType.FIRST))

            policy_step = common.algorithm_step(self.train_step, exp,
                                                policy_state)

            action_dist_param = common.get_distribution_params(
                policy_step.action)

            training_info = TrainingInfo(action_distribution=action_dist_param,
                                         info=policy_step.info)

            training_info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), training_info_ta,
                training_info)

            counter += 1

            return [counter, policy_step.state, training_info_ta]

        with tf.GradientTape(persistent=True,
                             watch_accessed_variables=False) as tape:
            tape.watch(self.trainable_variables)
            [_, _, training_info_ta] = tf.while_loop(
                cond=lambda counter, *_: tf.less(counter, num_steps),
                body=_train_loop_body,
                loop_vars=[counter, first_train_state, training_info_ta],
                back_prop=True,
                name="train_loop")
            training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                                  training_info_ta)
            training_info = training_info._replace(
                action=experience.action,
                reward=experience.reward,
                discount=experience.discount,
                step_type=experience.step_type,
                collect_info=experience.info,
                collect_action_distribution=experience.action_distribution)

            action_distribution = nested_distributions_from_specs(
                self._action_distribution_spec,
                training_info.action_distribution)
            collect_action_distribution = nested_distributions_from_specs(
                self._action_distribution_spec,
                training_info.collect_action_distribution)
            training_info = training_info._replace(
                action_distribution=action_distribution,
                collect_action_distribution=collect_action_distribution)

        loss_info, grads_and_vars = self.train_complete(
            tape=tape, training_info=training_info, weight=weight)

        del tape

        return training_info, loss_info, grads_and_vars
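
To close, a hedged sketch of the minimal subclass surface the class docstring asks for (rollout and train_step). The network attributes and the MyInfo namedtuple are hypothetical names introduced only for illustration, not part of the source.

```python
# Sketch only: assumes the classes used throughout this file
# (OffPolicyAlgorithm, ActionTimeStep, Experience, PolicyStep) are importable;
# _actor_network, _value_network and MyInfo are hypothetical names.
from collections import namedtuple

MyInfo = namedtuple('MyInfo', ['value'])


class MySimpleAlgorithm(OffPolicyAlgorithm):
    def _rollout_partial_state(self, time_step: ActionTimeStep, state=None):
        # Produce an action distribution from the observation; the actor
        # network attribute would be created in the subclass's __init__.
        action_dist, state = self._actor_network(time_step.observation, state)
        return PolicyStep(action=action_dist, state=state, info=())

    def train_step(self, experience: Experience, state):
        # Recompute the action distribution and return whatever the loss
        # needs (e.g. a value estimate) in `info`.
        action_dist, state = self._actor_network(experience.observation, state)
        value, _ = self._value_network(experience.observation)
        return PolicyStep(action=action_dist, state=state,
                          info=MyInfo(value=value))
```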