Example 1
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  F = 1 / (jnp.square(s_dim) - jnp.square(_T(s_dim)) + jnp.eye(k, dtype=A.dtype))
  F = F - jnp.eye(k, dtype=A.dtype)
  dSS = s_dim * dS
  SdS = _T(s_dim) * dS
  dU = jnp.matmul(U, F * (dSS + _T(dSS)))
  dV = jnp.matmul(V, F * (SdS + _T(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim
  return (s, U, Vt), (ds, dU, _T(dV))
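This rule is what fires when `jnp.linalg.svd` is differentiated. A minimal sketch (not from the source) that checks the singular-value tangents against central finite differences, assuming a small, well-conditioned real matrix:

import jax
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0], [1.0, 2.0], [0.5, 0.0]])
dA = jnp.array([[0.1, -0.2], [0.3, 0.0], [0.2, 0.1]])
f = lambda a: jnp.linalg.svd(a, compute_uv=False)     # singular values only
s, ds = jax.jvp(f, (A,), (dA,))                       # exercises the JVP rule above
eps = 1e-3
fd = (f(A + eps * dA) - f(A - eps * dA)) / (2 * eps)  # finite-difference reference
print(jnp.max(jnp.abs(ds - fd)))                      # small, up to float32 noise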
Example 2
def _lu_jvp_rule(primals, tangents):
    a, = primals
    a_dot, = tangents
    lu, pivots, permutation = lu_p.bind(a)

    a_shape = jnp.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    batch_dims = a_shape[:-2]
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    x = a_dot[iotas[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = jnp._constant_like(lu, 0)
    l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + jnp.eye(m, m, dtype=dtype)

    u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    la = triangular_solve(l,
                          x,
                          left_side=True,
                          transpose_a=False,
                          lower=True,
                          unit_diagonal=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    l_dot = jnp.matmul(l, jnp.tril(lau, -1))
    u_dot = jnp.matmul(jnp.triu(lau), u)
    lu_dot = l_dot + u_dot
    return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
                                       ad_util.Zero.from_value(permutation))
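Downstream, this rule is what makes `jax.scipy.linalg.lu` differentiable. A small consistency sketch (an assumption, not from the source): since A = P L U and the permutation has a zero tangent, the returned factors should satisfy Pᵀ dA = dL U + L dU, provided the perturbation is small enough not to change the pivoting.

import jax
import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg

A = jnp.array([[4.0, 3.0], [6.0, 3.0]])
dA = jnp.array([[0.1, -0.2], [0.3, 0.4]])
(p, l, u), (dp, dl, du) = jax.jvp(jsp_linalg.lu, (A,), (dA,))
# p is constant to first order, so differentiating A = p @ l @ u gives
# p.T @ dA == dl @ u + l @ du up to roundoff.
print(jnp.max(jnp.abs(p.T @ dA - (dl @ u + l @ du))))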
Example 3
def _pinv_jvp(rcond, primals, tangents):
    # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
    # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
    # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
    # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
    a, = primals
    a_dot, = tangents
    p = pinv(a, rcond=rcond)
    m, n = a.shape[-2:]
    # TODO(phawkins): on TPU, we would need to opt into high precision here.
    # TODO(phawkins): consider if this can be simplified in the Hermitian case.
    p_dot = -p @ a_dot @ p
    p_dot = p_dot + p @ _H(p) @ _H(a_dot) @ (jnp.eye(m, dtype=a.dtype) - a @ p)
    p_dot = p_dot + (jnp.eye(n, dtype=a.dtype) - p @ a) @ _H(a_dot) @ _H(p) @ p
    return p, p_dot
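The same Golub-Pereyra formula is reachable through the public API; a sketch (assumptions: a small full-column-rank real matrix, float32 finite differences as the reference):

import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
dA = jnp.array([[0.1, 0.0], [0.0, -0.2], [0.3, 0.1]])
p, p_dot = jax.jvp(jnp.linalg.pinv, (A,), (dA,))
eps = 1e-3
fd = (jnp.linalg.pinv(A + eps * dA) - jnp.linalg.pinv(A - eps * dA)) / (2 * eps)
print(jnp.max(jnp.abs(p_dot - fd)))  # small, up to finite-difference noise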
Example 4
def eigh_jvp_rule(primals, tangents, lower):
    # Derivative for eigh in the simplest case of distinct eigenvalues.
    # This is classic nondegenerate perturbation theory, but also see
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    # The general solution treating the case of degenerate eigenvalues is
    # considerably more complicated. Ambitious readers may refer to the general
    # methods below or refer to degenerate perturbation theory in physics.
    # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
    # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
    a, = primals
    a_dot, = tangents

    v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

    # for complex numbers we need eigenvalues to be full dtype of v, a:
    w = w_real.astype(a.dtype)
    eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
    # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
    Fmat = jnp.reciprocal(eye_n + w[..., jnp.newaxis, :] -
                          w[..., jnp.newaxis]) - eye_n
    # eigh impl doesn't support batch dims, but future-proof the grad.
    dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                  precision=lax.Precision.HIGHEST)
    vdag_adot_v = dot(dot(_H(v), a_dot), v)
    dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
    dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
    return (v, w_real), (dv, dw)
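For distinct eigenvalues the eigenvalue tangents reduce to the diagonal of Vᴴ dA V, which can be confirmed through the public `jnp.linalg.eigh`. A sketch, assuming a real symmetric matrix with well-separated eigenvalues and a symmetric tangent:

import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0, 0.0], [1.0, 5.0, 0.5], [0.0, 0.5, 9.0]])
dA = jnp.array([[0.1, 0.2, 0.0], [0.2, -0.3, 0.1], [0.0, 0.1, 0.4]])  # symmetric tangent
(w, v), (dw, dv) = jax.jvp(jnp.linalg.eigh, (A,), (dA,))
expected_dw = jnp.diagonal(v.T @ dA @ v)   # first-order perturbation of each eigenvalue
print(jnp.max(jnp.abs(dw - expected_dw)))  # ~0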
Example 5
File: eigh.py Project: cloudhan/jax
  def base_case(B, offset, b, agenda, blocks, eigenvectors):
    # Base case: for blocks under a minimum size, we cut off the recursion
    # and call the TPU Jacobi eigendecomposition implementation. The Jacobi
    # algorithm works well for small matrices but scales poorly, so the two
    # complement each other well.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # We replace the masked-out part of the matrix with the identity matrix.
    # We know that the TPU Jacobi eigh implementation will not alter the order
    # of the eigenvalues, so we know the eigendecomposition of the original
    # matrix is in the top-left corner of the eigendecomposition of the padded
    # matrix.
    # It is very important that the underlying eigh implementation does not sort
    # the eigenvalues for this reason! This is currently not true of JAX's CPU
    # and GPU eigendecompositions, and for those platforms this algorithm will
    # only do the right thing if termination_size == 1.
    H = _mask(H, (b, b), jnp.eye(B, dtype=H.dtype))
    eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False)
    eig_vecs = _mask(eig_vecs, (b, b))
    eig_vals = _mask(eig_vals, (b,))
    eig_vecs = jnp.dot(V, eig_vecs)

    blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, b))
    eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
    return agenda, blocks, eigenvectors
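Why padding with the identity is harmless can be seen on a tiny dense example (a sketch, not part of the source): the eigenvalues of blockdiag(H, I) are those of H together with a run of ones, and H's eigenvectors live in the top-left corner.

import jax.numpy as jnp

H = jnp.array([[2.0, 1.0], [1.0, 3.0]])
padded = jnp.block([[H, jnp.zeros((2, 2))],
                    [jnp.zeros((2, 2)), jnp.eye(2)]])
print(jnp.linalg.eigvalsh(H))       # eigenvalues of the 2x2 block
print(jnp.linalg.eigvalsh(padded))  # the same two values plus two 1.0 entries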
Example 6
def _pade3(A):
    b = (120., 60., 12., 1.)
    ident = jnp.eye(*A.shape, dtype=A.dtype)
    A2 = _precise_dot(A, A)
    U = _precise_dot(A, (b[3] * A2 + b[1] * ident))
    V = b[2] * A2 + b[0] * ident
    return U, V
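U collects the odd powers of A and V the even ones; the (3,3) Padé approximant to exp(A) is then solve(V - U, V + U). A sketch of how the pair is consumed, reusing `_pade3` from above and stubbing the module-private `_precise_dot` as a high-precision matmul (an assumption):

import jax.numpy as jnp
from jax.scipy.linalg import expm

def _precise_dot(a, b):                           # stand-in for the private helper
    return jnp.dot(a, b, precision='highest')

A = 0.01 * jnp.array([[1.0, 2.0], [-1.0, 0.5]])   # small norm, so Pade-3 is accurate
U, V = _pade3(A)
approx = jnp.linalg.solve(V - U, V + U)
print(jnp.max(jnp.abs(approx - expm(A))))         # tiny for such a small matrix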
Example 7
def inv(a):
    if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
        raise ValueError(
            f"Argument to inv must have shape [..., n, n], got {a.shape}.")
    return solve(
        a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)),
                         a.shape[:-2]))
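Usage is the familiar NumPy one, and the broadcast of the identity above is what makes the batched case work; a quick sketch of the contract:

import jax.numpy as jnp

a = jnp.array([[[2.0, 1.0], [1.0, 3.0]],
               [[4.0, 0.0], [0.0, 5.0]]])        # batch of two 2x2 matrices
a_inv = jnp.linalg.inv(a)
print(jnp.max(jnp.abs(a @ a_inv - jnp.eye(2))))  # ~0 for every batch element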
Example 8
def matrix_power(a, n):
    a = _promote_arg_dtypes(jnp.asarray(a))

    if a.ndim < 2:
        raise TypeError("{}-dimensional array given. Array must be at least "
                        "two-dimensional".format(a.ndim))
    if a.shape[-2] != a.shape[-1]:
        raise TypeError("Last 2 dimensions of the array must be square")
    try:
        n = operator.index(n)
    except TypeError as err:
        raise TypeError(
            "exponent must be an integer, got {}".format(n)) from err

    if n == 0:
        return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
    elif n < 0:
        a = inv(a)
        n = np.abs(n)

    if n == 1:
        return a
    elif n == 2:
        return a @ a
    elif n == 3:
        return (a @ a) @ a

    z = result = None
    while n > 0:
        z = a if z is None else (z @ z)
        n, bit = divmod(n, 2)
        if bit:
            result = z if result is None else (result @ z)

    return result
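The squaring loop only performs O(log n) matrix products; a quick sketch checking it against repeated multiplication and a negative exponent:

import jax.numpy as jnp

a = jnp.array([[1.0, 1.0], [0.0, 2.0]])
p5 = jnp.linalg.matrix_power(a, 5)
print(jnp.max(jnp.abs(p5 - a @ a @ a @ a @ a)))                           # ~0
print(jnp.max(jnp.abs(jnp.linalg.matrix_power(a, -1) @ a - jnp.eye(2))))  # ~0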
Example 9
def _pade5(A):
    b = (30240., 15120., 3360., 420., 30., 1.)
    ident = jnp.eye(*A.shape, dtype=A.dtype)
    A2 = _precise_dot(A, A)
    A4 = _precise_dot(A2, A2)
    U = _precise_dot(A, b[5] * A4 + b[3] * A2 + b[1] * ident)
    V = b[4] * A4 + b[2] * A2 + b[0] * ident
    return U, V
Example 10
def _pade7(A):
    b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
    ident = jnp.eye(*A.shape, dtype=A.dtype)
    A2 = _precise_dot(A, A)
    A4 = _precise_dot(A2, A2)
    A6 = _precise_dot(A4, A2)
    U = _precise_dot(A, b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident)
    V = b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident
    return U, V
Example 11
def _pade13(A):
  b = (64764752532480000., 32382376266240000., 7771770303897600.,
       1187353796428800., 129060195264000., 10559470521600., 670442572800.,
       33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = _precise_dot(A, A)
  A4 = _precise_dot(A2, A2)
  A6 = _precise_dot(A4, A2)
  U = _precise_dot(A, _precise_dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = _precise_dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U,V
Example 12
def _pade9(A):
  b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
       2162160., 110880., 3960., 90., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = _precise_dot(A, A)
  A4 = _precise_dot(A2, A2)
  A6 = _precise_dot(A4, A2)
  A8 = _precise_dot(A6, A2)
  U = _precise_dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U,V
Example 13
def _lu(a, permute_l):
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    p = jnp.real(jnp.array(permutation == jnp.arange(m)[:, None], dtype=dtype))
    k = min(m, n)
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        return jnp.matmul(p, l), u
    else:
        return p, l, u
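Through the public wrapper this yields the usual SciPy-style factors; a sketch of the contract a = p @ l @ u (and the `permute_l` variant):

import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg

a = jnp.array([[0.0, 2.0, 1.0],
               [1.0, 1.0, 0.0],
               [3.0, 0.0, 1.0]])
p, l, u = jsp_linalg.lu(a)
print(jnp.max(jnp.abs(a - p @ l @ u)))   # ~0
pl, u2 = jsp_linalg.lu(a, permute_l=True)
print(jnp.max(jnp.abs(a - pl @ u2)))     # ~0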
Example 14
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  s_diffs = jnp.square(s_dim) - jnp.square(_T(s_dim))
  s_diffs_zeros = jnp.eye(s.shape[-1], dtype=A.dtype)  # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.)  # is 1. where s_diffs is 0. and is 0. everywhere else
  F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
  dSS = s_dim * dS  # dS.dot(jnp.diag(s))
  SdS = _T(s_dim) * dS  # jnp.diag(s).dot(dS)

  s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + s_zeros) - s_zeros
  s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
  dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat
  dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = jnp.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
  if n > m:
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim

  return (s, U, Vt), (ds, dU, _H(dV))
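The `s_diffs_zeros` construction is the usual "add a mask where the denominator vanishes, then subtract it" trick so the division never produces inf or NaN on the diagonal. In isolation (a small sketch with distinct singular values):

import jax.numpy as jnp

s = jnp.array([3.0, 2.0, 1.0])
s_dim = s[None, :]
s_diffs = jnp.square(s_dim) - jnp.square(s_dim.T)   # zero on the diagonal
mask = jnp.eye(3)
F = 1.0 / (s_diffs + mask) - mask                   # finite everywhere, zero diagonal
print(F)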
Example 15
def split_spectrum(H, n, split_point, V0=None):
    """ The Hermitian matrix `H` is split into two matrices `H_minus`
  `H_plus`, respectively sharing its eigenspaces beneath and above
  its `split_point`th eigenvalue.

  Returns, in addition, `V_minus` and `V_plus`, isometries such that
  `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
  returned instead; this allows the overall isometries mapping from
  an initial input matrix to progressively smaller blocks to be formed.

  Args:
    H: The Hermitian matrix to split.
    split_point: The eigenvalue to split along.
    V0: Matrix of isometries to be updated.
  Returns:
    H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    V_minus: An isometry from the input space of `V0` to `H_minus`.
    H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    V_plus: An isometry from the input space of `V0` to `H_plus`.
    rank: The dynamic size of the m subblock.
  """
    N, _ = H.shape
    H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(
        H.dtype)
    U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
    P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
    rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)

    V_minus, V_plus = _projector_subspace(P, H, n, rank)
    H_minus = (V_minus.conj().T @ H) @ V_minus
    H_plus = (V_plus.conj().T @ H) @ V_plus
    if V0 is not None:
        V_minus = jnp.dot(V0, V_minus)
        V_plus = jnp.dot(V0, V_plus)
    return H_minus, V_minus, H_plus, V_plus, rank
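The projector comes from the polar (matrix-sign) factor of the shifted matrix: P = (I - sign(H - μ I)) / 2 projects onto the eigenspace below μ, and trace(P) counts how many eigenvalues lie below the split point. A dense-eigendecomposition sketch of that identity (using `jnp.linalg.eigh` instead of qdwh):

import jax.numpy as jnp

H = jnp.array([[2.0, 1.0], [1.0, 5.0]])   # eigenvalues ~1.70 and ~5.30
mu = 3.0
w, v = jnp.linalg.eigh(H - mu * jnp.eye(2))
sign_H = (v * jnp.sign(w)) @ v.T          # matrix sign of the shifted matrix
P = 0.5 * (jnp.eye(2) - sign_H)
print(jnp.round(jnp.trace(P)))            # 1.0 -> one eigenvalue of H lies below mu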
Example 16
def _empty_svd(a, *, full_matrices, compute_uv):
  batch_shape = a.shape[:-2]
  m, n = a.shape[-2:]
  s = jnp.empty(batch_shape + (0,), dtype=lax_internal._complex_basetype(a.dtype))
  if not compute_uv:
    return (s,)
  if full_matrices:
    size = max(m, n)
    u = jnp.broadcast_to(jnp.eye(size, dtype=a.dtype), batch_shape + (size, size))
  else:
    u = jnp.empty(batch_shape + (m, n), dtype=a.dtype)
  v = jnp.empty(batch_shape + (0, 0), dtype=a.dtype)
  if m < n:
    u, v = v, u
  return s, u, v
Example 17
def _lu(a, permute_l):
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    p = jnp.real(
        jnp.array(permutation[None, :] == jnp.arange(
            m, dtype=permutation.dtype)[:, None],
                  dtype=dtype))
    k = min(m, n)
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        return jnp.matmul(p, l), u
    else:
        return p, l, u
Example 18
def qr_jvp_rule(primals, tangents, full_matrices):
    # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
    x, = primals
    dx, = tangents
    q, r = qr_p.bind(x, full_matrices=False)
    *_, m, n = x.shape
    if full_matrices or m < n:
        raise NotImplementedError(
            "Unimplemented case of QR decomposition derivative")
    dx_rinv = triangular_solve(r, dx)  # Right side solve by default
    qt_dx_rinv = jnp.matmul(_H(q), dx_rinv)
    qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1)
    do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-symmetric
    # The following correction is necessary for complex inputs
    do = do + jnp.eye(n, dtype=do.dtype) * (qt_dx_rinv - jnp.real(qt_dx_rinv))
    dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv
    dr = jnp.matmul(qt_dx_rinv - do, r)
    return (q, r), (dq, dr)
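A consistency sketch for the tangents (assumption: a tall, full-rank real input and reduced QR): differentiating x = q r gives dx = dq r + q dr.

import jax
import jax.numpy as jnp

x = jnp.array([[2.0, 1.0], [0.5, 3.0], [1.0, -1.0]])
dx = jnp.array([[0.1, 0.0], [0.2, -0.1], [0.0, 0.3]])
(q, r), (dq, dr) = jax.jvp(lambda a: jnp.linalg.qr(a, mode='reduced'), (x,), (dx,))
print(jnp.max(jnp.abs(dx - (dq @ r + q @ dr))))  # ~0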
Example 19
def eigh(H,
         *,
         precision="float32",
         termination_size=256,
         n=None,
         sort_eigenvalues=True):
    """ Computes the eigendecomposition of the symmetric/Hermitian matrix H.

  Args:
    H: The `n x n` Hermitian input, padded to `N x N`.
    precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
    termination_size: Recursion ends once the blocks reach this linear size.
    n: the true (dynamic) size of the matrix.
    sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
      highest.
  Returns:
    vals: The `n` eigenvalues of `H`.
    vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
      of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
      to numerical error.
  """
    M, N = H.shape
    if M != N:
        raise TypeError(f"Input H of shape {H.shape} must be square.")

    if N <= termination_size:
        if n is not None:
            H = _mask(H, (n, n), jnp.eye(N, dtype=H.dtype))
        return lax_linalg.eigh_jacobi(H, sort_eigenvalues=sort_eigenvalues)

    # TODO(phawkins): consider rounding N up to a larger size to maximize reuse
    # between matrices.

    n = N if n is None else n
    with jax.default_matmul_precision(precision):
        eig_vals, eig_vecs = _eigh_work(H,
                                        n,
                                        termination_size=termination_size)
    eig_vals = _mask(jnp.real(eig_vals), (n, ), jnp.nan)
    if sort_eigenvalues:
        sort_idxs = jnp.argsort(eig_vals)
        eig_vals = eig_vals[sort_idxs]
        eig_vecs = eig_vecs[:, sort_idxs]
    return eig_vals, eig_vecs
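A quick sketch of the contract the docstring describes, checked here with the public `jnp.linalg.eigh` on a small Hermitian matrix:

import jax.numpy as jnp

H = jnp.array([[2.0, 1.0 + 1.0j], [1.0 - 1.0j, 3.0]])      # Hermitian
vals, vecs = jnp.linalg.eigh(H)
print(jnp.max(jnp.abs(H @ vecs - vecs * vals)))             # ~0
print(jnp.max(jnp.abs(vecs.conj().T @ vecs - jnp.eye(2))))  # columns are orthonormal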
Example 20
def phi(X):
    l = jnp.tril(X)
    return l / (jnp._constant_like(X, 1) +
                jnp.eye(X.shape[-1], dtype=X.dtype))
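On a concrete matrix, `phi` keeps the lower triangle and halves the diagonal (a pattern used in Cholesky-style derivative rules); the same arithmetic without the private `_constant_like` helper:

import jax.numpy as jnp

X = jnp.array([[2.0, 9.0], [4.0, 6.0]])
print(jnp.tril(X) / (1.0 + jnp.eye(2)))   # [[1., 0.], [4., 3.]]: diagonal halved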
Example 21
def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: Optional[lax.ConvGeneralDilatedDimensionNumbers] = None,
    precision: Optional[lax.PrecisionType] = None,
    preferred_element_type: Optional[DType] = None,
) -> lax.Array:
    """Extract patches subject to the receptive field of `conv_general_dilated`.

  Runs the input through a convolution with given parameters. The kernel of the
  convolution is constructed such that the output channel dimension `"C"`
  contains flattened image patches, so instead a single `"C"` dimension
  represents, for example, three dimensions `"chw"` collapsed. The order of
  these dimensions is `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`,
  where `rhs_spec == dimension_numbers[1]`, and the size of this `"C"`
  dimension is therefore the size of each patch, i.e.
  `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
  `lhs_spec == dimension_numbers[0]`.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, or a 3-tuple
      `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`. `None` defaults to `("NCHWD..., OIHWD..., NCHWD...")`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    A rank `n+2` array containing the flattened image patches in the output
    channel (`"C"`) dimension. For example if
    `dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
    numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
    the size of each patch
    (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).

  """
    filter_shape = tuple(filter_shape)
    dimension_numbers = lax.conv_dimension_numbers(lhs.shape,
                                                   (1, 1) + filter_shape,
                                                   dimension_numbers)

    lhs_spec, rhs_spec, out_spec = dimension_numbers

    spatial_size = prod(filter_shape)
    n_channels = lhs.shape[lhs_spec[1]]

    # Move separate `lhs` spatial locations into separate `rhs` channels.
    rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)

    rhs = rhs.reshape((spatial_size, 1) + filter_shape)
    rhs = jnp.tile(rhs, (n_channels, ) + (1, ) * (rhs.ndim - 1))
    rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

    out = lax.conv_general_dilated(
        lhs=lhs,
        rhs=rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        precision=None if precision is None else
        (precision, lax.Precision.DEFAULT),
        feature_group_count=n_channels,
        preferred_element_type=preferred_element_type)
    return out
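A short usage sketch through the public entry point, assuming an NCHW input: the output channel dimension ends up with size prod(filter_shape) * C.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(1 * 3 * 4 * 4, dtype=jnp.float32).reshape(1, 3, 4, 4)   # NCHW
patches = lax.conv_general_dilated_patches(
    x, filter_shape=(2, 2), window_strides=(1, 1), padding='VALID')
print(patches.shape)   # (1, 12, 3, 3): 2*2 patch * 3 channels = 12 output channels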
Example 22
File: eigh.py Project: cloudhan/jax
def _eigh_work(H, n, termination_size=256):
  """ The main work loop performing the symmetric eigendecomposition of H.
  Each step recursively computes a projector into the space of eigenvalues
  above jnp.mean(jnp.diag(H)). The result of the projections into and out of
  that space, along with the isometries accomplishing these, are then computed.
  This is performed recursively until the projections have size 1, and thus
  store an eigenvalue of the original input; the corresponding isometry is
  the related eigenvector. The results are then composed.

  This function cannot be Jitted because the internal split_spectrum cannot
  be.

  Args:
    H: The Hermitian input.
    n: The true (dynamic) shape of H.

  Returns:
    H, V: The result of the projection.
  """
  # We turn what was originally a recursive algorithm into an iterative
  # algorithm with an explicit stack.
  N, _ = H.shape
  n = jnp.asarray(n, jnp.int32)
  agenda = Stack.create(
    N + 1, _Subproblem(jnp.array(0, jnp.int32), jnp.array(0, jnp.int32)))
  agenda = agenda.push(_Subproblem(offset=jnp.int32(0), size=n))

  # eigenvectors is the array in which we build the output eigenvectors.
  # We initialize it with the identity matrix so the initial matrix
  # multiplications in split_spectrum are the identity.
  eigenvectors = jnp.eye(N, dtype=H.dtype)

  # blocks is an array representing a stack of Hermitian matrix blocks that we
  # need to recursively decompose. Subproblems are different sizes, so the stack
  # of blocks is ragged. Subproblems are left-aligned (i.e. starting at the 0th
  # column). Here is an ASCII art picture of three blocks A, B, C, embedded
  # in the larger `blocks` workspace (represented with trailing dots).
  #
  # A A A . . .
  # A A A . . .
  # A A A . . .
  # B B . . . .
  # B B . . . .
  # C C C C . .
  # C C C C . .
  # C C C C . .
  # C C C C . .
  #
  # Each step of the algorithm subdivides a block into two subblocks whose
  # sizes sum to the original block size. We overwrite the original block with
  # those two subblocks so we don't need any additional scratch space.
  #
  # At termination, "blocks" will contain 1x1 blocks (i.e., the eigenvalues) in
  # its first column.
  blocks = H

  def base_case(B, offset, b, agenda, blocks, eigenvectors):
    # Base case: for blocks under a minimum size, we cut off the recursion
    # and call the TPU Jacobi eigendecomposition implementation. The Jacobi
    # algorithm works well for small matrices but scales poorly, so the two
    # complement each other well.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # We replace the masked-out part of the matrix with the identity matrix.
    # We know that the TPU Jacobi eigh implementation will not alter the order
    # of the eigenvalues, so we know the eigendecomposition of the original
    # matrix is in the top-left corner of the eigendecomposition of the padded
    # matrix.
    # It is very important that the underlying eigh implementation does not sort
    # the eigenvalues for this reason! This is currently not true of JAX's CPU
    # and GPU eigendecompositions, and for those platforms this algorithm will
    # only do the right thing if termination_size == 1.
    H = _mask(H, (b, b), jnp.eye(B, dtype=H.dtype))
    eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False)
    eig_vecs = _mask(eig_vecs, (b, b))
    eig_vals = _mask(eig_vals, (b,))
    eig_vecs = jnp.dot(V, eig_vecs)

    blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, b))
    eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
    return agenda, blocks, eigenvectors

  def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
    # The recursive case of the algorithm, specialized to a static block size
    # of B.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    split_point = jnp.nanmedian(_mask(jnp.diag(H), (b,), jnp.nan))  # TODO: Improve this?
    H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(H, b, split_point, V0=V)

    blocks = _update_slice(blocks, H_minus, (offset, 0), (rank, rank))
    blocks = _update_slice(blocks, H_plus, (offset + rank, 0), (b - rank, b - rank))
    eigenvectors = _update_slice(eigenvectors, V_minus, (0, offset), (n, rank))
    eigenvectors = _update_slice(eigenvectors, V_plus, (0, offset + rank),
                                 (n, b - rank))

    agenda = agenda.push(_Subproblem(offset + rank, (b - rank)))
    agenda = agenda.push(_Subproblem(offset, rank))
    return agenda, blocks, eigenvectors

  def loop_cond(state):
    agenda, _, _ = state
    return ~agenda.empty()

  # It would be wasteful to perform all computation padded up to the original
  # matrix size. Instead, we form buckets of padded sizes e.g.,
  # [256, 512, 1024, ..., N], aiming for a balance between compilation time
  # and runtime.
  cutoff = min(N, termination_size)
  buckets = [cutoff]
  branches = [partial(base_case, cutoff)]
  i = cutoff
  while i < N:
    i = min(2 * i, N)
    buckets.append(i)
    branches.append(partial(recursive_case, i))
  buckets = jnp.array(buckets)

  def loop_body(state):
    agenda, blocks, eigenvectors = state
    (offset, b), agenda = agenda.pop()

    which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets)
    choice = jnp.argmin(which)
    return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors)

  _, blocks, eigenvectors = lax.while_loop(
      loop_cond, loop_body, (agenda, blocks, eigenvectors))
  return blocks[:, 0], eigenvectors
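The bucket selection in `loop_body` routes each popped subproblem to the smallest padded size that still fits it; in isolation the trick looks like this (a sketch with made-up sizes):

import jax.numpy as jnp

buckets = jnp.array([256, 512, 1024])
b = 300                                                    # size of the popped block
which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets)
print(int(jnp.argmin(which)))                              # 1 -> route to the 512 branch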