def testSort(self, shape, dimension, arity, bdims, is_stable):
    """Check batching of ``lax.sort`` for single- and multi-operand sorts.

    Args:
        shape: shape used for every operand.
        dimension: axis handed to ``lax.sort``.
        arity: number of operands sorted together.
        bdims: batch dimensions forwarded to ``self._CheckBatching``.
        is_stable: stability flag handed to ``lax.sort``.
    """
    rng = jtu.rand_default(self.rng())
    shapes = (shape,) * arity
    dtypes = (np.float32,) * arity

    if arity == 1:
        # Forward ``is_stable`` here as well; without it the single-operand
        # cases always run with lax.sort's default and the parametrization
        # over ``is_stable`` is silently ignored.
        fun = partial(lax.sort, dimension=dimension, is_stable=is_stable)
        self._CheckBatching(fun, 5, bdims, shapes, dtypes, rng)
    else:
        # Multi-operand sort returns one output per operand; check each
        # output separately.
        for i in range(arity):
            # ``i=i`` binds the current index so the lambda does not pick up
            # a later loop value.
            fun = lambda *args, i=i: lax.sort(
                args, dimension=dimension, is_stable=is_stable)[i]
            self._CheckBatching(fun, 5, bdims, shapes, dtypes, rng)
def svd(a, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False):
    """Singular value decomposition of ``a``.

    Args:
        a: input array; converted with ``jnp.asarray`` and passed through
            ``_promote_arg_dtypes`` before use.
        full_matrices: forwarded to ``lax_linalg.svd`` (general path only).
        compute_uv: when true return ``(u, s, vh)``; otherwise return only
            the singular values.
        hermitian: when true, build the result from ``lax_linalg.eigh``
            instead of calling ``lax_linalg.svd``.

    Returns:
        ``(u, s, vh)`` or ``s`` alone on the Hermitian path, with ``s`` in
        descending order; otherwise whatever ``lax_linalg.svd`` returns.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))

    if not hermitian:
        return lax_linalg.svd(a, full_matrices, compute_uv)

    # NOTE(review): the first output of eigh is consumed as the vector matrix
    # and the second as the signed values -- confirm against lax_linalg.eigh.
    vectors, values = lax_linalg.eigh(a)
    magnitudes = lax.abs(values)
    last_axis = magnitudes.ndim - 1

    if not compute_uv:
        ascending = lax.sort(magnitudes, dimension=-1)
        return lax.rev(ascending, dimensions=[last_axis])

    signs = lax.sign(values)
    # Carry the original positions through the sort so the vector columns
    # can be permuted the same way afterwards.
    order = lax.broadcasted_iota(np.int64, magnitudes.shape, dimension=last_axis)
    magnitudes, order, signs = lax.sort(
        (magnitudes, order, signs), dimension=-1, num_keys=1)

    # lax.sort is ascending; reverse everything to get descending order.
    magnitudes = lax.rev(magnitudes, dimensions=[last_axis])
    order = lax.rev(order, dimensions=[last_axis])
    signs = lax.rev(signs, dimensions=[last_axis])

    u = jnp.take_along_axis(vectors, order[..., None, :], axis=-1)
    # The signs dropped by abs() are folded back into the right factor.
    vh = _H(u * signs[..., None, :])
    return u, magnitudes, vh
def _sparsemax(x, axis):
    """Sparsemax of ``x`` along ``axis``: threshold and project to the simplex.

    Args:
        x: input array.
        axis: axis along which the projection is computed.

    Returns:
        ``max(x - threshold, 0)`` with the threshold computed per slice
        along ``axis``.
    """
    # 1-based positions along ``axis``, reshaped so they broadcast against
    # the remaining dimensions of ``x``.
    positions = jnp.arange(x.shape[axis]) + 1
    positions = reshape_to_broadcast(positions, x.shape, axis)

    # Values in descending order together with their running sum.
    descending = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    running_sum = jnp.cumsum(descending, axis=axis)

    # Number of elements that belong to the support.
    in_support = 1 + descending * positions > running_sum
    support_size = jnp.sum(
        jnp.where(in_support, 1, 0), axis=axis, keepdims=True)

    # Threshold taken at the last supported position, then the projection.
    sum_over_support = jnp.take_along_axis(
        running_sum, support_size - 1, axis=axis)
    threshold = (sum_over_support - 1) / support_size
    return jnp.maximum(x - threshold, 0)
def _entmax15(x, axis):
    """1.5-entmax of ``x`` along ``axis``.

    Args:
        x: input array; it is halved before any further computation.
        axis: axis along which the projection is computed.

    Returns:
        ``max(x / 2 - threshold, 0) ** 2`` with the threshold computed per
        slice along ``axis``.
    """
    x = x / 2

    # 1-based positions along ``axis``, reshaped so they broadcast against
    # the remaining dimensions of ``x``.
    positions = jnp.arange(x.shape[axis]) + 1
    positions = reshape_to_broadcast(positions, x.shape, axis)

    # Descending values with running sums of the values and their squares.
    descending = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    running_sum = jnp.cumsum(descending, axis=axis)
    running_sum_sq = jnp.cumsum(descending**2, axis=axis)

    mean = running_sum / positions
    var = running_sum_sq - (mean**2) * positions
    # Clamp at zero so the sqrt below never receives a negative value.
    # TODO: understand why we need this
    delta = jnp.maximum((1 - var) / positions, 0)
    candidates = mean - jnp.sqrt(delta)

    # Number of elements that belong to the support.
    support_size = jnp.sum(
        jnp.where(candidates <= descending, 1, 0), axis=axis, keepdims=True)

    # Threshold at the last supported position, then the projection.
    threshold = jnp.take_along_axis(candidates, support_size - 1, axis=axis)
    return jnp.maximum(x - threshold, 0) ** 2
def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable) check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)