def rint(x):
  # Round each element to the nearest integer, with ties going to even
  # (the semantics of numpy.rint).
  #
  # - Boolean and integer inputs are already integral, so they are converted
  #   to the default float dtype (dtypes.float_) and not rounded.
  # - Complex inputs round the real and imaginary parts independently.
  # - All other (floating) inputs use lax.round with TO_NEAREST_EVEN.
  _check_arraylike('rint', x)
  dtype = dtypes.dtype(x)
  # np.bool_ is not a subdtype of np.integer, so it needs an explicit check.
  # Without it, bool falls through to lax.round, which expects a
  # floating-point operand.
  if dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer):
    return lax.convert_element_type(x, dtypes.float_)
  if dtypes.issubdtype(dtype, np.complexfloating):
    return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
  return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
def sign(x):
  # Elementwise sign.
  #
  # Real input is delegated directly to lax.sign. Complex input uses this
  # convention: take the sign of the real part when it is nonzero, otherwise
  # the sign of the imaginary part. The result is complex with a zero
  # imaginary part.
  _check_arraylike('sign', x)
  if not dtypes.issubdtype(dtypes.dtype(x), np.complexfloating):
    return lax.sign(x)
  real_part = lax.real(x)
  # Fall back to the imaginary component only where the real one is zero.
  deciding = _where(real_part != 0, real_part, lax.imag(x))
  return lax.complex(lax.sign(deciding), _constant_like(real_part, 0))
def _sph_harm(m: jnp.ndarray,
              n: jnp.ndarray,
              theta: jnp.ndarray,
              phi: jnp.ndarray,
              n_max: int) -> jnp.ndarray:
  """Computes the spherical harmonics.

  Args:
    m: order of each harmonic. Negative orders are obtained from the
      corresponding positive order by conjugation and a (-1)**|m| factor.
    n: degree of each harmonic. It is indexed alongside jnp.arange(len(n)),
      so it is presumably 1-D (TODO confirm with callers).
    theta: angle multiplied by |m| in the complex phase factor
      cos(|m| theta) + i sin(|m| theta). Presumably the azimuth.
    phi: angle whose cosine is passed to the associated Legendre generator.
    n_max: maximum degree forwarded to _gen_associated_legendre.

  Returns:
    A complex array P_n^{|m|}(cos phi) * exp(i |m| theta), adjusted for
    negative m as described above.
  """
  order = abs(m)
  table = _gen_associated_legendre(n_max, jnp.cos(phi), True)
  # Pick P_n^{|m|} for each column; out-of-range indices are clipped rather
  # than raising.
  p_nm = table.at[order, n, jnp.arange(len(n))].get(mode="clip")
  angle = order * theta
  # Build the real and imaginary parts separately. This avoids a complex
  # multiply, which would mix the components.
  harmonics = lax.complex(p_nm * jnp.cos(angle), p_nm * jnp.sin(angle))
  # Negative order: Y_n^{-|m|} = (-1)^{|m|} * conj(Y_n^{|m|}).
  flipped = (-1.0)**order * jnp.conjugate(harmonics)
  return jnp.where(m < 0, flipped, harmonics)
def logaddexp(x1, x2):
  # Elementwise log(exp(x1) + exp(x2)), computed relative to max(x1, x2) so
  # that exp never overflows. Inputs are first promoted to a common inexact
  # dtype.
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if not dtypes.issubdtype(x1.dtype, np.floating):
    # Complex path. Assuming amax is one of the two inputs,
    # (x1 + x2) - 2*amax equals (the other input) - amax.
    offset = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    result = lax.add(amax, lax.log1p(lax.exp(offset)))
    # Presumably wraps the imaginary part into the principal range
    # around [-pi, pi]; confirm against _wrap_between.
    return lax.complex(lax.real(result), _wrap_between(lax.imag(result), np.pi))
  # Real floating path.
  gap = lax.sub(x1, x2)
  stable = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(gap)))))
  # gap is NaN when either input is NaN, or when both are infinities of the
  # same sign. In those cases x1 + x2 gives the correct NaN / +inf / -inf.
  return lax.select(lax_internal._isnan(gap), lax.add(x1, x2), stable)