def _quantile_regression_loss(
    dist_src: Array,
    tau_src: Array,
    dist_target: Array,
    huber_param: float = 0.,
    stop_target_gradients: bool = True,
) -> Numeric:
  """Quantile regression loss (optionally Huberized) between two quantile sets.

  Implements the loss from "Distributional Reinforcement Learning with
  Quantile Regression" by Dabney et al. (https://arxiv.org/abs/1710.10044).

  Args:
    dist_src: source probability distribution.
    tau_src: source distribution probability thresholds.
    dist_target: target probability distribution.
    huber_param: Huber loss parameter, defaults to 0 (no Huber loss).
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets.

  Returns:
    Quantile regression loss.
  """
  chex.assert_rank([dist_src, tau_src, dist_target], 1)
  chex.assert_type([dist_src, tau_src, dist_target], float)

  # Pairwise errors: rows index source quantiles, columns index target samples.
  pairwise_diff = dist_target[None, :] - dist_src[:, None]
  indicator = (pairwise_diff < 0.).astype(jnp.float32)
  # Optionally block gradients from flowing through the (target-side) indicator.
  indicator = jax.lax.select(
      stop_target_gradients, jax.lax.stop_gradient(indicator), indicator)
  tau_weight = jnp.abs(tau_src[:, None] - indicator)

  # Per-pair regression error; Huber-smoothed near zero when requested.
  if huber_param > 0.:
    per_pair_loss = clipping.huber_loss(pairwise_diff, huber_param)
  else:
    per_pair_loss = jnp.abs(pairwise_diff)
  weighted = tau_weight * per_pair_loss

  # Average over target-samples dimension, sum over src-samples dimension.
  return jnp.sum(jnp.mean(weighted, axis=-1))
def _quantile_regression_loss(dist_src: ArrayLike,
                              tau_src: ArrayLike,
                              dist_target: ArrayLike,
                              huber_param: float = 0.,
                              stop_target_gradients: bool = True
                              ) -> ArrayOrScalar:
  """Compute (Huber) QR loss between two discrete quantile-valued distributions.

  See "Distributional Reinforcement Learning with Quantile Regression" by
  Dabney et al. (https://arxiv.org/abs/1710.10044).

  Args:
    dist_src: source probability distribution.
    tau_src: source distribution probability thresholds.
    dist_target: target probability distribution.
    huber_param: Huber loss parameter, defaults to 0 (no Huber loss).
    stop_target_gradients: bool indicating whether or not to apply stop
      gradient to targets. Defaults to True, which matches the previous
      unconditional stop-gradient behavior.

  Returns:
    Quantile regression loss.
  """
  base.rank_assert([dist_src, tau_src, dist_target], 1)
  base.type_assert([dist_src, tau_src, dist_target], float)

  # Calculate quantile error.
  delta = dist_target[None, :] - dist_src[:, None]
  delta_neg = (delta < 0.).astype(jnp.float32)
  # `stop_target_gradients` is a static Python bool, so a plain `if` is safe;
  # by default gradients are blocked from flowing through the target indicator.
  if stop_target_gradients:
    delta_neg = jax.lax.stop_gradient(delta_neg)
  weight = jnp.abs(tau_src[:, None] - delta_neg)

  # Calculate Huber loss.
  if huber_param > 0.:
    loss = clipping.huber_loss(delta, huber_param)
  else:
    loss = jnp.abs(delta)
  loss *= weight

  # Average over target-samples dimension, sum over src-samples dimension.
  return jnp.sum(jnp.mean(loss, axis=-1))
def td_error_with_huber(x):
  """Apply a Huber loss to the TD error `x`.

  Uses the enclosing object's `self.large_delta` as the Huber threshold.
  """
  threshold = self.large_delta
  return clipping.huber_loss(x, threshold)