Example #1: prioritized double Q-learning loss (Huber TD errors with importance weighting)
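All three snippets below are `__call__` methods of loss classes, so none is self-contained. Judging by the identifiers they use, they assume roughly the following imports; the module paths follow DeepMind Acme's conventions and are an inference from the snippets, not confirmed by them. Note also that `jnp.DeviceArray` in the signatures dates these to pre-0.4 JAX; newer code would write `jax.Array`.

    # Imports the snippets appear to assume (inferred; paths may differ).
    from typing import Tuple

    import haiku as hk  # only Example #3 uses hk.Transformed / hk.Params
    import jax
    import jax.numpy as jnp
    import reverb
    import rlax

    from acme import types
    from acme.agents.jax.dqn import learning_lib
    from acme.jax import networks as networks_lib
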
  def __call__(
      self,
      network: networks_lib.FeedForwardNetwork,
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      batch: reverb.ReplaySample,
      key: networks_lib.PRNGKey,
  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
    """Calculate a loss on a single batch of data."""
    transitions: types.Transition = batch.data
    keys, probs, *_ = batch.info

    # Forward pass.
    if self.stochastic_network:
      q_tm1 = network.apply(params, key, transitions.observation)
      q_t_value = network.apply(target_params, key,
                                transitions.next_observation)
      q_t_selector = network.apply(params, key, transitions.next_observation)
    else:
      q_tm1 = network.apply(params, transitions.observation)
      q_t_value = network.apply(target_params, transitions.next_observation)
      q_t_selector = network.apply(params, transitions.next_observation)

    # Cast and clip rewards.
    d_t = (transitions.discount * self.discount).astype(jnp.float32)
    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                   self.max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                           q_t_selector)
    batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

    # Importance weighting.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= self.importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    loss = jnp.mean(importance_weights * batch_loss)  # []
    reverb_update = learning_lib.ReverbUpdate(
        keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
    extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
    return loss, extra
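
The core of the snippet above is a per-transition `rlax.double_q_learning` TD error followed by a Huber loss. A minimal standalone sketch with dummy shapes (all array contents are placeholders, and `delta=1.0` stands in for `self.huber_loss_parameter`):

    import jax
    import jax.numpy as jnp
    import rlax

    batch_size, num_actions = 8, 4
    q_tm1 = jnp.zeros((batch_size, num_actions))         # online Q-values at t-1
    a_tm1 = jnp.zeros((batch_size,), dtype=jnp.int32)    # actions taken at t-1
    r_t = jnp.zeros((batch_size,))                       # (clipped) rewards
    d_t = jnp.ones((batch_size,))                        # discounts at t
    q_t_value = jnp.zeros((batch_size, num_actions))     # target-net Q-values at t
    q_t_selector = jnp.zeros((batch_size, num_actions))  # online Q-values at t (argmax)

    # vmap maps the unbatched rlax op over the leading batch dimension.
    td_error = jax.vmap(rlax.double_q_learning)(
        q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
    per_example_loss = rlax.huber_loss(td_error, 1.0)    # shape [batch_size]
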
Example #2: prioritized categorical (distributional) double Q-learning loss
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        transitions: types.Transition = batch.data
        keys, probs, *_ = batch.info

        # Forward pass.
        _, logits_tm1, atoms_tm1 = network.apply(params,
                                                 transitions.observation)
        _, logits_t, atoms_t = network.apply(target_params,
                                             transitions.next_observation)
        q_t_selector, _, _ = network.apply(params,
                                           transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute categorical double Q-learning loss.
        batch_loss_fn = jax.vmap(rlax.categorical_double_q_learning,
                                 in_axes=(None, 0, 0, 0, 0, None, 0, 0))
        batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action,
                                   r_t, d_t, atoms_t, logits_t, q_t_selector)

        # Importance weighting.
        importance_weights = (1. / probs).astype(jnp.float32)
        importance_weights **= self.importance_sampling_exponent
        importance_weights /= jnp.max(importance_weights)

        # Reweight.
        loss = jnp.mean(importance_weights * batch_loss)  # []
        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(batch_loss).astype(jnp.float64))
        extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
        return loss, extra
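
Two details distinguish this example from #1. First, the `in_axes` pattern of the vmap: the atom supports (`atoms_tm1`, `atoms_t`) are shared across the batch, so they map with `None`, while every other argument is batched and maps with `0`. Second, the replay priorities are set from the absolute categorical loss itself rather than from a TD error. A shape-level sketch of the vmapped call (dummy values; the support range is an arbitrary choice):

    import jax
    import jax.numpy as jnp
    import rlax

    batch_size, num_actions, num_atoms = 8, 4, 51
    atoms = jnp.linspace(-10.0, 10.0, num_atoms)          # shared support -> in_axes=None
    logits_tm1 = jnp.zeros((batch_size, num_actions, num_atoms))
    a_tm1 = jnp.zeros((batch_size,), dtype=jnp.int32)
    r_t = jnp.zeros((batch_size,))
    d_t = jnp.ones((batch_size,))
    logits_t = jnp.zeros((batch_size, num_actions, num_atoms))
    q_t_selector = jnp.zeros((batch_size, num_actions))

    batch_loss_fn = jax.vmap(rlax.categorical_double_q_learning,
                             in_axes=(None, 0, 0, 0, 0, None, 0, 0))
    batch_loss = batch_loss_fn(atoms, logits_tm1, a_tm1, r_t, d_t,
                               atoms, logits_t, q_t_selector)  # shape [batch_size]
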
Example #3: prioritized double Q-learning loss over an unpacked transition tuple (Haiku-typed)
    def __call__(
        self,
        network: hk.Transformed,
        params: hk.Params,
        target_params: hk.Params,
        batch: reverb.ReplaySample,
        key: jnp.DeviceArray,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        o_tm1, a_tm1, r_t, d_t, o_t = batch.data
        keys, probs, *_ = batch.info

        # Forward pass.
        q_tm1 = network.apply(params, o_tm1)
        q_t_value = network.apply(target_params, o_t)
        q_t_selector = network.apply(params, o_t)

        # Cast and clip rewards.
        d_t = (d_t * self.discount).astype(jnp.float32)
        r_t = jnp.clip(r_t, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute double Q-learning n-step TD-error.
        batch_error = jax.vmap(rlax.double_q_learning)
        td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
        batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

        # Importance weighting.
        importance_weights = (1. / probs).astype(jnp.float32)
        importance_weights **= self.importance_sampling_exponent
        importance_weights /= jnp.max(importance_weights)

        # Reweight.
        loss = jnp.mean(importance_weights * batch_loss)  # []
        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
        extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
        return loss, extra
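
The importance-weighting block is identical in all three examples: it implements the standard prioritized-replay correction, weighting each sampled transition by (1/p_i)^beta and normalizing so the largest weight is 1. A minimal standalone sketch (the function name `prioritized_weights` is hypothetical):

    import jax.numpy as jnp

    def prioritized_weights(probs: jnp.ndarray, exponent: float) -> jnp.ndarray:
      """Prioritized-replay importance weights, normalized to a max of 1."""
      weights = (1.0 / probs.astype(jnp.float32)) ** exponent
      return weights / jnp.max(weights)
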