def _step(self, time_step, policy_state):
    policy_state = common.reset_state_if_necessary(
        policy_state, self._initial_state, time_step.is_first())
    # Pick the algorithm call according to the driver mode.
    if self._mode == self.PREDICT:
        step_func = functools.partial(
            self._algorithm.predict, epsilon_greedy=self._epsilon_greedy)
    elif self._mode == self.ON_POLICY_TRAINING:
        step_func = functools.partial(
            self._algorithm.rollout, mode=RLAlgorithm.ON_POLICY_TRAINING)
    elif self._mode == self.OFF_POLICY_TRAINING:
        step_func = functools.partial(
            self._algorithm.rollout, mode=RLAlgorithm.ROLLOUT)
    else:
        raise ValueError("Unsupported mode: %s" % self._mode)

    transformed_time_step = self._algorithm.transform_timestep(time_step)
    policy_step = step_func(transformed_time_step, policy_state)
    next_time_step = self._env_step(policy_step.action)

    # Feed the transition to trajectory observers (e.g. metrics).
    if self._observers:
        traj = from_transition(
            time_step, policy_step._replace(info=()), next_time_step)
        for observer in self._observers:
            observer(traj)

    # During training, feed the experience to the algorithm's experience
    # observers (e.g. a replay buffer).
    if self._algorithm.exp_observers and self._training:
        policy_step = nest_utils.distributions_to_params(policy_step)
        exp = make_experience(time_step, policy_step, policy_state)
        self._algorithm.observe(exp)

    return next_time_step, policy_step, transformed_time_step
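
# A minimal, self-contained sketch of the state-reset step above. This is a
# stand-in for `common.reset_state_if_necessary` (exact behavior assumed):
# wherever `is_first` is True, the recurrent policy state is replaced by the
# initial state, so one episode's state does not leak into the next.
import tensorflow as tf


def reset_state_if_necessary_sketch(policy_state, initial_state, is_first):
    """Return `initial_state` where `is_first` is True, else `policy_state`."""

    def _select(init, state):
        # Broadcast the [B] boolean mask against state tensors of shape [B, ...].
        mask = tf.reshape(is_first, [-1] + [1] * (state.shape.rank - 1))
        return tf.where(mask, init, state)

    return tf.nest.map_structure(_select, initial_state, policy_state)


# Example: batch of 3 environments; the second one just started a new episode.
state = tf.ones([3, 2])
initial = tf.zeros([3, 2])
print(reset_state_if_necessary_sketch(
    state, initial, tf.constant([False, True, False])))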
def observe(self, exp: Experience):
    """An algorithm can override to manipulate experience.

    Args:
        exp (Experience): The shapes can be either [Q, T, B, ...] or
            [B, ...], where Q is `learn_queue_cap` in
            `AsyncOffPolicyDriver`, T is the sequence length, and B is
            the batch size of the batched environment.
    """
    if not self._use_rollout_state:
        exp = exp._replace(state=())
    exp = nest_utils.distributions_to_params(exp)
    for observer in self._exp_observers:
        observer(exp)
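
# A hypothetical experience observer, only to illustrate the contract that
# `observe()` relies on: an observer is any callable that accepts the
# (batched) experience nest. In practice a replay buffer's add method plays
# this role; here a plain list collector stands in. The registration hook
# named below is an assumption, not necessarily the framework's actual API.
class ListExperienceObserver(object):
    """Collects every experience nest passed to it (illustration only)."""

    def __init__(self):
        self.experiences = []

    def __call__(self, exp):
        self.experiences.append(exp)


# Usage sketch (hypothetical registration method):
#   algorithm.add_experience_observer(ListExperienceObserver())
# After that, every call to `observe()` forwards the experience to the list.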
def _dequeue_and_step(self, algorithm):
    time_step, policy_state, env_ids = self._actor_q.dequeue_all()

    # pack
    time_step = tf.nest.map_structure(common.flatten_once, time_step)
    policy_state = tf.nest.map_structure(common.flatten_once, policy_state)

    # prediction forward
    transformed_time_step = algorithm.transform_timestep(time_step)
    policy_step = algorithm.rollout(
        transformed_time_step, policy_state, mode=RLAlgorithm.ROLLOUT)

    # unpack
    policy_step = nest_utils.distributions_to_params(policy_step)
    policy_step = tf.nest.map_structure(
        lambda e: tf.reshape(e, [env_ids.shape[0], -1] + list(e.shape[1:])),
        policy_step)

    return policy_step, env_ids
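
# A minimal sketch of the pack/unpack round trip above, with a stand-in for
# `common.flatten_once` (exact helper assumed): the leading [num_queues, B]
# dimensions are merged so the algorithm sees one flat batch, and the final
# reshape restores them so each result can be routed back to its environment.
import tensorflow as tf


def flatten_once_sketch(t):
    """Merge the first two dimensions: [Q, B, ...] -> [Q * B, ...]."""
    return tf.reshape(t, [-1] + list(t.shape[2:]))


x = tf.random.normal([4, 8, 3])        # [Q=4, B=8, 3 features]
flat = flatten_once_sketch(x)          # [32, 3], what the algorithm consumes
restored = tf.reshape(flat, [4, -1] + list(flat.shape[1:]))  # back to [4, 8, 3]
assert restored.shape == x.shape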
def _train_loop_body(counter, policy_state, info_ta):
    # Read the experience for the current time step.
    exp = tf.nest.map_structure(lambda ta: ta.read(counter), experience_ta)
    exp = nest_utils.params_to_distributions(
        exp, self.processed_experience_spec)
    # Reset the train state at episode boundaries.
    policy_state = common.reset_state_if_necessary(
        policy_state, initial_train_state,
        tf.equal(exp.step_type, StepType.FIRST))
    with tf.name_scope(scope):
        policy_step = self.train_step(exp, policy_state)
    # Record this step's info for the later loss computation.
    info_ta = tf.nest.map_structure(
        lambda ta, x: ta.write(counter, x), info_ta,
        nest_utils.distributions_to_params(policy_step.info))
    counter += 1
    return [counter, policy_step.state, info_ta]
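
# A minimal, self-contained sketch of how a loop body of this shape is
# typically driven (not the actual ALF training loop): experience is
# unstacked into a TensorArray, the body runs under `tf.while_loop`, and the
# per-step outputs written to another TensorArray are stacked afterwards.
import tensorflow as tf

values = tf.range(5, dtype=tf.float32)                        # stand-in "experience"
experience_ta = tf.TensorArray(tf.float32, size=5).unstack(values)
info_ta = tf.TensorArray(tf.float32, size=5)


def body(counter, state, info_ta):
    x = experience_ta.read(counter)           # read this step's experience
    state = state + x                         # stand-in for train_step()
    info_ta = info_ta.write(counter, state)   # record this step's output
    return [counter + 1, state, info_ta]


_, _, info_ta = tf.while_loop(
    lambda counter, state, info_ta: counter < 5, body,
    [0, tf.constant(0.0), info_ta])
print(info_ta.stack())  # cumulative sums: [0., 1., 3., 6., 10.]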
def _rollout_loop_body(self, counter, time_step, policy_state,
                       training_info_ta):
    next_time_step, policy_step, transformed_time_step = self._step(
        time_step, policy_state)

    training_info = ds.TrainingInfo(
        action=policy_step.action,
        reward=transformed_time_step.reward,
        discount=transformed_time_step.discount,
        step_type=transformed_time_step.step_type,
        rollout_info=nest_utils.distributions_to_params(policy_step.info),
        env_id=transformed_time_step.env_id)

    training_info_ta = tf.nest.map_structure(
        lambda ta, x: ta.write(counter, x), training_info_ta, training_info)

    counter += 1
    return [counter, next_time_step, policy_step.state, training_info_ta]
def after_train(self, training_info):
    """Adjust actor parameter according to KL-divergence."""
    action_param = nest_utils.distributions_to_params(
        training_info.info.action_distribution)
    exp_array = TracExperience(
        observation=training_info.info.observation,
        step_type=training_info.step_type,
        action_param=action_param,
        state=training_info.info.state)
    exp_array = common.create_and_unstack_tensor_array(
        exp_array, clear_after_read=False)
    dists, steps = self._trusted_updater.adjust_step(
        lambda: self._calc_change(exp_array), self._action_dist_clips)

    def _summarize():
        with self.name_scope:
            for i, d in enumerate(tf.nest.flatten(dists)):
                tf.summary.scalar("unadjusted_action_dist/%s" % i, d)
            tf.summary.scalar("adjust_steps", steps)

    common.run_if(common.should_record_summaries(), _summarize)

    ac_info = training_info.info.ac._replace(
        action_distribution=training_info.info.action_distribution)
    self._ac_algorithm.after_train(training_info._replace(info=ac_info))
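
# A conceptual sketch of the trusted-updater idea invoked above, not ALF's
# actual `TrustedUpdater.adjust_step`: after an optimizer update, measure how
# far the action distribution has moved; while the change exceeds the clip,
# move the variables back toward their pre-update values, and report how many
# shrink steps were needed (the `adjust_steps` summary above).
import tensorflow as tf


def adjust_step_sketch(variables, old_values, calc_change, clip, max_iters=10):
    """Shrink the parameter update until `calc_change()` is within `clip`."""
    steps = 0
    change = float(calc_change())
    while change > clip and steps < max_iters:
        # Move each variable halfway back toward its pre-update value.
        for var, old in zip(variables, old_values):
            var.assign(0.5 * (var + old))
        change = float(calc_change())
        steps += 1
    return change, steps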
def _train(self, experience, num_updates, mini_batch_size,
           mini_batch_length, update_counter_every_mini_batch,
           should_summarize):
    """Train using experience."""
    experience = nest_utils.params_to_distributions(
        experience, self.experience_spec)
    experience = self.transform_timestep(experience)
    experience = self.preprocess_experience(experience)
    experience = nest_utils.distributions_to_params(experience)

    length = experience.step_type.shape[1]
    mini_batch_length = (mini_batch_length or length)
    assert length % mini_batch_length == 0, (
        "length=%s not a multiple of mini_batch_length=%s" %
        (length, mini_batch_length))

    if len(tf.nest.flatten(
            self.train_state_spec)) > 0 and not self._use_rollout_state:
        if mini_batch_length == 1:
            logging.fatal(
                "Should use TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN when minibatch_length==1.")
        else:
            common.warning_once(
                "Consider using TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN.")

    # Split each [B, T, ...] trajectory into chunks of mini_batch_length.
    experience = tf.nest.map_structure(
        lambda x: tf.reshape(
            x, common.concat_shape([-1, mini_batch_length],
                                   tf.shape(x)[2:])), experience)

    batch_size = tf.shape(experience.step_type)[0]
    mini_batch_size = (mini_batch_size or batch_size)

    def _make_time_major(nest):
        """Put the time dim to axis=0."""
        return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                     nest)

    scope = get_current_scope()

    for u in tf.range(num_updates):
        if mini_batch_size < batch_size:
            indices = tf.random.shuffle(
                tf.range(tf.shape(experience.step_type)[0]))
            experience = tf.nest.map_structure(
                lambda x: tf.gather(x, indices), experience)
        for b in tf.range(0, batch_size, mini_batch_size):
            if update_counter_every_mini_batch:
                common.get_global_counter().assign_add(1)
            is_last_mini_batch = tf.logical_and(
                tf.equal(u, num_updates - 1),
                tf.greater_equal(b + mini_batch_size, batch_size))
            do_summary = tf.logical_or(is_last_mini_batch,
                                       update_counter_every_mini_batch)
            common.enable_summary(do_summary)
            batch = tf.nest.map_structure(
                lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                experience)
            batch = _make_time_major(batch)
            # TensorFlow graph mode loses the original name scope here. We
            # need to restore the original name scope.
            with tf.name_scope(scope):
                training_info, loss_info, grads_and_vars = self._update(
                    batch,
                    weight=tf.cast(tf.shape(batch.step_type)[1],
                                   tf.float32) / float(mini_batch_size))
            if should_summarize:
                if do_summary:
                    # Putting `if do_summary` under the above `with`
                    # statement does not help. Somehow the `if` statement
                    # will also lose the original name scope.
                    with tf.name_scope(scope):
                        self.summarize_train(training_info, loss_info,
                                             grads_and_vars)

    train_steps = batch_size * mini_batch_length * num_updates
    return train_steps
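
# A minimal sketch of the reshaping done above, with plain TF ops standing in
# for `common.concat_shape` and `common.transpose2`: trajectories of shape
# [B, T, ...] are split along time into chunks of `mini_batch_length`, giving
# [B * T / l, l, ...]; each sampled mini-batch is then transposed to
# time-major [l, mini_batch_size, ...] before being passed to `_update()`.
import tensorflow as tf

B, T, mini_batch_length = 4, 8, 2
experience = tf.random.normal([B, T, 3])                     # [B, T, obs_dim]

chunks = tf.reshape(experience, [-1, mini_batch_length, 3])  # [16, 2, 3]
mini_batch = chunks[:6]                                      # mini_batch_size = 6
time_major = tf.transpose(mini_batch, [1, 0, 2])             # [2, 6, 3]
print(chunks.shape, time_major.shape)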