def _step(self, time_step, policy_state):
    """Take one step in the environment and notify the observers."""
    policy_state = common.reset_state_if_necessary(
        policy_state, self._initial_state, time_step.is_first())
    policy_step = common.algorithm_step(
        self._algorithm,
        self._observation_transformer,
        time_step,
        state=policy_state,
        training=self._training,
        greedy_predict=self._greedy_predict)
    action = common.sample_action_distribution(policy_step.action)
    next_time_step = self._env_step(action)
    if self._observers:
        traj = from_transition(time_step,
                               policy_step._replace(action=action),
                               next_time_step)
        for observer in self._observers:
            observer(traj)
    if self._exp_observers:
        action_distribution_param = common.get_distribution_params(
            policy_step.action)
        exp = make_experience(
            time_step,
            policy_step._replace(action=action),
            action_distribution=action_distribution_param,
            state=policy_state if self._use_rollout_state else ())
        for observer in self._exp_observers:
            observer(exp)
    return next_time_step, policy_step, action
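# A minimal sketch of an observer compatible with the loop above, assuming
# observers are plain callables invoked once per trajectory (as in
# TF-Agents) and that `tensorflow` is imported as `tf`, as elsewhere in
# this module. `AverageRewardObserver` is a hypothetical name used for
# illustration, not part of this codebase.
class AverageRewardObserver(object):
    """Tracks the running mean reward over all observed trajectories."""

    def __init__(self):
        self._total = tf.Variable(0., trainable=False)
        self._count = tf.Variable(0., trainable=False)

    def __call__(self, traj):
        # Called by _step() once per batched transition.
        self._total.assign_add(tf.reduce_sum(traj.reward))
        self._count.assign_add(tf.cast(tf.size(traj.reward), tf.float32))

    def result(self):
        return self._total / tf.maximum(self._count, 1.)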
def _step(self, algorithm, time_step, state):
    """Perform one rollout step and sample an action from the policy."""
    time_step = algorithm.transform_timestep(time_step)
    policy_step = common.algorithm_step(algorithm.rollout, time_step, state)
    action_dist_param = common.get_distribution_params(policy_step.action)
    policy_step = common.sample_policy_action(policy_step)
    return policy_step, action_dist_param
def _train_loop_body(counter, policy_state, training_info_ta):
    # Note: `experience_ta`, `initial_train_state` and `self` are captured
    # from the enclosing scope.
    exp = tf.nest.map_structure(lambda ta: ta.read(counter), experience_ta)
    collect_action_distribution_param = exp.action_distribution
    collect_action_distribution = nested_distributions_from_specs(
        self._action_distribution_spec, collect_action_distribution_param)
    exp = exp._replace(action_distribution=collect_action_distribution)

    policy_state = common.reset_state_if_necessary(
        policy_state, initial_train_state,
        tf.equal(exp.step_type, StepType.FIRST))

    policy_step = common.algorithm_step(self.train_step, exp, policy_state)
    action_dist_param = common.get_distribution_params(policy_step.action)
    training_info = TrainingInfo(
        action_distribution=action_dist_param, info=policy_step.info)

    training_info_ta = tf.nest.map_structure(
        lambda ta, x: ta.write(counter, x), training_info_ta, training_info)

    counter += 1

    return [counter, policy_step.state, training_info_ta]
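# A minimal sketch of the while_loop that could drive _train_loop_body
# above, sitting inside the same enclosing method so the closure variables
# resolve (mirroring the pattern used by _iter below). `num_steps` and
# `initial_train_state` are assumed to be prepared by that enclosing
# method; the names here are illustrative.
counter = tf.zeros((), tf.int32)
[_, policy_state, training_info_ta] = tf.while_loop(
    cond=lambda *_: True,
    body=_train_loop_body,
    loop_vars=[counter, initial_train_state, training_info_ta],
    parallel_iterations=1,
    maximum_iterations=num_steps,
    name='train_loop')
# Stack the per-step entries into [T, B, ...] tensors for training.
training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                      training_info_ta)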
def _step(self, algorithm, time_step, state):
    policy_step = common.algorithm_step(
        algorithm,
        self._ob_transformer,
        time_step,
        state,
        greedy_predict=False,
        training=False)
    action_dist_param = common.get_distribution_params(policy_step.action)
    policy_step = common.sample_policy_action(policy_step)
    return policy_step, action_dist_param
def after_train(self, training_info):
    """Adjust the actor parameters according to the KL divergence."""
    exp_array = TracExperience(
        observation=training_info.info.observation,
        step_type=training_info.step_type,
        action_param=common.get_distribution_params(
            training_info.action_distribution),
        state=training_info.info.state)
    exp_array = common.create_and_unstack_tensor_array(
        exp_array, clear_after_read=False)
    dists, steps = self._trusted_updater.adjust_step(
        lambda: self._calc_change(exp_array), self._action_dist_clips)

    def _summarize():
        with self.name_scope:
            for i, d in enumerate(tf.nest.flatten(dists)):
                tf.summary.scalar("unadjusted_action_dist/%s" % i, d)
            tf.summary.scalar("adjust_steps", steps)

    common.run_if(common.should_record_summaries(), _summarize)
    self._ac_algorithm.after_train(
        training_info._replace(info=training_info.info.ac))
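# A minimal sketch of what common.run_if used above might look like,
# assuming it simply wraps tf.cond so that `func` only executes when `cond`
# holds; this is an illustration, not the actual implementation.
def run_if(cond, func):
    """Run `func` if the scalar bool tensor `cond` is True."""

    def _if_true():
        func()
        return tf.constant(True)

    # Both branches must return matching structures for tf.cond.
    return tf.cond(cond, _if_true, lambda: tf.constant(False))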
def _train_loop_body(self, counter, time_step, policy_state,
                     training_info_ta):
    next_time_step, policy_step, action = self._step(time_step, policy_state)
    # Stop gradients from flowing through the sampled action; they flow
    # through the action distribution parameters instead.
    action = tf.nest.map_structure(tf.stop_gradient, action)
    action_distribution_param = common.get_distribution_params(
        policy_step.action)
    training_info = make_training_info(
        action_distribution=action_distribution_param,
        action=action,
        reward=time_step.reward,
        discount=time_step.discount,
        step_type=time_step.step_type,
        info=policy_step.info)
    training_info_ta = tf.nest.map_structure(
        lambda ta, x: ta.write(counter, x), training_info_ta, training_info)
    counter += 1
    return [counter, next_time_step, policy_step.state, training_info_ta]
def _iter(self, time_step, policy_state):
    """One training iteration."""
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size

    def create_ta(s):
        return tf.TensorArray(
            dtype=s.dtype,
            size=self._train_interval + 1,
            element_shape=tf.TensorShape([batch_size]).concatenate(s.shape))

    training_info_ta = tf.nest.map_structure(create_ta,
                                             self._training_info_spec)

    with tf.GradientTape(
            watch_accessed_variables=False, persistent=True) as tape:
        tape.watch(self._trainable_variables)
        # Unroll exactly `_train_interval` steps; the loop is bounded by
        # maximum_iterations, so cond is always True.
        [counter, time_step, policy_state, training_info_ta] = tf.while_loop(
            cond=lambda *_: True,
            body=self._train_loop_body,
            loop_vars=[counter, time_step, policy_state, training_info_ta],
            back_prop=True,
            parallel_iterations=1,
            maximum_iterations=self._train_interval,
            name='iter_loop')

    if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
        next_time_step, policy_step, action = self._step(
            time_step, policy_state)
        next_state = policy_step.state
    else:
        policy_step = common.algorithm_step(self._algorithm.rollout,
                                            self._observation_transformer,
                                            time_step, policy_state)
        action = common.sample_action_distribution(policy_step.action)
        next_time_step = time_step
        next_state = policy_state

    action_distribution_param = common.get_distribution_params(
        policy_step.action)
    final_training_info = make_training_info(
        action_distribution=action_distribution_param,
        action=action,
        reward=time_step.reward,
        discount=time_step.discount,
        step_type=time_step.step_type,
        info=policy_step.info)

    # Re-enter the persistent tape so writing the final step and stacking
    # the TensorArrays are recorded for gradient computation.
    with tape:
        training_info_ta = tf.nest.map_structure(
            lambda ta, x: ta.write(counter, x), training_info_ta,
            final_training_info)
        training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                              training_info_ta)
        action_distribution = nested_distributions_from_specs(
            self._algorithm.action_distribution_spec,
            training_info.action_distribution)
        training_info = training_info._replace(
            action_distribution=action_distribution)

    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)
    del tape

    self._training_summary(training_info, loss_info, grads_and_vars)
    self._train_step_counter.assign_add(1)
    return next_time_step, next_state
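# A minimal sketch of how _iter might be driven, assuming the driver
# exposes a run() entry point; the name `run` and the `num_iterations`
# argument are illustrative, not necessarily the real driver API.
def run(self, num_iterations):
    time_step = self._env.reset()
    policy_state = self._initial_state
    for _ in range(num_iterations):
        # Each call unrolls `_train_interval` steps and applies one update.
        time_step, policy_state = self._iter(time_step, policy_state)
    return time_step, policy_state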