Example #1
0
def slogdet(a):
    """Sign and natural log of the absolute determinant, via LU factorization.

    Args:
      a: array of shape ``[..., n, n]``.

    Returns:
      A pair ``(sign, logabsdet)``. Where the LU diagonal contains a zero,
      ``sign`` is 0 and ``logabsdet`` is ``-inf``. For complex inputs ``sign``
      is the product of the unit phases of the LU diagonal (times the
      permutation sign); ``logabsdet`` is always real.

    Raises:
      ValueError: if ``a`` is not a (batch of) square matrix.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    dtype = lax.dtype(a)
    shape = jnp.shape(a)
    is_square = len(shape) >= 2 and shape[-1] == shape[-2]
    if not is_square:
        raise ValueError(
            "Argument to slogdet() must have shape [..., n, n], got {}".format(shape))

    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    singular = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)

    # Every pivot that differs from its row index is a row swap; an odd number
    # of swaps flips the sign of the determinant.
    swaps = jnp.count_nonzero(pivot != jnp.arange(shape[-1]), axis=-1)
    if not jnp.iscomplexobj(a):
        # Real input: negative diagonal entries also flip the sign, so fold
        # them into the same parity count and start from +1.
        unit = jnp.array(1, dtype=dtype)
        swaps = swaps + jnp.count_nonzero(diag < 0, axis=-1)
    else:
        unit = jnp.prod(diag / jnp.abs(diag), axis=-1)

    parity_sign = jnp.array(-2 * (swaps % 2) + 1, dtype=dtype)
    sign = jnp.where(singular, jnp.array(0, dtype=dtype), unit * parity_sign)
    log_abs = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype),
                        jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
    return sign, jnp.real(log_abs)
Example #2
0
def _map_coordinates(input, coordinates, order, mode, cval):
  """Sample `input` at fractional `coordinates` by nearest/linear interpolation.

  Args:
    input: array to interpolate (converted with ``jnp.asarray``).
    coordinates: sequence with one entry per axis of ``input``; entry ``i``
      holds the (possibly fractional) positions along axis ``i``.
    order: 0 for nearest-neighbour, 1 for linear; any other value raises.
    mode: boundary mode, looked up in ``_INDEX_FIXERS``.
    cval: value used for out-of-range samples when ``mode == 'constant'``;
      cast to ``input.dtype``.

  Returns:
    The interpolated values cast to ``input.dtype``. Integer inputs are first
    rounded with ``_round_half_away_from_zero``.

  Raises:
    ValueError: if ``len(coordinates) != input.ndim``.
    NotImplementedError: if ``mode`` has no entry in ``_INDEX_FIXERS`` or
      ``order`` is not 0 or 1.
  """
  input = jnp.asarray(input)
  coordinates = [jnp.asarray(c) for c in coordinates]
  cval = jnp.asarray(cval, input.dtype)

  if len(coordinates) != input.ndim:
    raise ValueError('coordinates must be a sequence of length input.ndim, but '
                     '{} != {}'.format(len(coordinates), input.ndim))

  # index_fixer(index, size) maps a raw index to the one actually read for
  # this boundary mode.
  index_fixer = _INDEX_FIXERS.get(mode)
  if index_fixer is None:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
        'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))

  # Only 'constant' mode needs a per-sample validity mask (invalid samples are
  # replaced by cval). Other modes return the Python literal True, which the
  # `valid is True` check below uses to skip the jnp.where entirely.
  if mode == 'constant':
    is_valid = lambda index, size: (0 <= index) & (index < size)
  else:
    is_valid = lambda index, size: True

  # interp_fun(coordinate) yields (index, weight) interpolation nodes.
  if order == 0:
    interp_fun = _nearest_indices_and_weights
  elif order == 1:
    interp_fun = _linear_indices_and_weights
  else:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates currently requires order<=1')

  # For each axis, collect (fixed_index, valid, weight) per node. Validity is
  # judged on the raw index, before index_fixer adjusts it.
  valid_1d_interpolations = []
  for coordinate, size in zip(coordinates, input.shape):
    interp_nodes = interp_fun(coordinate)
    valid_interp = []
    for index, weight in interp_nodes:
      fixed_index = index_fixer(index, size)
      valid = is_valid(index, size)
      valid_interp.append((fixed_index, valid, weight))
    valid_1d_interpolations.append(valid_interp)

  # Every combination of one node per axis contributes
  # (product of its weights) * input[indices] (or cval where invalid).
  outputs = []
  for items in itertools.product(*valid_1d_interpolations):
    indices, validities, weights = zip(*items)
    if all(valid is True for valid in validities):
      # fast path
      contribution = input[indices]
    else:
      all_valid = functools.reduce(operator.and_, validities)
      contribution = jnp.where(all_valid, input[indices], cval)
    outputs.append(_nonempty_prod(weights) * contribution)
  result = _nonempty_sum(outputs)
  if jnp.issubdtype(input.dtype, jnp.integer):
    result = _round_half_away_from_zero(result)
  return result.astype(input.dtype)
Example #3
0
File: eigh.py Project: cloudhan/jax
def _mask(x, dims, alternative=0):
  """Masks `x` up to the dynamic shape `dims`.

  Replaces values outside those dimensions with `alternative`. `alternative` is
  broadcast with `x`.
  """
  assert jnp.ndim(x) == len(dims)
  mask = None
  for i, d in enumerate(dims):
    if d is not None:
      mask_dim_i = lax.broadcasted_iota(jnp.int32, x.shape, i) < d
      mask = mask_dim_i if mask is None else (mask & mask_dim_i)
  return x if mask is None else jnp.where(mask, x, alternative)
Example #4
0
def logpdf(x, df, loc=0, scale=1):
    """Log of the chi-squared probability density function.

    Args:
      x: points at which to evaluate the log-density.
      df: degrees of freedom.
      loc: location shift; the density is evaluated at ``(x - loc) / scale``.
      scale: scale factor.

    Returns:
      The log-density; ``-inf`` where ``x < loc``.
    """
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    y = lax.div(lax.sub(x, loc), scale)
    df_on_two = lax.div(df, two)

    # (df/2 - 1) * log(y). At y == 0 with df == 2 this would be 0 * (-inf) =
    # nan, although the density there is 1/2 (scipy gives log(1/2)). Treat a
    # zero coefficient as contributing exactly 0, i.e. 0 * log(0) := 0.
    log_y_coeff = lax.sub(df_on_two, one)
    log_y_term = where(log_y_coeff == 0, 0, lax.mul(log_y_coeff, lax.log(y)))
    kernel = lax.sub(log_y_term, lax.div(y, two))

    # -(lgamma(df/2) + (df/2) * log 2): the log normalizing constant.
    nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))

    log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, log_probs)
Example #5
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` along ``axis`` in a numerically stable way.

    Args:
      a: input array (promoted to an inexact dtype).
      axis: axis or axes to reduce over; ``None`` reduces over all axes.
      b: optional scale factors for ``exp(a)``. Entries of ``a`` where
        ``b == 0`` are replaced by ``-inf`` so they contribute nothing.
      keepdims: keep reduced axes as size-1 dimensions.
      return_sign: if True, also return ``sign`` (``jnp.sign`` of the scaled
        sum, gradient-stopped). For complex inputs ``out`` is then the log of
        ``conj(sign) * sum``, which is the log-magnitude when ``jnp.sign``
        returns the unit phase ``z / |z|``.

    Returns:
      ``out``, or ``(out, sign)`` when ``return_sign`` is True. Without
      ``return_sign``, when ``b`` is given and the result is real, entries
      whose scaled sum is negative are ``nan``.
    """
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        # Zero weights must contribute nothing, even where a is +inf or nan.
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    # Shift by the (gradient-stopped) maximum for stability; a non-finite
    # maximum is replaced by 0 so it cannot poison the shift.
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                # Divide out the phase so `out` holds the log-magnitude and
                # `sign` the phase. conj(z/|z|) * z == |z|; multiplying by
                # `sign` itself would double the phase instead of removing it.
                sumexp = jnp.conj(sign) * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            # The log of a negative real sum is undefined without a sign.
            out = jnp.where(sign < 0, np.nan, out)
    return out
Example #6
0
def _lstsq(a, b, rcond, *, numpy_resid=False):
    """Least-squares solution of ``a @ x = b`` computed from the SVD of ``a``.

    Args:
      a: coefficient matrix of shape ``(m, n)``.
      b: right-hand side of shape ``(m,)`` or ``(m, k)``.
      rcond: relative cutoff for small singular values. ``None`` uses
        ``eps * max(m, n)``; negative values use machine epsilon.
      numpy_resid: if True, return empty residuals (as NumPy does) when the
        solution is rank deficient or ``m <= n``. This branches on ``rank`` in
        Python, so it requires concrete (non-traced) inputs.

    Returns:
      ``(x, resid, rank, s)``: the solution (1-D if ``b`` was 1-D), the
      squared residual norm of each column, the effective rank, and the
      singular values of ``a``.

    Raises:
      ValueError: if the leading dimensions of ``a`` and ``b`` differ.
      TypeError: if ``a`` is not 2-D or ``b`` is not 1-D or 2-D.
    """
    # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
    # TODO: add custom jvp rule for more robust lstsq differentiation
    a, b = _promote_arg_dtypes(a, b)
    if a.shape[0] != b.shape[0]:
        raise ValueError("Leading dimensions of input arrays must match")
    b_orig_ndim = b.ndim
    if b_orig_ndim == 1:
        b = b[:, None]
    if a.ndim != 2:
        raise TypeError(
            f"{a.ndim}-dimensional array given. Array must be two-dimensional")
    if b.ndim != 2:
        raise TypeError(
            f"{b.ndim}-dimensional array given. Array must be one or two-dimensional"
        )
    m, n = a.shape
    dtype = a.dtype
    if rcond is None:
        rcond = jnp.finfo(dtype).eps * max(n, m)
    else:
        rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
    u, s, vt = svd(a, full_matrices=False)
    # Keep singular values above rcond times the largest one (s[0]). When `a`
    # has a zero-sized dimension `s` is empty and s[0] would raise, so use an
    # empty mask; the products below then yield an all-zero solution.
    if s.size > 0:
        mask = s >= rcond * s[0]
    else:
        mask = jnp.zeros(s.shape, dtype=bool)
    rank = mask.sum()
    # Substitute 1 for discarded values so the reciprocal never divides by 0.
    safe_s = jnp.where(mask, s, 1)
    s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
    uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
    x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
    # Numpy returns empty residuals in some cases. To allow compilation, we
    # default to returning full residuals in all cases.
    if numpy_resid and (rank < n or m <= n):
        resid = jnp.asarray([])
    else:
        b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
        resid = norm(b - b_estimate, axis=0)**2
    if b_orig_ndim == 1:
        x = x.ravel()
    return x, resid, rank, s
Example #7
0
    def _evaluate_linear(self, indices, norm_distances):
        # slice for broadcasting over trailing dimensions in self.values
        vslice = (slice(None), ) + (None, ) * (self.values.ndim - len(indices))

        # find relevant values
        # each i and i+1 represents a edge
        edges = product(*[[i, i + 1] for i in indices])
        values = asarray(0.)
        for edge_indices in edges:
            weight = asarray(1.)
            for ei, i, yi in zip(edge_indices, indices, norm_distances):
                weight *= where(ei == i, 1 - yi, yi)
            values += self.values[edge_indices] * weight[vslice]
        return values
Example #8
0
File: special.py Project: zizai/jax
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  """Compute ``log(sum(b * exp(a)))`` along ``axis`` in a numerically stable way.

  Args:
    a: input array; integer/boolean inputs are promoted to a float dtype.
    axis: axis or axes to reduce over; ``None`` reduces over all axes.
    b: optional scale factors, broadcast against ``a`` and promoted to the
      same dtype.
    keepdims: keep reduced axes as size-1 dimensions.
    return_sign: if True, also return the sign of the (scaled) sum.

  Returns:
    ``out``, or ``(out, sign)`` when ``return_sign`` is True. Without
    ``return_sign``, entries whose scaled sum is negative are ``nan``.
  """
  a = jnp.asarray(a)
  if b is not None:
    a, b = jnp.broadcast_arrays(a, jnp.asarray(b))
  # The lax ops below (exp, mul, is_finite) need one shared floating dtype.
  # Promote integer/boolean inputs and mismatched a/b dtypes up front, as
  # scipy.special.logsumexp accepts them, instead of failing inside lax.
  dtype = a.dtype if b is None else jnp.result_type(a, b)
  if not jnp.issubdtype(dtype, jnp.inexact):
    dtype = jnp.result_type(dtype, float)
  a = a.astype(dtype)
  if b is not None:
    b = b.astype(dtype)
  pos_dims, dims = _reduction_dims(a, axis)
  # Shift by the (gradient-stopped) maximum for stability; a non-finite
  # maximum is replaced by 0 so it cannot poison the shift.
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  if b is None:
    out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                                  axis=dims, keepdims=keepdims)),
                  amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                     axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    # The log of a negative sum is undefined without returning its sign.
    out = jnp.where(sign < 0, np.nan, out)
  return out
Example #9
0
def _expn1(n, x):
    """Exponential integral E_n(x) evaluated through a power series.

    Returns ``z**(n-1) * psi / Gamma(n) - ans`` with ``z = -x`` and
    ``psi = -euler_gamma - log(x) + sum_{i=1}^{n-1} 1/i``. ``ans`` is a series
    accumulated in a ``lax.while_loop``. The loop runs while ``x > 0`` and
    ``|yk / ans|`` is larger than machine epsilon.

    NOTE(review): ``n`` serves both as a ``fori_loop`` bound and in array
    arithmetic, and the loop carry ``t`` starts as a Python float. This looks
    written for scalar ``n``/``x`` (e.g. vectorized by the caller); confirm.
    """
    # exponential integral En
    _c = _constant_like
    x = jnp.array(x)
    MACHEP = jnp.finfo(x.dtype).eps

    zero = _c(x, 0.0)
    one = _c(x, 1.0)
    # psi = -gamma - log(x) + sum_{i=1}^{n-1} 1/i (fori_loop's upper bound is
    # exclusive).
    psi = -jnp.euler_gamma - jnp.log(x)
    psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
    # For n == 1, use 2 here only to keep 1/(1 - n1) finite; that value is
    # discarded by the where() that builds `ans` below.
    n1 = jnp.where(n == _c(n, 1), one + one, n)
    init = dict(
        x=x,
        z=-x,
        xk=zero,
        yk=one,
        pk=one - n,
        ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
        t=jnp.inf,  # guarantees the loop condition holds on entry when x > 0
    )

    def body(d):
        # After k steps: xk = k, yk = z**k / k!, pk = 1 - n + k.
        d["xk"] += one
        d["yk"] *= d["z"] / d["xk"]
        d["pk"] += one
        # Skip the term whose denominator pk is zero (k == n - 1).
        d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
        # Convergence measure |yk / ans|; 1 when ans is exactly zero.
        d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
        return d

    def cond(d):
        return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    t = n
    r = n - _c(n, 1)
    # jnp.exp(gammaln(n)) == Gamma(n).
    return d["z"]**r * psi / jnp.exp(gammaln(t)) - d["ans"]
Example #10
0
def pinv(a, rcond=None):
    """Moore-Penrose pseudo-inverse computed from the SVD.

    Uses the same algorithm as NumPy's ``pinv``:
    https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979

    Args:
      a: (batch of) matrix; the last two axes are the matrix axes.
      rcond: relative cutoff. Singular values ``<= rcond * max(s)`` are
        discarded. Defaults to ``10 * max(a.shape[-2:]) * eps``.

    Returns:
      The pseudo-inverse, converted to the dtype of ``jnp.conj(a)``.
    """
    a_conj = jnp.conj(a)
    if rcond is None:
        rcond = 10. * max(a_conj.shape[-2:]) * jnp.finfo(a_conj.dtype).eps
    rcond = jnp.asarray(rcond)
    u, s, vh = svd(a_conj, full_matrices=False)
    largest = jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
    cutoff = rcond[..., jnp.newaxis] * largest
    # A discarded singular value becomes +inf so that its reciprocal is 0.
    s_kept = jnp.where(s > cutoff, s, jnp.inf)
    scaled_ut = jnp.divide(_T(u), s_kept[..., jnp.newaxis])
    result = jnp.matmul(_T(vh), scaled_ut)
    return lax.convert_element_type(result, a_conj.dtype)
Example #11
0
def _slogdet_qr(a):
  """slogdet via Householder QR (``geqrf``); real inputs only.

  One reason we might prefer QR decomposition is that it is more amenable to a
  fast batched implementation on TPU because of the lack of row pivoting.

  Args:
    a: real array of shape ``[..., n, n]``.

  Returns:
    ``(sign, log_abs_det)``.

  Raises:
    NotImplementedError: for complex inputs.
  """
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so we compute the magnitude as the
  # sum of the log-absolute diagonal values, and the sign separately. Only the
  # diagonal goes through log: geqrf (like LAPACK's) packs reflector data below
  # the diagonal, and zeros there would produce log(0) = -inf and nan
  # gradients even though they never contribute to the result.
  diag = jnp.diagonal(a, axis1=-2, axis2=-1)
  log_abs_det = jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)
  sign_diag = jnp.prod(jnp.sign(diag), axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
Example #12
0
def logpdf(x, alpha):
  """Log-density of the Dirichlet distribution.

  Args:
    x: sample(s), with categories along axis 0. It may have either
      ``alpha.shape[0]`` entries on that axis or one fewer; in the latter case
      the last coordinate is filled in as ``1 - x.sum(0)``.
    alpha: one-dimensional concentration parameters.

  Returns:
    The log-density, or ``-inf`` where ``_is_simplex(x)`` is False.

  Raises:
    ValueError: if ``alpha`` is not 1-D or ``x.shape[0]`` is incompatible.
  """
  x, alpha = _promote_dtypes_inexact(x, alpha)
  if alpha.ndim != 1:
    raise ValueError(
      f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  k = alpha.shape[0]
  if x.shape[0] not in (k, k - 1):
    raise ValueError(
      "`x` must have either the same number of entries as `alpha` "
      f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )
  one = lax._const(x, 1)
  if x.shape[0] == k - 1:
    # Complete the last coordinate so x sums to one along axis 0.
    remainder = lax.sub(one, x.sum(0, keepdims=True))
    x = jnp.concatenate([x, remainder], axis=0)
  log_beta = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    # Give alpha trailing singleton axes so it broadcasts against batched x.
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
  unnormalized = jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0)
  log_probs = lax.sub(unnormalized, log_beta)
  return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
Example #13
0
def _sph_harm(m: jnp.ndarray, n: jnp.ndarray, theta: jnp.ndarray,
              phi: jnp.ndarray, n_max: int) -> jnp.ndarray:
    """Computes the spherical harmonics.

    Args:
      m: order of each harmonic; negative orders are handled by conjugation.
      n: degree of each harmonic.
      theta: angle multiplied by ``|m|`` to form the complex phase.
      phi: angle whose cosine (the colatitude cosine) feeds the Legendre table.
      n_max: maximum degree passed to ``_gen_associated_legendre``.

    Returns:
      Complex array of harmonic values.
    """
    order = abs(m)
    table = _gen_associated_legendre(n_max, jnp.cos(phi), True)
    p = table[order, n, jnp.arange(len(n))]

    phase = order * theta
    e_iphase = lax.complex(jnp.cos(phase), jnp.sin(phase))
    # Scale real and imaginary parts separately rather than via a complex
    # product.
    y = lax.complex(p * jnp.real(e_iphase), p * jnp.imag(e_iphase))

    # Negative order: Y_n^{-|m|} = (-1)^{|m|} * conj(Y_n^{|m|}).
    return jnp.where(m < 0, (-1.0)**order * jnp.conjugate(y), y)
Example #14
0
 def body(d):
     """Advance a continued-fraction evaluation by one step, carried in `d`.

     Relies on `_c`, `one`, `zero`, `n` and `BIG` from the enclosing scope.
     Updates `d["k"]`, the recurrence history `pkm1`/`pkm2`/`qkm1`/`qkm2`,
     the current ratio `r`, the estimate `ans` and the relative change `t`,
     mutating `d` in place and returning it.
     """
     x = d["x"]
     d["k"] += _c(d["k"], 1)
     k = d["k"]
     # Odd and even steps use different (yk, xk) coefficient pairs.
     odd = k % _c(k, 2) == _c(k, 1)
     yk = jnp.where(odd, one, x)
     xk = jnp.where(odd, n + (k - _c(k, 1)) / _c(k, 2), k / _c(k, 2))
     # Three-term recurrence for the next numerator pk and denominator qk.
     pk = d["pkm1"] * yk + d["pkm2"] * xk
     qk = d["qkm1"] * yk + d["qkm2"] * xk
     # Accept the new ratio pk/qk only where qk != 0; elsewhere keep the old
     # r and ans and set t = 1.
     nz = qk != zero
     d["r"] = r = jnp.where(nz, pk / qk, d["r"])
     d["t"] = jnp.where(nz, abs((d["ans"] - r) / r), one)
     d["ans"] = jnp.where(nz, r, d["ans"])
     # Shift the recurrence history.
     d["pkm2"] = d["pkm1"]
     d["pkm1"] = pk
     d["qkm2"] = d["qkm1"]
     d["qkm1"] = qk
     # When |pk| exceeds BIG, divide all four history terms by BIG; a common
     # factor leaves the ratio pk/qk unchanged while avoiding overflow.
     is_big = abs(pk) > BIG
     for s in "pq":
         for i in "12":
             key = s + "km" + i
             d[key] = jnp.where(is_big, d[key] / BIG, d[key])
     return d
Example #15
0
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} =
  x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
  x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b

  Raises:
    ValueError: if a is not [..., m, m] or b's trailing shape differs from a's.
  """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    # 1x1 case: det(a) = a[..., 0, 0] and adjugate(a) = [[1]], so the
    # adjugate applied to b is b itself.
    if a_shape[-1] == 1:
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    # Broadcast the batch dimensions of the LU factors and b against each other.
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    # Each pivot that differs from its row index is a row swap; an odd number
    # of swaps makes the permutation's determinant -1.
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[:, -1] contains the full determinant and
    # partial_det[:, -2] contains det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    # Replace u_nn as described in the docstring so the upper-triangular solve
    # below stays finite when rank(a) == n - 1.
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    # Open-mesh index arrays over the batch dims (plus a trailing length-1
    # axis), used to gather permuted rows per batch element. The generator's
    # `b` is local to the generator and does not rebind the argument `b`.
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    # Filter out matrices the modified u still cannot handle: a nan/inf in this
    # probe solve marks rank(a) < n - 1, whose result is all zeros.
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    # Reorder the rows of x by the LU permutation, then solve with the unit
    # lower-triangular factor l.
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu,
                                    x,
                                    left_side=True,
                                    lower=True,
                                    unit_diagonal=True)
    # Scale every row except the last by det(a); the last row gets its
    # det(u)/u_nn factor from the modified u_nn in the next solve.
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x
Example #16
0
def log_ndtr(x, series_order=3):
    r"""Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates :math:`\log(\mathrm{ndtr}(x))` by either calling
  :math:`\log(\mathrm{ndtr}(x))` or using an asymptotic series. Specifically:

  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    :math:`\log(1-x) \approx -x, x \ll 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
    and take a log.
  - For `x <= lower_segment`, we use the series approximation of `erf` to compute
    the log CDF directly.

  The `lower_segment` and `upper_segment` thresholds are set based on the
  precision of the input:

  .. math::
    \begin{align}
    \mathit{lower\_segment} =&
      \ \begin{cases}
        -20 &  x.\mathrm{dtype}=\mathit{float64} \\
        -10 &  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases} \\
    \mathit{upper\_segment} =&
      \ \begin{cases}
        8&  x.\mathrm{dtype}=\mathit{float64} \\
        5&  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases}
    \end{align}


  When `x <= lower_segment`, the `ndtr` asymptotic series approximation is:

  .. math::
    \begin{align}
     \mathrm{ndtr}(x) =&\  \mathit{scale} * (1 + \mathit{sum}) + R_N \\
     \mathit{scale}   =&\  \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
     \mathit{sum}     =&\  \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
     R_N     =&\  O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
    \end{align}

  where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  `double-factorial
  <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.


  Args:
    x: an array of type `float32`, `float64`.
    series_order: Positive Python integer. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.

  Returns:
    an array with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is not a Python `int`.
    ValueError:  if `series_order` is not in `[0, 30]`.
  """
    # NOTE: bool is a subclass of int, so True/False also pass this check.
    if not isinstance(series_order, int):
        raise TypeError("series_order must be a Python integer.")
    if series_order < 0:
        raise ValueError("series_order must be non-negative.")
    if series_order > 30:
        raise ValueError("series_order must be <= 30.")

    x = jnp.asarray(x)
    dtype = lax.dtype(x)

    # The branch thresholds depend on precision; other dtypes are rejected.
    if dtype == jnp.float64:
        lower_segment = _LOGNDTR_FLOAT64_LOWER
        upper_segment = _LOGNDTR_FLOAT64_UPPER
    elif dtype == jnp.float32:
        lower_segment = _LOGNDTR_FLOAT32_LOWER
        upper_segment = _LOGNDTR_FLOAT32_UPPER
    else:
        raise TypeError("x.dtype={} is not supported.".format(np.dtype(dtype)))

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # We copy the main idea, with a few changes
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #     which extends the range of validity of this function.
    # * We use one fixed series_order for all of 'x', rather than adaptive.
    # * Our docstring properly reflects that this is an asymptotic series, not a
    #   Taylor series. We also provided a correct bound on the remainder.
    # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
    #   x=0. This happens even though the branch is unchosen because when x=0
    #   the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
    #   regardless of whether dy is finite. Note that the minimum is a NOP if
    #   the branch is chosen.
    return jnp.where(
        lax.gt(x, upper_segment),
        -_ndtr(-x),  # log(1-x) ~= -x, x << 1
        jnp.where(lax.gt(x, lower_segment),
                  lax.log(_ndtr(lax.max(x, lower_segment))),
                  _log_ndtr_lower(lax.min(x, lower_segment), series_order)))
Example #17
0
def _ndtri(p):
    """Implements ndtri core logic.

    Inverse of the standard normal CDF, using the piecewise rational
    approximations from the Cephes library. ``p <= 0`` maps to ``-inf`` and
    ``p >= 1`` to ``+inf``.

    Args:
      p: array of probabilities; its dtype sets the working precision.

    Returns:
      An array with the same shape as ``p``.
    """

    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # Each list is reversed so index 0 holds the constant term, which is the
    # order _create_polynomial expects.
    p0 = list(
        reversed([
            -5.99633501014107895267E1, 9.80010754185999661536E1,
            -5.66762857469070293439E1, 1.39312609387279679503E1,
            -1.23916583867381258016E0
        ]))
    q0 = list(
        reversed([
            1.0, 1.95448858338141759834E0, 4.67627912898881538453E0,
            8.63602421390890590575E1, -2.25462687854119370527E2,
            2.00260212380060660359E2, -8.20372256168333339912E1,
            1.59056225126211695515E1, -1.18331621121330003142E0
        ]))
    p1 = list(
        reversed([
            4.05544892305962419923E0, 3.15251094599893866154E1,
            5.71628192246421288162E1, 4.40805073893200834700E1,
            1.46849561928858024014E1, 2.18663306850790267539E0,
            -1.40256079171354495875E-1, -3.50424626827848203418E-2,
            -8.57456785154685413611E-4
        ]))
    q1 = list(
        reversed([
            1.0, 1.57799883256466749731E1, 4.53907635128879210584E1,
            4.13172038254672030440E1, 1.50425385692907503408E1,
            2.50464946208309415979E0, -1.42182922854787788574E-1,
            -3.80806407691578277194E-2, -9.33259480895457427372E-4
        ]))
    p2 = list(
        reversed([
            3.23774891776946035970E0, 6.91522889068984211695E0,
            3.93881025292474443415E0, 1.33303460815807542389E0,
            2.01485389549179081538E-1, 1.23716634817820021358E-2,
            3.01581553508235416007E-4, 2.65806974686737550832E-6,
            6.23974539184983293730E-9
        ]))
    q2 = list(
        reversed([
            1.0, 6.02427039364742014255E0, 3.67983563856160859403E0,
            1.37702099489081330271E0, 2.16236993594496635890E-1,
            1.34204006088543189037E-2, 3.28014464682127739104E-4,
            2.89247864745380683936E-6, 6.79019408009981274425E-9
        ]))

    dtype = lax.dtype(p).type
    shape = jnp.shape(p)

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        # coeffs[0] is the constant term: c0 + var * (c1 + var * (...)).
        coeffs = np.array(coeffs, dtype)
        if not coeffs.size:
            return jnp.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    # Work with the smaller tail: use 1 - p when p > 1 - exp(-2).
    maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = jnp.where(maybe_complement_p <= dtype(0.),
                              jnp.full(shape, dtype(0.5)), maybe_complement_p)

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - dtype(0.5)
    ww = lax.square(w)
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) /
                                _create_polynomial(ww, q0))
    x_for_big_p *= -dtype(np.sqrt(2. * np.pi))

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
    first_term = z - lax.log(z) / z
    second_term_small_p = (_create_polynomial(dtype(1.) / z, p2) /
                           _create_polynomial(dtype(1.) / z, q2) / z)
    second_term_otherwise = (_create_polynomial(dtype(1.) / z, p1) /
                             _create_polynomial(dtype(1.) / z, q1) / z)
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    # z >= 8 is equivalent to sanitized_mcp <= exp(-32).
    x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)), x_for_big_p,
                  jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise))

    # Only inputs that were complemented above keep this sign; all others are
    # negated.
    x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)
    infinity = jnp.full(shape, dtype(np.inf))
    x_nan_replaced = jnp.where(p <= dtype(0.0), -infinity,
                               jnp.where(p >= dtype(1.0), infinity, x))
    return x_nan_replaced
Example #18
0
def _polygamma(n, x):
    dtype = lax.dtype(n).type
    n_plus = n + dtype(1)
    sign = dtype(1) - (n_plus % dtype(2)) * dtype(2)
    return jnp.where(n == 0, digamma(x),
                     sign * jnp.exp(gammaln(n_plus)) * zeta(n_plus, x))
Example #19
0
def xlogy(x, y):
    """Compute ``x * log(y)``, defined as 0 wherever ``x == 0``.

    Where ``x == 0`` both operands are replaced by 1 before the product is
    formed, so ``log`` is never evaluated at the masked ``y`` values.
    """
    x, y = _promote_args_inexact("xlogy", x, y)
    x_is_zero = x == 0.
    safe_x = jnp.where(x_is_zero, 1., x)
    safe_y = jnp.where(x_is_zero, 1., y)
    product = lax.mul(safe_x, lax.log(safe_y))
    return jnp.where(x_is_zero, jnp.zeros_like(x), product)
Example #20
0
def _unique(ar,
            axis,
            return_index=False,
            return_inverse=False,
            return_counts=False,
            size=None,
            fill_value=None,
            return_true_size=False):
    """
  Find the unique elements of an array along a particular axis.

  Args:
    ar: input array.
    axis: axis along which uniqueness is determined.
    return_index: also return indices (via the sort permutation) of the
      selected elements in ``ar``.
    return_inverse: also return, for each element along ``axis``, the index of
      its unique value.
    return_counts: also return the number of occurrences of each unique value.
    size: if given, the result has exactly ``size`` entries along the unique
      axis. Without it, the uniqueness mask must be concrete
      (``core.concrete_or_error``).
    fill_value: with ``size``, the value written into entries beyond the true
      number of unique values.
    return_true_size: also return the number of unique values found.

  Returns:
    The unique values alone, or a tuple ``(unique, [index], [inverse],
    [counts], [true_size])`` containing only the requested extras, in that
    order.

  Raises:
    ValueError: if ``ar`` is empty along ``axis``, ``size`` is nonzero and
      ``fill_value`` is None.
  """
    if ar.shape[axis] == 0 and size and fill_value is None:
        raise ValueError(
            "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified"
        )

    # NOTE(review): presumably aux is `ar` sorted with `axis` moved to the
    # front, mask flags the first element of each run of equal values, and
    # perm is the sorting permutation; confirm against _unique_sorted_mask.
    aux, mask, perm = _unique_sorted_mask(ar, axis)
    if size is None:
        # Requires a concrete mask; `ind` is then the boolean mask itself.
        ind = core.concrete_or_error(
            None, mask, "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)
    else:
        # Fixed-size positions of the run starts (padded by jnp.nonzero).
        ind = nonzero(mask, size=size)[0]
    result = aux[ind] if aux.size else aux
    if fill_value is not None:
        fill_value = asarray(fill_value, dtype=result.dtype)
    if size is not None and fill_value is not None:
        if result.shape[0]:
            # Positions past the true unique count get fill_value.
            valid = lax.expand_dims(
                arange(size) < mask.sum(), tuple(range(1, result.ndim)))
            result = where(valid, result, fill_value)
        else:
            result = full_like(result,
                               fill_value,
                               shape=(size, *result.shape[1:]))
    result = moveaxis(result, 0, axis)

    ret = (result, )
    if return_index:
        if aux.size:
            ret += (perm[ind], )
        else:
            ret += (perm, )
    if return_inverse:
        if aux.size:
            # imask[j] is the unique-value index of sorted position j;
            # scattering through perm maps it back to the original order.
            imask = cumsum(mask) - 1
            inv_idx = zeros(mask.shape,
                            dtype=dtypes.canonicalize_dtype(dtypes.int_))
            inv_idx = inv_idx.at[perm].set(imask)
        else:
            inv_idx = zeros(ar.shape[axis], dtype=int)
        ret += (inv_idx, )
    if return_counts:
        if aux.size:
            # Counts are differences between consecutive run starts, with
            # mask.size as the end sentinel.
            if size is None:
                idx = append(nonzero(mask)[0], mask.size)
            else:
                # One extra slot for the sentinel; zero padding after
                # position 0 becomes mask.size, so padded counts are 0.
                idx = nonzero(mask, size=size + 1)[0]
                idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
            ret += (diff(idx), )
        elif ar.shape[axis]:
            ret += (array([ar.shape[axis]],
                          dtype=dtypes.canonicalize_dtype(dtypes.int_)), )
        else:
            ret += (empty(0, dtype=int), )
    if return_true_size:
        # Useful for internal uses of unique().
        ret += (mask.sum(), )
    return ret[0] if len(ret) == 1 else ret
Example #21
0
 def sturm_step0():
     """First step of the Sturm-sequence count, using enclosing-scope names.

     Returns the leading pivot ``alpha[0] - x`` (replaced by
     ``alpha0_perturbation`` wherever ``alpha[0] == x``) and a count that is
     ``ones`` where that pivot was negative and ``zeros`` otherwise.
     """
     pivot = alpha[0] - x
     is_negative = pivot < 0
     count = jnp.where(is_negative, ones, zeros)
     # Avoid an exactly-zero pivot; presumably because the next step divides
     # by it (see sturm_step).
     pivot = jnp.where(alpha[0] == x, alpha0_perturbation, pivot)
     return pivot, count
Example #22
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the exponential distribution.

    Returns ``-((x - loc) / scale + log(scale))`` for ``x >= loc`` and
    ``-inf`` below ``loc``.
    """
    x, loc, scale = _promote_args_inexact("expon.logpdf", x, loc, scale)
    standardized = lax.div(lax.sub(x, loc), scale)
    log_probs = lax.neg(lax.add(standardized, lax.log(scale)))
    below_support = lax.lt(x, loc)
    return where(below_support, -inf, log_probs)
Example #23
0
def logpmf(k, mu, loc=0):
    """Log of the Poisson probability mass function.

    Args:
      k: point(s) at which to evaluate.
      mu: Poisson mean.
      loc: shift; the pmf is evaluated at ``k - loc``.

    Returns:
      ``(k - loc) * log(mu) - lgamma(k - loc + 1) - mu`` on the support, and
      ``-inf`` where ``k - loc`` is negative or not an integer (matching
      ``scipy.stats.poisson.logpmf``).
    """
    k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
    zero = jnp._constant_like(k, 0)
    x = lax.sub(k, loc)
    log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
    # The pmf is zero off the non-negative integers; without the integrality
    # check a fractional k would get a finite, meaningless log-probability.
    outside_support = jnp.logical_or(lax.lt(x, zero), lax.ne(jnp.round(x), x))
    return jnp.where(outside_support, -jnp.inf, log_probs)
Example #24
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the uniform distribution on ``[loc, loc + scale]``.

    Inside the closed interval the value is ``-log(scale)``; outside it is
    ``-inf``.
    """
    x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
    upper = lax.add(loc, scale)
    outside = logical_or(lax.gt(x, upper), lax.lt(x, loc))
    return where(outside, -inf, lax.neg(lax.log(scale)))
Example #25
0
 def _evaluate_nearest(self, indices, norm_distances):
     idx_res = [
         where(yi <= .5, i, i + 1) for i, yi in zip(indices, norm_distances)
     ]
     return self.values[tuple(idx_res)]
Example #26
0
 def sturm_step(i, q, count):
     """Advance the Sturm-sequence recurrence to index ``i``.

     Computes the next pivot ``alpha[i] - beta_sq[i - 1] / q - x`` (``alpha``,
     ``beta_sq``, ``x`` and ``pivmin`` come from the enclosing scope). Where
     that pivot is ``<= pivmin``, ``count`` is incremented and the pivot is
     clamped to at most ``-pivmin``.
     """
     next_q = alpha[i] - beta_sq[i - 1] / q - x
     small = next_q <= pivmin
     next_count = jnp.where(small, count + 1, count)
     next_q = jnp.where(small, jnp.minimum(next_q, -pivmin), next_q)
     return next_q, next_count
Example #27
0
def cdf(k, mu, loc=0):
    """Poisson cumulative distribution function ``P(X <= k - loc)``.

    Args:
      k: point(s) at which to evaluate.
      mu: Poisson mean.
      loc: shift applied to ``k``.

    Returns:
      ``gammaincc(floor(k - loc) + 1, mu)``, and 0 where ``k < loc``.
    """
    # The label is used in promotion error messages, so it must name this
    # function (it previously said "poisson.logpmf").
    k, mu, loc = jnp._promote_args_inexact("poisson.cdf", k, mu, loc)
    zero = jnp._constant_like(k, 0)
    x = lax.sub(k, loc)
    # P(X <= x) for X ~ Poisson(mu) is the regularized upper incomplete gamma
    # function Q(floor(x) + 1, mu).
    p = gammaincc(jnp.floor(1 + x), mu)
    return jnp.where(lax.lt(x, zero), zero, p)
Example #28
0
def istft(Zxx,
          fs=1.0,
          window='hann',
          nperseg=None,
          noverlap=None,
          nfft=None,
          input_onesided=True,
          boundary=True,
          time_axis=-1,
          freq_axis=-2):
    """Inverse short-time Fourier transform, mirroring ``scipy.signal.istft``.

    Args:
      Zxx: STFT of the signal; at least 2-D.
      fs: sampling frequency of the time series.
      window: window spec for ``scipy.signal.get_window`` (str or tuple) or a
        1-D array of length ``nperseg``.
      nperseg: samples per segment. ``None`` derives it from the frequency axis
        length ``F``: ``2 * (F - 1)`` if ``input_onesided`` else ``F``.
      noverlap: samples shared by consecutive segments. ``None`` means
        ``nperseg // 2``; an explicit ``0`` means no overlap.
      nfft: FFT length; defaults from the frequency axis length.
      input_onesided: treat ``Zxx`` as a one-sided spectrum (inverted with
        ``irfft``; otherwise ``ifft``).
      boundary: if True, trim ``nperseg // 2`` samples from each end, undoing a
        boundary-extended forward STFT.
      time_axis: axis of ``Zxx`` holding the segments.
      freq_axis: axis of ``Zxx`` holding the frequencies.

    Returns:
      ``(time, x)``: sample times and the reconstructed signal.

    Raises:
      ValueError: for invalid shapes, axes, segment, overlap or FFT lengths.
    """
    # Input validation
    _check_arraylike("istft", Zxx)
    if Zxx.ndim < 2:
        raise ValueError('Input stft must be at least 2d!')
    freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
    time_axis = canonicalize_axis(time_axis, Zxx.ndim)
    if freq_axis == time_axis:
        raise ValueError('Must specify differing time and frequency axes!')

    Zxx = jnp.asarray(Zxx,
                      dtype=jax.dtypes.canonicalize_dtype(
                          np.result_type(Zxx, np.complex64)))

    n_default = (2 * (Zxx.shape[freq_axis] - 1)
                 if input_onesided else Zxx.shape[freq_axis])

    # Fall back to defaults only when an argument is omitted (None). An
    # `arg or default` form would also replace an explicit 0, letting
    # nperseg=0 bypass validation and turning noverlap=0 into 50% overlap.
    if nperseg is None:
        nperseg = n_default
    nperseg = jax.core.concrete_or_error(int, nperseg,
                                         "nperseg: segment length of STFT")
    if nperseg < 1:
        raise ValueError('nperseg must be a positive integer')

    if nfft is None:
        nfft = n_default
        if input_onesided and nperseg == n_default + 1:
            nfft += 1  # Odd nperseg, no FFT padding
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
    if nfft < nperseg:
        raise ValueError(
            f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

    if noverlap is None:
        noverlap = nperseg // 2
    noverlap = jax.core.concrete_or_error(int, noverlap, "noverlap of STFT")
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Rearrange axes if necessary
    if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
        outer_idxs = tuple(idx for idx in range(Zxx.ndim)
                           if idx not in {time_axis, freq_axis})
        Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

    # Perform IFFT
    ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft
    # xsubs: [..., T, N], N is the number of frames, T is the frame length.
    xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

    # Get window as array
    if isinstance(window, (str, tuple)):
        win = osp_signal.get_window(window, nperseg)
        win = jnp.asarray(win)
    else:
        win = jnp.asarray(window)
        if len(win.shape) != 1:
            raise ValueError('window must be 1-D')
        if win.shape[0] != nperseg:
            raise ValueError('window must have length of {0}'.format(nperseg))
    win = win.astype(xsubs.dtype)

    xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

    # make win broadcastable over xsubs
    win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1, ))
    x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
    win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
    norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

    # Remove extension points
    if boundary:
        # Explicit end index: when nperseg // 2 == 0 the slice `0:-0` would be
        # empty instead of keeping every sample.
        trim = nperseg // 2
        x = x[..., trim:x.shape[-1] - trim]
        norm = norm[..., trim:norm.shape[-1] - trim]
    # Normalize by the summed squared window, skipping near-zero entries.
    x /= jnp.where(norm > 1e-10, norm, 1.0)

    # Put axes back
    if x.ndim > 1:
        if time_axis != Zxx.ndim - 1:
            if freq_axis < time_axis:
                time_axis -= 1
            x = jnp.moveaxis(x, -1, time_axis)

    # NOTE(review): like scipy.signal.istft this uses x.shape[0], which for
    # multi-dimensional output is the leading axis rather than the time axis.
    time = jnp.arange(x.shape[0]) / fs
    return time, x