def genkey_and_read(self, keynet: Callable, query, flatten_result=True):
    """Generate key and read.

    Args:
        keynet (Callable): keynet(query) is a tensor of shape
            (batch_size, num_keys * (dim + 1))
        query (Tensor): the query from which the keys are generated
        flatten_result (bool): If True, the result shape will be
            (batch_size, num_keys * dim), otherwise it is
            (batch_size, num_keys, dim)
    Returns:
        result (Tensor): If flatten_result is True, its shape is
            (batch_size, num_keys * dim), otherwise it is
            (batch_size, num_keys, dim)
    """
    batch_size = tf.shape(query)[0]
    keys_and_scales = keynet(query)
    num_keys = keys_and_scales.shape[-1] // (self.dim + 1)
    assert num_keys * (self.dim + 1) == keys_and_scales.shape[-1]
    keys, scales = tf.split(
        keys_and_scales,
        num_or_size_splits=[num_keys * self.dim, num_keys],
        axis=-1)
    keys = tf.reshape(
        keys, concat_shape(tf.shape(keys)[:-1], [num_keys, self.dim]))
    scales = tf.math.softplus(tf.reshape(scales, tf.shape(keys)[:-1]))
    r = self.read(keys, scales)
    if flatten_result:
        r = tf.reshape(r, (batch_size, num_keys * self.dim))
    return r
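# A minimal usage sketch for genkey_and_read (all names and sizes here are
# hypothetical; `memory` is assumed to be an instance of this class with
# dim=8, and keynet must output num_keys * (dim + 1) units):
#
#     keynet = tf.keras.layers.Dense(3 * (8 + 1))  # 3 keys: 8 dims + 1 scale each
#     query = tf.random.normal((32, 16))           # (batch_size, query_dim)
#     r = memory.genkey_and_read(keynet, query)    # shape (32, 3 * 8)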
def _train(self, experience, num_updates, mini_batch_size,
           mini_batch_length):
    """Train using experience."""
    experience = self.transform_timestep(experience)
    experience = self.preprocess_experience(experience)

    length = experience.step_type.shape[1]
    mini_batch_length = (mini_batch_length or length)
    assert length % mini_batch_length == 0, (
        "length=%s not a multiple of mini_batch_length=%s" %
        (length, mini_batch_length))

    if len(tf.nest.flatten(
            self.train_state_spec)) > 0 and not self._use_rollout_state:
        if mini_batch_length == 1:
            logging.fatal(
                "Should use TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN when minibatch_length==1.")
        else:
            common.warning_once(
                "Consider using TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN.")

    experience = tf.nest.map_structure(
        lambda x: tf.reshape(
            x, common.concat_shape([-1, mini_batch_length],
                                   tf.shape(x)[2:])), experience)

    batch_size = tf.shape(experience.step_type)[0]
    mini_batch_size = (mini_batch_size or batch_size)

    def _make_time_major(nest):
        """Put the time dim to axis=0."""
        return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                     nest)

    for u in tf.range(num_updates):
        if mini_batch_size < batch_size:
            indices = tf.random.shuffle(
                tf.range(tf.shape(experience.step_type)[0]))
            experience = tf.nest.map_structure(
                lambda x: tf.gather(x, indices), experience)
        for b in tf.range(0, batch_size, mini_batch_size):
            batch = tf.nest.map_structure(
                lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                experience)
            batch = _make_time_major(batch)
            training_info, loss_info, grads_and_vars = self._update(
                batch,
                weight=tf.cast(tf.shape(batch.step_type)[1], tf.float32) /
                float(mini_batch_size))
            common.get_global_counter().assign_add(1)
            self.training_summary(training_info, loss_info, grads_and_vars)

    self.metric_summary()
    train_steps = batch_size * mini_batch_length * num_updates
    return train_steps
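# A worked shape example for the minibatch reshape above (hypothetical
# sizes): with experience of shape (B=4, T=8, ...) and mini_batch_length=2,
# the reshape folds every length-2 window into its own row, so minibatches
# can then be sampled along axis 0:
#
#     x = tf.zeros((4, 8, 5))
#     y = tf.reshape(x, common.concat_shape([-1, 2], tf.shape(x)[2:]))
#     # y.shape == (16, 2, 5)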
def _predict(self, inputs, batch_size=None, training=True):
    if inputs is None:
        assert self._input_tensor_spec is None
        assert batch_size is not None
    else:
        tf.nest.assert_same_structure(inputs, self._input_tensor_spec)
        batch_size = tf.shape(tf.nest.flatten(inputs)[0])[0]
    shape = common.concat_shape([batch_size], [self._noise_dim])
    noise = tf.random.normal(shape=shape)
    gen_inputs = noise if inputs is None else [noise, inputs]
    if self._predict_net and not training:
        outputs = self._predict_net(gen_inputs)[0]
    else:
        outputs = self._net(gen_inputs)[0]
    return outputs, gen_inputs
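# A minimal call sketch for _predict (hypothetical; `generator` is assumed
# to be an instance of this class with _noise_dim=16 and no conditional
# input spec):
#
#     outputs, gen_inputs = generator._predict(inputs=None, batch_size=32)
#     # gen_inputs is the sampled noise of shape (32, 16); with conditional
#     # inputs it would instead be the list [noise, inputs].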
def preprocess_experience(self, exp: Experience):
    """Compute advantages and put them into exp.info."""
    advantages = value_ops.generalized_advantage_estimation(
        rewards=exp.reward,
        values=exp.info.value,
        step_types=exp.step_type,
        discounts=exp.discount * self._loss._gamma,
        td_lambda=self._loss._lambda,
        time_major=False)
    advantages = tf.concat([
        advantages,
        tf.zeros(
            shape=common.concat_shape(tf.shape(advantages)[:-1], [1]),
            dtype=advantages.dtype)
    ], axis=-1)
    returns = exp.info.value + advantages
    return exp._replace(info=PPOInfo(returns, advantages))
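# Shape note (hypothetical sizes): with exp.reward of shape (B, T),
# generalized_advantage_estimation() returns (B, T - 1) advantages because
# the final step has no successor; the tf.concat above pads a zero column
# so that advantages and returns are again (B, T), aligned with
# exp.step_type and exp.info.value.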
def _train(self, experience, num_updates, mini_batch_size,
           mini_batch_length, update_counter_every_mini_batch,
           should_summarize):
    """Train using experience."""
    experience = nest_utils.params_to_distributions(
        experience, self.experience_spec)
    experience = self.transform_timestep(experience)
    experience = self.preprocess_experience(experience)
    experience = nest_utils.distributions_to_params(experience)

    length = experience.step_type.shape[1]
    mini_batch_length = (mini_batch_length or length)
    assert length % mini_batch_length == 0, (
        "length=%s not a multiple of mini_batch_length=%s" %
        (length, mini_batch_length))

    if len(tf.nest.flatten(
            self.train_state_spec)) > 0 and not self._use_rollout_state:
        if mini_batch_length == 1:
            logging.fatal(
                "Should use TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN when minibatch_length==1.")
        else:
            common.warning_once(
                "Consider using TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN.")

    experience = tf.nest.map_structure(
        lambda x: tf.reshape(
            x, common.concat_shape([-1, mini_batch_length],
                                   tf.shape(x)[2:])), experience)

    batch_size = tf.shape(experience.step_type)[0]
    mini_batch_size = (mini_batch_size or batch_size)

    def _make_time_major(nest):
        """Put the time dim to axis=0."""
        return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                     nest)

    scope = get_current_scope()

    for u in tf.range(num_updates):
        if mini_batch_size < batch_size:
            indices = tf.random.shuffle(
                tf.range(tf.shape(experience.step_type)[0]))
            experience = tf.nest.map_structure(
                lambda x: tf.gather(x, indices), experience)
        for b in tf.range(0, batch_size, mini_batch_size):
            if update_counter_every_mini_batch:
                common.get_global_counter().assign_add(1)
            is_last_mini_batch = tf.logical_and(
                tf.equal(u, num_updates - 1),
                tf.greater_equal(b + mini_batch_size, batch_size))
            do_summary = tf.logical_or(is_last_mini_batch,
                                       update_counter_every_mini_batch)
            common.enable_summary(do_summary)
            batch = tf.nest.map_structure(
                lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                experience)
            batch = _make_time_major(batch)
            # Tensorflow graph mode loses the original name scope here. We
            # need to restore the original name scope.
            with tf.name_scope(scope):
                training_info, loss_info, grads_and_vars = self._update(
                    batch,
                    weight=tf.cast(tf.shape(batch.step_type)[1],
                                   tf.float32) / float(mini_batch_size))
            if should_summarize:
                if do_summary:
                    # Putting `if do_summary` under the above `with`
                    # statement does not help. Somehow the `if` statement
                    # will also lose the original name scope.
                    with tf.name_scope(scope):
                        self.summarize_train(training_info, loss_info,
                                             grads_and_vars)

    train_steps = batch_size * mini_batch_length * num_updates
    return train_steps
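# Summary-gating sketch for the loop above (hypothetical sizes): with
# num_updates=2, batch_size=64 and mini_batch_size=32, do_summary is True
# only for the minibatch with u=1, b=32 (the last minibatch of the last
# update), unless update_counter_every_mini_batch is set, in which case
# every minibatch bumps the global counter and writes summaries.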
def train(self,
          experience: Experience,
          num_updates=1,
          mini_batch_size=None,
          mini_batch_length=None):
    """Train using `experience`.

    Args:
        experience (Experience): experience from replay_buffer. It is
            assumed to be batch major.
        num_updates (int): number of optimization steps
        mini_batch_size (int): number of sequences for each minibatch
        mini_batch_length (int): the length of the sequence for each sample
            in the minibatch
    Returns:
        train_steps (int): the actual number of time steps that have been
            trained (a step might be trained multiple times)
    """
    experience = self._algorithm.preprocess_experience(experience)
    length = experience.step_type.shape[1]
    mini_batch_length = (mini_batch_length or length)
    assert length % mini_batch_length == 0, (
        "length=%s not a multiple of mini_batch_length=%s" %
        (length, mini_batch_length))

    if len(tf.nest.flatten(self._algorithm.train_state_spec)
           ) > 0 and not self._use_rollout_state:
        if mini_batch_length == 1:
            logging.fatal(
                "Should use TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN when minibatch_length==1.")
        else:
            warning_once(
                "Consider using TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN.")

    experience = tf.nest.map_structure(
        lambda x: tf.reshape(
            x, common.concat_shape([-1, mini_batch_length],
                                   tf.shape(x)[2:])), experience)

    batch_size = tf.shape(experience.step_type)[0]
    mini_batch_size = (mini_batch_size or batch_size)

    def _make_time_major(nest):
        """Put the time dim to axis=0."""
        return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                     nest)

    for u in tf.range(num_updates):
        if mini_batch_size < batch_size:
            indices = tf.random.shuffle(
                tf.range(tf.shape(experience.step_type)[0]))
            experience = tf.nest.map_structure(
                lambda x: tf.gather(x, indices), experience)
        for b in tf.range(0, batch_size, mini_batch_size):
            batch = tf.nest.map_structure(
                lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                experience)
            batch = _make_time_major(batch)
            is_last_mini_batch = tf.logical_and(
                tf.equal(u, num_updates - 1),
                tf.greater_equal(b + mini_batch_size, batch_size))
            common.enable_summary(is_last_mini_batch)
            training_info, loss_info, grads_and_vars = self._update(
                batch,
                weight=tf.cast(tf.shape(batch.step_type)[1], tf.float32) /
                float(mini_batch_size))
            if is_last_mini_batch:
                self._training_summary(training_info, loss_info,
                                       grads_and_vars)
            self._train_step_counter.assign_add(1)

    train_steps = batch_size * mini_batch_length * num_updates
    return train_steps
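# A minimal call sketch (hypothetical driver setup; gather_all() is assumed
# to return batch-major experience of shape (B, T, ...)):
#
#     experience = replay_buffer.gather_all()
#     train_steps = driver.train(experience, num_updates=4,
#                                mini_batch_size=64, mini_batch_length=8)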