def _actor_train_step(self, exp: Experience, state: DdpgActorState):
    action, actor_state = self._actor_network(
        exp.observation, exp.step_type, network_state=state.actor)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(action)
        q_value, critic_state = self._critic_network(
            (exp.observation, action), network_state=state.critic)

    dqda = tape.gradient(q_value, action)

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                    self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            tf.stop_gradient(dqda + action), action)
        loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
        return loss

    actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
    state = DdpgActorState(actor=actor_state, critic=critic_state)
    info = LossInfo(
        loss=tf.add_n(tf.nest.flatten(actor_loss)), extra=actor_loss)
    return PolicyStep(action=action, state=state, info=info)
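# --- A minimal, self-contained sketch (illustration only, not part of the
# algorithm above): it checks why the surrogate loss
#     0.5 * (stop_gradient(dqda + action) - action)**2
# is used. Its gradient w.r.t. `action` is exactly -dqda, so minimizing it
# with gradient descent performs gradient ascent on the critic's Q value.
# The tensors below are made-up toy values.
import tensorflow as tf

action = tf.Variable([[0.3, -0.7]])
dqda = tf.constant([[1.5, -2.0]])  # pretend dQ/da obtained from a critic
with tf.GradientTape() as tape:
    surrogate = 0.5 * tf.square(tf.stop_gradient(dqda + action) - action)
    surrogate = tf.reduce_sum(surrogate)
grad = tape.gradient(surrogate, action)
print(grad.numpy())  # [[-1.5  2. ]] == -dqda, so a descent step increases Q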
def calc_loss(self, training_info: TrainingInfo):
    critic_loss = self._calc_critic_loss(training_info)
    alpha_loss = training_info.info.alpha.loss
    actor_loss = training_info.info.actor.loss
    return LossInfo(
        loss=actor_loss.loss + critic_loss.loss + alpha_loss.loss,
        extra=SacLossInfo(
            actor=actor_loss.extra,
            critic=critic_loss.extra,
            alpha=alpha_loss.extra))
def __call__(self, training_info: TrainingInfo, value):
    """Calculate actor critic loss.

    The first dimension of all the tensors is the time dimension and the
    second dimension is the batch dimension.

    Args:
        training_info (TrainingInfo): training_info collected by
            (On/Off)PolicyDriver. All tensors in training_info are time-major.
        value (tf.Tensor): the time-major tensor for the value at each time
            step.
    Returns:
        loss_info (LossInfo): with loss_info.extra being ActorCriticLossInfo
    """
    returns, advantages = self._calc_returns_and_advantages(
        training_info, value)

    def _summary():
        with tf.name_scope('ActorCriticLoss'):
            tf.summary.scalar("values", tf.reduce_mean(value))
            tf.summary.scalar("returns", tf.reduce_mean(returns))
            tf.summary.scalar("advantages/mean", tf.reduce_mean(advantages))
            tf.summary.histogram("advantages/value", advantages)
            tf.summary.scalar("explained_variance_of_return_by_value",
                              common.explained_variance(value, returns))

    if self._debug_summaries:
        common.run_if(common.should_record_summaries(), _summary)

    if self._normalize_advantages:
        advantages = _normalize_advantages(advantages, axes=(0, 1))

    if self._advantage_clip:
        advantages = tf.clip_by_value(advantages, -self._advantage_clip,
                                      self._advantage_clip)

    pg_loss = self._pg_loss(training_info, tf.stop_gradient(advantages))
    td_loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)
    loss = pg_loss + self._td_loss_weight * td_loss

    entropy_loss = ()
    if self._entropy_regularization is not None:
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            training_info.action_distribution, self._action_spec)
        entropy_loss = -entropy
        loss -= self._entropy_regularization * entropy_for_gradient

    return LossInfo(
        loss=loss,
        extra=ActorCriticLossInfo(
            td_loss=td_loss, pg_loss=pg_loss, entropy_loss=entropy_loss))
def _update_loss(loss_info, training_info, name, algorithm):
    if algorithm is None:
        return loss_info
    new_loss_info = algorithm.calc_loss(getattr(training_info.info, name))
    return LossInfo(
        loss=add_ignore_empty(loss_info.loss, new_loss_info.loss),
        scalar_loss=add_ignore_empty(loss_info.scalar_loss,
                                     new_loss_info.scalar_loss),
        extra=loss_info.extra._replace(**{name: new_loss_info.extra}))
def calc_loss(self, training_info: TrainingInfo): """Calculate loss.""" self.add_reward_summary("reward", training_info.reward) mbp_loss_info = self._mbp.calc_loss(training_info.info.mbp_info) mba_loss_info = self._mba.calc_loss( training_info._replace(info=training_info.info.mba_info)) return LossInfo(loss=mbp_loss_info.loss + mba_loss_info.loss, extra=MerlinLossInfo(mbp=mbp_loss_info.extra, mba=mba_loss_info.extra))
def calc_loss(self, training_info: TrainingInfo):
    critic_loss = self._critic_loss(
        training_info=training_info,
        value=training_info.info.critic.q_value,
        target_value=training_info.info.critic.target_q_value)
    actor_loss = training_info.info.actor_loss
    return LossInfo(
        loss=critic_loss.loss + actor_loss.loss,
        extra=DdpgLossInfo(critic=critic_loss.extra, actor=actor_loss.extra))
def decode_step(self, latent_vector, observations):
    """Calculate decoding loss."""
    decoders = tf.nest.flatten(self._decoders)
    observations = tf.nest.flatten(observations)
    decoder_losses = [
        decoder.train_step((latent_vector, obs)).info
        for decoder, obs in zip(decoders, observations)
    ]
    loss = tf.add_n([decoder_loss.loss for decoder_loss in decoder_losses])
    decoder_losses = tf.nest.pack_sequence_as(self._decoders, decoder_losses)
    return LossInfo(loss=loss, extra=decoder_losses)
def __call__(self, training_info: TrainingInfo, value, target_value):
    returns = value_ops.one_step_discounted_return(
        rewards=training_info.reward,
        values=target_value,
        step_types=training_info.step_type,
        discounts=training_info.discount * self._gamma)
    returns = common.tensor_extend(returns, value[-1])
    if self._debug_summaries:
        with tf.name_scope('OneStepTDLoss'):
            tf.summary.scalar("values", tf.reduce_mean(value))
            tf.summary.scalar("returns", tf.reduce_mean(returns))
    loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)
    return LossInfo(loss=loss, extra=loss)
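# --- A minimal numeric sketch (toy numbers; the indexing convention is an
# assumption inferred from the time-major layout above, and the real
# value_ops helper additionally uses step_types to handle episode
# boundaries). It shows the one-step TD target that the loss regresses
# `value` toward:
#     returns[t] = reward[t+1] + gamma * discount[t+1] * target_value[t+1]
# with the last slot padded by value[-1] so `returns` and `value` line up.
import tensorflow as tf

gamma = 0.99
reward = tf.constant([0.0, 1.0, 0.0])        # [T], time-major
discount = tf.constant([1.0, 1.0, 0.0])      # 0 marks the end of an episode
target_value = tf.constant([0.5, 0.4, 0.3])
value = tf.constant([0.6, 0.5, 0.2])

returns = reward[1:] + gamma * discount[1:] * target_value[1:]
returns = tf.concat([returns, value[-1:]], axis=0)   # pad the final step
td_loss = tf.square(tf.stop_gradient(returns) - value)
print(returns.numpy(), td_loss.numpy())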
def _calc_critic_loss(self, training_info):
    critic_info = training_info.info.critic
    target_critic = critic_info.target_critic
    critic_loss1 = self._critic_loss(
        training_info=training_info,
        value=critic_info.critic1,
        target_value=target_critic)
    critic_loss2 = self._critic_loss(
        training_info=training_info,
        value=critic_info.critic2,
        target_value=target_critic)
    critic_loss = critic_loss1.loss + critic_loss2.loss
    return LossInfo(loss=critic_loss, extra=critic_loss)
def train_step(self, inputs, state: MBPState):
    """Train one step.

    Args:
        inputs (tuple): a tuple of (observation, action)
        state (MBPState): the state from the previous step
    """
    observation, _ = inputs
    latent_vector, kld, next_state = self.encode_step(inputs, state)
    # TODO: decoder for action
    decoder_loss = self.decode_step(latent_vector, observation)
    return AlgorithmStep(
        outputs=latent_vector,
        state=next_state,
        info=LossInfo(
            loss=self._loss_weight * (decoder_loss.loss + kld),
            extra=MBPLossInfo(decoder=decoder_loss, vae=kld)))
def _alpha_train_step(self, log_pi):
    alpha_loss = self._log_alpha * tf.stop_gradient(
        -log_pi - self._target_entropy)
    info = SacAlphaInfo(loss=LossInfo(loss=alpha_loss, extra=alpha_loss))
    return info
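# --- A minimal sketch (toy numbers) of how this temperature loss behaves.
# The gradient of  log_alpha * stop_gradient(-log_pi - target_entropy)
# w.r.t. log_alpha is -(log_pi + target_entropy): with gradient descent,
# alpha grows when the policy's entropy estimate (-log_pi) falls below the
# target entropy and shrinks when it is above, which is the standard SAC
# temperature adjustment.
import tensorflow as tf

log_alpha = tf.Variable(0.0)
target_entropy = -2.0
log_pi = tf.constant(-1.0)   # entropy estimate -log_pi = 1.0, above the target
with tf.GradientTape() as tape:
    alpha_loss = log_alpha * tf.stop_gradient(-log_pi - target_entropy)
grad = tape.gradient(alpha_loss, log_alpha)
print(grad.numpy())  # 3.0 > 0, so a descent step decreases log_alpha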
def _actor_train_step(self, exp: Experience, state: SacActorState,
                      action_distribution, action, log_pi):
    if self._is_continuous:
        critic_input = (exp.observation, action)
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(action)
            critic1, critic1_state = self._critic_network1(
                critic_input,
                step_type=exp.step_type,
                network_state=state.critic1)
            critic2, critic2_state = self._critic_network2(
                critic_input,
                step_type=exp.step_type,
                network_state=state.critic2)
            target_q_value = tf.minimum(critic1, critic2)

        dqda = tape.gradient(target_q_value, action)

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                        self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                tf.stop_gradient(dqda + action), action)
            loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
            return loss

        actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
        alpha = tf.stop_gradient(tf.exp(self._log_alpha))
        actor_loss += alpha * log_pi
    else:
        critic1, critic1_state = self._critic_network1(
            exp.observation,
            step_type=exp.step_type,
            network_state=state.critic1)
        critic2, critic2_state = self._critic_network2(
            exp.observation,
            step_type=exp.step_type,
            network_state=state.critic2)
        assert isinstance(action_distribution,
                          tfp.distributions.Categorical), (
            "Only `tfp.distributions.Categorical` is supported, received: " +
            str(type(action_distribution)))
        action_probs = action_distribution.probs
        log_action_probs = tf.math.log(action_probs + 1e-8)
        target_q_value = tf.stop_gradient(tf.minimum(critic1, critic2))
        alpha = tf.stop_gradient(tf.exp(self._log_alpha))
        actor_loss = tf.reduce_mean(
            action_probs * (alpha * log_action_probs - target_q_value),
            axis=-1)

    state = SacActorState(critic1=critic1_state, critic2=critic2_state)
    info = SacActorInfo(loss=LossInfo(loss=actor_loss, extra=actor_loss))
    return state, info