Example #1
0
def _unique_sorted_mask(ar, axis):
    aux = moveaxis(ar, axis, 0)
    if np.issubdtype(aux.dtype, np.complexfloating):
        # Work around issue in sorting of complex numbers with Nan only in the
        # imaginary component. This can be removed if sorting in this situation
        # is fixed to match numpy.
        aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
    size, *out_shape = aux.shape
    if _prod(out_shape) == 0:
        size = 1
        perm = zeros(1, dtype=int)
    else:
        perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
    aux = aux[perm]
    if aux.size:
        if dtypes.issubdtype(aux.dtype, np.inexact):
            # This is appropriate for both float and complex due to the documented behavior of np.unique:
            # See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
            neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
        else:
            neq = lax.ne
        mask = ones(size, dtype=bool).at[1:].set(
            any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
    else:
        mask = zeros(size, dtype=bool)
    return aux, mask, perm
Example #2
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)
    dimadd = lambda x: lax.expand_dims(x, dims)
    amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_singletons = dimadd(amax)
    if b is None:
        out = lax.add(
            lax.log(
                lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                           _constant_like(a, 0), lax.add, dims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                            _constant_like(a, 0), lax.add, dims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return dimadd(out) if keepdims else out
Example #3
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return out
Example #4
0
def cond(x, p=None):
    _assertNoEmpty2d(x)
    if p in (None, 2):
        s = la.svd(x, compute_uv=False)
        return s[..., 0] / s[..., -1]
    elif p == -2:
        s = la.svd(x, compute_uv=False)
        r = s[..., -1] / s[..., 0]
    else:
        _assertRankAtLeast2(x)
        _assertNdSquareness(x)
        invx = la.inv(x)
        r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(
            invx, ord=p, axis=(-2, -1))

    # Convert nans to infs unless the original array had nan entries
    orig_nan_check = jnp.full_like(r, ~jnp.isnan(r).any())
    nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
    r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r)
    return r
Example #5
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            # Use jnp.array(nan) to avoid false positives in debug_nans
            # (see https://github.com/google/jax/issues/7634)
            out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
    return out
Example #6
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            out = jnp.where(sign < 0, np.nan, out)
    return out
Example #7
0
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n}
  x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
  x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    if a_shape[-1] == 1:
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[:, -1] contains the full determinant and
    # partial_det[:, -2] contains det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    # filter out any matrices that are not full rank
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu,
                                    x,
                                    left_side=True,
                                    lower=True,
                                    unit_diagonal=True)
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x