import jax
import jax.numpy as jnp

# `networks`, `snr_loss_fn`, `snr_alpha`, `reward_scale`, `discount`, and
# `self._use_snr` are captured from the enclosing learner scope in which
# these loss functions are defined.


def actor_loss(policy_params,
               q_params,
               target_q_params,
               alpha,
               transitions,
               snr_state,
               key,
               in_initial_bc_iters):
  dist_params = networks.policy_network.apply(
      policy_params, transitions.observation)

  if in_initial_bc_iters:
    # Initial behavioral-cloning phase: maximize the log-likelihood of the
    # dataset actions; no SNR term in the BC iterations.
    log_prob = networks.log_prob(dist_params, transitions.action)
    min_q = 0.
    actor_loss = -log_prob
    sn = 0.
    new_snr_state = snr_state
  else:
    # Standard SAC-style actor loss on actions sampled from the policy.
    key, sub_key = jax.random.split(key)
    action = networks.sample(dist_params, sub_key)
    log_prob = networks.log_prob(dist_params, action)
    q_action = networks.q_network.apply(
        q_params, transitions.observation, action)
    min_q = jnp.min(q_action, axis=-1)
    actor_loss = alpha * log_prob - min_q

    # SNR regularization is only applied after the initial BC iters.
    if self._use_snr:
      next_dist_params = networks.policy_network.apply(
          policy_params, transitions.next_observation)
      # Reach into the underlying (tanh-squashed) Gaussian for its loc/scale.
      next_dist_params = [
          next_dist_params._distribution._distribution.loc,
          next_dist_params._distribution._distribution.scale,
      ]
      key, sub_key = jax.random.split(key)
      sn, (masked_s, C, new_snr_state) = snr_loss_fn(
          next_dist_params, transitions.observation, transitions.action,
          transitions.next_observation, transitions.discount, sub_key,
          snr_state, q_params, target_q_params)
      actor_loss = actor_loss + snr_alpha * sn
    else:
      sn = 0.
      new_snr_state = snr_state

  return jnp.mean(actor_loss), (min_q, jnp.mean(log_prob), sn, new_snr_state)
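# A minimal sketch (not from the original code) of how `actor_loss` could be
# wired into a jitted update step with optax. `actor_optimizer` is an assumed
# closure (e.g. optax.adam(3e-4)) and the metrics dict is illustrative.
# Because `in_initial_bc_iters` selects a Python-level branch, it must be a
# static argument under `jax.jit`, so one update is compiled per phase.
import functools

import optax

actor_grad_fn = jax.value_and_grad(actor_loss, has_aux=True)


@functools.partial(jax.jit, static_argnames='in_initial_bc_iters')
def update_actor(policy_params, q_params, target_q_params, alpha, transitions,
                 snr_state, opt_state, key, in_initial_bc_iters):
  key, sub_key = jax.random.split(key)
  # Gradients w.r.t. `policy_params` only; the aux tuple mirrors the loss.
  (loss, (min_q, log_prob, sn, new_snr_state)), grads = actor_grad_fn(
      policy_params, q_params, target_q_params, alpha, transitions, snr_state,
      sub_key, in_initial_bc_iters)
  updates, new_opt_state = actor_optimizer.update(grads, opt_state)
  new_policy_params = optax.apply_updates(policy_params, updates)
  metrics = {'actor_loss': loss, 'min_q': jnp.mean(min_q),
             'log_prob': log_prob, 'snr': sn}
  return new_policy_params, new_opt_state, new_snr_state, key, metrics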
def critic_loss(q_params,
                policy_params,
                target_q_params,
                alpha,
                transitions,
                key):
  q_old_action = networks.q_network.apply(
      q_params, transitions.observation, transitions.action)
  next_dist_params = networks.policy_network.apply(
      policy_params, transitions.next_observation)
  next_action = networks.sample(next_dist_params, key)
  # Computed but unused: the entropy term in the target is disabled below.
  next_log_prob = networks.log_prob(next_dist_params, next_action)
  next_q = networks.q_network.apply(
      target_q_params, transitions.next_observation, next_action)
  # next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob
  next_v = jnp.min(next_q, axis=-1)
  target_q = jax.lax.stop_gradient(
      transitions.reward * reward_scale +
      transitions.discount * discount * next_v)
  q_error = q_old_action - jnp.expand_dims(target_q, -1)
  # q_loss = 0.5 * jnp.mean(jnp.square(q_error))
  q_loss = jnp.mean(jnp.square(q_error))
  # Rescale so the loss sums (rather than averages) over the critic ensemble.
  q_loss = q_loss * q_error.shape[-1]
  return q_loss
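# A companion sketch for the critic side: gradients w.r.t. `q_params`,
# followed by a Polyak average of the online parameters into
# `target_q_params`, which the bootstrapped `target_q` above relies on.
# `q_optimizer` and `tau` are assumed names, not part of the original code.
critic_grad_fn = jax.value_and_grad(critic_loss)


@jax.jit
def update_critic(q_params, policy_params, target_q_params, alpha, transitions,
                  opt_state, key, tau):
  key, sub_key = jax.random.split(key)
  loss, grads = critic_grad_fn(
      q_params, policy_params, target_q_params, alpha, transitions, sub_key)
  updates, new_opt_state = q_optimizer.update(grads, opt_state)
  new_q_params = optax.apply_updates(q_params, updates)
  # target <- tau * online + (1 - tau) * target.
  new_target_q_params = optax.incremental_update(
      new_q_params, target_q_params, tau)
  return new_q_params, new_target_q_params, new_opt_state, key, loss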