Example #1
def _dct2(x, axes, norm):
  axis1, axis2 = map(partial(_canonicalize_axis, num_dims=x.ndim), axes)
  N1, N2 = x.shape[axis1], x.shape[axis2]
  v = _dct_interleave(_dct_interleave(x, axis1), axis2)
  V = jnp.fft.fftn(v, axes=axes)
  k1 = lax.expand_dims(jnp.arange(N1), [a for a in range(x.ndim) if a != axis1])
  k2 = lax.expand_dims(jnp.arange(N2), [a for a in range(x.ndim) if a != axis2])
  out = _W4(N1, k1) * (_W4(N2, k2) * V + _W4(N2, -k2) * jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2))
  out = 2 * out.real
  if norm == 'ortho':
    return _dct_ortho_norm(_dct_ortho_norm(out, axis1), axis2)
  return out
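Every example in this listing leans on lax.expand_dims to insert size-1 axes so a 1-D index vector broadcasts against an N-D input. A minimal sketch of that pattern, with a hypothetical shape standing in for the DCT input:

# Insert size-1 dimensions everywhere except the transform axis so the
# frequency index k broadcasts against the input array.
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 5, 6))   # stand-in input
axis = 1                   # transform axis
k = lax.expand_dims(jnp.arange(x.shape[axis]),
                    [a for a in range(x.ndim) if a != axis])
print(k.shape)             # (1, 5, 1): broadcastable against x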
Example #2
def _dct_ortho_norm(out, axis):
    factor = lax.concatenate([
        lax.full((1, ), 4, out.dtype),
        lax.full((out.shape[axis] - 1, ), 2, out.dtype)
    ], 0)
    factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
    return out / lax.sqrt(factor * out.shape[axis])
Example #3
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)
    dimadd = lambda x: lax.expand_dims(x, dims)
    amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_singletons = dimadd(amax)
    if b is None:
        out = lax.add(
            lax.log(
                lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                           _constant_like(a, 0), lax.add, dims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                            _constant_like(a, 0), lax.add, dims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return dimadd(out) if keepdims else out
Example #4
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return out
Example #5
def _pinv_jvp(rcond, primals, tangents):
    # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
    # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
    # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
    # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
    a, = primals
    a_dot, = tangents
    p = pinv(a, rcond=rcond)
    m, n = a.shape[-2:]
    # TODO(phawkins): on TPU, we would need to opt into high precision here.
    # TODO(phawkins): consider if this can be simplified in the Hermitian case.
    p_dot = -p @ a_dot @ p
    I_n = lax.expand_dims(jnp.eye(m, dtype=a.dtype), range(a.ndim - 2))
    p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (I_n - a @ p)
    I_m = lax.expand_dims(jnp.eye(n, dtype=a.dtype), range(a.ndim - 2))
    p_dot = p_dot + (I_m - p @ a) @ _H(a_dot) @ _H(p) @ p
    return p, p_dot
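A small sketch of exercising this JVP rule through the public API, assuming jnp.linalg.pinv is the entry point that registers it:

# Forward-mode derivative of the pseudo-inverse in a chosen tangent direction.
import jax
import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])   # 3x2, full column rank
da = jnp.ones_like(a)                            # tangent direction
p, dp = jax.jvp(jnp.linalg.pinv, (a,), (da,))
print(p.shape, dp.shape)                         # (2, 3) (2, 3)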
Example #6
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float64,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be
      computed.
  """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
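The doctest above uses the default axis=-1; a short sketch of the axis argument (the named-axis branch taken under collectives is not shown):

# Moving the one-hot (class) dimension with `axis`.
import jax
import jax.numpy as jnp

labels = jnp.array([0, 2])
print(jax.nn.one_hot(labels, 3).shape)          # (2, 3): classes on the last axis
print(jax.nn.one_hot(labels, 3, axis=0).shape)  # (3, 2): classes on axis 0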
Example #7
def pinv(a, rcond=None):
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  a = jnp.conj(a)
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
  rcond = jnp.asarray(rcond)
  u, s, vh = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
  cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
  s = jnp.where(s > cutoff, s, jnp.inf)
  res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]))
  return lax.convert_element_type(res, a.dtype)
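A quick check of the defining Moore-Penrose identities, assuming jnp.linalg.pinv is the public wrapper around this routine:

# a @ pinv(a) @ a == a and pinv(a) @ a @ pinv(a) == pinv(a), up to float error.
import jax.numpy as jnp

a = jnp.array([[1., 2., 3.], [4., 5., 6.]])     # 2x3, rank 2
p = jnp.linalg.pinv(a)                          # shape (3, 2)
print(jnp.allclose(a @ p @ a, a, atol=1e-5))    # True
print(jnp.allclose(p @ a @ p, p, atol=1e-5))    # True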
Example #8
def _slogdet_lu(a):
    dtype = lax.dtype(a)
    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
    iota = lax.expand_dims(jnp.arange(a.shape[-1]), range(pivot.ndim - 1))
    parity = jnp.count_nonzero(pivot != iota, axis=-1)
    if jnp.iscomplexobj(a):
        sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
    else:
        sign = jnp.array(1, dtype=dtype)
        parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
    sign = jnp.where(is_zero, jnp.array(0, dtype=dtype),
                     sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
    logdet = jnp.where(is_zero, jnp.array(-jnp.inf, dtype=dtype),
                       jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
    return sign, jnp.real(logdet)
Example #9
def dct(x, type=2, n=None, axis=-1, norm=None):
    if type != 2:
        raise NotImplementedError('Only DCT type 2 is implemented.')

    axis = _canonicalize_axis(axis, x.ndim)
    if n is not None:
        x = lax.pad(x, jnp.array(0, x.dtype),
                    [(0, n - x.shape[axis] if a == axis else 0, 0)
                     for a in range(x.ndim)])

    N = x.shape[axis]
    v = _dct_interleave(x, axis)
    V = jnp.fft.fft(v, axis=axis)
    k = lax.expand_dims(jnp.arange(N), [a for a in range(x.ndim) if a != axis])
    out = V * _W4(N, k)
    out = 2 * out.real
    if norm == 'ortho':
        out = _dct_ortho_norm(out, axis)
    return out
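A usage sketch comparing against SciPy, assuming jax.scipy.fft.dct is available in the installed JAX version:

# DCT-II with orthonormal scaling should agree with scipy.fft.dct.
import numpy as np
import scipy.fft
from jax.scipy import fft as jfft

x = np.random.default_rng(0).standard_normal((4, 8)).astype(np.float32)
out_jax = jfft.dct(x, type=2, axis=-1, norm='ortho')
out_ref = scipy.fft.dct(x, type=2, axis=-1, norm='ortho')
print(np.allclose(out_jax, out_ref, atol=1e-4))  # True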
Example #10
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            with jax.debug_nans(False):
                out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype),
                                out)
    return out
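A short sketch of the signed branch, assuming the public jax.scipy.special.logsumexp wrapper: with negative weights b, the sign output lets the true (possibly negative) sum be recovered.

# sign * exp(out) reconstructs sum(b * exp(a)) even when it is negative.
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0, -1.0, 1.0])
out, sign = logsumexp(a, b=b, return_sign=True)
print(sign * jnp.exp(out))         # ~15.415
print(jnp.sum(b * jnp.exp(a)))     # same value, computed directly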
Example #11
def _irfft_transpose(t, fft_lengths):
    # The transpose of IRFFT is the RFFT of the cotangent times a scaling
    # factor and a mask. The mask scales the cotangent for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these components
    # are de-duplicated in the RFFT.
    x = fft(t, xla_client.FftType.RFFT, fft_lengths)
    n = x.shape[-1]
    is_odd = fft_lengths[-1] % 2
    full = partial(lax.full_like, t, dtype=t.dtype)
    mask = lax.concatenate([
        full(1.0, shape=(1, )),
        full(2.0, shape=(n - 2 + is_odd, )),
        full(1.0, shape=(1 - is_odd, ))
    ],
                           dimension=0)
    scale = 1 / prod(fft_lengths)
    out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
    assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
    # Use JAX's convention for complex gradients
    # https://github.com/google/jax/issues/6223#issuecomment-807740707
    return lax.conj(out)
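This transpose rule is what reverse-mode differentiation through jnp.fft.irfft ultimately relies on; a minimal sketch of a gradient that exercises it (the numerical values are not checked here):

# Real-valued loss of an inverse RFFT; the gradient w.r.t. the complex input
# follows JAX's conjugate convention for complex cotangents.
import jax
import jax.numpy as jnp

x = jnp.array([3.0, 1.0 - 2.0j, 0.5 + 0.5j], dtype=jnp.complex64)
loss = lambda z: jnp.sum(jnp.fft.irfft(z, n=4) ** 2)
g = jax.grad(loss)(x)
print(g.dtype, g.shape)            # complex64 (3,)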
Example #12
def slogdet(a):
    a = _promote_arg_dtypes(jnp.asarray(a))
    dtype = lax.dtype(a)
    a_shape = jnp.shape(a)
    if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
        msg = "Argument to slogdet() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(a_shape))
    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
    iota = lax.expand_dims(jnp.arange(a.shape[-1]), range(pivot.ndim - 1))
    parity = jnp.count_nonzero(pivot != iota, axis=-1)
    if jnp.iscomplexobj(a):
        sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
    else:
        sign = jnp.array(1, dtype=dtype)
        parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
    sign = jnp.where(is_zero, jnp.array(0, dtype=dtype),
                     sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
    logdet = jnp.where(is_zero, jnp.array(-jnp.inf, dtype=dtype),
                       jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
    return sign, jnp.real(logdet)
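A usage sketch: the (sign, logabsdet) pair recovers the determinant without the overflow or underflow that det itself can hit.

# det(a) == sign * exp(logabsdet)
import jax.numpy as jnp

a = jnp.array([[2., 1.], [1., 3.]])
sign, logabsdet = jnp.linalg.slogdet(a)
print(sign * jnp.exp(logabsdet))   # ~5.0
print(jnp.linalg.det(a))           # ~5.0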
Example #13
def _one_hot(x: Array, num_classes: int, *, dtype: Any,
             axis: Union[int, AxisName]) -> Array:
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)  # type: ignore[arg-type]
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
    return jnp.asarray(lhs == rhs, dtype=dtype)
Example #14
def _reduction(a,
               name,
               np_fun,
               op,
               init_val,
               has_identity=True,
               preproc=None,
               bool_op=None,
               upcast_f16_for_computation=False,
               axis=None,
               dtype=None,
               out=None,
               keepdims=False,
               initial=None,
               where_=None,
               parallel_reduce=None):
    bool_op = bool_op or op
    # Note: we must accept out=None as an argument, because numpy reductions delegate to
    # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
    # exists, passing along all its arguments.
    if out is not None:
        raise NotImplementedError(
            f"The 'out' argument to jnp.{name} is not supported.")
    _check_arraylike(name, a)
    lax_internal._check_user_dtype_supported(dtype, name)
    axis = core.concrete_or_error(None, axis,
                                  f"axis argument to jnp.{name}().")

    if initial is None and not has_identity and where_ is not None:
        raise ValueError(
            f"reduction operation {name} does not have an identity, so to use a "
            f"where mask one has to specify 'initial'")

    a = a if isinstance(a, ndarray) else _asarray(a)
    a = preproc(a) if preproc else a
    pos_dims, dims = _reduction_dims(a, axis)

    if initial is None and not has_identity:
        shape = np.shape(a)
        if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
            raise ValueError(
                f"zero-size array to reduction operation {name} which has no identity"
            )

    result_dtype = dtypes.canonicalize_dtype(
        dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
    if upcast_f16_for_computation and dtypes.issubdtype(
            result_dtype, np.inexact):
        computation_dtype = _upcast_f16(result_dtype)
    else:
        computation_dtype = result_dtype
    a = lax.convert_element_type(a, computation_dtype)
    op = op if computation_dtype != np.bool_ else bool_op
    # NB: in XLA, init_val must be an identity for the op, so the user-specified
    # initial value must be applied afterward.
    init_val = _reduction_init_val(a, init_val)
    if where_ is not None:
        a = _where(where_, a, init_val)
    if pos_dims is not dims:
        if parallel_reduce is None:
            raise NotImplementedError(
                f"Named reductions not implemented for jnp.{name}()")
        result = parallel_reduce(a, dims)
    else:
        result = lax.reduce(a, init_val, op, dims)
    if initial is not None:
        result = op(lax.convert_element_type(initial, a.dtype), result)
    if keepdims:
        result = lax.expand_dims(result, pos_dims)
    return lax.convert_element_type(result, dtype or result_dtype)
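The user-facing reductions built on this helper expose where, initial, and keepdims, which map onto the masking and init_val handling above; a small sketch:

# Masked reductions: elements where the mask is False are replaced by the
# reduction's identity (or the supplied initial value).
import jax.numpy as jnp

a = jnp.array([[1., 2., 3.], [4., 5., 6.]])
mask = a > 2.5
print(jnp.sum(a, axis=1, where=mask, initial=0.0))       # [ 3. 15.]
print(jnp.max(a, axis=0, where=mask, initial=-jnp.inf))  # [4. 5. 6.]
print(jnp.sum(a, axis=1, keepdims=True).shape)           # (2, 1)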
Example #15
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
             = x_{n} * prod_{i=1...n-1}(u_{ii}).
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    if a_shape[-1] == 1:
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    iota = lax.expand_dims(jnp.arange(a_shape[-1]), range(pivots.ndim - 1))
    parity = jnp.count_nonzero(pivots != iota, axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[:, -1] contains the full determinant and
    # partial_det[:, -2] contains det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    # filter out any matrices that are not full rank
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu,
                                    x,
                                    left_side=True,
                                    lower=True,
                                    unit_diagonal=True)
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x
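A sketch of the payoff: because det's gradient is routed through this cofactor trick, grad(det) stays finite on singular matrices (for rank n-1 the gradient is the rank-1 transposed adjugate).

# det([[1, 2], [2, 4]]) == 0, yet the gradient is finite.
import jax
import jax.numpy as jnp

a = jnp.array([[1., 2.], [2., 4.]])       # rank 1
print(jnp.linalg.det(a))                  # 0.0
print(jax.grad(jnp.linalg.det)(a))        # [[ 4. -2.] [-2.  1.]]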
Example #16
 def expand_dims(self, dimensions: Sequence[int]):
     # follows lax.expand_dims, not jnp.expand_dims, so dimensions is a sequence
     ndim_out = self.ndim + len(set(dimensions))
     dimensions = [canonicalize_axis(d, ndim_out) for d in dimensions]
     return PRNGKeyArray(self.impl, lax.expand_dims(self._keys, dimensions))
Example #17
def _unique(ar,
            axis,
            return_index=False,
            return_inverse=False,
            return_counts=False,
            size=None,
            fill_value=None,
            return_true_size=False):
    """
  Find the unique elements of an array along a particular axis.
  """
    if ar.shape[axis] == 0 and size and fill_value is None:
        raise ValueError(
            "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified"
        )

    aux, mask, perm = _unique_sorted_mask(ar, axis)
    if size is None:
        ind = core.concrete_or_error(
            None, mask, "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)
    else:
        ind = nonzero(mask, size=size)[0]
    result = aux[ind] if aux.size else aux
    if fill_value is not None:
        fill_value = asarray(fill_value, dtype=result.dtype)
    if size is not None and fill_value is not None:
        if result.shape[0]:
            valid = lax.expand_dims(
                arange(size) < mask.sum(), tuple(range(1, result.ndim)))
            result = where(valid, result, fill_value)
        else:
            result = full_like(result,
                               fill_value,
                               shape=(size, *result.shape[1:]))
    result = moveaxis(result, 0, axis)

    ret = (result, )
    if return_index:
        if aux.size:
            ret += (perm[ind], )
        else:
            ret += (perm, )
    if return_inverse:
        if aux.size:
            imask = cumsum(mask) - 1
            inv_idx = zeros(mask.shape,
                            dtype=dtypes.canonicalize_dtype(dtypes.int_))
            inv_idx = inv_idx.at[perm].set(imask)
        else:
            inv_idx = zeros(ar.shape[axis], dtype=int)
        ret += (inv_idx, )
    if return_counts:
        if aux.size:
            if size is None:
                idx = append(nonzero(mask)[0], mask.size)
            else:
                idx = nonzero(mask, size=size + 1)[0]
                idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
            ret += (diff(idx), )
        elif ar.shape[axis]:
            ret += (array([ar.shape[axis]],
                          dtype=dtypes.canonicalize_dtype(dtypes.int_)), )
        else:
            ret += (empty(0, dtype=int), )
    if return_true_size:
        # Useful for internal uses of unique().
        ret += (mask.sum(), )
    return ret[0] if len(ret) == 1 else ret
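A usage sketch of the size/fill_value path, assuming a JAX version where jnp.unique accepts these arguments; fixing size gives a static output shape, so the call works under jit:

# Static-shape unique: pad (or truncate) the sorted unique values to `size`.
import jax
import jax.numpy as jnp

@jax.jit
def first_uniques(x):
    return jnp.unique(x, size=4, fill_value=-1)

print(first_uniques(jnp.array([3, 1, 3, 2])))   # [ 1  2  3 -1]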