예제 #1
0
파일: special.py 프로젝트: GregCT/jax
def i0(x):
    """Return ``exp(|x|) * lax.bessel_i0e(x)``.

    ``lax.bessel_i0e`` is documented by JAX as the exponentially scaled
    modified Bessel function of order zero, so multiplying by ``exp(|x|)``
    removes that scaling.

    Args:
        x: Input passed through ``_promote_args_inexact`` before use
            (presumably promoted to an inexact dtype -- confirm against
            that helper).

    Returns:
        The elementwise product described above.
    """
    (promoted,) = _promote_args_inexact("i0", x)
    undo_scaling = lax.exp(lax.abs(promoted))
    scaled_bessel = lax.bessel_i0e(promoted)
    return lax.mul(undo_scaling, scaled_bessel)
예제 #2
0
 def variance(self):
     """Return the circular variance, broadcast to ``self.batch_shape``.

     Computed as ``1 - i1e(c) / i0e(c)`` where ``c`` is
     ``self.concentration``. Both Bessel terms are the exponentially
     scaled variants, so their common scaling factor cancels in the ratio.
     """
     kappa = self.concentration
     bessel_ratio = lax.bessel_i1e(kappa) / lax.bessel_i0e(kappa)
     circular_variance = 1. - bessel_ratio
     return jnp.broadcast_to(circular_variance, self.batch_shape)
예제 #3
0
파일: special.py 프로젝트: GregCT/jax
def i0e(x):
    """Return ``lax.bessel_i0e`` applied to the promoted input.

    Args:
        x: Input passed through ``_promote_args_inexact`` before use
            (presumably promoted to an inexact dtype -- confirm against
            that helper).

    Returns:
        The result of ``lax.bessel_i0e`` on the promoted argument.
    """
    (promoted,) = _promote_args_inexact("i0e", x)
    scaled_bessel = lax.bessel_i0e(promoted)
    return scaled_bessel
예제 #4
0
 def log_prob(self, value):
     """Return the log-density at ``value`` (angles in radians).

     Evaluates ``k * cos(value - loc) - log(2 * pi) - log(I0(k))`` with
     ``k = self.concentration``, where ``I0`` is the modified Bessel
     function of the first kind of order zero.

     NOTE(review): this is the von Mises log-density; the enclosing class is
     not visible here, so confirm that is the intended distribution.

     Args:
         value: Point(s) at which to evaluate the log-density.

     Returns:
         The log-density, broadcast across ``value``, ``self.loc`` and
         ``self.concentration``.
     """
     kappa = self.concentration
     # The normalizer needs log(I0(kappa)), not the raw scaled Bessel value.
     # Since i0e(k) = exp(-|k|) * I0(k), we have
     # log(I0(k)) = log(i0e(k)) + |k|, which also avoids overflow in I0
     # for large concentrations.
     log_i0 = jnp.log(lax.bessel_i0e(kappa)) + jnp.abs(kappa)
     return -(jnp.log(2 * jnp.pi) + log_i0) + (
         kappa * jnp.cos(value - self.loc))