def loss_fn(params, rng_input, target_quantile_vals, target_r,
            target_next_r):
  def online(state):
    return network.apply(params, state, num_quantiles=num_tau_samples,
                         rng=rng_input)

  model_output = jax.vmap(online)(states)
  quantile_values = model_output.quantile_values
  quantiles = model_output.quantiles
  representations = model_output.representation
  representations = jnp.squeeze(representations)
  chosen_action_quantile_values = jax.vmap(
      lambda x, y: x[:, y][:, None])(quantile_values, actions)
  # Shape of bellman_errors and huber_loss:
  # batch_size x num_tau_prime_samples x num_tau_samples x 1.
  bellman_errors = (target_quantile_vals[:, :, None, :] -
                    chosen_action_quantile_values[:, None, :, :])
  # The huber loss (see Section 2.3 of the paper) is defined via two cases:
  # case_one: |bellman_errors| <= kappa
  # case_two: |bellman_errors| > kappa
  huber_loss_case_one = (
      (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
      0.5 * bellman_errors**2)
  huber_loss_case_two = (
      (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) * kappa *
      (jnp.abs(bellman_errors) - 0.5 * kappa))
  huber_loss = huber_loss_case_one + huber_loss_case_two
  # Tile by num_tau_prime_samples along a new dimension. Shape is now
  # batch_size x num_tau_prime_samples x num_tau_samples x 1.
  # These quantiles will be used for computation of the quantile huber loss
  # below (see section 2.3 of the paper).
  quantiles = jnp.tile(quantiles[:, None, :, :],
                       [1, num_tau_prime_samples, 1, 1]).astype(jnp.float32)
  # Shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
  quantile_huber_loss = (jnp.abs(quantiles - jax.lax.stop_gradient(
      (bellman_errors < 0).astype(jnp.float32))) * huber_loss) / kappa
  # Sum over current quantile value (num_tau_samples) dimension,
  # average over target quantile value (num_tau_prime_samples) dimension.
  # Shape: batch_size x num_tau_prime_samples x 1.
  quantile_huber_loss = jnp.sum(quantile_huber_loss, axis=2)
  quantile_huber_loss = jnp.mean(quantile_huber_loss, axis=1)
  online_dist = metric_utils.representation_distances(
      representations, target_r, distance_fn)
  target_dist = metric_utils.target_distances(
      target_next_r, rewards, distance_fn, cumulative_gamma)
  metric_loss = jnp.mean(
      jax.vmap(losses.huber_loss)(online_dist, target_dist))
  loss = ((1. - mico_weight) * quantile_huber_loss +
          mico_weight * metric_loss)
  return jnp.mean(loss), (jnp.mean(quantile_huber_loss), metric_loss)
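
# The pattern shared by each loss_fn in this listing: it returns a
# (scalar_loss, aux) pair, which the training step consumes with
# jax.value_and_grad(..., has_aux=True) to obtain gradients plus auxiliary
# values for logging. The self-contained sketch below illustrates only that
# pattern; toy_loss, the quadratic objective, and the optax.adam settings are
# illustrative assumptions, not the agents' actual training setup.
import jax
import jax.numpy as jnp
import optax


def toy_loss(params, targets):
  preds = params['w'] * targets  # stand-in for the network forward pass
  main_loss = jnp.mean((preds - targets)**2)
  aux_loss = jnp.mean(jnp.abs(preds))
  return main_loss, (main_loss, aux_loss)


params = {'w': jnp.ones(())}
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)
targets = jnp.arange(4, dtype=jnp.float32)

grad_fn = jax.value_and_grad(toy_loss, has_aux=True)
(loss, (main_loss, aux_loss)), grads = grad_fn(params, targets)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
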
def loss_fn(params, bellman_target, target_r, target_next_r):
  def q_online(state):
    return network_def.apply(params, state)

  model_output = jax.vmap(q_online)(states)
  q_values = model_output.q_values
  q_values = jnp.squeeze(q_values)
  representations = model_output.representation
  representations = jnp.squeeze(representations)
  replay_chosen_q = jax.vmap(lambda x, y: x[y])(q_values, actions)
  bellman_loss = jnp.mean(
      jax.vmap(losses.mse_loss)(bellman_target, replay_chosen_q))
  online_dist = metric_utils.representation_distances(
      representations, target_r, distance_fn)
  target_dist = metric_utils.target_distances(
      target_next_r, rewards, distance_fn, cumulative_gamma)
  metric_loss = jnp.mean(
      jax.vmap(losses.huber_loss)(online_dist, target_dist))
  loss = ((1. - mico_weight) * bellman_loss +
          mico_weight * metric_loss)
  return jnp.mean(loss), (bellman_loss, metric_loss)
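
# The metric_loss term above compares pairwise distances between online
# representations against a bootstrapped target. Below is a self-contained
# sketch of that target, assuming metric_utils.target_distances follows the
# MICo operator |r_x - r_y| + gamma * distance(phi_target(x'), phi_target(y'))
# over all pairs of transitions in the batch; cosine_distance and
# pairwise_target_distances here are illustrative re-derivations, not the
# repo's metric_utils implementation.
import jax
import jax.numpy as jnp


def cosine_distance(x, y, epsilon=1e-8):
  # Angular distance between two representation vectors.
  cos_similarity = jnp.dot(x, y) / (
      jnp.linalg.norm(x) * jnp.linalg.norm(y) + epsilon)
  return jnp.arccos(jnp.clip(cos_similarity, -1., 1.))


def pairwise_target_distances(target_next_r, rewards, gamma):
  # All-pairs |r_i - r_j| via broadcasting; shape [batch_size, batch_size].
  reward_diffs = jnp.abs(rewards[:, None] - rewards[None, :])
  # All-pairs distances between target next-state representations.
  next_state_dists = jax.vmap(
      lambda x: jax.vmap(lambda y: cosine_distance(x, y))(target_next_r))(
          target_next_r)
  return (reward_diffs + gamma * next_state_dists).reshape(-1)
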
def loss_fn(params, bellman_target, target_r, target_next_r):
  def q_online(state):
    return network_def.apply(params, state)

  model_output = jax.vmap(q_online)(states)
  logits = model_output.logits
  logits = jnp.squeeze(logits)
  representations = model_output.representation
  representations = jnp.squeeze(representations)
  # Fetch the logits for its selected action. We use vmap to perform this
  # indexing across the batch.
  chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
  bellman_errors = (bellman_target[:, None, :] -
                    chosen_action_logits[:, :, None])  # Input `u' of Eq. 9.
  # Eq. 9 of paper.
  huber_loss = (
      (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
      0.5 * bellman_errors**2 +
      (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) *
      kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))
  tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) /
             num_atoms)  # Quantile midpoints. See Lemma 2 of paper.
  # Eq. 10 of paper.
  tau_bellman_diff = jnp.abs(
      tau_hat[None, :, None] - (bellman_errors < 0).astype(jnp.float32))
  quantile_huber_loss = tau_bellman_diff * huber_loss
  # Sum over tau dimension, average over target value dimension.
  quantile_loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1)
  online_dist = metric_utils.representation_distances(
      representations, target_r, distance_fn)
  target_dist = metric_utils.target_distances(
      target_next_r, rewards, distance_fn, cumulative_gamma)
  metric_loss = jnp.mean(
      jax.vmap(losses.huber_loss)(online_dist, target_dist))
  loss = ((1. - mico_weight) * quantile_loss +
          mico_weight * metric_loss)
  return jnp.mean(loss), (loss, jnp.mean(quantile_loss), metric_loss)
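
# A standalone restatement of the quantile-Huber weighting used above
# (Eqs. 9-10): each Bellman error u is penalized by |tau_hat - 1{u < 0}| times
# the Huber loss of u, averaged over target samples and summed over quantile
# atoms. The toy batch of 2 transitions with 3 atoms below is an arbitrary
# illustrative choice, not data from the agent.
import jax.numpy as jnp


def quantile_huber_loss(bellman_target, chosen_action_logits, kappa=1.0):
  num_atoms = chosen_action_logits.shape[-1]
  # Shape: batch_size x num_atoms (online) x num_atoms (target).
  bellman_errors = (bellman_target[:, None, :] -
                    chosen_action_logits[:, :, None])
  huber = jnp.where(
      jnp.abs(bellman_errors) <= kappa,
      0.5 * bellman_errors**2,
      kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))
  tau_hat = (jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / num_atoms
  weights = jnp.abs(tau_hat[None, :, None] -
                    (bellman_errors < 0).astype(jnp.float32))
  # Average over the target-sample axis, sum over the quantile axis.
  return jnp.sum(jnp.mean(weights * huber, axis=2), axis=1)


example_targets = jnp.array([[0.1, 0.4, 0.9], [0.0, 0.2, 0.3]])
example_logits = jnp.array([[0.2, 0.5, 0.8], [0.1, 0.1, 0.4]])
per_example_loss = quantile_huber_loss(example_targets, example_logits)
# per_example_loss has shape [batch_size] = [2].
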
def loss_fn(params, bellman_target, loss_multipliers, target_r,
            target_next_r):
  def q_online(state):
    return network_def.apply(params, state, support)

  model_output = jax.vmap(q_online)(states)
  logits = model_output.logits
  logits = jnp.squeeze(logits)
  representations = model_output.representation
  representations = jnp.squeeze(representations)
  # Fetch the logits for its selected action. We use vmap to perform this
  # indexing across the batch.
  chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
  c51_loss = jax.vmap(losses.softmax_cross_entropy_loss_with_logits)(
      bellman_target, chosen_action_logits)
  c51_loss *= loss_multipliers
  online_dist, norm_average, angular_distance = (
      metric_utils.representation_distances(
          representations, target_r, distance_fn,
          return_distance_components=True))
  target_dist = metric_utils.target_distances(
      target_next_r, rewards, distance_fn, cumulative_gamma)
  metric_loss = jnp.mean(
      jax.vmap(losses.huber_loss)(online_dist, target_dist))
  loss = ((1. - mico_weight) * c51_loss +
          mico_weight * metric_loss)
  aux_losses = {
      'loss': loss,
      'mean_loss': jnp.mean(loss),
      'c51_loss': jnp.mean(c51_loss),
      'metric_loss': metric_loss,
      'norm_average': jnp.mean(norm_average),
      'angular_distance': jnp.mean(angular_distance),
  }
  return jnp.mean(loss), aux_losses
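
# A minimal sketch of the categorical (C51) cross-entropy term used above,
# written with plain jax.nn primitives rather than the repo's losses module;
# the helper name and the uniform toy distributions are assumptions for
# illustration only. loss_multipliers plays the role of per-sample
# importance-sampling weights from prioritized replay.
import jax
import jax.numpy as jnp


def categorical_cross_entropy(target_distribution, logits):
  # Cross entropy between the projected Bellman target distribution and the
  # online logits over the support atoms.
  return -jnp.sum(target_distribution * jax.nn.log_softmax(logits))


batch_targets = jnp.full((32, 51), 1. / 51)  # toy projected target distributions
batch_logits = jnp.zeros((32, 51))           # toy online logits
per_sample_loss = jax.vmap(categorical_cross_entropy)(batch_targets,
                                                      batch_logits)
loss_multipliers = jnp.ones(32)              # toy importance-sampling weights
weighted_loss = jnp.mean(per_sample_loss * loss_multipliers)
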
def loss_fn(params, log_alpha):
  """Calculates the loss for one transition."""
  def critic_online(state, action):
    return network_def.apply(params, state, action,
                             method=network_def.critic)

  # We use frozen_params so that gradients can flow back to the actor without
  # being used to update the critic.
  def frozen_critic_online(state, action):
    return network_def.apply(frozen_params, state, action,
                             method=network_def.critic)

  def actor_online(state, action):
    return network_def.apply(params, state, action,
                             method=network_def.actor)

  def q_target(next_state, rng):
    return network_def.apply(target_params, next_state, rng, True)

  # J_Q(\theta) from equation (5) in paper.
  q_values_1, q_values_2, representations = jax.vmap(critic_online)(
      states, actions)
  q_values_1 = jnp.squeeze(q_values_1)
  q_values_2 = jnp.squeeze(q_values_2)
  representations = jnp.squeeze(representations)

  brng1 = jnp.stack(jax.random.split(rng1, num=batch_size))
  target_outputs = jax.vmap(q_target)(next_states, brng1)
  target_q_values_1, target_q_values_2 = target_outputs.critic
  target_next_r = target_outputs.representation
  target_q_values_1 = jnp.squeeze(target_q_values_1)
  target_q_values_2 = jnp.squeeze(target_q_values_2)
  target_next_r = jnp.squeeze(target_next_r)
  target_q_values = jnp.minimum(target_q_values_1, target_q_values_2)

  alpha_value = jnp.exp(log_alpha)
  log_probs = target_outputs.actor.log_probability
  targets = reward_scale_factor * rewards + cumulative_gamma * (
      target_q_values - alpha_value * log_probs) * (1. - terminals)
  targets = jax.lax.stop_gradient(targets)
  critic_loss_1 = jax.vmap(losses.mse_loss)(q_values_1, targets)
  critic_loss_2 = jax.vmap(losses.mse_loss)(q_values_2, targets)
  critic_loss = jnp.mean(critic_loss_1 + critic_loss_2)

  # J_{\pi}(\phi) from equation (9) in paper.
  brng2 = jnp.stack(jax.random.split(rng2, num=batch_size))
  mean_actions, sampled_actions, action_log_probs = jax.vmap(
      actor_online)(states, brng2)
  q_values_no_grad_1, q_values_no_grad_2, target_r = jax.vmap(
      frozen_critic_online)(states, sampled_actions)
  q_values_no_grad_1 = jnp.squeeze(q_values_no_grad_1)
  q_values_no_grad_2 = jnp.squeeze(q_values_no_grad_2)
  target_r = jnp.squeeze(target_r)
  no_grad_q_values = jnp.minimum(q_values_no_grad_1, q_values_no_grad_2)
  alpha_value = jnp.exp(jax.lax.stop_gradient(log_alpha))
  policy_loss = jnp.mean(alpha_value * action_log_probs - no_grad_q_values)

  # J(\alpha) from equation (18) in paper.
  entropy_diffs = -action_log_probs - target_entropy
  alpha_loss = jnp.mean(log_alpha * jax.lax.stop_gradient(entropy_diffs))

  # MICo loss.
  distance_fn = metric_utils.cosine_distance
  online_dist = metric_utils.representation_distances(
      representations, target_r, distance_fn)
  target_dist = metric_utils.target_distances(
      target_next_r, rewards, distance_fn, cumulative_gamma)
  mico_loss = jnp.mean(
      jax.vmap(losses.huber_loss)(online_dist, target_dist))

  # Giving a smaller weight to the critic empirically gives better results.
  sac_loss = 0.5 * critic_loss + 1.0 * policy_loss + 1.0 * alpha_loss
  combined_loss = (1. - mico_weight) * sac_loss + mico_weight * mico_loss

  return combined_loss, {
      'mico_loss': mico_loss,
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
      'alpha_loss': alpha_loss,
      'critic_value_1': jnp.mean(q_values_1),
      'critic_value_2': jnp.mean(q_values_2),
      'target_value_1': jnp.mean(target_q_values_1),
      'target_value_2': jnp.mean(target_q_values_2),
      'mean_action': jnp.mean(mean_actions, axis=0),
  }
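
# A self-contained numeric sketch of the soft Bellman target constructed in
# the critic loss above: targets = scale * r + gamma * (min(Q1, Q2) -
# alpha * log_pi) * (1 - terminal). All inputs below are made-up illustrative
# values, not outputs of the actual networks.
import jax.numpy as jnp


def soft_bellman_targets(rewards, terminals, target_q1, target_q2,
                         log_probs, log_alpha, gamma, reward_scale=1.0):
  alpha = jnp.exp(log_alpha)
  min_q = jnp.minimum(target_q1, target_q2)   # clipped double-Q estimate
  soft_value = min_q - alpha * log_probs      # entropy-regularized next value
  return reward_scale * rewards + gamma * soft_value * (1. - terminals)


example_targets = soft_bellman_targets(
    rewards=jnp.array([1.0, 0.0]),
    terminals=jnp.array([0.0, 1.0]),
    target_q1=jnp.array([2.0, 3.0]),
    target_q2=jnp.array([1.5, 2.5]),
    log_probs=jnp.array([-1.2, -0.7]),
    log_alpha=jnp.log(0.2),
    gamma=0.99)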