Example #1
File: fqf.py Project: winston-ds/rljax
 def _loss_cum_p(self, params_cum_p, params, state, action):
     # Embed the state; stop_gradient keeps the fraction loss from updating the feature network.
     feature = jax.lax.stop_gradient(self.net["feature"].apply(params["feature"], state))
     # Proposed fractions tau (with the 0 and 1 endpoints) and their midpoints tau_hat.
     cum_p, cum_p_prime = self.cum_p_net.apply(params_cum_p, feature)
     # Quantile values F^{-1} at the interior fractions and at the midpoints, for the taken action.
     quantile = get_quantile_at_action(self.net["quantile"].apply(params["quantile"], feature, cum_p[:, 1:-1]), action)
     quantile_prime = get_quantile_at_action(self.net["quantile"].apply(params["quantile"], feature, cum_p_prime), action)
     # NOTE: Proposition 1 in the paper requires that F^{-1} be non-decreasing. I relax this requirement and
     # calculate gradients of the taus even when F^{-1} is not non-decreasing.
     val1 = quantile - quantile_prime[:, :-1]
     sign1 = quantile > jnp.concatenate([quantile_prime[:, :1], quantile[:, :-1]], axis=1)
     val2 = quantile - quantile_prime[:, 1:]
     sign2 = quantile < jnp.concatenate([quantile[:, 1:], quantile_prime[:, -1:]], axis=1)
     # Analytic gradient of the 1-Wasserstein distance w.r.t. the interior fractions, with sign
     # flips for the non-monotone case; stop_gradient so that differentiating the surrogate
     # loss below reproduces exactly this gradient for the fraction-proposal network.
     grad = jnp.where(sign1, val1, -val1) + jnp.where(sign2, val2, -val2)
     grad = jax.lax.stop_gradient(grad.reshape(-1, self.num_quantiles - 1))
     return (cum_p[:, 1:-1] * grad).sum(axis=1).mean(), None
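Up to the sign flips handled by sign1 and sign2, the grad above is the derivative of the 1-Wasserstein distance with respect to each interior fraction from Proposition 1 of the FQF paper (Yang et al., 2019):

    ∂W_1/∂τ_i = 2 F^{-1}(τ_i) - F^{-1}(τ̂_i) - F^{-1}(τ̂_{i-1}),   i = 1, ..., N-1,

where τ̂_i = (τ_i + τ_{i+1}) / 2 are the midpoint fractions (cum_p_prime) and N = num_quantiles. Assuming I read the indexing correctly, val1 and val2 are the two differences F^{-1}(τ_i) - F^{-1}(τ̂_{i-1}) and F^{-1}(τ_i) - F^{-1}(τ̂_i), so val1 + val2 recovers exactly this expression.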
Example #2
File: fqf.py Project: winston-ds/rljax
 def _calculate_value(
     self,
     params: hk.Params,
     feature: np.ndarray,
     action: np.ndarray,
     cum_p: jnp.ndarray,
 ) -> jnp.ndarray:
     # Quantile values at the proposed fractions, gathered for the taken action.
     return get_quantile_at_action(self.net["quantile"].apply(params["quantile"], feature, cum_p), action)
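All three examples depend on the shared helper get_quantile_at_action, whose definition is not shown here. A minimal sketch of such a gather, assuming (hypothetically) that the quantile head returns an array of shape (batch_size, num_quantiles, num_actions) and that action holds integer indices of shape (batch_size, 1), could look like:

 import jax.numpy as jnp

 def get_quantile_at_action(quantile_s, action):
     # quantile_s: (batch_size, num_quantiles, num_actions)
     # action:     (batch_size, 1) integer action indices (assumed shapes)
     assert quantile_s.shape[0] == action.shape[0]
     # Broadcast the action index over the quantile axis and gather along
     # the action axis -> (batch_size, num_quantiles, 1).
     return jnp.take_along_axis(quantile_s, jnp.expand_dims(action, axis=1), axis=2)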
Example #3
 def _calculate_value(
     self,
     params: hk.Params,
     state: np.ndarray,
     action: np.ndarray,
     *args,
     **kwargs,
 ) -> jnp.ndarray:
     # Generic variant: the network maps the raw state (plus any extra arguments)
     # to quantile values, which are gathered for the taken action.
     return get_quantile_at_action(self.net.apply(params, state, *args, **kwargs), action)
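As a quick shape check against the helper sketch above (dummy names and sizes, purely illustrative):

 import jax.numpy as jnp

 batch_size, num_quantiles, num_actions = 4, 32, 6
 quantile_s = jnp.zeros((batch_size, num_quantiles, num_actions))
 action = jnp.zeros((batch_size, 1), dtype=jnp.int32)
 # Gathering the chosen action's quantiles yields shape (4, 32, 1).
 print(get_quantile_at_action(quantile_s, action).shape)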