Example #1
0
def _quantile_regression_loss(
    dist_src: Array,
    tau_src: Array,
    dist_target: Array,
    huber_param: float = 0.,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Compute the (optionally Huberized) quantile regression loss.

    The loss is computed between two discrete quantile-valued distributions,
    as in "Distributional Reinforcement Learning with Quantile Regression"
    by Dabney et al. (https://arxiv.org/abs/1710.10044).

    Args:
      dist_src: source probability distribution.
      tau_src: source distribution probability thresholds.
      dist_target: target probability distribution.
      huber_param: Huber loss parameter, defaults to 0 (no Huber loss).
      stop_target_gradients: bool indicating whether or not to apply stop
        gradient to targets.

    Returns:
      Quantile regression loss.
    """
    chex.assert_rank([dist_src, tau_src, dist_target], 1)
    chex.assert_type([dist_src, tau_src, dist_target], float)

    # All pairwise errors: entry [i, j] is target_j - src_i.
    pairwise_delta = dist_target[None, :] - dist_src[:, None]

    # 0/1 indicator of negative errors; optionally detached from the graph
    # so no gradient flows back through the target distribution.
    neg_indicator = (pairwise_delta < 0.).astype(jnp.float32)
    neg_indicator = jax.lax.select(
        stop_target_gradients,
        jax.lax.stop_gradient(neg_indicator),
        neg_indicator)

    # Asymmetric quantile weights |tau - 1{delta < 0}|.
    quantile_weight = jnp.abs(tau_src[:, None] - neg_indicator)

    # Per-pair magnitude: Huberized when huber_param > 0, |delta| otherwise.
    if huber_param > 0.:
        magnitude = clipping.huber_loss(pairwise_delta, huber_param)
    else:
        magnitude = jnp.abs(pairwise_delta)
    weighted = quantile_weight * magnitude

    # Average over target-samples dimension, sum over src-samples dimension.
    return jnp.sum(jnp.mean(weighted, axis=-1))
Example #2
0
def _quantile_regression_loss(dist_src: ArrayLike,
                              tau_src: ArrayLike,
                              dist_target: ArrayLike,
                              huber_param: float = 0.) -> ArrayOrScalar:
    """Compute (Huber) QR loss between two discrete quantile-valued distributions.

    See "Distributional Reinforcement Learning with Quantile Regression" by
    Dabney et al. (https://arxiv.org/abs/1710.10044).

    Args:
      dist_src: source probability distribution.
      tau_src: source distribution probability thresholds.
      dist_target: target probability distribution.
      huber_param: Huber loss parameter, defaults to 0 (no Huber loss).

    Returns:
      Quantile regression loss.
    """
    base.rank_assert([dist_src, tau_src, dist_target], 1)
    base.type_assert([dist_src, tau_src, dist_target], float)

    # All pairwise TD errors between target and source quantile samples:
    # entry [i, j] is target_j - src_i.
    td_errors = dist_target[None, :] - dist_src[:, None]

    # 0/1 mask of negative errors; detached so no gradient reaches the
    # target distribution through the indicator.
    below = jax.lax.stop_gradient((td_errors < 0.).astype(jnp.float32))

    # Asymmetric quantile weights |tau - 1{error < 0}|.
    tau_weight = jnp.abs(tau_src[:, None] - below)

    # Per-pair magnitude: Huberized when requested, plain |error| otherwise.
    if huber_param > 0.:
        magnitude = clipping.huber_loss(td_errors, huber_param)
    else:
        magnitude = jnp.abs(td_errors)
    weighted = tau_weight * magnitude

    # Average over target-samples dimension, sum over src-samples dimension.
    return jnp.sum(jnp.mean(weighted, axis=-1))
Example #3
0
 def td_error_with_huber(x):
     # Apply a Huber loss to the TD error `x`, using the enclosing object's
     # `self.large_delta` as the Huber threshold.
     # NOTE(review): `self` and `clipping` are captured from an enclosing
     # scope not visible in this fragment — confirm against the caller.
     return clipping.huber_loss(x, self.large_delta)