示例#1
0
def rint(x):
    """Round each element of ``x`` to the nearest integer, ties to even.

    Dispatch is by the dtype of ``x``:

    * boolean and integer inputs already hold integral values, so they are
      only converted to the default float dtype (``dtypes.float_``);
    * complex inputs have their real and imaginary parts rounded
      independently and are then recombined;
    * every other input is passed to ``lax.round`` with
      ``RoundingMethod.TO_NEAREST_EVEN``.

    Args:
      x: array-like input; validated by ``_check_arraylike``.

    Returns:
      The rounded values. Boolean and integer inputs come back as
      ``dtypes.float_``.
    """
    _check_arraylike('rint', x)
    dtype = dtypes.dtype(x)
    # np.bool_ is not a subtype of np.integer, so without the explicit bool
    # check a boolean input would fall through to lax.round, which is a
    # floating-point operation. Treat bool like the integer types instead.
    if (dtypes.issubdtype(dtype, np.bool_)
            or dtypes.issubdtype(dtype, np.integer)):
        return lax.convert_element_type(x, dtypes.float_)
    if dtypes.issubdtype(dtype, np.complexfloating):
        # Round each component on its own via the real-valued path.
        return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
    return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
示例#2
0
def sign(x):
    """Return the element-wise sign of ``x``.

    Non-complex inputs are delegated directly to ``lax.sign``.

    For complex inputs the result is built with ``lax.complex``: its real
    part is the sign of the real component of ``x`` or, where that real
    component equals zero, the sign of the imaginary component. Its
    imaginary part is the value produced by ``_constant_like(re, 0)``
    (presumably a zero matching the real component -- confirm against the
    helper).

    Args:
      x: array-like input; validated by ``_check_arraylike``.
    """
    _check_arraylike('sign', x)
    if not dtypes.issubdtype(dtypes.dtype(x), np.complexfloating):
        return lax.sign(x)

    real_part = lax.real(x)
    # Fall back to the imaginary component wherever the real one is zero.
    deciding = _where(real_part != 0, real_part, lax.imag(x))
    return lax.complex(lax.sign(deciding), _constant_like(real_part, 0))
示例#3
0
文件: special.py 项目: jbampton/jax
def _sph_harm(m: jnp.ndarray, n: jnp.ndarray, theta: jnp.ndarray,
              phi: jnp.ndarray, n_max: int) -> jnp.ndarray:
    """Computes the spherical harmonics.

    Args:
      m: order of each harmonic; negative entries are handled through the
        conjugate of the ``abs(m)`` result.
      n: degree of each harmonic; ``len(n)`` sets how many samples are read.
      theta: angle multiplied by ``abs(m)`` to form the complex phase.
      phi: angle whose cosine is handed to ``_gen_associated_legendre``.
      n_max: forwarded as the first argument of ``_gen_associated_legendre``.

    Returns:
      The complex harmonic values, one per entry of ``n``.
    """
    order = abs(m)

    # Table from the Legendre helper. NOTE(review): the meaning of the
    # trailing ``True`` flag is not visible here -- confirm against the helper.
    table = _gen_associated_legendre(n_max, jnp.cos(phi), True)

    # Pick entry (|m|, n, i) for the i-th sample; out-of-range indices clip.
    sample_index = jnp.arange(len(n))
    legendre_val = table.at[order, n, sample_index].get(mode="clip")

    # Complex phase cos(|m| * theta) + i * sin(|m| * theta).
    phase_angle = order * theta
    phase = lax.complex(jnp.cos(phase_angle), jnp.sin(phase_angle))
    non_negative_order = lax.complex(legendre_val * jnp.real(phase),
                                     legendre_val * jnp.imag(phase))

    # Negative order: (-1)^|m| times the conjugate of the |m| harmonic.
    negative_order = (-1.0)**order * jnp.conjugate(non_negative_order)
    return jnp.where(m < 0, negative_order, non_negative_order)
示例#4
0
def logaddexp(x1, x2):
  """Compute ``log(exp(x1) + exp(x2))`` element-wise.

  Both arguments are first promoted by ``_promote_args_inexact``. The result
  is formed as ``max + log1p(exp(delta))``, where ``delta`` depends on dtype:

  * floating point: ``delta = -|x1 - x2|``. Where ``x1 - x2`` is NaN, the
    result is ``x1 + x2`` instead.
  * otherwise (the branch that returns a ``lax.complex`` value):
    ``delta = x1 + x2 - 2 * max``, and the imaginary part of the result is
    passed through ``_wrap_between(..., np.pi)`` (presumably wrapping it
    into a range bounded by pi -- confirm against the helper).
  """
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  larger = lax.max(x1, x2)

  if not dtypes.issubdtype(x1.dtype, np.floating):
    twice_larger = lax.mul(larger, _constant_like(larger, 2))
    gap = lax.sub(lax.add(x1, x2), twice_larger)
    total = lax.add(larger, lax.log1p(lax.exp(gap)))
    wrapped_imag = _wrap_between(lax.imag(total), np.pi)
    return lax.complex(lax.real(total), wrapped_imag)

  diff = lax.sub(x1, x2)
  # A NaN difference means NaNs or infinities of the same sign; the plain
  # sum is used for those elements.
  plain_sum = lax.add(x1, x2)
  stable = lax.add(larger, lax.log1p(lax.exp(lax.neg(lax.abs(diff)))))
  return lax.select(lax_internal._isnan(diff), plain_sum, stable)