def _prepare_specs(self, algorithm):
    """Build ``self._training_info_spec`` from the env and ``algorithm``.

    The ``info`` part of the spec is discovered by calling
    ``algorithm.train_step`` once on the initial time step and the initial
    state, then dropping the leading (presumably batch) dimension of every
    tensor in the returned ``info``.

    Args:
        algorithm: algorithm providing ``action_distribution_spec`` and
            ``train_step``.
    """
    env_time_step_spec = self._env.time_step_spec()

    def _params_spec_of(dist_spec):
        return dist_spec.input_params_spec

    dist_param_spec = tf.nest.map_structure(
        _params_spec_of, algorithm.action_distribution_spec)

    # NOTE(review): the probe uses ``train_step``; if the training loop
    # records ``rollout`` outputs instead, confirm both produce the same
    # ``info`` structure.
    initial_time_step = self.get_initial_time_step()
    probe_step = algorithm.train_step(initial_time_step, self._initial_state)

    def _unbatched_spec(tensor):
        # Keep everything but dim 0 so the spec describes one batch element.
        return tf.TensorSpec(tensor.shape[1:], tensor.dtype)

    probe_info_spec = tf.nest.map_structure(_unbatched_spec, probe_step.info)

    self._training_info_spec = make_training_info(
        action_distribution=dist_param_spec,
        action=self._env.action_spec(),
        step_type=env_time_step_spec.step_type,
        reward=env_time_step_spec.reward,
        discount=env_time_step_spec.discount,
        info=probe_info_spec)
def _train_loop_body(counter, policy_state, training_info_ta):
    """Train on one replayed experience step (``tf.while_loop`` body).

    Reads entry ``counter`` of the enclosing ``experience_ta`` and rebuilds
    the collect-time action distribution from its stored parameters. Where
    the step type is ``StepType.FIRST``, it resets ``policy_state`` to the
    enclosing ``initial_train_state``. It then runs ``common.algorithm_step``
    with ``training=True`` and writes the resulting training info at index
    ``counter`` of ``training_info_ta``.

    Args:
        counter: scalar index of the experience entry / output slot.
        policy_state: training policy state carried across iterations.
        training_info_ta: nest of ``tf.TensorArray`` receiving the records.

    Returns:
        ``[counter + 1, new_policy_state, training_info_ta]``.
    """
    def _read(array):
        return array.read(counter)

    exp = tf.nest.map_structure(_read, experience_ta)
    # The stored experience holds distribution *parameters*; keep them for
    # the record and swap in real distribution objects for the algorithm.
    stored_dist_params = exp.action_distribution
    exp = exp._replace(action_distribution=nested_distributions_from_specs(
        self._action_distribution_spec, stored_dist_params))

    is_first = tf.equal(exp.step_type, StepType.FIRST)
    policy_state = common.reset_state_if_necessary(policy_state,
                                                   initial_train_state,
                                                   is_first)

    policy_step = common.algorithm_step(self._algorithm,
                                        self._observation_transformer,
                                        exp,
                                        policy_state,
                                        training=True)

    record = make_training_info(
        action=exp.action,
        action_distribution=common.get_distribution_params(
            policy_step.action),
        reward=exp.reward,
        discount=exp.discount,
        step_type=exp.step_type,
        info=policy_step.info,
        collect_info=exp.info,
        collect_action_distribution=stored_dist_params)

    training_info_ta = tf.nest.map_structure(
        lambda array, value: array.write(counter, value), training_info_ta,
        record)
    return [counter + 1, policy_step.state, training_info_ta]
def _train_loop_body(self, counter, time_step, policy_state,
                     training_info_ta):
    """Take one step of the unrolled training loop (``tf.while_loop`` body).

    Advances one step via ``self._step`` and records a training-info entry
    at index ``counter`` of ``training_info_ta``. The recorded entry uses
    the reward, discount and step type of the *current* ``time_step``.

    Args:
        counter: scalar index of the slot to write.
        time_step: time step the step starts from.
        policy_state: policy state matching ``time_step``.
        training_info_ta: nest of ``tf.TensorArray`` receiving the records.

    Returns:
        ``[counter + 1, next_time_step, new_policy_state, training_info_ta]``.
    """
    next_time_step, policy_step, action = self._step(time_step, policy_state)
    # Block gradients through the action returned by ``self._step``.
    frozen_action = tf.nest.map_structure(tf.stop_gradient, action)

    step_record = make_training_info(
        action_distribution=common.get_distribution_params(
            policy_step.action),
        action=frozen_action,
        reward=time_step.reward,
        discount=time_step.discount,
        step_type=time_step.step_type,
        info=policy_step.info)

    def _store(array, value):
        return array.write(counter, value)

    training_info_ta = tf.nest.map_structure(_store, training_info_ta,
                                             step_record)
    return [counter + 1, next_time_step, policy_step.state, training_info_ta]
def _iter(self, time_step, policy_state):
    """Run one training iteration.

    The iteration proceeds as follows:

    1. Unroll ``self._train_interval`` steps of ``self._train_loop_body``
       inside a persistent ``tf.GradientTape`` that watches
       ``self._trainable_variables``.
    2. Produce one extra "final" record according to
       ``self._final_step_mode``.
    3. Stack all records into a ``training_info`` nest and pass it to
       ``self._algorithm.train_complete``.

    Args:
        time_step: time step the unroll starts from.
        policy_state: policy state matching ``time_step``.

    Returns:
        Tuple ``(next_time_step, next_state)`` to resume from next time.
    """
    step_index = tf.zeros((), tf.int32)
    n_batch = self._env.batch_size

    def _new_array(spec):
        # ``_train_interval`` loop slots plus one slot for the final record.
        return tf.TensorArray(
            dtype=spec.dtype,
            size=self._train_interval + 1,
            element_shape=tf.TensorShape([n_batch]).concatenate(spec.shape))

    info_arrays = tf.nest.map_structure(_new_array, self._training_info_spec)

    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as tape:
        tape.watch(self._trainable_variables)
        step_index, time_step, policy_state, info_arrays = tf.while_loop(
            cond=lambda *_: True,
            body=self._train_loop_body,
            loop_vars=[step_index, time_step, policy_state, info_arrays],
            back_prop=True,
            parallel_iterations=1,
            maximum_iterations=self._train_interval,
            name='iter_loop')

    # The final step is computed after the recording context above has
    # exited (presumably so it contributes no gradients — confirm).
    if self._final_step_mode != OnPolicyDriver.FINAL_STEP_SKIP:
        last_step = common.algorithm_step(self._algorithm.rollout,
                                          self._observation_transformer,
                                          time_step, policy_state)
        last_action = common.sample_action_distribution(last_step.action)
        next_time_step, next_state = time_step, policy_state
    else:
        next_time_step, last_step, last_action = self._step(
            time_step, policy_state)
        next_state = last_step.state

    last_record = make_training_info(
        action_distribution=common.get_distribution_params(last_step.action),
        action=last_action,
        reward=time_step.reward,
        discount=time_step.discount,
        step_type=time_step.step_type,
        info=last_step.info)

    # Re-enter the tape so the write/stack ops linking the recorded arrays
    # to ``training_info`` are tracked as well.
    with tape:
        info_arrays = tf.nest.map_structure(
            lambda array, value: array.write(step_index, value), info_arrays,
            last_record)
        training_info = tf.nest.map_structure(lambda array: array.stack(),
                                              info_arrays)
        training_info = training_info._replace(
            action_distribution=nested_distributions_from_specs(
                self._algorithm.action_distribution_spec,
                training_info.action_distribution))

    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)
    del tape

    self._training_summary(training_info, loss_info, grads_and_vars)
    self._train_step_counter.assign_add(1)
    return next_time_step, next_state
def _prepare_specs(self, algorithm):
    """Derive and cache the tensor specs used by this driver.

    Probes ``algorithm`` with the initial time step and zero distribution
    parameters to discover the structure of its outputs. It then sets
    the following attributes on ``self``:

    - ``_time_step_spec`` and ``_action_spec``;
    - ``_pred_policy_step_spec``;
    - ``_action_distribution_spec`` and ``_action_dist_param_spec``;
    - ``_experience_spec`` and ``_processed_experience_spec``;
    - ``_training_info_spec``.

    Args:
        algorithm: algorithm whose ``predict``, ``preprocess_experience``
            and training-mode step are probed.
    """
    def _unbatched(nest):
        # Drop the leading (presumably batch) dimension of every tensor.
        return tf.nest.map_structure(
            lambda t: tf.TensorSpec(t.shape[1:], t.dtype), nest)

    def _as_distribution_spec(spec):
        # Plain tensor action specs become a Deterministic distribution
        # parameterized by ``loc``; other specs pass through unchanged.
        if not isinstance(spec, tf.TensorSpec):
            return spec
        return DistributionSpec(tfp.distributions.Deterministic,
                                input_params_spec={"loc": spec},
                                sample_spec=spec)

    # Specs probed from the prediction path.
    initial_time_step = self.get_initial_time_step()
    self._time_step_spec = _unbatched(initial_time_step)
    self._action_spec = self._env.action_spec()

    pred_step = algorithm.predict(initial_time_step, self._initial_state)
    pred_info_spec = _unbatched(pred_step.info)
    self._pred_policy_step_spec = PolicyStep(
        action=self._action_spec,
        state=algorithm.predict_state_spec,
        info=pred_info_spec)

    # Action distribution specs and their parameter specs.
    self._action_distribution_spec = tf.nest.map_structure(
        _as_distribution_spec, algorithm.action_distribution_spec)
    self._action_dist_param_spec = tf.nest.map_structure(
        lambda dist_spec: dist_spec.input_params_spec,
        self._action_distribution_spec)

    # Experience specs (raw and after ``preprocess_experience``).
    ts_spec = self._time_step_spec
    self._experience_spec = Experience(
        step_type=ts_spec.step_type,
        reward=ts_spec.reward,
        discount=ts_spec.discount,
        observation=ts_spec.observation,
        prev_action=self._action_spec,
        action=self._action_spec,
        info=pred_info_spec,
        action_distribution=self._action_dist_param_spec)

    zero_dist_params = common.zero_tensor_from_nested_spec(
        self._experience_spec.action_distribution, self._env.batch_size)
    probe_dist = nested_distributions_from_specs(
        self._action_distribution_spec, zero_dist_params)
    probe_exp = Experience(step_type=initial_time_step.step_type,
                           reward=initial_time_step.reward,
                           discount=initial_time_step.discount,
                           observation=initial_time_step.observation,
                           prev_action=initial_time_step.prev_action,
                           action=initial_time_step.prev_action,
                           info=pred_step.info,
                           action_distribution=probe_dist)

    processed_info = algorithm.preprocess_experience(probe_exp).info
    self._processed_experience_spec = self._experience_spec._replace(
        info=_unbatched(processed_info))

    # Training-info spec, probed from one training-mode step.
    # NOTE(review): the probe feeds the *unprocessed* ``probe_exp`` while
    # ``collect_info`` below uses the processed-experience info spec;
    # confirm this matches what the training loop feeds the algorithm.
    trained_step = common.algorithm_step(
        algorithm,
        ob_transformer=self._observation_transformer,
        time_step=probe_exp,
        state=common.get_initial_policy_state(self._env.batch_size,
                                              algorithm.train_state_spec),
        training=True)
    self._training_info_spec = make_training_info(
        action=self._action_spec,
        action_distribution=self._action_dist_param_spec,
        step_type=ts_spec.step_type,
        reward=ts_spec.reward,
        discount=ts_spec.discount,
        info=_unbatched(trained_step.info),
        collect_info=self._processed_experience_spec.info,
        collect_action_distribution=self._action_dist_param_spec)