Example #1
0
 def testSort(self, shape, dimension, arity, bdims, is_stable):
   """Runs `self._CheckBatching` on `lax.sort` for `arity` float32 operands.

   Every operand has shape `shape`, is sorted along `dimension` with the
   requested `is_stable` flag, and is batched according to `bdims`.
   """
   rng = jtu.rand_default(self.rng())
   operand_shapes = (shape,) * arity
   operand_dtypes = (np.float32,) * arity
   # NOTE(review): the meaning of the literal `5` passed below depends on
   # `_CheckBatching`'s signature, which is not visible here -- confirm.
   if arity == 1:
     # Forward `is_stable` here as well; it used to be dropped for the
     # single-operand case, so that parameterization was never exercised.
     fun = partial(lax.sort, dimension=dimension, is_stable=is_stable)
     self._CheckBatching(fun, 5, bdims, operand_shapes, operand_dtypes, rng)
   else:
     # Sorting a tuple of operands returns a tuple; check each output in turn.
     # `i=i` binds the current index at definition time (no late binding).
     for i in range(arity):
       fun = lambda *args, i=i: lax.sort(args,
                                         dimension=dimension,
                                         is_stable=is_stable)[i]
       self._CheckBatching(fun, 5, bdims, operand_shapes, operand_dtypes, rng)
Example #2
0
File: linalg.py  Project: ahoenselaar/jax
def svd(a,
        full_matrices: bool = True,
        compute_uv: bool = True,
        hermitian: bool = False):
    """Singular value decomposition of ``a``.

    Non-Hermitian inputs are handed straight to ``lax_linalg.svd``. For
    Hermitian inputs the result is built from ``lax_linalg.eigh``: the
    singular values are the absolute eigenvalues, returned in descending
    order, and the eigenvector columns are permuted to match.

    Args:
      a: array-like input; converted with ``jnp.asarray`` and then passed
        through ``_promote_arg_dtypes``.
      full_matrices: forwarded to ``lax_linalg.svd``; not consulted on the
        Hermitian path.
      compute_uv: if true, return ``(u, s, vh)``; otherwise return ``s`` only.
      hermitian: take the eigendecomposition-based path.

    Returns:
      ``(u, s, vh)`` when ``compute_uv`` is true, else ``s``.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    if not hermitian:
        return lax_linalg.svd(a, full_matrices, compute_uv)

    # The first result is used as the eigenvector matrix (its columns are
    # gathered below) and the second as the eigenvalues.
    eigvecs, eigvals = lax_linalg.eigh(a)
    s = lax.abs(eigvals)
    last_axis = s.ndim - 1

    if not compute_uv:
        # lax.sort is ascending; reverse to get descending singular values.
        return lax.rev(lax.sort(s, dimension=-1), dimensions=[last_axis])

    sign = lax.sign(eigvals)
    order = lax.broadcasted_iota(np.int64, s.shape, dimension=last_axis)
    # Sort on |eigenvalue| only (num_keys=1); `order` and `sign` ride along.
    s, order, sign = lax.sort((s, order, sign), dimension=-1, num_keys=1)
    # Flip all three to descending order.
    s, order, sign = (lax.rev(s, dimensions=[last_axis]),
                      lax.rev(order, dimensions=[last_axis]),
                      lax.rev(sign, dimensions=[last_axis]))
    u = jnp.take_along_axis(eigvecs, order[..., None, :], axis=-1)
    vh = _H(u * sign[..., None, :])
    return u, s, vh
Example #3
0
def _sparsemax(x, axis):
    """Sparsemax of ``x`` along ``axis``.

    Projects ``x`` onto the simplex along ``axis``. Entries below a
    data-dependent threshold ``tau`` become exactly zero; the rest are shifted
    down by ``tau``.
    """
    # 1-based positions along `axis`, reshaped by the module helper so they
    # broadcast against `x`.
    positions = reshape_to_broadcast(jnp.arange(x.shape[axis]) + 1, x.shape, axis)

    # Descending sort along `axis`, plus running sums of the sorted values.
    z_desc = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    running_sum = jnp.cumsum(z_desc, axis=axis)

    # Support size: how many positions j satisfy 1 + j * z_(j) > sum_{i<=j} z_(i).
    in_support = 1 + z_desc * positions > running_sum
    support_size = jnp.sum(jnp.where(in_support, 1, 0),
                           axis=axis,
                           keepdims=True)

    # tau = (sum of the top-k sorted values - 1) / k, with k = support size.
    top_k_sum = jnp.take_along_axis(running_sum, support_size - 1, axis=axis)
    tau = (top_k_sum - 1) / support_size
    return jnp.maximum(x - tau, 0)
Example #4
0
def _entmax15(x, axis):
    """1.5-entmax of ``x`` along ``axis``.

    Works on ``z = x / 2``. For every prefix of the descending-sorted ``z``
    it computes a candidate threshold ``tau_k``. The support size is the
    number of positions where ``tau_k <= z_(k)``, and the result is
    ``max(z - tau, 0) ** 2`` with ``tau`` the candidate at the support size.
    """
    z = x / 2

    # 1-based positions along `axis`, reshaped to broadcast against `z`.
    positions = reshape_to_broadcast(jnp.arange(z.shape[axis]) + 1, z.shape, axis)

    # Prefix statistics of the descending-sorted values.
    z_desc = jnp.flip(lax.sort(z, dimension=axis), axis=axis)
    prefix_sum = jnp.cumsum(z_desc, axis=axis)
    prefix_sum_sq = jnp.cumsum(z_desc**2, axis=axis)
    prefix_mean = prefix_sum / positions
    # Sum of squared deviations of each prefix from its own mean.
    prefix_ss = prefix_sum_sq - (prefix_mean**2) * positions
    delta = (1 - prefix_ss) / positions
    # When prefix_ss exceeds 1, delta is negative and sqrt would give NaN.
    # Clamping at 0 keeps every candidate finite. A clamped candidate equals
    # the prefix mean, which is strictly above z_(k) unless the top-k values
    # are all equal. In that case prefix_ss == 0 and no clamping happens.
    # So clamped prefixes are never counted as part of the support below.
    delta = jnp.maximum(delta, 0)
    candidate_tau = prefix_mean - jnp.sqrt(delta)

    support_size = jnp.sum(jnp.where(candidate_tau <= z_desc, 1, 0),
                           axis=axis,
                           keepdims=True)

    # Threshold at the support size, then project and square.
    tau = jnp.take_along_axis(candidate_tau, support_size - 1, axis=axis)
    return jnp.maximum(z - tau, 0)**2
Example #5
0
 def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory):
   """Runs `check_grads` on `lax.sort` along `axis`.

   Checks up to order 2 in both "fwd" and "rev" modes with eps=1e-2, on a
   random operand of the given `shape` and `dtype`.
   """
   def sort_along_axis(operand):
     return lax.sort(operand, dimension=axis, is_stable=is_stable)

   operand = rng_factory(self.rng())(shape, dtype)
   check_grads(sort_along_axis, (operand,), 2, ["fwd", "rev"], eps=1e-2)