Example #1
    def genkey_and_read(self, keynet: Callable, query, flatten_result=True):
        """Generate key and read.

        Args:
            keynet (Callable): keynet(query) should return a tensor of shape
              (batch_size, num_keys * (dim + 1))
            query (Tensor): the query from which the keys are generated
            flatten_result (bool): If True, the result shape will be
               (batch_size, num_keys * dim), otherwise it is
               (batch_size, num_keys, dim)
        Returns:
            result (Tensor): If flatten_result is True,
              its shape is (batch_size, num_keys * dim); otherwise it is
              (batch_size, num_keys, dim)

        """
        batch_size = tf.shape(query)[0]
        keys_and_scales = keynet(query)
        num_keys = keys_and_scales.shape[-1] // (self.dim + 1)
        assert num_keys * (self.dim + 1) == keys_and_scales.shape[-1]
        keys, scales = tf.split(
            keys_and_scales,
            num_or_size_splits=[num_keys * self.dim, num_keys],
            axis=-1)
        keys = tf.reshape(
            keys, concat_shape(tf.shape(keys)[:-1], [num_keys, self.dim]))
        scales = tf.math.softplus(tf.reshape(scales, tf.shape(keys)[:-1]))

        r = self.read(keys, scales)
        if flatten_result:
            r = tf.reshape(r, (batch_size, num_keys * self.dim))
        return r
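
A minimal usage sketch of the method above (hypothetical setup: `memory` stands
for the enclosing module, assumed to have `dim == 8` and the `read(keys, scales)`
method the snippet relies on; `keynet` is just a Dense layer of the required
output size):

    import tensorflow as tf

    dim, num_keys = 8, 3
    # keynet must map a query to num_keys * (dim + 1) values:
    # num_keys * dim key components plus num_keys raw (pre-softplus) scales.
    keynet = tf.keras.layers.Dense(num_keys * (dim + 1))
    query = tf.random.normal([4, 16])    # (batch_size, query_dim)
    result = memory.genkey_and_read(keynet, query)
    # result.shape == (4, 24), i.e. (batch_size, num_keys * dim)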
Example #2
    def _train(self, experience, num_updates, mini_batch_size,
               mini_batch_length):
        """Train using experience."""

        experience = self.transform_timestep(experience)
        experience = self.preprocess_experience(experience)

        length = experience.step_type.shape[1]
        mini_batch_length = (mini_batch_length or length)
        assert length % mini_batch_length == 0, (
            "length=%s not a multiple of mini_batch_length=%s" %
            (length, mini_batch_length))

        if len(tf.nest.flatten(
                self.train_state_spec)) > 0 and not self._use_rollout_state:
            if mini_batch_length == 1:
                logging.fatal(
                    "Should use TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN when minibatch_length==1.")
            else:
                common.warning_once(
                    "Consider using TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN.")

        # Reshape [B, T, ...] experience into [B * T / mini_batch_length,
        # mini_batch_length, ...] so that each row is a short training sequence.
        experience = tf.nest.map_structure(
            lambda x: tf.reshape(
                x, common.concat_shape([-1, mini_batch_length],
                                       tf.shape(x)[2:])), experience)

        batch_size = tf.shape(experience.step_type)[0]
        mini_batch_size = (mini_batch_size or batch_size)

        def _make_time_major(nest):
            """Put the time dim to axis=0."""
            return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                         nest)

        for u in tf.range(num_updates):
            if mini_batch_size < batch_size:
                indices = tf.random.shuffle(
                    tf.range(tf.shape(experience.step_type)[0]))
                experience = tf.nest.map_structure(
                    lambda x: tf.gather(x, indices), experience)
            for b in tf.range(0, batch_size, mini_batch_size):
                batch = tf.nest.map_structure(
                    lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                    experience)
                batch = _make_time_major(batch)
                training_info, loss_info, grads_and_vars = self._update(
                    batch,
                    weight=tf.cast(tf.shape(batch.step_type)[1], tf.float32) /
                    float(mini_batch_size))
                common.get_global_counter().assign_add(1)
                self.training_summary(training_info, loss_info, grads_and_vars)

        self.metric_summary()
        train_steps = batch_size * mini_batch_length * num_updates
        return train_steps
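
To make the minibatch reshape above concrete, here is a standalone sketch with
hypothetical sizes; it uses tf.concat directly in place of common.concat_shape:

    import tensorflow as tf

    B, T, l = 4, 8, 2          # batch size, unroll length, mini_batch_length
    x = tf.zeros([B, T, 5])    # one leaf of the experience nest
    # Each length-T trajectory is split into T // l chunks of length l.
    y = tf.reshape(x, tf.concat([[-1, l], tf.shape(x)[2:]], axis=0))
    # y.shape == (16, 2, 5), i.e. (B * T // l, l, 5)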
Example #3
    def _predict(self, inputs, batch_size=None, training=True):
        """Sample noise and pass it (plus any conditioning inputs) through the net."""
        if inputs is None:
            assert self._input_tensor_spec is None
            assert batch_size is not None
        else:
            tf.nest.assert_same_structure(inputs, self._input_tensor_spec)
            batch_size = tf.shape(tf.nest.flatten(inputs)[0])[0]
        shape = common.concat_shape([batch_size], [self._noise_dim])
        noise = tf.random.normal(shape=shape)
        gen_inputs = noise if inputs is None else [noise, inputs]
        if self._predict_net and not training:
            outputs = self._predict_net(gen_inputs)[0]
        else:
            outputs = self._net(gen_inputs)[0]
        return outputs, gen_inputs
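
A rough usage sketch (hypothetical names: `generator` is the object owning
_predict, `observations` is a tensor matching `_input_tensor_spec`; as the
indexing `[0]` above implies, the wrapped networks are assumed to return an
`(output, state)` tuple):

    # Unconditional sampling: inputs is None, so batch_size must be given.
    samples, gen_inputs = generator._predict(
        inputs=None, batch_size=32, training=False)
    # gen_inputs is just the noise tensor, shape (32, noise_dim).

    # Conditional sampling: batch_size is inferred from inputs, and the
    # network consumes [noise, inputs].
    samples, gen_inputs = generator._predict(inputs=observations)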
Example #4
    def preprocess_experience(self, exp: Experience):
        """Compute advantages and put them into exp.info."""
        advantages = value_ops.generalized_advantage_estimation(
            rewards=exp.reward,
            values=exp.info.value,
            step_types=exp.step_type,
            discounts=exp.discount * self._loss._gamma,
            td_lambda=self._loss._lambda,
            time_major=False)
        # The advantage of the last step cannot be estimated, so pad a zero
        # at the end of the time axis to restore the original length.
        advantages = tf.concat([
            advantages,
            tf.zeros(
                shape=common.concat_shape(tf.shape(advantages)[:-1], [1]),
                dtype=advantages.dtype)
        ], axis=-1)
        returns = exp.info.value + advantages
        return exp._replace(info=PPOInfo(returns, advantages))
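
For reference, generalized advantage estimation computes
A_t = delta_t + gamma * lambda * A_{t+1} with
delta_t = r_{t+1} + gamma * V_{t+1} - V_t. A minimal batch-major sketch of that
recursion (a plain Python loop without the step-type masking and per-step
discounts the real value_ops helper handles) yields one step fewer than the
input, which is why the code above pads a zero at the end of the time axis:

    import tensorflow as tf

    def gae(rewards, values, gamma, lam):
        # rewards, values: [B, T]; returns advantages of shape [B, T - 1].
        deltas = rewards[:, 1:] + gamma * values[:, 1:] - values[:, :-1]
        adv = tf.zeros_like(deltas[:, -1])
        advs = []
        for t in reversed(range(deltas.shape[1])):
            adv = deltas[:, t] + gamma * lam * adv
            advs.append(adv)
        return tf.stack(advs[::-1], axis=1)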
Example #5
    def _train(self, experience, num_updates, mini_batch_size,
               mini_batch_length, update_counter_every_mini_batch,
               should_summarize):
        """Train using experience."""
        experience = nest_utils.params_to_distributions(
            experience, self.experience_spec)
        experience = self.transform_timestep(experience)
        experience = self.preprocess_experience(experience)
        experience = nest_utils.distributions_to_params(experience)

        length = experience.step_type.shape[1]
        mini_batch_length = (mini_batch_length or length)
        assert length % mini_batch_length == 0, (
            "length=%s not a multiple of mini_batch_length=%s" %
            (length, mini_batch_length))

        if len(tf.nest.flatten(
                self.train_state_spec)) > 0 and not self._use_rollout_state:
            if mini_batch_length == 1:
                logging.fatal(
                    "Should use TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN when minibatch_length==1.")
            else:
                common.warning_once(
                    "Consider using TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN.")

        experience = tf.nest.map_structure(
            lambda x: tf.reshape(
                x, common.concat_shape([-1, mini_batch_length],
                                       tf.shape(x)[2:])), experience)

        batch_size = tf.shape(experience.step_type)[0]
        mini_batch_size = (mini_batch_size or batch_size)

        def _make_time_major(nest):
            """Put the time dim to axis=0."""
            return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                         nest)

        scope = get_current_scope()

        for u in tf.range(num_updates):
            if mini_batch_size < batch_size:
                indices = tf.random.shuffle(
                    tf.range(tf.shape(experience.step_type)[0]))
                experience = tf.nest.map_structure(
                    lambda x: tf.gather(x, indices), experience)
            for b in tf.range(0, batch_size, mini_batch_size):
                if update_counter_every_mini_batch:
                    common.get_global_counter().assign_add(1)
                is_last_mini_batch = tf.logical_and(
                    tf.equal(u, num_updates - 1),
                    tf.greater_equal(b + mini_batch_size, batch_size))
                do_summary = tf.logical_or(is_last_mini_batch,
                                           update_counter_every_mini_batch)
                common.enable_summary(do_summary)
                batch = tf.nest.map_structure(
                    lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                    experience)
                batch = _make_time_major(batch)
                # TensorFlow graph mode loses the original name scope here,
                # so we need to restore it before calling self._update.
                with tf.name_scope(scope):
                    training_info, loss_info, grads_and_vars = self._update(
                        batch,
                        weight=tf.cast(
                            tf.shape(batch.step_type)[1], tf.float32) /
                        float(mini_batch_size))
                if should_summarize:
                    if do_summary:
                        # Putting `if do_summary` under the above `with`
                        # statement does not help: somehow the `if` statement
                        # also loses the original name scope.
                        with tf.name_scope(scope):
                            self.summarize_train(training_info, loss_info,
                                                 grads_and_vars)

        train_steps = batch_size * mini_batch_length * num_updates
        return train_steps
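
The name-scope workaround above follows a general TF2 pattern: capture the
scope string while the scope is active (tf.name_scope yields it with a trailing
"/"), then re-enter it wherever graph mode has dropped it, since a name passed
to tf.name_scope that ends with "/" is treated as an existing scope rather than
uniquified. A toy illustration (assuming this documented behavior):

    import tensorflow as tf

    with tf.name_scope("train") as scope:    # scope == "train/"
        pass

    @tf.function
    def step(x):
        # Re-enter the captured scope so ops keep their original prefix.
        with tf.name_scope(scope):
            return tf.identity(x, name="loss")    # op named "train/loss"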
Example #6
    def train(self,
              experience: Experience,
              num_updates=1,
              mini_batch_size=None,
              mini_batch_length=None):
        """Train using `experience`.

        Args:
            experience (Experience): experience from replay_buffer. It is
                assumed to be batch major.
            num_updates (int): number of optimization steps
            mini_batch_size (int): number of sequences for each minibatch
            mini_batch_length (int): the length of the sequence for each
                sample in the minibatch

        Returns:
            train_steps (int): the actual number of time steps that have been
                trained (a step might be trained multiple times)
        """

        experience = self._algorithm.preprocess_experience(experience)

        length = experience.step_type.shape[1]
        mini_batch_length = (mini_batch_length or length)
        assert length % mini_batch_length == 0, (
            "length=%s not a multiple of mini_batch_length=%s" %
            (length, mini_batch_length))

        if (len(tf.nest.flatten(self._algorithm.train_state_spec)) > 0
                and not self._use_rollout_state):
            if mini_batch_length == 1:
                logging.fatal(
                    "Should use TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN when minibatch_length==1.")
            else:
                warning_once(
                    "Consider using TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN.")

        experience = tf.nest.map_structure(
            lambda x: tf.reshape(
                x, common.concat_shape([-1, mini_batch_length],
                                       tf.shape(x)[2:])), experience)

        batch_size = tf.shape(experience.step_type)[0]
        mini_batch_size = (mini_batch_size or batch_size)

        def _make_time_major(nest):
            """Put the time dim to axis=0."""
            return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                         nest)

        for u in tf.range(num_updates):
            if mini_batch_size < batch_size:
                indices = tf.random.shuffle(
                    tf.range(tf.shape(experience.step_type)[0]))
                experience = tf.nest.map_structure(
                    lambda x: tf.gather(x, indices), experience)
            for b in tf.range(0, batch_size, mini_batch_size):
                batch = tf.nest.map_structure(
                    lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                    experience)
                batch = _make_time_major(batch)
                is_last_mini_batch = tf.logical_and(
                    tf.equal(u, num_updates - 1),
                    tf.greater_equal(b + mini_batch_size, batch_size))
                common.enable_summary(is_last_mini_batch)
                training_info, loss_info, grads_and_vars = self._update(
                    batch,
                    weight=tf.cast(tf.shape(batch.step_type)[1], tf.float32) /
                    float(mini_batch_size))
                if is_last_mini_batch:
                    self._training_summary(training_info, loss_info,
                                           grads_and_vars)

        self._train_step_counter.assign_add(1)
        train_steps = batch_size * mini_batch_length * num_updates
        return train_steps
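
A rough end-to-end usage sketch (hypothetical shapes and a hypothetical
`driver` owning this method; replay buffer wiring omitted):

    # experience.step_type: [64, 128] -- 64 batch-major trajectories of 128 steps
    train_steps = driver.train(
        experience, num_updates=4, mini_batch_size=256, mini_batch_length=8)
    # Internally the 64 trajectories are reshaped into 64 * 128 / 8 = 1024
    # sequences of length 8, shuffled, and consumed in minibatches of 256
    # sequences, 4 times over; train_steps == 1024 * 8 * 4.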