Example No. 1
  def __call__(
      self,
      network: networks_lib.FeedForwardNetwork,
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      batch: reverb.ReplaySample,
      key: networks_lib.PRNGKey,
  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
    """Calculate a loss on a single batch of data."""
    transitions: types.Transition = batch.data
    keys, probs, *_ = batch.info

    # Forward pass.
    if self.stochastic_network:
      q_tm1 = network.apply(params, key, transitions.observation)
      q_t_value = network.apply(target_params, key,
                                transitions.next_observation)
      q_t_selector = network.apply(params, key, transitions.next_observation)
    else:
      q_tm1 = network.apply(params, transitions.observation)
      q_t_value = network.apply(target_params, transitions.next_observation)
      q_t_selector = network.apply(params, transitions.next_observation)

    # Cast and clip rewards.
    d_t = (transitions.discount * self.discount).astype(jnp.float32)
    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                   self.max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                           q_t_selector)
    batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

    # Importance weighting.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= self.importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    loss = jnp.mean(importance_weights * batch_loss)  # []
    reverb_update = learning_lib.ReverbUpdate(
        keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
    extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
    return loss, extra
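
A quick way to see what the prioritized-replay reweighting above does is to run it on a toy batch. The sampling probabilities and exponent below are made up for illustration; only the normalization pattern mirrors the loss:

import jax.numpy as jnp

# Hypothetical sampling probabilities for a batch of four transitions.
probs = jnp.array([0.4, 0.3, 0.2, 0.1])
importance_sampling_exponent = 0.6  # assumed value, not taken from the example

# Inverse-probability weights, tempered by the exponent and normalized so the
# largest weight equals 1, exactly as in the loss above.
weights = (1. / probs) ** importance_sampling_exponent
weights /= jnp.max(weights)
print(weights)  # rarer transitions (smaller probs) receive larger weights
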
Example No. 2
  def __call__(
      self,
      network: networks_lib.FeedForwardNetwork,
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      batch: reverb.ReplaySample,
      key: networks_lib.PRNGKey,
  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
    """Calculate a loss on a single batch of data."""
    del key
    transitions: types.Transition = batch.data

    # Forward pass.
    q_online_s = network.apply(params, transitions.observation)
    action_one_hot = jax.nn.one_hot(transitions.action, q_online_s.shape[-1])
    q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1)
    q_target_s = network.apply(target_params, transitions.observation)
    q_target_next = network.apply(target_params, transitions.next_observation)

    # Cast and clip rewards.
    d_t = (transitions.discount * self.discount).astype(jnp.float32)
    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                   self.max_abs_reward).astype(jnp.float32)

    # Munchausen term: tau * log_pi(a|s)
    munchausen_term = self.entropy_temperature * jax.nn.log_softmax(
        q_target_s / self.entropy_temperature, axis=-1)
    munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
    munchausen_term_a = jnp.clip(munchausen_term_a,
                                 a_min=self.clip_value_min,
                                 a_max=0.)

    # Soft Bellman operator applied to q
    next_v = self.entropy_temperature * jax.nn.logsumexp(
        q_target_next / self.entropy_temperature, axis=-1)
    target_q = jax.lax.stop_gradient(r_t + self.munchausen_coefficient *
                                     munchausen_term_a + d_t * next_v)

    batch_loss = rlax.huber_loss(target_q - q_online_sa,
                                 self.huber_loss_parameter)
    loss = jnp.mean(batch_loss)

    extra = learning_lib.LossExtra(metrics={})
    return loss, extra
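
A minimal sketch of the two soft-RL quantities used in this Munchausen loss, evaluated on toy numbers (the temperature and Q-values are illustrative, not taken from the example):

import jax
import jax.numpy as jnp

tau = 0.03                      # entropy temperature (assumed value)
q = jnp.array([1.0, 2.0, 0.5])  # toy action values for a single state

# Munchausen term: tau * log_pi(a|s) under the softmax policy induced by q.
# Every entry is <= 0, which is why the loss above clips it to at most 0.
munchausen_term = tau * jax.nn.log_softmax(q / tau)

# Soft state value: tau * logsumexp(q / tau), a smooth stand-in for max(q).
soft_v = tau * jax.nn.logsumexp(q / tau)
print(munchausen_term, soft_v)
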
Example No. 3
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        transitions: types.Transition = batch.data
        keys, probs, *_ = batch.info

        # Forward pass.
        _, logits_tm1, atoms_tm1 = network.apply(params,
                                                 transitions.observation)
        _, logits_t, atoms_t = network.apply(target_params,
                                             transitions.next_observation)
        q_t_selector, _, _ = network.apply(params,
                                           transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute categorical double Q-learning loss.
        batch_loss_fn = jax.vmap(rlax.categorical_double_q_learning,
                                 in_axes=(None, 0, 0, 0, 0, None, 0, 0))
        batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action,
                                   r_t, d_t, atoms_t, logits_t, q_t_selector)

        # Importance weighting.
        importance_weights = (1. / probs).astype(jnp.float32)
        importance_weights **= self.importance_sampling_exponent
        importance_weights /= jnp.max(importance_weights)

        # Reweight.
        loss = jnp.mean(importance_weights * batch_loss)  # []
        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(batch_loss).astype(jnp.float64))
        extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
        return loss, extra
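
One detail worth calling out in the categorical (C51-style) loss above is the vmap axis spec: the atom locations are shared across the batch (in_axes=None) while logits, actions, rewards and discounts carry a leading batch axis. A hypothetical atom grid, with v_min/v_max/num_atoms chosen only for illustration:

import jax.numpy as jnp

v_min, v_max, num_atoms = -10.0, 10.0, 51      # assumed settings, not from the example
atoms = jnp.linspace(v_min, v_max, num_atoms)  # shape [num_atoms], shared by every batch element
# Per-example logits would then have shape [num_actions, num_atoms], so the
# batched call maps over logits/actions/rewards/discounts but not over atoms,
# matching in_axes=(None, 0, 0, 0, 0, None, 0, 0).
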
Example No. 4
    def __call__(
        self,
        network: hk.Transformed,
        params: hk.Params,
        target_params: hk.Params,
        batch: reverb.ReplaySample,
        key: jnp.DeviceArray,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        o_tm1, a_tm1, r_t, d_t, o_t = batch.data
        keys, probs, *_ = batch.info

        # Forward pass.
        q_tm1 = network.apply(params, o_tm1)
        q_t_value = network.apply(target_params, o_t)
        q_t_selector = network.apply(params, o_t)

        # Cast and clip rewards.
        d_t = (d_t * self.discount).astype(jnp.float32)
        r_t = jnp.clip(r_t, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute double Q-learning n-step TD-error.
        batch_error = jax.vmap(rlax.double_q_learning)
        td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
        batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

        # Importance weighting.
        importance_weights = (1. / probs).astype(jnp.float32)
        importance_weights **= self.importance_sampling_exponent
        importance_weights /= jnp.max(importance_weights)

        # Reweight.
        loss = jnp.mean(importance_weights * batch_loss)  # []
        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
        extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
        return loss, extra
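
Example No. 4 is the same prioritized double Q-learning loss as Example No. 1, just with the batch unpacked as a flat (o_tm1, a_tm1, r_t, d_t, o_t) tuple. For reference, the per-transition quantity rlax.double_q_learning produces can be reproduced with plain array operations; all values below are made up:

import jax.numpy as jnp

# Toy values for a single transition.
q_tm1 = jnp.array([1.0, 2.0, 3.0])         # online Q(o_tm1, .)
q_t_selector = jnp.array([0.5, 2.5, 1.0])  # online Q(o_t, .): selects the action
q_t_value = jnp.array([0.7, 1.8, 1.2])     # target Q(o_t, .): evaluates it
a_tm1, r_t, d_t = 1, 1.0, 0.99

# Double Q-learning: pick the greedy action with the online network, score it
# with the target network, then form the TD error against Q(o_tm1, a_tm1).
a_star = jnp.argmax(q_t_selector)
td_error = r_t + d_t * q_t_value[a_star] - q_tm1[a_tm1]
print(td_error)  # should agree with rlax.double_q_learning on the same inputs
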
Example No. 5
  def __call__(
      self,
      network: networks_lib.FeedForwardNetwork,
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      batch: reverb.ReplaySample,
      key: networks_lib.PRNGKey,
  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
    """Calculate a loss on a single batch of data."""
    del key
    transitions: types.Transition = batch.data
    dist_q_tm1 = network.apply(params, transitions.observation)['q_dist']
    dist_q_target_t = network.apply(target_params,
                                    transitions.next_observation)['q_dist']
    # Swap distribution and action dimension, since
    # rlax.quantile_q_learning expects it that way.
    dist_q_tm1 = jnp.swapaxes(dist_q_tm1, 1, 2)
    dist_q_target_t = jnp.swapaxes(dist_q_target_t, 1, 2)
    quantiles = ((jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5) /
                 self.num_atoms)
    batch_quantile_q_learning = jax.vmap(rlax.quantile_q_learning,
                                         in_axes=(0, None, 0, 0, 0, 0, 0,
                                                  None))
    losses = batch_quantile_q_learning(
        dist_q_tm1,
        quantiles,
        transitions.action,
        transitions.reward,
        transitions.discount,
        dist_q_target_t,  # No double Q-learning here.
        dist_q_target_t,
        self.huber_param,
    )
    loss = jnp.mean(losses)
    extra = learning_lib.LossExtra(metrics={'mean_loss': loss})
    return loss, extra
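
The quantile midpoints computed above are the fixed regression targets of the quantile (QR-DQN-style) loss, and for a small num_atoms they are easy to verify by hand. The num_atoms below is chosen purely for illustration; the example reads it from self.num_atoms:

import jax.numpy as jnp

num_atoms = 4
quantiles = (jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / num_atoms
print(quantiles)  # [0.125 0.375 0.625 0.875], the midpoints of four equal probability bins

The two jnp.swapaxes calls simply exchange the action and distribution axes of the network's 'q_dist' output so the quantile dimension sits where rlax.quantile_q_learning expects it, as the comment in the example notes.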