    def loss_fn(params, bellman_target, target_r, target_next_r):
        def q_online(state):
            return network_def.apply(params, state)

        model_output = jax.vmap(q_online)(states)
        q_values = model_output.q_values
        q_values = jnp.squeeze(q_values)
        representations = model_output.representation
        representations = jnp.squeeze(representations)
        # Fetch the Q-value of each state's chosen action. We use vmap to
        # perform this indexing across the batch.
        replay_chosen_q = jax.vmap(lambda x, y: x[y])(q_values, actions)
        # Squared TD error against the Bellman target.
        bellman_loss = jnp.mean(
            jax.vmap(losses.mse_loss)(bellman_target, replay_chosen_q))
        # MICo metric loss between online and target representations.
        online_dist = metric_utils.representation_distances(
            representations, target_r, distance_fn)
        target_dist = metric_utils.target_distances(target_next_r, rewards,
                                                    distance_fn,
                                                    cumulative_gamma)
        metric_loss = jnp.mean(
            jax.vmap(losses.huber_loss)(online_dist, target_dist))
        loss = ((1. - mico_weight) * bellman_loss + mico_weight * metric_loss)
        return jnp.mean(loss), (bellman_loss, metric_loss)
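The loss above closes over names defined in the enclosing training function (network_def, states, actions, rewards, mico_weight, distance_fn, cumulative_gamma). As a minimal usage sketch, assuming an optax optimizer called `optimizer` (an assumption, not part of the snippet), the combined loss can be differentiated in one call and the aux tuple used for logging:

import jax
import optax

def train_step(params, opt_state, bellman_target, target_r, target_next_r):
    # Differentiate the scalar loss; aux carries the two component losses.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, (bellman_loss, metric_loss)), grads = grad_fn(
        params, bellman_target, target_r, target_next_r)
    # `optimizer` is an assumed optax optimizer, e.g. optax.adam(1e-4).
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, bellman_loss, metric_loss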
Example #2
    def loss_fn(params, bellman_target, target_r, target_next_r):
        def q_online(state):
            return network_def.apply(params, state)

        model_output = jax.vmap(q_online)(states)
        logits = model_output.logits
        logits = jnp.squeeze(logits)
        representations = model_output.representation
        representations = jnp.squeeze(representations)
        # Fetch the logits for each state's selected action. We use vmap to
        # perform this indexing across the batch.
        chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
        bellman_errors = (bellman_target[:, None, :] -
                          chosen_action_logits[:, :, None]
                          )  # Input `u' of Eq. 9.
        # Eq. 9 of paper.
        huber_loss = (
            (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) * 0.5 *
            bellman_errors**2 +
            (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) * kappa *
            (jnp.abs(bellman_errors) - 0.5 * kappa))

        tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / num_atoms
                   )  # Quantile midpoints.  See Lemma 2 of paper.
        # Eq. 10 of paper.
        tau_bellman_diff = jnp.abs(tau_hat[None, :, None] -
                                   (bellman_errors < 0).astype(jnp.float32))
        quantile_huber_loss = tau_bellman_diff * huber_loss
        # Sum over tau dimension, average over target value dimension.
        quantile_loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1)
        online_dist = metric_utils.representation_distances(
            representations, target_r, distance_fn)
        target_dist = metric_utils.target_distances(target_next_r, rewards,
                                                    distance_fn,
                                                    cumulative_gamma)
        metric_loss = jnp.mean(
            jax.vmap(losses.huber_loss)(online_dist, target_dist))
        loss = ((1. - mico_weight) * quantile_loss + mico_weight * metric_loss)
        return jnp.mean(loss), (loss, jnp.mean(quantile_loss), metric_loss)
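For reference, the quantile Huber term that the code above builds inline (Eqs. 9 and 10) can be written as a standalone per-transition function. This is a sketch, not code from the source, assuming both inputs are 1-D arrays of length num_atoms:

import jax.numpy as jnp

def quantile_huber_loss(bellman_target, chosen_action_logits, kappa, num_atoms):
    # bellman_errors[i, j] = target_j - prediction_i, the input `u` of Eq. 9.
    bellman_errors = bellman_target[None, :] - chosen_action_logits[:, None]
    abs_errors = jnp.abs(bellman_errors)
    # Huber loss of Eq. 9: quadratic within kappa, linear outside.
    huber = jnp.where(abs_errors <= kappa,
                      0.5 * bellman_errors ** 2,
                      kappa * (abs_errors - 0.5 * kappa))
    # Quantile midpoints (Lemma 2) and the asymmetric weighting of Eq. 10.
    tau_hat = (jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / num_atoms
    tau_weight = jnp.abs(tau_hat[:, None] -
                         (bellman_errors < 0).astype(jnp.float32))
    # Sum over the quantile dimension, average over the target samples.
    return jnp.sum(jnp.mean(tau_weight * huber, axis=1), axis=0)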
Example #3
    def loss_fn(params, bellman_target, loss_multipliers, target_r,
                target_next_r):
        def q_online(state):
            return network_def.apply(params, state, support)

        model_output = jax.vmap(q_online)(states)
        logits = model_output.logits
        logits = jnp.squeeze(logits)
        representations = model_output.representation
        representations = jnp.squeeze(representations)
        # Fetch the logits for each state's selected action. We use vmap to
        # perform this indexing across the batch.
        chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
        c51_loss = jax.vmap(losses.softmax_cross_entropy_loss_with_logits)(
            bellman_target, chosen_action_logits)
        c51_loss *= loss_multipliers
        online_dist, norm_average, angular_distance = (
            metric_utils.representation_distances(
                representations,
                target_r,
                distance_fn,
                return_distance_components=True))
        target_dist = metric_utils.target_distances(target_next_r, rewards,
                                                    distance_fn,
                                                    cumulative_gamma)
        metric_loss = jnp.mean(
            jax.vmap(losses.huber_loss)(online_dist, target_dist))
        loss = ((1. - mico_weight) * c51_loss + mico_weight * metric_loss)
        aux_losses = {
            'loss': loss,
            'mean_loss': jnp.mean(loss),
            'c51_loss': jnp.mean(c51_loss),
            'metric_loss': metric_loss,
            'norm_average': jnp.mean(norm_average),
            'angular_distance': jnp.mean(angular_distance),
        }
        return jnp.mean(loss), aux_losses
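The per-sample `loss` vector returned in aux_losses, together with the loss_multipliers argument, suggests a prioritized replay setup. Purely as an illustration of that pattern (standard prioritized-replay practice, not taken from this code; both helpers are hypothetical), the weights and refreshed priorities might be computed as:

import jax.numpy as jnp

def importance_weights(probabilities):
    # Hypothetical helper: inverse-sqrt importance-sampling weights,
    # normalized so the largest weight in the batch is 1.
    weights = 1.0 / jnp.sqrt(probabilities + 1e-10)
    return weights / jnp.max(weights)

def new_priorities(per_sample_loss):
    # Hypothetical helper: the square root of the per-sample loss is a
    # common choice of refreshed priority.
    return jnp.sqrt(per_sample_loss + 1e-10)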
Example #4
    def loss_fn(params, log_alpha):
        """Calculates the loss for one transition."""
        def critic_online(state, action):
            return network_def.apply(params,
                                     state,
                                     action,
                                     method=network_def.critic)

        # We use frozen_params so that gradients can flow back to the actor without
        # being used to update the critic.
        def frozen_critic_online(state, action):
            return network_def.apply(frozen_params,
                                     state,
                                     action,
                                     method=network_def.critic)

        def actor_online(state, action):
            return network_def.apply(params,
                                     state,
                                     action,
                                     method=network_def.actor)

        def q_target(next_state, rng):
            return network_def.apply(target_params, next_state, rng, True)

        # J_Q(\theta) from equation (5) in paper.
        q_values_1, q_values_2, representations = jax.vmap(critic_online)(
            states, actions)
        q_values_1 = jnp.squeeze(q_values_1)
        q_values_2 = jnp.squeeze(q_values_2)
        representations = jnp.squeeze(representations)

        brng1 = jnp.stack(jax.random.split(rng1, num=batch_size))
        target_outputs = jax.vmap(q_target)(next_states, brng1)
        target_q_values_1, target_q_values_2 = target_outputs.critic
        target_next_r = target_outputs.representation
        target_q_values_1 = jnp.squeeze(target_q_values_1)
        target_q_values_2 = jnp.squeeze(target_q_values_2)
        target_next_r = jnp.squeeze(target_next_r)
        target_q_values = jnp.minimum(target_q_values_1, target_q_values_2)

        alpha_value = jnp.exp(log_alpha)
        log_probs = target_outputs.actor.log_probability
        targets = reward_scale_factor * rewards + cumulative_gamma * (
            target_q_values - alpha_value * log_probs) * (1. - terminals)
        targets = jax.lax.stop_gradient(targets)
        critic_loss_1 = jax.vmap(losses.mse_loss)(q_values_1, targets)
        critic_loss_2 = jax.vmap(losses.mse_loss)(q_values_2, targets)
        critic_loss = jnp.mean(critic_loss_1 + critic_loss_2)

        # J_{\pi}(\phi) from equation (9) in paper.
        brng2 = jnp.stack(jax.random.split(rng2, num=batch_size))
        mean_actions, sampled_actions, action_log_probs = jax.vmap(
            actor_online)(states, brng2)

        q_values_no_grad_1, q_values_no_grad_2, target_r = jax.vmap(
            frozen_critic_online)(states, sampled_actions)
        q_values_no_grad_1 = jnp.squeeze(q_values_no_grad_1)
        q_values_no_grad_2 = jnp.squeeze(q_values_no_grad_2)
        target_r = jnp.squeeze(target_r)
        no_grad_q_values = jnp.minimum(q_values_no_grad_1, q_values_no_grad_2)
        alpha_value = jnp.exp(jax.lax.stop_gradient(log_alpha))
        policy_loss = jnp.mean(alpha_value * action_log_probs -
                               no_grad_q_values)

        # J(\alpha) from equation (18) in paper.
        entropy_diffs = -action_log_probs - target_entropy
        alpha_loss = jnp.mean(log_alpha * jax.lax.stop_gradient(entropy_diffs))

        # MICo loss.
        distance_fn = metric_utils.cosine_distance
        online_dist = metric_utils.representation_distances(
            representations, target_r, distance_fn)
        target_dist = metric_utils.target_distances(target_next_r, rewards,
                                                    distance_fn,
                                                    cumulative_gamma)
        mico_loss = jnp.mean(
            jax.vmap(losses.huber_loss)(online_dist, target_dist))

        # Empirically, giving a smaller weight to the critic loss yields
        # better results.
        sac_loss = 0.5 * critic_loss + 1.0 * policy_loss + 1.0 * alpha_loss
        combined_loss = (1. - mico_weight) * sac_loss + mico_weight * mico_loss
        return combined_loss, {
            'mico_loss': mico_loss,
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            'alpha_loss': alpha_loss,
            'critic_value_1': jnp.mean(q_values_1),
            'critic_value_2': jnp.mean(q_values_2),
            'target_value_1': jnp.mean(target_q_values_1),
            'target_value_2': jnp.mean(target_q_values_2),
            'mean_action': jnp.mean(mean_actions, axis=0)
        }
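Since this loss is a function of both params and log_alpha, a single value_and_grad call with argnums can return separate gradients for the networks and the temperature. A sketch assuming two optax optimizers, `optimizer` and `alpha_optimizer` (both assumptions, not shown above):

import jax
import optax

def sac_train_step(params, log_alpha, opt_state, alpha_opt_state):
    grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True)
    (_, aux), (param_grads, alpha_grad) = grad_fn(params, log_alpha)
    # Update the critic/actor parameters and the temperature separately.
    updates, opt_state = optimizer.update(param_grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    alpha_updates, alpha_opt_state = alpha_optimizer.update(
        alpha_grad, alpha_opt_state, log_alpha)
    log_alpha = optax.apply_updates(log_alpha, alpha_updates)
    return params, log_alpha, opt_state, alpha_opt_state, aux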