Example #1
    def _step(self, time_step, policy_state):
        policy_state = common.reset_state_if_necessary(policy_state,
                                                       self._initial_state,
                                                       time_step.is_first())
        if self._mode == self.PREDICT:
            step_func = functools.partial(self._algorithm.predict,
                                          epsilon_greedy=self._epsilon_greedy)
        elif self._mode == self.ON_POLICY_TRAINING:
            step_func = functools.partial(self._algorithm.rollout,
                                          mode=RLAlgorithm.ON_POLICY_TRAINING)
        elif self._mode == self.OFF_POLICY_TRAINING:
            step_func = functools.partial(self._algorithm.rollout,
                                          mode=RLAlgorithm.ROLLOUT)
        else:
            raise ValueError("Unsupported mode: %s" % self._mode)
        transformed_time_step = self._algorithm.transform_timestep(time_step)
        policy_step = step_func(transformed_time_step, policy_state)

        next_time_step = self._env_step(policy_step.action)
        if self._observers:
            traj = from_transition(time_step, policy_step._replace(info=()),
                                   next_time_step)
            for observer in self._observers:
                observer(traj)
        if self._algorithm.exp_observers and self._training:
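            # Distributions in policy_step.info are converted to their
            # parameter tensors before being packed into the experience.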
            policy_step = nest_utils.distributions_to_params(policy_step)
            exp = make_experience(time_step, policy_step, policy_state)
            self._algorithm.observe(exp)

        return next_time_step, policy_step, transformed_time_step
Example #2
    def observe(self, exp: Experience):
        """An algorithm can override to manipulate experience.

        Args:
            exp (Experience): The shapes can be either [Q, T, B, ...] or
                [B, ...], where Q is `learn_queue_cap` in `AsyncOffPolicyDriver`,
                T is the sequence length, and B is the batch size of the batched
                environment.
        """
        if not self._use_rollout_state:
            exp = exp._replace(state=())
        exp = nest_utils.distributions_to_params(exp)
        for observer in self._exp_observers:
            observer(exp)
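The call to nest_utils.distributions_to_params above replaces any tfp.Distribution found in the experience nest with plain tensors, so that downstream observers only ever see tensors. A minimal, self-contained sketch of that idea (not the ALF implementation; the name distributions_to_params_sketch and the filtering by tf.is_tensor are assumptions):

import tensorflow as tf
import tensorflow_probability as tfp


def distributions_to_params_sketch(nested):
    """Replace every tfp.Distribution leaf with its tensor-valued parameters."""

    def _convert(x):
        if isinstance(x, tfp.distributions.Distribution):
            # Keep only the tensor-valued constructor parameters so the result
            # can be written to a replay buffer or TensorArray.
            return {k: v for k, v in x.parameters.items() if tf.is_tensor(v)}
        return x

    return tf.nest.map_structure(_convert, nested)


dist = tfp.distributions.Normal(loc=tf.zeros(3), scale=tf.ones(3))
params = distributions_to_params_sketch({"action_dist": dist, "value": tf.zeros(3)})
# params["action_dist"] is now {"loc": <tensor>, "scale": <tensor>}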
Example #3
    def _dequeue_and_step(self, algorithm):
        time_step, policy_state, env_ids = self._actor_q.dequeue_all()
        # pack
        time_step = tf.nest.map_structure(common.flatten_once, time_step)
        policy_state = tf.nest.map_structure(common.flatten_once, policy_state)

        # prediction forward
        transformed_time_step = algorithm.transform_timestep(time_step)
        policy_step = algorithm.rollout(transformed_time_step,
                                        policy_state,
                                        mode=RLAlgorithm.ROLLOUT)

        # unpack
        policy_step = nest_utils.distributions_to_params(policy_step)
        policy_step = tf.nest.map_structure(
            lambda e: tf.reshape(e, [env_ids.shape[0], -1] + list(e.shape[1:])),
            policy_step)
        return policy_step, env_ids
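The pack/unpack comments above are pure shape bookkeeping: the leading dimensions are flattened into a single batch dimension for the rollout call and folded back per env_id afterwards. A tiny standalone sketch with made-up sizes, assuming common.flatten_once merges the first two dimensions (which the unpack reshape above suggests):

import tensorflow as tf

num_envs, per_env = 3, 4
x = tf.zeros([num_envs, per_env, 5])             # e.g. one field of time_step

packed = tf.reshape(x, [-1, 5])                  # "pack":   [3, 4, 5] -> [12, 5]
# ... run rollout() on the packed batch ...
unpacked = tf.reshape(packed, [num_envs, -1] + list(packed.shape[1:]))
print(unpacked.shape)                            # "unpack": back to (3, 4, 5)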
Example #4
        def _train_loop_body(counter, policy_state, info_ta):
            exp = tf.nest.map_structure(lambda ta: ta.read(counter),
                                        experience_ta)
            exp = nest_utils.params_to_distributions(
                exp, self.processed_experience_spec)
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_train_state,
                tf.equal(exp.step_type, StepType.FIRST))

            with tf.name_scope(scope):
                policy_step = self.train_step(exp, policy_state)

            info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), info_ta,
                nest_utils.distributions_to_params(policy_step.info))

            counter += 1

            return [counter, policy_step.state, info_ta]
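A loop body with the signature (counter, policy_state, info_ta) is typically driven by tf.while_loop, where info_ta is a (nest of) tf.TensorArray accumulating one entry per step. A minimal, self-contained sketch of that driving pattern in plain TF2 (hypothetical loop_body, not the ALF driver):

import tensorflow as tf


def loop_body(counter, state, info_ta):
    # Stand-in for _train_loop_body: produce one per-step value and store it.
    step_info = tf.cast(counter, tf.float32) + state
    info_ta = info_ta.write(counter, step_info)
    return counter + 1, state, info_ta


num_steps = 5
info_ta = tf.TensorArray(tf.float32, size=num_steps)
counter, state, info_ta = tf.while_loop(
    cond=lambda counter, *_: counter < num_steps,
    body=loop_body,
    loop_vars=[tf.constant(0), tf.constant(1.0), info_ta])
print(info_ta.stack())  # [1. 2. 3. 4. 5.] -- one entry written per step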
Example #5
    def _rollout_loop_body(self, counter, time_step, policy_state,
                           training_info_ta):

        next_time_step, policy_step, transformed_time_step = self._step(
            time_step, policy_state)

        training_info = ds.TrainingInfo(
            action=policy_step.action,
            reward=transformed_time_step.reward,
            discount=transformed_time_step.discount,
            step_type=transformed_time_step.step_type,
            rollout_info=nest_utils.distributions_to_params(policy_step.info),
            env_id=transformed_time_step.env_id)

        training_info_ta = tf.nest.map_structure(
            lambda ta, x: ta.write(counter, x), training_info_ta,
            training_info)

        counter += 1

        return [counter, next_time_step, policy_step.state, training_info_ta]
Example #6
    def after_train(self, training_info):
        """Adjust actor parameter according to KL-divergence."""
        action_param = nest_utils.distributions_to_params(
            training_info.info.action_distribution)
        exp_array = TracExperience(observation=training_info.info.observation,
                                   step_type=training_info.step_type,
                                   action_param=action_param,
                                   state=training_info.info.state)
        exp_array = common.create_and_unstack_tensor_array(
            exp_array, clear_after_read=False)
        dists, steps = self._trusted_updater.adjust_step(
            lambda: self._calc_change(exp_array), self._action_dist_clips)

        def _summarize():
            with self.name_scope:
                for i, d in enumerate(tf.nest.flatten(dists)):
                    tf.summary.scalar("unadjusted_action_dist/%s" % i, d)
                tf.summary.scalar("adjust_steps", steps)

        common.run_if(common.should_record_summaries(), _summarize)
        ac_info = training_info.info.ac._replace(
            action_distribution=training_info.info.action_distribution)
        self._ac_algorithm.after_train(training_info._replace(info=ac_info))
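The action_param stored above (the output of distributions_to_params) is what later lets the pre-update action distribution be rebuilt and compared against the updated one. A minimal sketch of such a comparison (mean_action_kl is a hypothetical helper, a Normal action distribution is assumed, and TrustedUpdater's actual interface is not shown):

import tensorflow as tf
import tensorflow_probability as tfp


def mean_action_kl(old_params, new_dist):
    # Rebuild the pre-update distribution from its stored parameters and
    # measure how far the updated policy has moved away from it.
    old_dist = tfp.distributions.Normal(loc=old_params["loc"],
                                        scale=old_params["scale"])
    return tf.reduce_mean(tfp.distributions.kl_divergence(old_dist, new_dist))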
Example #7
    def _train(self, experience, num_updates, mini_batch_size,
               mini_batch_length, update_counter_every_mini_batch,
               should_summarize):
        """Train using experience."""
        experience = nest_utils.params_to_distributions(
            experience, self.experience_spec)
        experience = self.transform_timestep(experience)
        experience = self.preprocess_experience(experience)
        experience = nest_utils.distributions_to_params(experience)

        length = experience.step_type.shape[1]
        mini_batch_length = (mini_batch_length or length)
        assert length % mini_batch_length == 0, (
            "length=%s not a multiple of mini_batch_length=%s" %
            (length, mini_batch_length))

        if (len(tf.nest.flatten(self.train_state_spec)) > 0
                and not self._use_rollout_state):
            if mini_batch_length == 1:
                logging.fatal(
                    "Should use TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN when mini_batch_length == 1.")
            else:
                common.warning_once(
                    "Consider using TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN.")

        experience = tf.nest.map_structure(
            lambda x: tf.reshape(
                x, common.concat_shape([-1, mini_batch_length],
                                       tf.shape(x)[2:])), experience)

        batch_size = tf.shape(experience.step_type)[0]
        mini_batch_size = (mini_batch_size or batch_size)

        def _make_time_major(nest):
            """Put the time dim to axis=0."""
            return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                         nest)

        scope = get_current_scope()

        for u in tf.range(num_updates):
            if mini_batch_size < batch_size:
                indices = tf.random.shuffle(
                    tf.range(tf.shape(experience.step_type)[0]))
                experience = tf.nest.map_structure(
                    lambda x: tf.gather(x, indices), experience)
            for b in tf.range(0, batch_size, mini_batch_size):
                if update_counter_every_mini_batch:
                    common.get_global_counter().assign_add(1)
                is_last_mini_batch = tf.logical_and(
                    tf.equal(u, num_updates - 1),
                    tf.greater_equal(b + mini_batch_size, batch_size))
                do_summary = tf.logical_or(is_last_mini_batch,
                                           update_counter_every_mini_batch)
                common.enable_summary(do_summary)
                batch = tf.nest.map_structure(
                    lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                    experience)
                batch = _make_time_major(batch)
                # TensorFlow graph mode loses the original name scope here, so
                # we need to restore it explicitly.
                with tf.name_scope(scope):
                    training_info, loss_info, grads_and_vars = self._update(
                        batch,
                        weight=tf.cast(
                            tf.shape(batch.step_type)[1], tf.float32) /
                        float(mini_batch_size))
                if should_summarize:
                    if do_summary:
                        # Putting `if do_summary` under the above `with`
                        # statement does not help: the `if` statement also
                        # loses the original name scope.
                        with tf.name_scope(scope):
                            self.summarize_train(training_info, loss_info,
                                                 grads_and_vars)

        train_steps = batch_size * mini_batch_length * num_updates
        return train_steps
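The two nest maps above are again only shape manipulation: the first reshape splits the [B, L, ...] experience into mini-sequences of length mini_batch_length, and _make_time_major then moves the time axis to the front before _update is called. A tiny standalone sketch with made-up sizes (rank-3 tensors assumed for brevity):

import tensorflow as tf

B, L, mini_batch_length = 4, 8, 2
x = tf.zeros([B, L, 3])                 # e.g. experience.observation

# [B, L, ...] -> [B * L / mini_batch_length, mini_batch_length, ...]
x = tf.reshape(x, tf.concat([[-1, mini_batch_length], tf.shape(x)[2:]], axis=0))
print(x.shape)                          # (16, 2, 3)

# _make_time_major: swap batch and time axes, [N, T, ...] -> [T, N, ...]
x = tf.transpose(x, [1, 0, 2])
print(x.shape)                          # (2, 16, 3)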