def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_params: FrozenDict,
    critic_target_params: FrozenDict,
    log_alpha_params: FrozenDict,
) -> jnp.ndarray:
    # sample the next action from the current policy and get its log-probability
    next_action, next_log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, next_state, rng, True, False
    )
    # evaluate both target critics and take the pessimistic (minimum) estimate
    target_Q1, target_Q2 = apply_double_critic_model(
        critic_target_params, next_state, next_action, False
    )
    # entropy-regularized target: subtract alpha * log pi(a'|s')
    target_Q = (
        jnp.minimum(target_Q1, target_Q2)
        - jnp.exp(apply_constant_model(log_alpha_params, -3.5, False)) * next_log_p
    )
    # one-step Bellman backup, masked by the terminal flag
    target_Q = reward + not_done * discount * target_Q
    return target_Q
def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
) -> jnp.ndarray:
    # get the mean and log-std of the target policy at the next state
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, next_state, None, False, True
    )
    # reparameterized sample, then squash into the valid action range
    next_action = mu + jnp.exp(log_sig) * random.normal(rng, mu.shape)
    next_action = max_action * nn.tanh(next_action)
    # evaluate both target critics and take the pessimistic (minimum) estimate
    target_Q1, target_Q2 = apply_double_critic_model(
        critic_target_params, next_state, next_action, False
    )
    target_Q = jnp.minimum(target_Q1, target_Q2)
    # one-step Bellman backup, masked by the terminal flag
    target_Q = reward + not_done * discount * target_Q
    return target_Q
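For context, here is a hedged sketch of how a TD-target helper like the second variant above is typically consumed in a critic update. The surrounding variable names are illustrative, not taken from the source, and the snippet relies on the same jax/jnp imports as the code above.

# Hypothetical usage (illustrative names): compute the bootstrap target once,
# with gradients blocked, then fit both critics against it.
target_Q = jax.lax.stop_gradient(
    get_td_target(
        rng, state, action, next_state, reward, not_done,
        discount, max_action, action_dim,
        actor_target_params, critic_target_params,
    )
)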
def loss_fn(actor_params):
    # sample an action from the current policy and get its log-probability
    actor_action, log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, state, rng, True, False
    )
    q1, q2 = apply_double_critic_model(critic_params, state, actor_action, False)
    min_q = jnp.minimum(q1, q2)
    # per-sample actor loss, with the temperature held fixed via stop_gradient
    partial_loss_fn = jax.vmap(
        partial(
            actor_loss_fn,
            jax.lax.stop_gradient(apply_constant_model(log_alpha_params, -3.5, False)),
        ),
    )
    actor_loss = partial_loss_fn(log_p, min_q)
    return jnp.mean(actor_loss), log_p
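The per-sample `actor_loss_fn` that `jax.vmap` maps over is not shown above. A minimal sketch, assuming the standard SAC per-sample objective (this exact body is an assumption, not necessarily the author's implementation):

# Assumed per-sample actor loss: alpha * log pi(a|s) - min(Q1, Q2).
# Minimizing this maximizes the entropy-regularized value of the sampled action.
def actor_loss_fn(log_alpha: jnp.ndarray, log_p: jnp.ndarray, min_q: jnp.ndarray) -> jnp.ndarray:
    return jnp.exp(log_alpha) * log_p - min_q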
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.
    """
    # get the policy distribution for each state and sample `action_sample_size`
    # actions from each
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )
    mu = jnp.expand_dims(mu, axis=1)
    sig = jnp.expand_dims(jnp.exp(log_sig), axis=1)
    sampled_actions = (
        mu + random.normal(rng, (batch_size, action_sample_size, action_dim)) * sig
    )
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )
    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # evaluate each of the sampled actions at its corresponding state
    # we keep the `sampled_actions` array unsquashed because we need it to
    # calculate the log probabilities, but we pass the squashed actions to the critic
    Q1 = apply_double_critic_model(
        critic_target_params,
        states_repeated,
        max_action * nn.tanh(sampled_actions),
        True,
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))
    Q1 = jax.lax.stop_gradient(Q1)
    return Q1, sampled_actions
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.
    """
    state_dim = state.shape[-1]
    # get the policy distribution for each state and sample `action_sample_size`
    # actions from each
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, state_dim, max_action, state, None, False, True
    )
    sig = jnp.exp(log_sig)
    sampled_actions = mu + random.normal(rng, (batch_size, action_sample_size)) * sig
    sampled_actions = max_action * nn.tanh(sampled_actions)
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )
    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # evaluate each of the sampled actions at its corresponding state
    Q1 = apply_double_critic_model(
        critic_target_params, states_repeated, sampled_actions, True
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))
    Q1 = jax.lax.stop_gradient(Q1)
    return Q1, sampled_actions
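As a point of reference, the sampled Q-values returned above are typically turned into the weights of MPO's nonparametric policy in the E-step. A minimal sketch, assuming a dual temperature variable `eta`; this helper and its name are illustrative assumptions, not part of the code above:

# q(s, a) ∝ pi_target(a|s) * exp(Q(s, a) / eta); since the actions were drawn from
# pi_target, a softmax over the sampled Q-values per state gives the weights.
def nonparametric_policy_weights(Q1: jnp.ndarray, eta: float) -> jnp.ndarray:
    return jax.nn.softmax(Q1 / eta, axis=1)  # shape: (batch_size, action_sample_size)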
def loss_fn(critic_params):
    # current Q estimates from both critic heads
    current_Q1, current_Q2 = apply_double_critic_model(
        critic_params, state, action, False
    )
    # squared TD error of each head against the shared bootstrap target
    critic_loss = double_mse(current_Q1, current_Q2, target_Q)
    return jnp.mean(critic_loss)
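A minimal sketch of what `double_mse` is assumed to compute: the sum of squared TD errors of both critic heads against the shared target, left per-sample so the mean above aggregates over the batch (the exact definition in the source may differ):

def double_mse(q1: jnp.ndarray, q2: jnp.ndarray, qt: jnp.ndarray) -> jnp.ndarray:
    # per-sample squared errors for both heads against the same target
    return jnp.square(qt - q1) + jnp.square(qt - q2)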