def _step(self, time_step, policy_state):
    """Take one environment step and notify trajectory/experience observers."""
    policy_state = common.reset_state_if_necessary(
        policy_state, self._initial_state, time_step.is_first())
    policy_step = common.algorithm_step(
        self._algorithm,
        self._observation_transformer,
        time_step,
        state=policy_state,
        training=self._training,
        greedy_predict=self._greedy_predict)
    action = common.sample_action_distribution(policy_step.action)
    next_time_step = self._env_step(action)
    if self._observers:
        traj = from_transition(time_step,
                               policy_step._replace(action=action),
                               next_time_step)
        for observer in self._observers:
            observer(traj)
    if self._exp_observers:
        action_distribution_param = common.get_distribution_params(
            policy_step.action)
        exp = make_experience(
            time_step,
            policy_step._replace(action=action),
            action_distribution=action_distribution_param,
            state=policy_state if self._use_rollout_state else ())
        for observer in self._exp_observers:
            observer(exp)
    return next_time_step, policy_step, action
def _step(self, algorithm, time_step, state):
    """Run one rollout step and sample the action from its distribution."""
    time_step = algorithm.transform_timestep(time_step)
    policy_step = common.algorithm_step(algorithm.rollout, time_step, state)
    action_dist_param = common.get_distribution_params(policy_step.action)
    policy_step = common.sample_policy_action(policy_step)
    return policy_step, action_dist_param
def _train_loop_body(counter, policy_state, training_info_ta):
    """Loop body for the training `tf.while_loop`.

    Reads one step of experience, runs `train_step`, and writes the result
    into `training_info_ta`. Closes over `experience_ta` and
    `initial_train_state` from the enclosing scope.
    """
    exp = tf.nest.map_structure(lambda ta: ta.read(counter), experience_ta)
    collect_action_distribution_param = exp.action_distribution
    collect_action_distribution = nested_distributions_from_specs(
        self._action_distribution_spec, collect_action_distribution_param)
    exp = exp._replace(action_distribution=collect_action_distribution)

    policy_state = common.reset_state_if_necessary(
        policy_state, initial_train_state,
        tf.equal(exp.step_type, StepType.FIRST))

    policy_step = common.algorithm_step(self.train_step, exp, policy_state)
    action_dist_param = common.get_distribution_params(policy_step.action)

    training_info = TrainingInfo(
        action_distribution=action_dist_param, info=policy_step.info)

    training_info_ta = tf.nest.map_structure(
        lambda ta, x: ta.write(counter, x), training_info_ta, training_info)

    counter += 1
    return [counter, policy_step.state, training_info_ta]
def _step(self, algorithm, time_step, state):
    """Run one prediction step and sample the action from its distribution."""
    policy_step = common.algorithm_step(
        algorithm,
        self._ob_transformer,
        time_step,
        state,
        greedy_predict=False,
        training=False)
    action_dist_param = common.get_distribution_params(policy_step.action)
    policy_step = common.sample_policy_action(policy_step)
    return policy_step, action_dist_param
def _eval(self):
    """Evaluate the algorithm on the evaluation environment and log metrics."""
    global_step = get_global_counter()
    with tf.summary.record_if(True):
        eager_compute(
            metrics=self._eval_metrics,
            environment=self._eval_env,
            state_spec=self._algorithm.predict_state_spec,
            action_fn=lambda time_step, state: common.algorithm_step(
                algorithm_step_func=self._algorithm.greedy_predict,
                time_step=self._algorithm.transform_timestep(time_step),
                state=state),
            num_episodes=self._num_eval_episodes,
            step_metrics=self._driver.get_step_metrics(),
            train_step=global_step,
            summary_writer=self._eval_summary_writer,
            summary_prefix="Metrics")
        metric_utils.log_metrics(self._eval_metrics)
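# --- Hypothetical standalone sketch, not part of the original code. It only
# illustrates the `tf.summary.record_if(True)` pattern used by `_eval` above:
# inside the context, summary ops are recorded unconditionally. The writer
# path and metric name below are made up for illustration.
import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/eval_summaries")
with writer.as_default(), tf.summary.record_if(True):
    tf.summary.scalar("Metrics/AverageReturn", 1.0, step=0)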
def _iter(self, time_step, policy_state):
    """One training iteration."""
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size

    def create_ta(s):
        return tf.TensorArray(
            dtype=s.dtype,
            size=self._train_interval + 1,
            element_shape=tf.TensorShape([batch_size]).concatenate(s.shape))

    training_info_ta = tf.nest.map_structure(create_ta,
                                             self._training_info_spec)

    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as tape:
        tape.watch(self._trainable_variables)
        [counter, time_step, policy_state, training_info_ta] = tf.while_loop(
            cond=lambda *_: True,
            body=self._train_loop_body,
            loop_vars=[counter, time_step, policy_state, training_info_ta],
            back_prop=True,
            parallel_iterations=1,
            maximum_iterations=self._train_interval,
            name='iter_loop')

    if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
        next_time_step, policy_step, action = self._step(
            time_step, policy_state)
        next_state = policy_step.state
    else:
        policy_step = common.algorithm_step(self._algorithm.rollout,
                                            self._observation_transformer,
                                            time_step, policy_state)
        action = common.sample_action_distribution(policy_step.action)
        next_time_step = time_step
        next_state = policy_state

    action_distribution_param = common.get_distribution_params(
        policy_step.action)

    final_training_info = make_training_info(
        action_distribution=action_distribution_param,
        action=action,
        reward=time_step.reward,
        discount=time_step.discount,
        step_type=time_step.step_type,
        info=policy_step.info)

    with tape:
        training_info_ta = tf.nest.map_structure(
            lambda ta, x: ta.write(counter, x), training_info_ta,
            final_training_info)
        training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                              training_info_ta)
        action_distribution = nested_distributions_from_specs(
            self._algorithm.action_distribution_spec,
            training_info.action_distribution)
        training_info = training_info._replace(
            action_distribution=action_distribution)

    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)
    del tape

    self._training_summary(training_info, loss_info, grads_and_vars)
    self._train_step_counter.assign_add(1)
    return next_time_step, next_state
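# --- Hypothetical standalone sketch, not part of the original code. It only
# illustrates the TensorArray-accumulation pattern used by `_iter` and
# `_train_loop_body` above: each loop iteration writes one element at index
# `counter`, and `stack()` afterwards yields a tensor with a leading time
# dimension. The function name and values below are made up.
import tensorflow as tf


def _sketch_accumulate(num_steps=4, batch_size=2):
    ta = tf.TensorArray(
        dtype=tf.float32,
        size=num_steps,
        element_shape=tf.TensorShape([batch_size]))

    def body(counter, ta):
        # Write one per-batch value per iteration.
        step_value = tf.fill([batch_size], tf.cast(counter, tf.float32))
        return [counter + 1, ta.write(counter, step_value)]

    _, ta = tf.while_loop(
        cond=lambda *_: True,
        body=body,
        loop_vars=[tf.zeros((), tf.int32), ta],
        maximum_iterations=num_steps)
    return ta.stack()  # shape [num_steps, batch_size]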
def _prepare_specs(self, algorithm):
    """Prepare various tensor specs."""

    def extract_spec(nest):
        return tf.nest.map_structure(
            lambda t: tf.TensorSpec(t.shape[1:], t.dtype), nest)

    time_step = self.get_initial_time_step()
    self._time_step_spec = extract_spec(time_step)
    self._action_spec = self._env.action_spec()

    policy_step = algorithm.predict(time_step, self._initial_state)
    info_spec = extract_spec(policy_step.info)
    self._pred_policy_step_spec = PolicyStep(
        action=self._action_spec,
        state=algorithm.predict_state_spec,
        info=info_spec)

    def _to_distribution_spec(spec):
        if isinstance(spec, tf.TensorSpec):
            return DistributionSpec(
                tfp.distributions.Deterministic,
                input_params_spec={"loc": spec},
                sample_spec=spec)
        return spec

    self._action_distribution_spec = tf.nest.map_structure(
        _to_distribution_spec, algorithm.action_distribution_spec)
    self._action_dist_param_spec = tf.nest.map_structure(
        lambda spec: spec.input_params_spec, self._action_distribution_spec)

    self._experience_spec = Experience(
        step_type=self._time_step_spec.step_type,
        reward=self._time_step_spec.reward,
        discount=self._time_step_spec.discount,
        observation=self._time_step_spec.observation,
        prev_action=self._action_spec,
        action=self._action_spec,
        info=info_spec,
        action_distribution=self._action_dist_param_spec)

    action_dist_params = common.zero_tensor_from_nested_spec(
        self._experience_spec.action_distribution, self._env.batch_size)
    action_dist = nested_distributions_from_specs(
        self._action_distribution_spec, action_dist_params)

    exp = Experience(
        step_type=time_step.step_type,
        reward=time_step.reward,
        discount=time_step.discount,
        observation=time_step.observation,
        prev_action=time_step.prev_action,
        action=time_step.prev_action,
        info=policy_step.info,
        action_distribution=action_dist)

    processed_exp = algorithm.preprocess_experience(exp)
    self._processed_experience_spec = self._experience_spec._replace(
        info=extract_spec(processed_exp.info))

    policy_step = common.algorithm_step(
        algorithm,
        ob_transformer=self._observation_transformer,
        time_step=exp,
        state=common.get_initial_policy_state(self._env.batch_size,
                                              algorithm.train_state_spec),
        training=True)
    info_spec = extract_spec(policy_step.info)

    self._training_info_spec = make_training_info(
        action=self._action_spec,
        action_distribution=self._action_dist_param_spec,
        step_type=self._time_step_spec.step_type,
        reward=self._time_step_spec.reward,
        discount=self._time_step_spec.discount,
        info=info_spec,
        collect_info=self._processed_experience_spec.info,
        collect_action_distribution=self._action_dist_param_spec)
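# --- Hypothetical standalone sketch, not part of the original code. It only
# illustrates what the `extract_spec` helper above computes: for every tensor
# in a nest it drops the leading batch dimension and keeps the per-element
# shape and dtype. The example nest and shapes below are made up.
import tensorflow as tf

batched = {"observation": tf.zeros([32, 84, 84, 3]), "reward": tf.zeros([32])}
specs = tf.nest.map_structure(
    lambda t: tf.TensorSpec(t.shape[1:], t.dtype), batched)
# specs["observation"] -> TensorSpec(shape=(84, 84, 3), dtype=tf.float32)
# specs["reward"]      -> TensorSpec(shape=(), dtype=tf.float32)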
def prepare_off_policy_specs(self, time_step: ActionTimeStep):
    """Prepare various tensor specs for off_policy training.

    prepare_off_policy_specs is called by OffPolicyDriver._prepare_spec().
    """
    self._env_batch_size = time_step.step_type.shape[0]
    self._time_step_spec = common.extract_spec(time_step)
    initial_state = common.get_initial_policy_state(self._env_batch_size,
                                                    self.train_state_spec)
    transformed_timestep = self.transform_timestep(time_step)
    policy_step = self.rollout(transformed_timestep, initial_state)
    info_spec = common.extract_spec(policy_step.info)

    self._action_distribution_spec = tf.nest.map_structure(
        common.to_distribution_spec, self.action_distribution_spec)
    self._action_dist_param_spec = tf.nest.map_structure(
        lambda spec: spec.input_params_spec, self._action_distribution_spec)

    self._experience_spec = Experience(
        step_type=self._time_step_spec.step_type,
        reward=self._time_step_spec.reward,
        discount=self._time_step_spec.discount,
        observation=self._time_step_spec.observation,
        prev_action=self._action_spec,
        action=self._action_spec,
        info=info_spec,
        action_distribution=self._action_dist_param_spec,
        state=self.train_state_spec if self._use_rollout_state else ())

    action_dist_params = common.zero_tensor_from_nested_spec(
        self._experience_spec.action_distribution, self._env_batch_size)
    action_dist = nested_distributions_from_specs(
        self._action_distribution_spec, action_dist_params)

    exp = Experience(
        step_type=time_step.step_type,
        reward=time_step.reward,
        discount=time_step.discount,
        observation=time_step.observation,
        prev_action=time_step.prev_action,
        action=time_step.prev_action,
        info=policy_step.info,
        action_distribution=action_dist,
        state=initial_state if self._use_rollout_state else ())

    transformed_exp = self.transform_timestep(exp)
    processed_exp = self.preprocess_experience(transformed_exp)
    self._processed_experience_spec = self._experience_spec._replace(
        observation=common.extract_spec(processed_exp.observation),
        info=common.extract_spec(processed_exp.info))

    policy_step = common.algorithm_step(
        algorithm_step_func=self.train_step,
        time_step=processed_exp,
        state=initial_state)
    info_spec = common.extract_spec(policy_step.info)

    self._training_info_spec = TrainingInfo(
        action_distribution=self._action_dist_param_spec, info=info_spec)