Example #1
def _calc_P_Q(A):
    A = jnp.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)
    n_squarings = 0
    if A.dtype == 'float64' or A.dtype == 'complex128':
        maxnorm = 5.371920351148152
        n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
        A = A / 2**n_squarings.astype(A.dtype)
        conds = jnp.array([
            1.495585217958292e-002, 2.539398330063230e-001,
            9.504178996162932e-001, 2.097847961257068e+000
        ],
                          dtype=A_L1.dtype)
        idx = jnp.digitize(A_L1, conds)
        U, V = lax.switch(idx, [_pade3, _pade5, _pade7, _pade9, _pade13], A)
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        maxnorm = 3.925724783138660
        n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
        A = A / 2**n_squarings.astype(A.dtype)
        conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000],
                          dtype=A_L1.dtype)
        idx = jnp.digitize(A_L1, conds)
        U, V = lax.switch(idx, [_pade3, _pade5, _pade7], A)
    else:
        raise TypeError(f"A.dtype={A.dtype} is not supported.")
    P = U + V  # p_m(A) : numerator
    Q = -U + V  # q_m(A) : denominator
    return P, Q, n_squarings
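A minimal driver sketch (not part of the original listing): assuming `_calc_P_Q` above and its `_pade*` helpers are in scope, the standard scaling-and-squaring `expm` combines the three return values like this.

import jax.numpy as jnp
from jax import lax
from jax.scipy.linalg import solve

def expm_sketch(A):
    # Pade approximant r_m(A / 2**s) = q_m(A)^-1 p_m(A), via a solve, not an inverse.
    P, Q, n_squarings = _calc_P_Q(A)
    R = solve(Q, P)
    # Undo the scaling: expm(A) = r_m(A / 2**s) ** (2**s).
    return lax.fori_loop(0, n_squarings.astype(jnp.int32), lambda _, R: R @ R, R)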
Example #2
def _calc_P_Q(A):
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)
  n_squarings = 0
  if A.dtype == 'float64' or A.dtype == 'complex128':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    U7, V7 = _pade7(A)
    U9, V9 = _pade9(A)
    maxnorm = 5.371920351148152
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U13, V13 = _pade13(A)
    conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
                       9.504178996162932e-001, 2.097847961257068e+000])
    U = jnp.select((A_L1 < conds), (U3, U5, U7, U9), U13)
    V = jnp.select((A_L1 < conds), (V3, V5, V7, V9), V13)
  elif A.dtype == 'float32' or A.dtype == 'complex64':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    maxnorm = 3.925724783138660
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U7, V7 = _pade7(A)
    conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
    U = jnp.select((A_L1 < conds), (U3, U5), U7)
    V = jnp.select((A_L1 < conds), (V3, V5), V7)
  else:
    raise TypeError("A.dtype={} is not supported.".format(A.dtype))
  P = U + V  # p_m(A) : numerator
  Q = -U + V # q_m(A) : denominator
  return P, Q, n_squarings
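Note that `jnp.select` returns the choice for the first condition that holds, which is why `conds` must be sorted ascending. A tiny standalone demo of that first-match behavior (illustrative values only):

import jax.numpy as jnp

x = jnp.array(0.5)
hits = [x < c for c in (0.01, 0.25, 0.95, 2.1)]   # [False, False, True, True]
choices = [jnp.array(v) for v in (3., 5., 7., 9.)]
print(jnp.select(hits, choices, jnp.array(13.)))  # 7.0: first True condition wins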
Example #3
    def _update_T_Z(m, T, Z):
        mu = np_linalg.eigvals(lax.dynamic_slice(T, (m - 1, m - 1),
                                                 (2, 2))) - T[m, m]
        r = np_linalg.norm(jnp.array([mu[0], T[m, m - 1]])).astype(T.dtype)
        c = mu[0] / r
        s = T[m, m - 1] / r
        G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)

        # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
        T_rows = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
        col_mask = jnp.arange(N) >= m - 1
        G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
        T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
        T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m - 1, axis=0)

        # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
        T_cols = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
        row_mask = jnp.arange(N)[:, jnp.newaxis] < m + 1
        T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
        T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
        T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m - 1, axis=1)

        # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
        Z_cols = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
        Z = lax.dynamic_update_slice_in_dim(Z,
                                            Z_cols @ G.conj().T,
                                            m - 1,
                                            axis=1)
        return T, Z
Example #4
def roots(p, *, strip_zeros=True):
    # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
    p = atleast_1d(p)
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")

    # strip_zeros=False is unsafe because leading zeros aren't removed
    if not strip_zeros:
        if p.size > 1:
            return _roots_no_zeros(p)
        else:
            return array([])

    if all(p == 0):
        return array([])

    # factor out trivial roots
    start, end = _nonzero_range(p)
    # number of trailing zeros = number of roots at 0
    trailing_zeros = p.size - end

    # strip leading and trailing zeros
    p = p[start:end]

    if p.size < 2:
        return zeros(trailing_zeros, p.dtype)
    else:
        roots = _roots_no_zeros(p)
        # combine roots and zero roots
        roots = hstack((roots, zeros(trailing_zeros, p.dtype)))
        return roots
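A quick usage check, assuming this is the implementation behind the public `jnp.roots` (the `_roots_no_zeros` and `_nonzero_range` helpers are presumed in scope for the code above):

import jax.numpy as jnp

# x**3 - 3x**2 + 2x = x(x - 1)(x - 2): the trailing zero contributes a root at 0.
p = jnp.array([1., -3., 2., 0.])
print(jnp.roots(p))   # {2, 1, 0} as complex values; ordering is not guaranteed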
Example #5
    def body(k, state):
        pivot, perm, a = state
        m_idx = jnp.arange(m)
        n_idx = jnp.arange(n)

        if jnp.issubdtype(a.dtype, jnp.complexfloating):
            t = a[:, k]
            magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
        else:
            magnitude = jnp.abs(a[:, k])
        i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
        pivot = ops.index_update(pivot, ops.index[k], i)

        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])

        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        x = a[k, k]
        a = ops.index_update(a, ops.index[:, k],
                             jnp.where(m_idx > k, a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
        a = a - jnp.where(
            (m_idx[:, None] > k) & (n_idx > k), jnp.outer(a[:, k], a[k, :]),
            jnp.array(0, dtype=a.dtype))
        return pivot, perm, a
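The pivot search keeps shapes loop-invariant by masking rows above `k` with `-inf` instead of slicing; a standalone sketch of that trick:

import jax.numpy as jnp

col = jnp.array([3., 9., 4., 7.])
k = 2
m_idx = jnp.arange(col.shape[0])
# Rows i < k are already pivoted, so they are masked out of the argmax.
i = jnp.argmax(jnp.where(m_idx >= k, jnp.abs(col), -jnp.inf))
print(i)   # 3: the 9.0 in row 1 is excluded from the search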
Example #6
def _roots_no_zeros(p):
    # build companion matrix and find its eigenvalues (the roots)
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    A = diag(ones((p.size - 2, ), p.dtype), -1)
    A = A.at[0, :].set(-p[1:] / p[0])
    return linalg.eigvals(A)
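For a concrete cubic, the companion-matrix construction looks like this (a sketch using public `jnp` APIs rather than the module-local `array`/`diag` imports):

import jax.numpy as jnp

p = jnp.array([1., -6., 11., -6.])   # (x - 1)(x - 2)(x - 3)
A = jnp.diag(jnp.ones(p.size - 2, p.dtype), -1)
A = A.at[0, :].set(-p[1:] / p[0])
# A == [[ 6., -11.,   6.],
#       [ 1.,   0.,   0.],
#       [ 0.,   1.,   0.]]
print(jnp.linalg.eigvals(A))         # approximately {3, 2, 1}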
Example #7
def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes,
           fill_value=0):
  """Similar to lax.dynamic_slice, but handles arrays with dynamic sizes.

  Returns fill_value instead of clamping start_indices for those elements that
  would overflow the side of the array.

  Args:
    operand: the array to slice
    start_indices: the offset of the start of the slice
    dynamic_slice_sizes: the true (unpadded) size of the slice
    static_slice_sizes: the padded size of the slice, which must be known at
      compile time. The static size must be larger than the dynamic size.
    fill_value: value with which to replace masked-out elements.
  Returns:
    An array with static shape `static_slice_sizes`, padded from its true
    (dynamic) size `dynamic_slice_sizes`.
  """
  # We must pad the input array so the dynamic_slice is guaranteed to fall
  # entirely in bounds.
  padded = lax.pad(operand,
                   jnp.array(0, operand.dtype),
                   [(0, d, 0) for d in static_slice_sizes])
  out = lax.dynamic_slice(padded, tuple(jnp.int32(i) for i in start_indices),
                          static_slice_sizes)
  return _mask(out, dynamic_slice_sizes, fill_value)
Example #8
def _fft_core(func_name, fft_type, a, s, axes, norm):
    full_name = "jax.numpy.fft." + func_name

    if s is not None:
        s = tuple(map(operator.index, s))
        if np.any(np.less(s, 0)):
            raise ValueError("Shape should be non-negative.")

    if s is not None and axes is not None and len(s) != len(axes):
        # Same error as numpy.
        raise ValueError("Shape and axes have different lengths.")

    orig_axes = axes
    if axes is None:
        if s is None:
            axes = range(a.ndim)
        else:
            axes = range(a.ndim - len(s), a.ndim)

    if len(axes) != len(set(axes)):
        raise ValueError(
            f"{full_name} does not support repeated axes. Got axes {axes}.")

    if len(axes) > 3:
        # XLA does not support FFTs over more than 3 dimensions
        raise ValueError("%s only supports 1D, 2D, and 3D FFTs. "
                         "Got axes %s with input rank %s." %
                         (full_name, orig_axes, a.ndim))

    # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
    if orig_axes is not None:
        axes = tuple(range(a.ndim - len(axes), a.ndim))
        a = jnp.moveaxis(a, orig_axes, axes)

    if s is not None:
        a = jnp.asarray(a)
        in_s = list(a.shape)
        for axis, x in safe_zip(axes, s):
            in_s[axis] = x
        if fft_type == xla_client.FftType.IRFFT:
            in_s[-1] = (in_s[-1] // 2 + 1)
        # Cropping
        a = a[tuple(map(slice, in_s))]
        # Padding
        a = jnp.pad(a, [(0, x - y) for x, y in zip(in_s, a.shape)])
    else:
        if fft_type == xla_client.FftType.IRFFT:
            s = [a.shape[axis] for axis in axes[:-1]]
            if axes:
                s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
        else:
            s = [a.shape[axis] for axis in axes]
    transformed = lax.fft(a, fft_type, tuple(s))
    transformed *= _fft_norm(jnp.array(s, dtype=transformed.dtype), func_name,
                             norm)

    if orig_axes is not None:
        transformed = jnp.moveaxis(transformed, axes, orig_axes)
    return transformed
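The cropping/padding branch above is what implements numpy's `s`/`n` semantics; assuming the public `jnp.fft` wrappers route through this helper, the behavior is:

import jax.numpy as jnp

x = jnp.ones(4)
print(jnp.fft.fft(x, n=8).shape)   # (8,): input zero-padded before the transform
print(jnp.fft.fft(x, n=2).shape)   # (2,): input cropped before the transform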
Example #9
  def __getitem__(self, key):
    if not isinstance(key, tuple):
      key = (key,)

    params = [self.axis, self.ndmin, self.trans1d, -1]

    if isinstance(key[0], str):
      # split off the directive
      directive, *key = key  # pytype: disable=bad-unpacking
      # check two special cases: matrix directives
      if directive == "r":
        params[-1] = 0
      elif directive == "c":
        params[-1] = 1
      else:
        vec = directive.split(",")
        k = len(vec)
        if k < 4:
          vec += params[k:]
        else:
          # ignore everything after the first three comma-separated ints
          vec = vec[:3] + params[-1:]
        try:
          params = list(map(int, vec))
        except ValueError as err:
          raise ValueError(
            f"could not understand directive {directive!r}"
          ) from err

    axis, ndmin, trans1d, matrix = params

    output = []
    for item in key:
      if isinstance(item, slice):
        newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
      elif isinstance(item, str):
        raise ValueError("string directive must be placed at the beginning")
      else:
        newobj = item

      newobj = array(newobj, copy=False, ndmin=ndmin)

      if trans1d != -1 and ndmin - np.ndim(item) > 0:
        shape_obj = list(range(ndmin))
        # Calculate number of left shifts, with overflow protection by mod
        num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
        shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])

        newobj = transpose(newobj, shape_obj)

      output.append(newobj)

    res = concatenate(tuple(output), axis=axis)

    if matrix != -1 and res.ndim == 1:
      # insert 2nd dim at axis 0 or 1
      res = expand_dims(res, matrix)

    return res
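This `__getitem__` is the machinery behind numpy-style concatenators such as `jnp.r_` and `jnp.c_`; a few directive examples, assuming that public interface:

import jax.numpy as jnp

print(jnp.r_[1:4, 0, jnp.array([5, 6])])   # [1 2 3 0 5 6]
print(jnp.r_['0,2', [1, 2], [3, 4]])       # '0,2' = axis 0, ndmin 2 -> [[1 2] [3 4]]
print(jnp.c_[jnp.array([1, 2]), jnp.array([3, 4])])   # columns: [[1 3] [2 4]]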
Example #10
def _eval_expint_k(A, B, x):
    # helper function for all subsequent intervals
    A, B = [jnp.array(U, dtype=x.dtype) for U in [A, B]]
    one = _constant_like(x, 1.0)
    w = one / x
    f = jnp.polyval(A, w) / jnp.polyval(B, w)
    f = w * f + one
    return jnp.exp(x) * w * f
Example #11
def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
  index = axis_index(axis_name)
  if axis_index_groups is not None:
    indices = np.array(axis_index_groups).flatten()
    axis_index_to_group_index = indices.argsort() % len(axis_index_groups[0])
    index = lax_numpy.array(axis_index_to_group_index)[index]
  outs = tree_util.tree_map(partial(_expand, all_gather_dimension, axis_size, index), x)
  return psum(outs, axis_name, axis_index_groups=axis_index_groups)
Example #12
def _slogdet_lu(a):
    dtype = lax.dtype(a)
    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
    iota = lax.expand_dims(jnp.arange(a.shape[-1]), range(pivot.ndim - 1))
    parity = jnp.count_nonzero(pivot != iota, axis=-1)
    if jnp.iscomplexobj(a):
        sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
    else:
        sign = jnp.array(1, dtype=dtype)
        parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
    sign = jnp.where(is_zero, jnp.array(0, dtype=dtype),
                     sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
    logdet = jnp.where(is_zero, jnp.array(-jnp.inf, dtype=dtype),
                       jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
    return sign, jnp.real(logdet)
Example #13
def setxor1d(ar1, ar2, assume_unique=False):
    _check_arraylike("setxor1d", ar1, ar2)
    ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
    ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")

    ar1 = ravel(ar1)
    ar2 = ravel(ar2)

    if not assume_unique:
        ar1 = unique(ar1)
        ar2 = unique(ar2)

    aux = concatenate((ar1, ar2))
    if aux.size == 0:
        return aux

    aux = sort(aux)
    flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
    return aux[flag[1:] & flag[:-1]]
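Usage sketch, assuming this is the public `jnp.setxor1d` (inputs must be concrete because of `concrete_or_error`):

import jax.numpy as jnp

a = jnp.array([1, 2, 3, 2])
b = jnp.array([2, 3, 5, 7, 5])
print(jnp.setxor1d(a, b))   # [1 5 7]: values appearing in exactly one input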
Example #14
def slogdet(a):
    a = _promote_arg_dtypes(jnp.asarray(a))
    dtype = lax.dtype(a)
    a_shape = jnp.shape(a)
    if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
        msg = "Argument to slogdet() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(a_shape))
    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
    parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
    if jnp.iscomplexobj(a):
        sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
    else:
        sign = jnp.array(1, dtype=dtype)
        parity = parity + jnp.count_nonzero(diag < 0, axis=-1)
    sign = jnp.where(is_zero, jnp.array(0, dtype=dtype),
                     sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
    logdet = jnp.where(is_zero, jnp.array(-jnp.inf, dtype=dtype),
                       jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
    return sign, jnp.real(logdet)
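Usage sketch for the sign/log-magnitude decomposition, assuming this backs `jnp.linalg.slogdet`:

import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.]])       # det(a) = -2
sign, logabsdet = jnp.linalg.slogdet(a)
print(sign, logabsdet)                    # -1.0 and log(2) ~ 0.6931
print(sign * jnp.exp(logabsdet))          # ~ -2.0, recovering det(a)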
Example #15
def _triage_segments(window, nperseg, input_length):
    """
  Parses window and nperseg arguments for spectrogram and _spectral_helper.
  This is a helper function, not meant to be called externally.
  Parameters
  ----------
  window : string, tuple, or ndarray
      If window is specified by a string or tuple and nperseg is not
      specified, nperseg is set to the default of 256 and returns a window of
      that length.
      If instead the window is array_like and nperseg is not specified, then
      nperseg is set to the length of the window. A ValueError is raised if
      the user supplies both an array_like window and a value for nperseg but
      nperseg does not equal the length of the window.
  nperseg : int
      Length of each segment
  input_length: int
      Length of input signal, i.e. x.shape[-1]. Used to test for errors.
  Returns
  -------
  win : ndarray
      window. If the function was called with a string or tuple then this will
      hold the actual array used as a window.
  nperseg : int
      Length of each segment. If window is str or tuple, nperseg is set to
      256. If window is array_like, nperseg is set to the length of the
      window.
  """
    # parse window; if array like, then set nperseg = win.shape
    if isinstance(window, (str, tuple)):
        # if nperseg not specified
        if nperseg is None:
            nperseg = 256  # then change to default
        if nperseg > input_length:
            warnings.warn(f'nperseg = {nperseg} is greater than input length '
                          f' = {input_length}, using nperseg = {input_length}')
            nperseg = input_length
        win = jnp.array(osp_signal.get_window(window, nperseg))
    else:
        win = jnp.asarray(window)
        if len(win.shape) != 1:
            raise ValueError('window must be 1-D')
        if input_length < win.shape[-1]:
            raise ValueError('window is longer than input signal')
        if nperseg is None:
            nperseg = win.shape[0]
        elif nperseg != win.shape[0]:
            raise ValueError("value specified for nperseg is different"
                             " from length of window")
    return win, nperseg
Example #16
def _lu(a, permute_l):
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    p = jnp.real(jnp.array(permutation == jnp.arange(m)[:, None], dtype=dtype))
    k = min(m, n)
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        return jnp.matmul(p, l), u
    else:
        return p, l, u
Example #17
def pinv(a, rcond=None):
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  a = jnp.conj(a)
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(a.dtype).eps)
  rcond = jnp.asarray(rcond)
  u, s, vh = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
  cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
  s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
  res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]))
  return lax.convert_element_type(res, a.dtype)
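A quick property check, assuming this is `jnp.linalg.pinv`: the defining Moore-Penrose identity A A+ A = A should hold.

import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])   # tall matrix, full column rank
a_pinv = jnp.linalg.pinv(a)
print(a_pinv.shape)                                  # (2, 3)
print(jnp.allclose(a @ a_pinv @ a, a, atol=1e-5))    # True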
Example #18
def _lu(a, permute_l):
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    p = jnp.real(
        jnp.array(permutation[None, :] == jnp.arange(
            m, dtype=permutation.dtype)[:, None],
                  dtype=dtype))
    k = min(m, n)
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        return jnp.matmul(p, l), u
    else:
        return p, l, u
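Usage sketch, assuming this implements `jax.scipy.linalg.lu`: the three factors reconstruct the input as a = p @ l @ u.

import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[0., 2.], [3., 4.]])   # needs a row swap, so p is non-trivial
p, l, u = lu(a)
print(jnp.allclose(p @ l @ u, a))     # True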
Example #19
def sph_harm(m: jnp.ndarray,
             n: jnp.ndarray,
             theta: jnp.ndarray,
             phi: jnp.ndarray,
             n_max: Optional[int] = None) -> jnp.ndarray:
    r"""Computes the spherical harmonics.

  The JAX version has one extra argument `n_max`, the maximum value in `n`.

  The spherical harmonic of degree `n` and order `m` can be written as
  :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
  where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
  {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and
  :math:`\theta` are the colatitude and longitude, respectively. :math:`N_n^m` is
  chosen in the way that the spherical harmonics form a set of orthonormal basis
  functions of :math:`L^2(S^2)`.

  Args:
    m: The order of the harmonic; must have `|m| <= n`. Return values for
      `|m| > n` are undefined.
    n: The degree of the harmonic; must have `n >= 0`. The standard notation for
      degree in descriptions of spherical harmonics is `l (lower case L)`. We
      use `n` here to be consistent with `scipy.special.sph_harm`. Return
      values for `n < 0` are undefined.
    theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
    phi: The polar (colatitudinal) coordinate; must be in [0, pi].
    n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true
      maximum value of `n`, the results are clipped to `n_max`. For example,
      `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
      actually returns
      `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`
  Returns:
    A 1D array containing the spherical harmonics at (m, n, theta, phi).
  """

    if jnp.isscalar(phi):
        phi = jnp.array([phi])

    if n_max is None:
        n_max = jnp.max(n)
    n_max = core.concrete_or_error(
        int, n_max,
        'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
        'be statically specified to use `sph_harm` within JAX transformations.'
    )

    return _sph_harm(m, n, theta, phi, n_max)
Example #20
def roots(p, *, strip_zeros=True):
    _check_arraylike("roots", p)
    p = atleast_1d(*_promote_dtypes_inexact(p))
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

    if strip_zeros:
        num_leading_zeros = core.concrete_or_error(
            int, num_leading_zeros,
            "The error occurred in the jnp.roots() function. To use this within a "
            "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
            "will be result in some returned roots being set to NaN.")
        return _roots_no_zeros(p[num_leading_zeros:])
    else:
        return _roots_with_zeros(p, num_leading_zeros)
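Usage sketch of the two paths, assuming this is the public `jnp.roots`: eager code can keep the default, while jitted code must pass `strip_zeros=False` and tolerate NaN placeholders.

import jax
import jax.numpy as jnp
from functools import partial

p = jnp.array([0., 1., -3., 2.])       # leading zero; true roots are {2, 1}
print(jnp.roots(p))                    # leading zero stripped eagerly
# Under jit the number of leading zeros cannot be made concrete, so
# strip_zeros=False is required; the extra root comes back as NaN.
jit_roots = jax.jit(partial(jnp.roots, strip_zeros=False))
print(jit_roots(p))                    # {2, 1, nan} up to ordering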
Example #21
def _update_slice(operand, update, start_indices, update_dims):
    """
  Similar to lax.dynamic_update_slice, but handles padded updates where padding
  values should not overwrite existing values in the array.

  Args:
    operand: the array to update
    update: the padded array to write
    start_indices: the offset at which to write `update`.
    update_dims: the true dimensions of the padded update `update`. Only values
      inside the rectangle given by `update_dims` will be overwritten.
  """
    operand_shape = operand.shape
    operand = lax.pad(operand, jnp.array(0, operand.dtype),
                      [(0, d, 0) for d in update.shape])
    start_indices = tuple(jnp.int32(i) for i in start_indices)
    t = lax.dynamic_slice(operand, start_indices, update.shape)
    t = _mask(update, update_dims, t)
    operand = lax.dynamic_update_slice(operand, t, start_indices)
    return lax.slice(operand, [0] * operand.ndim, operand_shape)
Example #22
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            # Use jnp.array(nan) to avoid false positives in debug_nans
            # (see https://github.com/google/jax/issues/7634)
            out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
    return out
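Usage sketch, assuming this is `jax.scipy.special.logsumexp`: with weights `b` the weighted sum can be negative, and `return_sign` recovers it safely.

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1., 2., 3.])
print(logsumexp(a))   # log(e + e**2 + e**3) ~ 3.4076
out, sign = logsumexp(a, b=jnp.array([1., 1., -2.]), return_sign=True)
print(sign, out)      # sign = -1.0: e + e**2 - 2*e**3 is negative; |sum| = exp(out)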
Example #23
def _expint1(x):
    # 0 < x <= 2
    A = [
        -5.350447357812542947283e0,
        2.185049168816613393830e2,
        -4.176572384826693777058e3,
        5.541176756393557601232e4,
        -3.313381331178144034309e5,
        1.592627163384945414220e6,
    ]
    B = [
        1.0,
        -5.250547959112862969197e1,
        1.259616186786790571525e3,
        -1.756549581973534652631e4,
        1.493062117002725991967e5,
        -7.294949239640527645655e5,
        1.592627163384945429726e6,
    ]
    A, B = [jnp.array(U, dtype=x.dtype) for U in [A, B]]
    f = jnp.polyval(A, x) / jnp.polyval(B, x)
    return x * f + jnp.euler_gamma + jnp.log(x)
Example #24
def _lstsq(a, b, rcond, *, numpy_resid=False):
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = _promote_dtypes_inexact(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
      f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
      f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if rcond is None:
    rcond = jnp.finfo(dtype).eps * max(n, m)
  else:
    rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
  u, s, vt = svd(a, full_matrices=False)
  mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
  rank = mask.sum()
  safe_s = jnp.where(mask, s, 1).astype(a.dtype)
  s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
  uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
  x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0) ** 2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s
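Usage sketch, assuming this backs `jnp.linalg.lstsq`: fitting a line through three points that lie exactly on y = 3 + 3x.

import jax.numpy as jnp

a = jnp.array([[1., 1.], [1., 2.], [1., 3.]])   # columns: intercept, x
b = jnp.array([6., 9., 12.])
x, resid, rank, s = jnp.linalg.lstsq(a, b)
print(x)       # ~ [3., 3.]
print(rank)    # 2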
Example #25
def poly(seq_of_zeros):
    _check_arraylike('poly', seq_of_zeros)
    seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
    seq_of_zeros = atleast_1d(seq_of_zeros)

    sh = seq_of_zeros.shape
    if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
        # import at runtime to avoid circular import
        from jax._src.numpy import linalg
        seq_of_zeros = linalg.eigvals(seq_of_zeros)

    if seq_of_zeros.ndim != 1:
        raise ValueError("input must be 1d or non-empty square 2d array.")

    dt = seq_of_zeros.dtype
    if len(seq_of_zeros) == 0:
        return ones((), dtype=dt)

    a = ones((1, ), dtype=dt)
    for k in range(len(seq_of_zeros)):
        a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')

    return a
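`poly` is the inverse of `roots`; a round trip, assuming the public `jnp.poly`/`jnp.roots` pair:

import jax.numpy as jnp

print(jnp.poly(jnp.array([1., 2.])))        # [ 1. -3.  2.], i.e. (x - 1)(x - 2)
print(jnp.roots(jnp.array([1., -3., 2.])))  # back to {2, 1}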
Example #26
def _expn1(n, x):
    # exponential integral En
    _c = _constant_like
    x = jnp.array(x)
    MACHEP = jnp.finfo(x.dtype).eps

    zero = _c(x, 0.0)
    one = _c(x, 1.0)
    psi = -jnp.euler_gamma - jnp.log(x)
    psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
    n1 = jnp.where(n == _c(n, 1), one + one, n)
    init = dict(
        x=x,
        z=-x,
        xk=zero,
        yk=one,
        pk=one - n,
        ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
        t=jnp.inf,
    )

    def body(d):
        d["xk"] += one
        d["yk"] *= d["z"] / d["xk"]
        d["pk"] += one
        d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
        d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
        return d

    def cond(d):
        return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

    d = lax.while_loop(cond, body, init)
    t = n
    r = n - _c(n, 1)
    return d["z"]**r * psi / jnp.exp(gammaln(t)) - d["ans"]
Example #27
def all_gather(x, axis_name, *, axis_index_groups=None):
  """Gather values of x across all replicas.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  This is equivalent to, but faster than, all_to_all(broadcast(x)).

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would run all gather over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) representing the result of an all-gather along the axis
    ``axis_name``. Shapes are the same as ``x.shape``, but with a leading
    dimension of the axis_size.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [[0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]]

  An example of using axis_index_groups, groups split by even & odd device ids:

  >>> x = np.arange(16).reshape(4, 4)
  >>> print(x)
  [[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(
  ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]))(x)
  >>> print(y)
  [[[ 0  1  2  3]
    [ 8  9 10 11]]
   [[12 13 14 15]
    [ 4  5  6  7]]
   [[ 0  1  2  3]
    [ 8  9 10 11]]
   [[12 13 14 15]
    [ 4  5  6  7]]]
  """

  index = axis_index(axis_name)
  if axis_index_groups is not None:
    indices = np.array(axis_index_groups).flatten()
    axis_index_to_group_index = indices.argsort() % len(axis_index_groups[0])
    index = lax_numpy.array(axis_index_to_group_index)[index]

  axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)

  return _allgather(x, 0, axis_size, index, axis_name, axis_index_groups)
Example #28
def eigh_tridiagonal(d,
                     e,
                     *,
                     eigvals_only=False,
                     select='a',
                     select_range=None,
                     tol=None):
    if not eigvals_only:
        raise NotImplementedError(
            "Calculation of eigenvectors is not implemented")

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Implements the Sturm sequence recurrence."""
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form:
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        i = 1
        peel = (n - 1) % blocksize
        unroll_cnt = peel

        def unrolled_steps(args):
            start, q, count = args
            for j in range(unroll_cnt):
                q, count = sturm_step(start + j, q, count)
            return start + unroll_cnt, q, count

        i, q, count = unrolled_steps((i, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        unroll_cnt = blocksize

        def cond(iqc):
            i, q, count = iqc
            return jnp.less(i, n)

        _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
        return count

    alpha = jnp.asarray(d)
    beta = jnp.asarray(e)
    supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64,
                        jnp.complex128)
    if alpha.dtype != beta.dtype:
        raise TypeError(
            "diagonal and off-diagonal values must have same dtype, "
            f"got {alpha.dtype} and {beta.dtype}")
    if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
        raise TypeError(
            "Only float32, float64, complex64, and complex128 inputs are "
            "supported by jax.scipy.linalg.eigh_tridiagonal, got "
            f"{alpha.dtype} and {beta.dtype}")
    n = alpha.shape[0]
    if n <= 1:
        return jnp.real(alpha)

    if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
        alpha = jnp.real(alpha)
        beta_sq = jnp.real(beta * jnp.conj(beta))
        beta_abs = jnp.sqrt(beta_sq)
    else:
        beta_abs = jnp.abs(beta)
        beta_sq = jnp.square(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = jnp.concatenate(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
    lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype)
    one = np.ones([], dtype=alpha.dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
    alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol is not None:
        abs_tol = jnp.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    # The proof is left as an exercise to the reader.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    if select == 'a':
        target_counts = jnp.arange(n, dtype=jnp.int32)
    elif select == 'i':
        if select_range[0] > select_range[1]:
            raise ValueError('Got empty index range in select_range.')
        target_counts = jnp.arange(select_range[0],
                                   select_range[1] + 1,
                                   dtype=jnp.int32)
    elif select == 'v':
        # TODO(phawkins): requires dynamic shape support.
        raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                                  "implemented")
    else:
        raise ValueError("'select must have a value in {'a', 'i', 'v'}.")

    # Run binary search for all desired eigenvalues in parallel, starting from
    # an interval slightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # We widen the Gershgorin interval a bit.
    norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
    lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
    upper = lambda_est_max + norm_slack + fudge * pivmin

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = jnp.shape(target_counts)
    lower = jnp.broadcast_to(lower, shape=target_shape)
    upper = jnp.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = jnp.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

    # Start parallel binary searches.
    def cond(args):
        i, lower, _, upper = args
        return jnp.logical_and(jnp.less(i, max_it),
                               jnp.less(abs_tol, jnp.amax(upper - lower)))

    def body(args):
        i, lower, mid, upper = args
        counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
        lower = jnp.where(counts <= target_counts, mid, lower)
        upper = jnp.where(counts > target_counts, mid, upper)
        mid = 0.5 * (lower + upper)
        return i + 1, lower, mid, upper

    _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
    return mid
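A cross-check against the dense eigensolver, assuming this is `jax.scipy.linalg.eigh_tridiagonal`:

import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal

d = jnp.array([2., 2., 2., 2.])    # main diagonal
e = jnp.array([-1., -1., -1.])     # off-diagonal
w = eigh_tridiagonal(d, e, eigvals_only=True)
T = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)
print(jnp.allclose(w, jnp.linalg.eigvalsh(T), atol=1e-5))   # True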
Example #29
def ppf(q, loc=0, scale=1):
    return jnp.array(special.ndtri(q) * scale + loc, 'float64')
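Usage sketch, assuming this is the normal-distribution quantile exposed as `jax.scipy.stats.norm.ppf`: the median maps back to `loc`.

import jax.numpy as jnp
from jax.scipy.stats import norm

print(norm.ppf(jnp.array(0.5), loc=3., scale=2.))   # 3.0
print(norm.ppf(jnp.array(0.975)))                   # ~ 1.96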
Example #30
def _eigh_work(H, n, termination_size=256):
  """ The main work loop performing the symmetric eigendecomposition of H.
  Each step recursively computes a projector into the space of eigenvalues
  above jnp.mean(jnp.diag(H)). The result of the projections into and out of
  that space, along with the isometries accomplishing these, are then computed.
  This is performed recursively until the projections have size 1, and thus
  store an eigenvalue of the original input; the corresponding isometry is
  the related eigenvector. The results are then composed.

  This function cannot be Jitted because the internal split_spectrum cannot
  be.

  Args:
    H: The Hermitian input.
    n: The true (dynamic) shape of H.

  Returns:
    H, V: The result of the projection.
  """
  # We turn what was originally a recursive algorithm into an iterative
  # algorithm with an explicit stack.
  N, _ = H.shape
  n = jnp.asarray(n, jnp.int32)
  agenda = Stack.create(
    N + 1, _Subproblem(jnp.array(0, jnp.int32), jnp.array(0, jnp.int32)))
  agenda = agenda.push(_Subproblem(offset=jnp.int32(0), size=n))

  # eigenvectors is the array in which we build the output eigenvectors.
  # We initialize it with the identity matrix so the initial matrix
  # multiplications in _split_spectrum_jittable are the identity.
  eigenvectors = jnp.eye(N, dtype=H.dtype)

  # blocks is an array representing a stack of Hermitian matrix blocks that we
  # need to recursively decompose. Subproblems are different sizes, so the stack
  # of blocks is ragged. Subproblems are left-aligned (i.e. starting at the 0th
  # column). Here is an ASCII art picture of three blocks A, B, C, embedded
  # in the larger `blocks` workspace (represented with trailing dots).
  #
  # A A A . . .
  # A A A . . .
  # A A A . . .
  # B B . . . .
  # B B . . . .
  # C C C C . .
  # C C C C . .
  # C C C C . .
  # C C C C . .
  #
  # Each step of the algorithm subdivides a block into two subblocks whose
  # sizes sum to the original block size. We overwrite the original block with
  # those two subblocks so we don't need any additional scratch space.
  #
  # At termination, "blocks" will contain 1x1 blocks (i.e., the eigenvalues) in
  # its first column.
  blocks = H

  def base_case(B, offset, b, agenda, blocks, eigenvectors):
    # Base case: for blocks under a minimum size, we cut off the recursion
    # and call the TPU Jacobi eigendecomposition implementation. The Jacobi
    # algorithm works well for small matrices but scales poorly, so the two
    # complement each other well.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # We replace the masked-out part of the matrix with the identity matrix.
    # We know that the TPU Jacobi eigh implementation will not alter the order
    # of the eigenvalues, so we know the eigendecomposition of the original
    # matrix is in the top-left corner of the eigendecomposition of the padded
    # matrix.
    # It is very important that the underlying eigh implementation does not sort
    # the eigenvalues for this reason! This is currently not true of JAX's CPU
    # and GPU eigendecompositions, and for those platforms this algorithm will
    # only do the right thing if termination_size == 1.
    H = _mask(H, (b, b), jnp.eye(B, dtype=H.dtype))
    eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False)
    eig_vecs = _mask(eig_vecs, (b, b))
    eig_vals = _mask(eig_vals, (b,))
    eig_vecs = jnp.dot(V, eig_vecs)

    blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, b))
    eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
    return agenda, blocks, eigenvectors

  def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
    # The recursive case of the algorithm, specialized to a static block size
    # of B.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    split_point = jnp.nanmedian(_mask(jnp.diag(H), (b,), jnp.nan))  # TODO: Improve this?
    H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(H, b, split_point, V0=V)

    blocks = _update_slice(blocks, H_minus, (offset, 0), (rank, rank))
    blocks = _update_slice(blocks, H_plus, (offset + rank, 0), (b - rank, b - rank))
    eigenvectors = _update_slice(eigenvectors, V_minus, (0, offset), (n, rank))
    eigenvectors = _update_slice(eigenvectors, V_plus, (0, offset + rank),
                                 (n, b - rank))

    agenda = agenda.push(_Subproblem(offset + rank, (b - rank)))
    agenda = agenda.push(_Subproblem(offset, rank))
    return agenda, blocks, eigenvectors

  def loop_cond(state):
    agenda, _, _ = state
    return ~agenda.empty()

  # It would be wasteful to perform all computation padded up to the original
  # matrix size. Instead, we form buckets of padded sizes e.g.,
  # [256, 512, 1024, ..., N], aiming for a balance between compilation time
  # and runtime.
  cutoff = min(N, termination_size)
  buckets = [cutoff]
  branches = [partial(base_case, cutoff)]
  i = cutoff
  while i < N:
    i = min(2 * i, N)
    buckets.append(i)
    branches.append(partial(recursive_case, i))
  buckets = jnp.array(buckets)

  def loop_body(state):
    agenda, blocks, eigenvectors = state
    (offset, b), agenda = agenda.pop()

    which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets)
    choice = jnp.argmin(which)
    return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors)

  _, blocks, eigenvectors = lax.while_loop(
      loop_cond, loop_body, (agenda, blocks, eigenvectors))
  return blocks[:, 0], eigenvectors
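A pure-Python trace of the bucket construction above, e.g. for N = 1000 and the default termination_size = 256 (illustrative values only):

N, termination_size = 1000, 256
cutoff = min(N, termination_size)
buckets = [cutoff]
i = cutoff
while i < N:
    i = min(2 * i, N)
    buckets.append(i)
print(buckets)   # [256, 512, 1000]: one base-case size plus two recursive sizes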