Example #1
def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_params: FrozenDict,
    critic_target_params: FrozenDict,
    log_alpha_params: FrozenDict,
) -> jnp.ndarray:
    # sample the next action and its log-probability from the current policy
    next_action, next_log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, next_state, rng, True, False)

    # take the pessimistic (minimum) of the two target critics and subtract the
    # entropy term weighted by the learned temperature alpha = exp(log_alpha)
    target_Q1, target_Q2 = apply_double_critic_model(critic_target_params,
                                                     next_state, next_action,
                                                     False)
    target_Q = (jnp.minimum(target_Q1, target_Q2) -
                jnp.exp(apply_constant_model(log_alpha_params, -3.5, False)) *
                next_log_p)
    # bootstrap only through non-terminal transitions
    target_Q = reward + not_done * discount * target_Q

    return target_Q
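
For context, a soft TD target like the one returned above is usually regressed onto by both critic heads with a mean-squared-error loss. The snippet below is a minimal, self-contained sketch of that step; the stand-in arrays and the names current_Q1/current_Q2 are assumptions for illustration, not part of the library shown here.

import jax
import jax.numpy as jnp

# Stand-in batch values: `target_Q` plays the role of get_td_target's output,
# `current_Q1`/`current_Q2` are hypothetical critic outputs for (state, action).
target_Q = jnp.array([10.3, 5.1])
current_Q1 = jnp.array([9.8, 5.6])
current_Q2 = jnp.array([10.1, 4.9])

# no gradient flows through the bootstrapped target
td = jax.lax.stop_gradient(target_Q)
critic_loss = jnp.mean((current_Q1 - td) ** 2 + (current_Q2 - td) ** 2)
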
Example #2
def loss_fn(log_alpha_params):
    # average the per-sample temperature loss over the batch of log-probabilities
    partial_loss_fn = jax.vmap(
        partial(
            alpha_loss_fn,
            apply_constant_model(log_alpha_params, -3.5, False),
            target_entropy,
        ))
    return jnp.mean(partial_loss_fn(log_p))
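
alpha_loss_fn itself is not shown in this snippet; the partial(...) call only tells us it takes the log-alpha constant, the target entropy, and then the vmapped log probability. The sketch below is a hedged, self-contained guess at a common SAC temperature objective such a helper might implement; it is not the library's verified definition.

import jax
import jax.numpy as jnp

# Hypothetical per-sample temperature loss (a common SAC formulation):
# minimizing -log_alpha * (log_p + target_entropy) raises alpha when the
# policy's entropy falls below the target and lowers it otherwise.
def alpha_loss_fn(log_alpha, target_entropy, log_p):
    return -log_alpha * jax.lax.stop_gradient(log_p + target_entropy)

# usage mirroring the vmap in the snippet above (stand-in values)
log_p = jnp.array([-1.3, -0.7, -2.1])
loss = jnp.mean(jax.vmap(lambda lp: alpha_loss_fn(-3.5, -6.0, lp))(log_p))
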
Example #3
    def loss_fn(mlo, slo, actor_params):
        # get the distribution of the actor network (current policy)
        mu, log_sig = apply_gaussian_policy_model(
            actor_params, action_dim, max_action, state, None, False, True
        )
        sig = jnp.exp(log_sig)
        # get the distribution of the target network (old policy)
        target_mu, target_log_sig = apply_gaussian_policy_model(
            actor_target_params, action_dim, max_action, state, None, False, True
        )
        target_mu = jax.lax.stop_gradient(target_mu)
        target_log_sig = jax.lax.stop_gradient(target_log_sig)
        target_sig = jnp.exp(target_log_sig)

        # get the log likelihoods of the sampled actions according to the
        # decoupled distributions described in section 4.2.1 of
        # Relative Entropy Regularized Policy Iteration. this ensures that
        # the nonparametric policy won't collapse to give a probability of 1
        # to the best action, which is a risk when we use the on-policy
        # distribution to calculate the likelihood.
        actor_log_prob = gaussian_likelihood(sampled_actions, target_mu, log_sig)
        actor_log_prob += gaussian_likelihood(sampled_actions, mu, target_log_sig)
        actor_log_prob = actor_log_prob.transpose((0, 1))

        mu_kl = kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
        sig_kl = kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()

        mlo = mu_lagrange_step(mlo, eps_mu - jax.lax.stop_gradient(mu_kl))
        slo = sig_lagrange_step(slo, eps_sig - jax.lax.stop_gradient(sig_kl))

        # maximize the log likelihood, regularized by the divergence between
        # the target policy and the current policy. the goal here is to fit
        # the parametric policy to have the minimum divergence with the nonparametric
        # distribution based on the sampled actions.
        actor_loss = -(actor_log_prob * weights).sum(axis=1).mean()
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(mlo.target, 1.0, True)
        ) * (eps_mu - mu_kl)
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(slo.target, 100.0, True)
        ) * (eps_sig - sig_kl)
        return actor_loss.mean(), (mlo, slo)
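
kl_mvg_diag is referenced above but not defined in the snippet. The decoupled update calls it twice: once letting only the mean move (covariance pinned to the target's) and once letting only the covariance move, each bounded by eps_mu / eps_sig through the learned Lagrange multipliers. The following is a hedged sketch of the standard closed-form KL between diagonal Gaussians that such a helper would compute; the real implementation may differ in reduction or argument order.

import jax.numpy as jnp

# Hypothetical diagonal-Gaussian KL, KL(N(mu0, diag(sig0^2)) || N(mu1, diag(sig1^2))),
# summed over the action dimension (standard closed form; a guess at what
# kl_mvg_diag computes, not the library's verified implementation).
def kl_mvg_diag(mu0, sig0, mu1, sig1):
    var0, var1 = sig0 ** 2, sig1 ** 2
    return 0.5 * jnp.sum(
        jnp.log(var1 / var0) + (var0 + (mu0 - mu1) ** 2) / var1 - 1.0,
        axis=-1,
    )

# usage mirroring the mean/covariance split in the loss above (stand-in values)
mu, sig = jnp.array([[0.1, -0.2]]), jnp.array([[0.5, 0.4]])
target_mu, target_sig = jnp.array([[0.0, 0.0]]), jnp.array([[0.5, 0.5]])
mu_kl = kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
sig_kl = kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()
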
Example #4
def loss_fn(actor_params):
    # sample an action from the current policy and evaluate both critics on it
    actor_action, log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, state, rng, True, False)
    q1, q2 = apply_double_critic_model(critic_params, state, actor_action,
                                       False)
    min_q = jnp.minimum(q1, q2)
    # the temperature is treated as a constant (stop_gradient) in the actor update
    partial_loss_fn = jax.vmap(
        partial(
            actor_loss_fn,
            jax.lax.stop_gradient(
                apply_constant_model(log_alpha_params, -3.5, False)),
        ))
    actor_loss = partial_loss_fn(log_p, min_q)
    return jnp.mean(actor_loss), log_p
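
As with alpha_loss_fn, actor_loss_fn is defined elsewhere; the partial(...) call only shows it takes the detached log-alpha constant followed by vmapped (log_p, min_q) pairs. A common SAC actor objective it presumably implements is alpha * log_pi(a|s) - min(Q1, Q2); the sketch below is an assumption about that per-sample term, including whether the exponentiation of log-alpha happens inside the helper.

import jax
import jax.numpy as jnp

# Hypothetical per-sample actor loss in the usual SAC form (alpha * log_pi - min_Q);
# exponentiating log_alpha inside the helper is an assumption, not something the
# snippet above confirms.
def actor_loss_fn(log_alpha, log_p, min_q):
    return jnp.exp(log_alpha) * log_p - min_q

# usage mirroring the vmap over (log_p, min_q) pairs (stand-in values)
log_p = jnp.array([-1.0, -0.5])
min_q = jnp.array([3.2, 4.1])
loss = jnp.mean(jax.vmap(lambda lp, q: actor_loss_fn(-3.5, lp, q))(log_p, min_q))
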
Example #5
def loss_fn(sig_lagrange_params):
    # dual loss for the sigma Lagrange multiplier: weight the constraint residual `reg`
    # by the learned multiplier value
    return jnp.sum(apply_constant_model(sig_lagrange_params, 100.0, True) * reg)