Example 1
def zeros_from_spec(nested_spec, batch_size):
    """Create nested zero Tensors or Distributions.

    A zero tensor with shape[0]=`batch_size` is created for each TensorSpec, and
    a distribution with all parameters as zero Tensors is created for each
    DistributionSpec.

    Args:
        nested_spec (nested TensorSpec or DistributionSpec): specs defining the
            structure, shapes and dtypes of the result
        batch_size (int): batch size added as the first dimension to the shapes
            in the TensorSpecs. If None, no batch dimension is added.
    Returns:
        nested Tensor or Distribution
    """

    def _zero_tensor(spec):
        if batch_size is None:
            shape = spec.shape
        else:
            spec_shape = tf.convert_to_tensor(value=spec.shape, dtype=tf.int32)
            shape = tf.concat(([batch_size], spec_shape), axis=0)
        dtype = spec.dtype
        return tf.zeros(shape, dtype)

    param_spec = nest_utils.to_distribution_param_spec(nested_spec)
    params = tf.nest.map_structure(_zero_tensor, param_spec)
    return nest_utils.params_to_distributions(params, nested_spec)
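
For illustration, here is a minimal, self-contained sketch of the TensorSpec path of `zeros_from_spec`; the DistributionSpec handling through alf's `nest_utils` is omitted, and the nested spec below is made up:

import tensorflow as tf

# Hypothetical nested spec; only the plain-TensorSpec path is sketched here.
spec = dict(observation=tf.TensorSpec(shape=(4, ), dtype=tf.float32),
            reward=tf.TensorSpec(shape=(), dtype=tf.float32))

def _zero_tensor(s, batch_size=2):
    spec_shape = tf.convert_to_tensor(value=s.shape, dtype=tf.int32)
    shape = tf.concat(([batch_size], spec_shape), axis=0)
    return tf.zeros(shape, s.dtype)

zeros = tf.nest.map_structure(_zero_tensor, spec)
# zeros['observation'].shape == (2, 4), zeros['reward'].shape == (2,)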
Example 2
    def rollout(self, max_num_steps, time_step, policy_state):
        counter = tf.zeros((), tf.int32)
        batch_size = self._env.batch_size
        maximum_iterations = math.ceil(max_num_steps / self._env.batch_size)

        def create_ta(s):
            return tf.TensorArray(dtype=s.dtype,
                                  size=maximum_iterations,
                                  element_shape=tf.TensorShape(
                                      [batch_size]).concatenate(s.shape))

        training_info_ta = tf.nest.map_structure(
            create_ta,
            self._training_info_spec._replace(
                rollout_info=nest_utils.to_distribution_param_spec(
                    self._training_info_spec.rollout_info)))

        [counter, time_step, policy_state, training_info_ta] = tf.while_loop(
            cond=lambda *_: True,
            body=self._rollout_loop_body,
            loop_vars=[counter, time_step, policy_state, training_info_ta],
            maximum_iterations=maximum_iterations,
            back_prop=False,
            name="rollout_loop")

        training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                              training_info_ta)

        training_info = nest_utils.params_to_distributions(
            training_info, self._training_info_spec)

        self._algorithm.summarize_rollout(training_info)
        self._algorithm.summarize_metrics()

        return time_step, policy_state
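
The core pattern in `rollout` — writing each step's outputs into a `tf.TensorArray` inside `tf.while_loop` and calling `stack()` afterwards — can be sketched standalone; the step function and sizes here are made up:

import tensorflow as tf

batch_size, num_steps = 3, 5
ta = tf.TensorArray(dtype=tf.float32, size=num_steps,
                    element_shape=tf.TensorShape([batch_size]))

def body(counter, state, ta):
    state = state + 1.0                     # stand-in for one environment step
    ta = ta.write(counter, state)
    return counter + 1, state, ta

[counter, state, ta] = tf.while_loop(
    cond=lambda *_: True,
    body=body,
    loop_vars=[tf.zeros((), tf.int32), tf.zeros([batch_size]), ta],
    maximum_iterations=num_steps,
    name="toy_rollout_loop")

result = ta.stack()                         # time-major: [num_steps, batch_size]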
Example 3
    def _update_total_dists(new_action, exp, total_dists):
        old_action = nest_utils.params_to_distributions(
            exp.action_param, self._action_distribution_spec)
        dists = nest_map(_dist, old_action, new_action)
        valid_masks = tf.cast(tf.not_equal(exp.step_type, StepType.LAST),
                              tf.float32)
        dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
        return nest_map(lambda x, y: x + y, total_dists, dists)
Example 4
    def _iter(self, time_step, policy_state):
        """One training iteration."""
        counter = tf.zeros((), tf.int32)
        batch_size = self._env.batch_size

        def create_ta(s):
            return tf.TensorArray(dtype=s.dtype,
                                  size=self._train_interval,
                                  element_shape=tf.TensorShape(
                                      [batch_size]).concatenate(s.shape))

        training_info_ta = tf.nest.map_structure(
            create_ta,
            self._training_info_spec._replace(
                info=nest_utils.to_distribution_param_spec(
                    self._training_info_spec.info)))

        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self._trainable_variables)
            [counter, next_time_step, next_state, training_info_ta
             ] = tf.while_loop(cond=lambda *_: True,
                               body=self._train_loop_body,
                               loop_vars=[
                                   counter, time_step, policy_state,
                                   training_info_ta
                               ],
                               back_prop=True,
                               parallel_iterations=1,
                               maximum_iterations=self._train_interval,
                               name='iter_loop')

            training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                                  training_info_ta)

            training_info = nest_utils.params_to_distributions(
                training_info, self._training_info_spec)

        loss_info, grads_and_vars = self._algorithm.train_complete(
            tape, training_info)

        del tape

        self._algorithm.summarize_train(training_info, loss_info,
                                        grads_and_vars)
        self._algorithm.summarize_metrics()

        common.get_global_counter().assign_add(1)

        return [next_time_step, next_state]
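
The tape usage above reduces to a small sketch (hypothetical variables and loss): `watch_accessed_variables=False` plus an explicit `tape.watch()` restricts what the persistent tape records, and the tape is deleted once gradients have been taken:

import tensorflow as tf

v = tf.Variable(1.0)
w = tf.Variable(2.0)
with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
    tape.watch([v])                  # only `v` is tracked; `w` is not
    loss = v * w
grads = tape.gradient(loss, [v, w])  # grads[1] is None because `w` was unwatched
del tape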
Example 5
        def _train_loop_body(counter, policy_state, info_ta):
            exp = tf.nest.map_structure(lambda ta: ta.read(counter),
                                        experience_ta)
            exp = nest_utils.params_to_distributions(
                exp, self.processed_experience_spec)
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_train_state,
                tf.equal(exp.step_type, StepType.FIRST))

            with tf.name_scope(scope):
                policy_step = self.train_step(exp, policy_state)

            info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), info_ta,
                nest_utils.distributions_to_params(policy_step.info))

            counter += 1

            return [counter, policy_step.state, info_ta]
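
The `distributions_to_params` / `params_to_distributions` pair converts distributions in a nest to their parameter Tensors (which a `tf.TensorArray` can store) and back. Roughly, for a single distribution, the round trip looks like this sketch using tensorflow_probability directly rather than alf's `nest_utils`:

import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.distributions.Normal(loc=tf.zeros([4]), scale=tf.ones([4]))
params = {'loc': dist.loc, 'scale': dist.scale}   # plain Tensors, TensorArray-friendly
rebuilt = tfp.distributions.Normal(**params)      # same distribution, rebuilt from params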
Example 6
    def _update(self, experience, weight):
        batch_size = tf.shape(experience.step_type)[1]
        counter = tf.zeros((), tf.int32)
        initial_train_state = common.get_initial_policy_state(
            batch_size, self.train_state_spec)
        if self._use_rollout_state:
            first_train_state = tf.nest.map_structure(
                lambda state: state[0, ...], experience.state)
        else:
            first_train_state = initial_train_state
        num_steps = tf.shape(experience.step_type)[0]

        def create_ta(s):
            # TensorArray cannot use Tensor (batch_size) as element_shape
            ta_batch_size = experience.step_type.shape[1]
            return tf.TensorArray(dtype=s.dtype,
                                  size=num_steps,
                                  element_shape=tf.TensorShape(
                                      [ta_batch_size]).concatenate(s.shape))

        experience_ta = tf.nest.map_structure(
            create_ta,
            nest_utils.to_distribution_param_spec(
                self.processed_experience_spec))
        experience_ta = tf.nest.map_structure(
            lambda elem, ta: ta.unstack(elem), experience, experience_ta)
        info_ta = tf.nest.map_structure(
            create_ta,
            nest_utils.to_distribution_param_spec(self.train_step_info_spec))

        scope = get_current_scope()

        def _train_loop_body(counter, policy_state, info_ta):
            exp = tf.nest.map_structure(lambda ta: ta.read(counter),
                                        experience_ta)
            exp = nest_utils.params_to_distributions(
                exp, self.processed_experience_spec)
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_train_state,
                tf.equal(exp.step_type, StepType.FIRST))

            with tf.name_scope(scope):
                policy_step = self.train_step(exp, policy_state)

            info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), info_ta,
                nest_utils.distributions_to_params(policy_step.info))

            counter += 1

            return [counter, policy_step.state, info_ta]

        with tf.GradientTape(persistent=True,
                             watch_accessed_variables=False) as tape:
            tape.watch(self.trainable_variables)
            [_, _, info_ta] = tf.while_loop(
                cond=lambda counter, *_: tf.less(counter, num_steps),
                body=_train_loop_body,
                loop_vars=[counter, first_train_state, info_ta],
                back_prop=True,
                name="train_loop")
            info = tf.nest.map_structure(lambda ta: ta.stack(), info_ta)
            info = nest_utils.params_to_distributions(
                info, self.train_step_info_spec)
            experience = nest_utils.params_to_distributions(
                experience, self.processed_experience_spec)
            training_info = TrainingInfo(action=experience.action,
                                         reward=experience.reward,
                                         discount=experience.discount,
                                         step_type=experience.step_type,
                                         rollout_info=experience.rollout_info,
                                         info=info,
                                         env_id=experience.env_id)

        loss_info, grads_and_vars = self.train_complete(
            tape=tape, training_info=training_info, weight=weight)

        del tape

        return training_info, loss_info, grads_and_vars
Example 7
    def _train(self, experience, num_updates, mini_batch_size,
               mini_batch_length, update_counter_every_mini_batch,
               should_summarize):
        """Train using experience."""
        experience = nest_utils.params_to_distributions(
            experience, self.experience_spec)
        experience = self.transform_timestep(experience)
        experience = self.preprocess_experience(experience)
        experience = nest_utils.distributions_to_params(experience)

        length = experience.step_type.shape[1]
        mini_batch_length = (mini_batch_length or length)
        assert length % mini_batch_length == 0, (
            "length=%s not a multiple of mini_batch_length=%s" %
            (length, mini_batch_length))

        if len(tf.nest.flatten(
                self.train_state_spec)) > 0 and not self._use_rollout_state:
            if mini_batch_length == 1:
                logging.fatal(
                    "Should use TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN when minibatch_length==1.")
            else:
                common.warning_once(
                    "Consider using TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN.")

        experience = tf.nest.map_structure(
            lambda x: tf.reshape(
                x, common.concat_shape([-1, mini_batch_length],
                                       tf.shape(x)[2:])), experience)

        batch_size = tf.shape(experience.step_type)[0]
        mini_batch_size = (mini_batch_size or batch_size)

        def _make_time_major(nest):
            """Put the time dim to axis=0."""
            return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                         nest)

        scope = get_current_scope()

        for u in tf.range(num_updates):
            if mini_batch_size < batch_size:
                indices = tf.random.shuffle(
                    tf.range(tf.shape(experience.step_type)[0]))
                experience = tf.nest.map_structure(
                    lambda x: tf.gather(x, indices), experience)
            for b in tf.range(0, batch_size, mini_batch_size):
                if update_counter_every_mini_batch:
                    common.get_global_counter().assign_add(1)
                is_last_mini_batch = tf.logical_and(
                    tf.equal(u, num_updates - 1),
                    tf.greater_equal(b + mini_batch_size, batch_size))
                do_summary = tf.logical_or(is_last_mini_batch,
                                           update_counter_every_mini_batch)
                common.enable_summary(do_summary)
                batch = tf.nest.map_structure(
                    lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                    experience)
                batch = _make_time_major(batch)
                # TensorFlow graph mode loses the original name scope here, so
                # we need to restore it explicitly.
                with tf.name_scope(scope):
                    training_info, loss_info, grads_and_vars = self._update(
                        batch,
                        weight=tf.cast(
                            tf.shape(batch.step_type)[1], tf.float32) /
                        float(mini_batch_size))
                if should_summarize:
                    if do_summary:
                        # Putting `if do_summary` under the above `with` statement
                        # does not help; somehow the `if` statement also loses
                        # the original name scope.
                        with tf.name_scope(scope):
                            self.summarize_train(training_info, loss_info,
                                                 grads_and_vars)

        train_steps = batch_size * mini_batch_length * num_updates
        return train_steps
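
To make the mini-batch slicing in `_train` concrete, here is a toy version (made-up shapes, using plain `tf.concat`/`tf.transpose` instead of alf's `common.concat_shape`/`common.transpose2`) of the reshape into `[-1, mini_batch_length, ...]` followed by the time-major transpose:

import tensorflow as tf

x = tf.reshape(tf.range(2 * 8 * 3), [2, 8, 3])   # [B=2, T=8, feature=3]
mini_batch_length = 4
x = tf.reshape(
    x, tf.concat([[-1, mini_batch_length], tf.shape(x)[2:]], axis=0))
# x.shape == [4, 4, 3]: each row now covers mini_batch_length consecutive steps
time_major = tf.transpose(x, [1, 0, 2])          # [T=4, B=4, feature=3]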