Example #1
    def _actor_train_step(self, exp: Experience, state: DdpgActorState):
        action, actor_state = self._actor_network(
            exp.observation, exp.step_type, network_state=state.actor)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(action)
            q_value, critic_state = self._critic_network(
                (exp.observation, action), network_state=state.critic)

        dqda = tape.gradient(q_value, action)

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                        self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                tf.stop_gradient(dqda + action), action)
            loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
            return loss

        actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
        state = DdpgActorState(actor=actor_state, critic=critic_state)
        info = LossInfo(
            loss=tf.add_n(tf.nest.flatten(actor_loss)), extra=actor_loss)
        return PolicyStep(action=action, state=state, info=info)
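
The surrogate loss in `actor_loss_fn` encodes the deterministic policy gradient: because `dqda + action` is wrapped in `tf.stop_gradient`, the gradient of `0.5 * (stop_gradient(dqda + action) - action)**2` with respect to `action` is exactly `-dqda`, so descending on this loss moves the action in the direction that increases the critic's Q value. Below is a minimal, self-contained check of that identity; the quadratic `toy_critic` is a made-up stand-in for a real critic network.

import tensorflow as tf


def toy_critic(action):
    # Stand-in for a critic Q(s, a): a simple concave quadratic in the action.
    return -tf.reduce_sum(tf.square(action - 1.0), axis=-1)


action = tf.Variable([[0.2, -0.3]])

with tf.GradientTape() as outer_tape:
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(action)
        q_value = toy_critic(action)
    dqda = tape.gradient(q_value, action)
    # Same surrogate loss as in the example above.
    loss = 0.5 * tf.reduce_sum(
        tf.square(tf.stop_gradient(dqda + action) - action))

grad = outer_tape.gradient(loss, action)
# The surrogate's gradient w.r.t. the action equals -dqda.
tf.debugging.assert_near(grad, -dqda)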
Example #2
 def calc_loss(self, training_info: TrainingInfo):
     critic_loss = self._calc_critic_loss(training_info)
     alpha_loss = training_info.info.alpha.loss
     actor_loss = training_info.info.actor.loss
     return LossInfo(loss=actor_loss.loss + critic_loss.loss +
                     alpha_loss.loss,
                     extra=SacLossInfo(actor=actor_loss.extra,
                                       critic=critic_loss.extra,
                                       alpha=alpha_loss.extra))
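
`LossInfo` and the per-algorithm `*LossInfo` structures used throughout these examples behave like namedtuples (Example #4 even calls `_replace` on one): the scalar `loss` fields are summed into a single training loss, while the `extra` fields are nested so per-component losses stay available for summaries. The sketch below illustrates that composition pattern; the namedtuple definitions here are simplified stand-ins, not the library's actual definitions.

import collections

# Simplified stand-ins for the structures constructed in the examples above.
LossInfo = collections.namedtuple(
    'LossInfo', ['loss', 'extra', 'scalar_loss'], defaults=[()])
SacLossInfo = collections.namedtuple('SacLossInfo', ['actor', 'critic', 'alpha'])

actor_loss = LossInfo(loss=1.2, extra=1.2)
critic_loss = LossInfo(loss=0.7, extra=0.7)
alpha_loss = LossInfo(loss=0.1, extra=0.1)

# Sum the losses for the optimizer, keep the pieces in `extra` for logging.
total = LossInfo(
    loss=actor_loss.loss + critic_loss.loss + alpha_loss.loss,
    extra=SacLossInfo(actor=actor_loss.extra,
                      critic=critic_loss.extra,
                      alpha=alpha_loss.extra))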
Example #3
    def __call__(self, training_info: TrainingInfo, value):
        """Cacluate actor critic loss

        The first dimension of all the tensors is time dimension and the second
        dimesion is the batch dimension.

        Args:
            training_info (TrainingInfo): training_info collected by
                (On/Off)PolicyDriver. All tensors in training_info are time-major
            value (tf.Tensor): the time-major tensor for the value at each time
                step
            final_value (tf.Tensor): the value at one step ahead.
        Returns:
            loss_info (LossInfo): with loss_info.extra being ActorCriticLossInfo
        """

        returns, advantages = self._calc_returns_and_advantages(
            training_info, value)

        def _summary():
            with tf.name_scope('ActorCriticLoss'):
                tf.summary.scalar("values", tf.reduce_mean(value))
                tf.summary.scalar("returns", tf.reduce_mean(returns))
                tf.summary.scalar("advantages/mean",
                                  tf.reduce_mean(advantages))
                tf.summary.histogram("advantages/value", advantages)
                tf.summary.scalar("explained_variance_of_return_by_value",
                                  common.explained_variance(value, returns))

        if self._debug_summaries:
            common.run_if(common.should_record_summaries(), _summary)

        if self._normalize_advantages:
            advantages = _normalize_advantages(advantages, axes=(0, 1))

        if self._advantage_clip:
            advantages = tf.clip_by_value(advantages, -self._advantage_clip,
                                          self._advantage_clip)

        pg_loss = self._pg_loss(training_info, tf.stop_gradient(advantages))

        td_loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)

        loss = pg_loss + self._td_loss_weight * td_loss

        entropy_loss = ()
        if self._entropy_regularization is not None:
            entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
                training_info.action_distribution, self._action_spec)
            entropy_loss = -entropy
            loss -= self._entropy_regularization * entropy_for_gradient

        return LossInfo(loss=loss,
                        extra=ActorCriticLossInfo(td_loss=td_loss,
                                                  pg_loss=pg_loss,
                                                  entropy_loss=entropy_loss))
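
When `_normalize_advantages` is enabled, the advantages are standardized jointly over the time and batch axes (`axes=(0, 1)`) before the policy-gradient loss. A possible implementation of such a helper is sketched below for illustration; the real helper may differ in details such as the epsilon or variance handling.

import tensorflow as tf


def normalize_advantages(advantages, axes=(0, 1), eps=1e-8):
    # Zero-mean, unit-variance normalization over the given axes
    # (time and batch for time-major [T, B] tensors).
    mean, variance = tf.nn.moments(advantages, axes=list(axes), keepdims=True)
    return (advantages - mean) * tf.math.rsqrt(variance + eps)


advantages = tf.random.normal([20, 8])   # [T, B]
advantages = normalize_advantages(advantages)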
Example #4
 def _update_loss(loss_info, training_info, name, algorithm):
     if algorithm is None:
         return loss_info
     new_loss_info = algorithm.calc_loss(
         getattr(training_info.info, name))
     return LossInfo(
         loss=add_ignore_empty(loss_info.loss, new_loss_info.loss),
         scalar_loss=add_ignore_empty(loss_info.scalar_loss,
                                      new_loss_info.scalar_loss),
         extra=loss_info.extra._replace(**{name: new_loss_info.extra}))
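
`add_ignore_empty` lets optional loss terms be represented by an empty nest `()` (as `entropy_loss` is in Example #3) and simply skipped when totals are accumulated. An illustrative version of such a helper, assuming `()` or None means "no contribution":

def add_ignore_empty(x, y):
    """Add two loss terms, treating an empty nest () or None as absent.

    Illustrative sketch only; the library helper may handle more cases.
    """

    def _empty(v):
        return v is None or (isinstance(v, tuple) and len(v) == 0)

    if _empty(x):
        return y
    if _empty(y):
        return x
    return x + y


assert add_ignore_empty((), 3.0) == 3.0
assert add_ignore_empty(1.0, 2.0) == 3.0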
Example #5
    def calc_loss(self, training_info: TrainingInfo):
        """Calculate loss."""
        self.add_reward_summary("reward", training_info.reward)
        mbp_loss_info = self._mbp.calc_loss(training_info.info.mbp_info)
        mba_loss_info = self._mba.calc_loss(
            training_info._replace(info=training_info.info.mba_info))

        return LossInfo(loss=mbp_loss_info.loss + mba_loss_info.loss,
                        extra=MerlinLossInfo(mbp=mbp_loss_info.extra,
                                             mba=mba_loss_info.extra))
Example #6
    def calc_loss(self, training_info: TrainingInfo):
        critic_loss = self._critic_loss(
            training_info=training_info,
            value=training_info.info.critic.q_value,
            target_value=training_info.info.critic.target_q_value)

        actor_loss = training_info.info.actor_loss

        return LossInfo(
            loss=critic_loss.loss + actor_loss.loss,
            extra=DdpgLossInfo(
                critic=critic_loss.extra, actor=actor_loss.extra))
Example #7
 def decode_step(self, latent_vector, observations):
     """Calculate decoding loss."""
     decoders = tf.nest.flatten(self._decoders)
     observations = tf.nest.flatten(observations)
     decoder_losses = [
         decoder.train_step((latent_vector, obs)).info
         for decoder, obs in zip(decoders, observations)
     ]
     loss = tf.add_n([decoder_loss.loss for decoder_loss in decoder_losses])
     decoder_losses = tf.nest.pack_sequence_as(self._decoders,
                                               decoder_losses)
     return LossInfo(loss=loss, extra=decoder_losses)
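
`decode_step` relies on `tf.nest` to pair each decoder with its observation and then restore the original structure for the per-decoder losses. A small standalone illustration of that flatten/pack round trip, using made-up toy structures in place of `self._decoders` and `observations`:

import tensorflow as tf

decoders = {'image': 'image_decoder', 'state': 'state_decoder'}
observations = {'image': tf.zeros([4, 64]), 'state': tf.zeros([4, 8])}

flat_decoders = tf.nest.flatten(decoders)   # deterministic (sorted-key) order
flat_obs = tf.nest.flatten(observations)
# A real decode step would zip flat_decoders with flat_obs, as in the example;
# here we just pretend each decoder produced a scalar loss.
flat_losses = [tf.constant(0.3), tf.constant(0.7)]

total_loss = tf.add_n(flat_losses)
per_decoder = tf.nest.pack_sequence_as(decoders, flat_losses)
# per_decoder == {'image': 0.3, 'state': 0.7}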
Example #8
 def __call__(self, training_info: TrainingInfo, value, target_value):
     returns = value_ops.one_step_discounted_return(
         rewards=training_info.reward,
         values=target_value,
         step_types=training_info.step_type,
         discounts=training_info.discount * self._gamma)
     returns = common.tensor_extend(returns, value[-1])
     if self._debug_summaries:
         with tf.name_scope('OneStepTDLoss'):
             tf.summary.scalar("values", tf.reduce_mean(value))
             tf.summary.scalar("returns", tf.reduce_mean(returns))
     loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)
     return LossInfo(loss=loss, extra=loss)
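
The one-step return here bootstraps with the target value, R[t] = r[t+1] + d[t+1] * V_target[t+1] (with gamma already folded into the discount), which yields T-1 entries; appending `value[-1]` pads it to length T so the loss lines up element-wise with the predicted values, making the TD error at the final step zero. A rough standalone sketch of that computation for time-major [T, B] tensors; the indexing convention is simplified and the real `value_ops.one_step_discounted_return` also handles step types.

import tensorflow as tf


def one_step_discounted_return(rewards, target_values, discounts):
    # R[t] = r[t+1] + discount[t+1] * V_target[t+1], for t = 0 .. T-2.
    return rewards[1:] + discounts[1:] * target_values[1:]


# Time-major [T, B] toy tensors.
rewards = tf.constant([[0.0], [1.0], [1.0]])
values = tf.constant([[0.5], [0.4], [0.3]])
target_values = tf.constant([[0.6], [0.5], [0.2]])
discounts = tf.constant([[0.99], [0.99], [0.0]])   # 0 at episode end

returns = one_step_discounted_return(rewards, target_values, discounts)
returns = tf.concat([returns, values[-1:]], axis=0)   # extend to length T
td_loss = tf.square(tf.stop_gradient(returns) - values)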
Example #9
    def _calc_critic_loss(self, training_info):
        critic_info = training_info.info.critic

        target_critic = critic_info.target_critic

        critic_loss1 = self._critic_loss(training_info=training_info,
                                         value=critic_info.critic1,
                                         target_value=target_critic)

        critic_loss2 = self._critic_loss(training_info=training_info,
                                         value=critic_info.critic2,
                                         target_value=target_critic)

        critic_loss = critic_loss1.loss + critic_loss2.loss
        return LossInfo(loss=critic_loss, extra=critic_loss)
Example #10
    def train_step(self, inputs, state: MBPState):
        """Train one step.

        Args:
            inputs (tuple): a tuple of (observation, action)
            state (MBPState): state from the previous time step
        """
        observation, _ = inputs
        latent_vector, kld, next_state = self.encode_step(inputs, state)

        # TODO: decoder for action
        decoder_loss = self.decode_step(latent_vector, observation)

        return AlgorithmStep(
            outputs=latent_vector,
            state=next_state,
            info=LossInfo(loss=self._loss_weight * (decoder_loss.loss + kld),
                          extra=MBPLossInfo(decoder=decoder_loss, vae=kld)))
Example #11
 def _alpha_train_step(self, log_pi):
     alpha_loss = self._log_alpha * tf.stop_gradient(-log_pi -
                                                     self._target_entropy)
     info = SacAlphaInfo(loss=LossInfo(loss=alpha_loss, extra=alpha_loss))
     return info
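
The temperature update above follows the standard SAC objective `log_alpha * stop_gradient(-log_pi - target_entropy)`: when the policy's entropy estimate `-log_pi` drops below `target_entropy`, the stop-gradient factor is negative, so a gradient-descent step increases `log_alpha` and strengthens the entropy bonus. A toy numeric check of that direction, with made-up numbers:

import tensorflow as tf

log_alpha = tf.Variable(0.0)
target_entropy = -2.0        # a common heuristic is -dim(action_space)
log_pi = tf.constant(2.5)    # entropy estimate -log_pi = -2.5 < target_entropy

with tf.GradientTape() as tape:
    alpha_loss = log_alpha * tf.stop_gradient(-log_pi - target_entropy)

grad = tape.gradient(alpha_loss, log_alpha)
# grad == -log_pi - target_entropy == -0.5: descending increases log_alpha,
# i.e. the entropy bonus grows when the policy has become too deterministic.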
Example #12
    def _actor_train_step(self, exp: Experience, state: SacActorState,
                          action_distribution, action, log_pi):

        if self._is_continuous:
            critic_input = (exp.observation, action)

            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(action)
                critic1, critic1_state = self._critic_network1(
                    critic_input,
                    step_type=exp.step_type,
                    network_state=state.critic1)

                critic2, critic2_state = self._critic_network2(
                    critic_input,
                    step_type=exp.step_type,
                    network_state=state.critic2)

                target_q_value = tf.minimum(critic1, critic2)

            dqda = tape.gradient(target_q_value, action)

            def actor_loss_fn(dqda, action):
                if self._dqda_clipping:
                    dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                            self._dqda_clipping)
                loss = 0.5 * losses.element_wise_squared_loss(
                    tf.stop_gradient(dqda + action), action)
                loss = tf.reduce_sum(loss,
                                     axis=list(range(1, len(loss.shape))))
                return loss

            actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
            alpha = tf.stop_gradient(tf.exp(self._log_alpha))
            actor_loss += alpha * log_pi
        else:
            critic1, critic1_state = self._critic_network1(
                exp.observation,
                step_type=exp.step_type,
                network_state=state.critic1)

            critic2, critic2_state = self._critic_network2(
                exp.observation,
                step_type=exp.step_type,
                network_state=state.critic2)

            assert isinstance(
                action_distribution, tfp.distributions.Categorical), \
                ("Only `tfp.distributions.Categorical` is supported, "
                 "received: " + str(type(action_distribution)))

            action_probs = action_distribution.probs
            log_action_probs = tf.math.log(action_probs + 1e-8)

            target_q_value = tf.stop_gradient(tf.minimum(critic1, critic2))
            alpha = tf.stop_gradient(tf.exp(self._log_alpha))
            actor_loss = tf.reduce_mean(
                action_probs * (alpha * log_action_probs - target_q_value),
                axis=-1)

        state = SacActorState(critic1=critic1_state, critic2=critic2_state)
        info = SacActorInfo(loss=LossInfo(loss=actor_loss, extra=actor_loss))
        return state, info
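
In the discrete branch the expectation over actions can be computed exactly: the actor loss is the probability-weighted combination of `alpha * log_pi(a|s) - Q_min(s, a)` over all actions (the example averages over the action axis with `reduce_mean`, which differs from the exact expectation only by a constant factor). A small standalone sketch with made-up numbers:

import tensorflow as tf
import tensorflow_probability as tfp

# Toy discrete-action example: one state, three actions.
logits = tf.constant([[2.0, 0.5, -1.0]])
action_distribution = tfp.distributions.Categorical(logits=logits)

# probs_parameter() works whether the distribution was built from logits or probs.
action_probs = action_distribution.probs_parameter()
log_action_probs = tf.math.log(action_probs + 1e-8)

critic1 = tf.constant([[1.0, 0.2, -0.3]])
critic2 = tf.constant([[0.8, 0.4, -0.1]])
target_q_value = tf.stop_gradient(tf.minimum(critic1, critic2))

alpha = 0.2
actor_loss = tf.reduce_sum(
    action_probs * (alpha * log_action_probs - target_q_value), axis=-1)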