from functools import partial

import jax
import jax.numpy as jnp
from flax.core import FrozenDict

# note: only the standard imports these snippets rely on are listed above; the
# project-specific helpers (PRNGSequence, apply_gaussian_policy_model,
# apply_double_critic_model, apply_constant_model, ...) are defined elsewhere.


def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_params: FrozenDict,
    critic_target_params: FrozenDict,
    log_alpha_params: FrozenDict,
) -> jnp.ndarray:
    # sample the next action from the current policy and get its log-probability
    next_action, next_log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, next_state, rng, True, False
    )

    # clipped double-Q: take the minimum of the two target critics and subtract
    # the entropy term scaled by the temperature alpha = exp(log_alpha)
    target_Q1, target_Q2 = apply_double_critic_model(
        critic_target_params, next_state, next_action, False
    )
    target_Q = (
        jnp.minimum(target_Q1, target_Q2)
        - jnp.exp(apply_constant_model(log_alpha_params, -3.5, False)) * next_log_p
    )

    # one-step bootstrapped target, masked at terminal transitions
    target_Q = reward + not_done * discount * target_Q

    return target_Q
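# Usage sketch (not part of the original listing): the TD target above is typically
# treated as a fixed regression target for both critic heads, giving the usual
# clipped-double-Q mean-squared-error loss. `critic_loss_fn` is a hypothetical name
# introduced here for illustration.
def critic_loss_fn(
    critic_params: FrozenDict,
    state: jnp.ndarray,
    action: jnp.ndarray,
    target_Q: jnp.ndarray,
) -> jnp.ndarray:
    current_Q1, current_Q2 = apply_double_critic_model(critic_params, state, action, False)
    # regress both critic heads onto the same (already stop-gradient) TD target
    return jnp.mean((current_Q1 - target_Q) ** 2) + jnp.mean((current_Q2 - target_Q) ** 2)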
def loss_fn(log_alpha_params):
    # temperature loss: vmap the per-sample `alpha_loss_fn` over the batch of
    # log-probabilities, with the current log-alpha and the entropy target bound
    partial_loss_fn = jax.vmap(
        partial(
            alpha_loss_fn,
            apply_constant_model(log_alpha_params, -3.5, False),
            target_entropy,
        )
    )
    return jnp.mean(partial_loss_fn(log_p))
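# Assumed helper (not part of the original listing): a plausible per-sample
# definition of `alpha_loss_fn`, the standard SAC temperature objective, which
# drives the policy entropy toward `target_entropy`. The argument order matches
# the `partial(...)` call above (log-alpha, entropy target, log-probability).
def alpha_loss_fn(log_alpha: jnp.ndarray, target_entropy: float, log_p: jnp.ndarray) -> jnp.ndarray:
    # only the temperature receives gradient; the log-probability is treated as data
    return -log_alpha * jax.lax.stop_gradient(log_p + target_entropy)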
def loss_fn(mlo, slo, actor_params):
    # get the distribution of the actor network (current policy)
    mu, log_sig = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, state, None, False, True
    )
    sig = jnp.exp(log_sig)

    # get the distribution of the target network (old policy)
    target_mu, target_log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )
    target_mu = jax.lax.stop_gradient(target_mu)
    target_log_sig = jax.lax.stop_gradient(target_log_sig)
    target_sig = jnp.exp(target_log_sig)

    # get the log-likelihoods of the sampled actions according to the
    # decoupled distributions described in section 4.2.1 of
    # "Relative Entropy Regularized Policy Iteration".
    # this ensures that the nonparametric policy won't collapse to give
    # a probability of 1 to the best action, which is a risk when we use
    # the on-policy distribution to calculate the likelihood.
    actor_log_prob = gaussian_likelihood(sampled_actions, target_mu, log_sig)
    actor_log_prob += gaussian_likelihood(sampled_actions, mu, target_log_sig)
    actor_log_prob = actor_log_prob.transpose((0, 1))

    # decoupled KL terms: one for the mean, one for the covariance
    mu_kl = kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
    sig_kl = kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()

    # update the Lagrange multipliers toward satisfying the KL constraints
    mlo = mu_lagrange_step(mlo, eps_mu - jax.lax.stop_gradient(mu_kl))
    slo = sig_lagrange_step(slo, eps_sig - jax.lax.stop_gradient(sig_kl))

    # maximize the weighted log-likelihood, regularized by the divergence between
    # the target policy and the current policy. the goal here is to fit
    # the parametric policy to have the minimum divergence with the nonparametric
    # distribution based on the sampled actions.
    actor_loss = -(actor_log_prob * weights).sum(axis=1).mean()
    actor_loss -= jax.lax.stop_gradient(
        apply_constant_model(mlo.target, 1.0, True)
    ) * (eps_mu - mu_kl)
    actor_loss -= jax.lax.stop_gradient(
        apply_constant_model(slo.target, 100.0, True)
    ) * (eps_sig - sig_kl)

    return actor_loss.mean(), (mlo, slo)
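# Assumed helpers (not part of the original listing): plausible definitions of the
# two utilities used above, the log-likelihood of a diagonal Gaussian and the KL
# divergence between two diagonal Gaussians. The signatures and the convention that
# the sigmas are standard deviations are assumptions made here for illustration.
def gaussian_likelihood(x: jnp.ndarray, mu: jnp.ndarray, log_sig: jnp.ndarray) -> jnp.ndarray:
    # log N(x; mu, diag(exp(log_sig)^2)), summed over the last (action) dimension
    return jnp.sum(
        -0.5 * ((x - mu) / jnp.exp(log_sig)) ** 2 - log_sig - 0.5 * jnp.log(2.0 * jnp.pi),
        axis=-1,
    )


def kl_mvg_diag(mu_p: jnp.ndarray, sig_p: jnp.ndarray, mu_q: jnp.ndarray, sig_q: jnp.ndarray) -> jnp.ndarray:
    # KL( N(mu_p, diag(sig_p^2)) || N(mu_q, diag(sig_q^2)) ), summed over dimensions
    return jnp.sum(
        jnp.log(sig_q / sig_p) + (sig_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sig_q ** 2) - 0.5,
        axis=-1,
    )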
def loss_fn(actor_params):
    # sample an action from the current policy and get its log-probability
    actor_action, log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, state, rng, True, False
    )

    # evaluate the sampled action with both critics and take the minimum
    q1, q2 = apply_double_critic_model(critic_params, state, actor_action, False)
    min_q = jnp.minimum(q1, q2)

    # per-sample actor objective; the temperature is held fixed via stop_gradient
    partial_loss_fn = jax.vmap(
        partial(
            actor_loss_fn,
            jax.lax.stop_gradient(
                apply_constant_model(log_alpha_params, -3.5, False)
            ),
        ),
    )
    actor_loss = partial_loss_fn(log_p, min_q)

    return jnp.mean(actor_loss), log_p
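# Assumed helper (not part of the original listing): a plausible per-sample
# definition of `actor_loss_fn`, the SAC policy objective, trading the entropy
# bonus (scaled by alpha = exp(log_alpha)) against the clipped-double-Q value.
# The argument order matches the `partial(...)` call above.
def actor_loss_fn(log_alpha: jnp.ndarray, log_p: jnp.ndarray, min_q: jnp.ndarray) -> jnp.ndarray:
    return jnp.exp(log_alpha) * log_p - min_q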
def loss_fn(sig_lagrange_params):
    # scalar objective for the covariance Lagrange multiplier: scale the
    # constraint residual `reg` by the current multiplier value
    return jnp.sum(apply_constant_model(sig_lagrange_params, 100.0, True) * reg)
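# Usage sketch (not part of the original listing): the scalar loss above would
# typically sit inside the multiplier update, which differentiates it with respect
# to the multiplier parameters and applies one optimizer step. A flax.optim-style
# optimizer exposing `.target` and `.apply_gradient` is assumed here, mirroring how
# `mlo` / `slo` are threaded through the actor loss.
def sig_lagrange_step(optimizer, reg: jnp.ndarray):
    def loss_fn(sig_lagrange_params):
        return jnp.sum(apply_constant_model(sig_lagrange_params, 100.0, True) * reg)

    grad = jax.grad(loss_fn)(optimizer.target)
    return optimizer.apply_gradient(grad)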