import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


def _loss_cum_p(self, params_cum_p, params, state, action):
    # Compute state features without backpropagating into the feature network.
    feature = jax.lax.stop_gradient(self.net["feature"].apply(params["feature"], state))
    # Proposed fractions tau_i and their midpoints tau_hat_i.
    cum_p, cum_p_prime = self.cum_p_net.apply(params_cum_p, feature)
    # Quantile values F^{-1}(tau_i) at the N - 1 interior fractions, for the taken action.
    quantile = get_quantile_at_action(
        self.net["quantile"].apply(params["quantile"], feature, cum_p[:, 1:-1]),
        action,
    )
    # Quantile values F^{-1}(tau_hat_i) at the N midpoints.
    quantile_prime = get_quantile_at_action(
        self.net["quantile"].apply(params["quantile"], feature, cum_p_prime),
        action,
    )
    # NOTE: Proposition 1 in the paper requires that F^{-1} be non-decreasing. I relax this
    # requirement and calculate gradients of the fractions even when F^{-1} is not non-decreasing.
    val1 = quantile - quantile_prime[:, :-1]
    sign1 = quantile > jnp.concatenate([quantile_prime[:, :1], quantile[:, :-1]], axis=1)
    val2 = quantile - quantile_prime[:, 1:]
    sign2 = quantile < jnp.concatenate([quantile[:, 1:], quantile_prime[:, -1:]], axis=1)
    # Signed gradient of the 1-Wasserstein distance w.r.t. each proposed fraction.
    grad = jnp.where(sign1, val1, -val1) + jnp.where(sign2, val2, -val2)
    grad = jax.lax.stop_gradient(grad.reshape(-1, self.num_quantiles - 1))
    # Surrogate loss whose gradient w.r.t. the proposed fractions equals `grad`.
    return (cum_p[:, 1:-1] * grad).sum(axis=1).mean(), None
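# A minimal sketch (hypothetical helper, not part of the class) of what the sign trick
# above reduces to when F^{-1} actually is non-decreasing: both `sign` masks are then
# True, and the summed terms recover Proposition 1 of the FQF paper,
#     dW_1/dtau_i = 2 * F^{-1}(tau_i) - F^{-1}(tau_hat_{i-1}) - F^{-1}(tau_hat_i).
# Here `quantile` is assumed to have shape (batch, N - 1) holding F^{-1}(tau_i), and
# `quantile_prime` shape (batch, N) holding F^{-1} at the midpoints.
def _fqf_fraction_grad(quantile: jnp.ndarray, quantile_prime: jnp.ndarray) -> jnp.ndarray:
    return 2.0 * quantile - quantile_prime[:, :-1] - quantile_prime[:, 1:]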
def _calculate_value(
    self,
    params: hk.Params,
    feature: np.ndarray,
    action: np.ndarray,
    cum_p: jnp.ndarray,
) -> jnp.ndarray:
    # Feature-based variant: evaluate the quantile head at the given fractions
    # and gather the values for the taken action.
    return get_quantile_at_action(self.net["quantile"].apply(params["quantile"], feature, cum_p), action)
def _calculate_value(
    self,
    params: hk.Params,
    state: np.ndarray,
    action: np.ndarray,
    *args,
    **kwargs,
) -> jnp.ndarray:
    # Generic variant: the network maps states directly to quantile values,
    # which are then gathered for the taken action.
    return get_quantile_at_action(self.net.apply(params, state, *args, **kwargs), action)
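# A hedged usage sketch, assuming an optax optimizer: the fraction proposal network is
# typically updated by differentiating _loss_cum_p w.r.t. its first argument only; the
# feature is stop-gradient'd inside the loss, so no other parameters receive gradients.
# The names `agent`, `opt`, `opt_state`, and `update_cum_p` are assumptions for
# illustration, not part of this module.
import optax


def update_cum_p(agent, opt, opt_state, params_cum_p, params, state, action):
    # has_aux=True because _loss_cum_p returns a (loss, None) pair.
    grads, _ = jax.grad(agent._loss_cum_p, has_aux=True)(params_cum_p, params, state, action)
    updates, opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params_cum_p, updates), opt_state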