Exemple #1
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """Forward-mode (JVP) rule for the SVD primitive.

  The primal outputs are always obtained from a thin decomposition
  (``full_matrices=False``) with ``compute_uv=True``.

  Args:
    primals: one-element sequence holding the input matrix ``A``.
    tangents: one-element sequence holding the input tangent ``dA``.
    full_matrices: requested ``full_matrices`` flag. Combined with
      ``compute_uv=True`` it is not supported by this rule.
    compute_uv: whether tangents for ``U`` and ``Vt`` are produced.

  Returns:
    ``((s,), (ds,))`` when ``compute_uv`` is false, otherwise
    ``((s, U, Vt), (ds, dU, dVt))``.

  Raises:
    NotImplementedError: if both ``compute_uv`` and ``full_matrices`` are set.
  """
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  # NOTE(review): `_H` / `_T` are helpers defined outside this block; this rule
  # assumes `_H` is the conjugate transpose and `_T` the plain transpose of the
  # last two axes -- confirm against their definitions.
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  # Only the real part of the diagonal is kept as the singular-value tangent.
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal and 0 on it: the identity is
  # added before inverting to avoid dividing by zero, then subtracted again.
  eye_k = jnp.eye(k, dtype=A.dtype)
  F = 1 / (jnp.square(s_dim) - jnp.square(_T(s_dim)) + eye_k)
  F = F - eye_k
  dSS = s_dim * dS  # dS scaled column-wise by s
  SdS = _T(s_dim) * dS  # dS scaled row-wise by s

  # Safe reciprocal of s: entries where s == 0 map to 0 instead of inf.
  s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + s_zeros) - s_zeros
  s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
  # Diagonal contribution from the anti-Hermitian part of dS. For real inputs
  # the diagonal of (dS - dS^T) is exactly zero, so this term vanishes there.
  dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat
  # `_H` (not `_T`) is required below so that complex inputs are handled
  # consistently with Ut = _H(U) and V = _H(Vt) above.
  dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = jnp.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    # Tall case: add the part of dA @ V lying outside the column space of U.
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
  if n > m:
    # Wide case: the analogous correction for V.
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim
  # Vt = _H(V), hence its tangent is _H(dV).
  return (s, U, Vt), (ds, dU, _H(dV))
Exemple #2
0
def _lu_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the LU primitive.

    Args:
      primals: one-element sequence holding the matrix ``a``.
      tangents: one-element sequence holding the tangent of ``a``.

    Returns:
      ``((lu, pivots, permutation), (lu_dot, zero, zero))`` where the pivot and
      permutation tangents are symbolic zeros.
    """
    (a,) = primals
    (a_dot,) = tangents
    lu, pivots, permutation = lu_p.bind(a)

    shape = jnp.shape(a)
    rank = len(shape)
    batch_shape = shape[:-2]
    m, n = shape[-2:]
    k = min(m, n)
    dtype = lax.dtype(a)

    # Reorder the rows of the tangent with the permutation returned by the
    # primal computation, independently for every batch element.
    index_grids = jnp.ix_(*[lax.iota(jnp.int32, d) for d in batch_shape + (1, )])
    permuted_dot = a_dot[index_grids[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    zero = jnp._constant_like(lu, 0)
    no_padding = (0, 0, 0)

    # Unit-lower-triangular factor, padded with zero columns to (m, m).
    lower_padding = [no_padding] * rank
    lower_padding[-1] = (0, m - k, 0)
    strict_lower = jnp.tril(lu[..., :, :k], -1)
    lower = lax.pad(strict_lower, zero, lower_padding)
    lower = lower + jnp.eye(m, m, dtype=dtype)

    # Upper-triangular factor, padded with zero rows to (n, n); an identity
    # block is placed in the bottom-right corner of the padded matrix.
    corner_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
                         ((k, 0, 0), (k, 0, 0)))
    upper_padding = [no_padding] * rank
    upper_padding[-2] = (0, n - k, 0)
    upper = lax.pad(jnp.triu(lu[..., :k, :]), zero, upper_padding) + corner_eye

    # inv(L) . A'
    linv_adot = triangular_solve(lower,
                                 permuted_dot,
                                 left_side=True,
                                 transpose_a=False,
                                 lower=True,
                                 unit_diagonal=True)
    # inv(L) . A' . inv(U)
    linv_adot_uinv = triangular_solve(upper,
                                      linv_adot,
                                      left_side=False,
                                      transpose_a=False,
                                      lower=False)

    lower_dot = jnp.matmul(lower, jnp.tril(linv_adot_uinv, -1))
    upper_dot = jnp.matmul(jnp.triu(linv_adot_uinv), upper)
    primals_out = (lu, pivots, permutation)
    tangents_out = (lower_dot + upper_dot,
                    ad_util.Zero.from_value(pivots),
                    ad_util.Zero.from_value(permutation))
    return primals_out, tangents_out
Exemple #3
0
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
    """Remove a constant or piecewise-linear trend from ``data`` along ``axis``.

    Args:
      data: array-like input.
      axis: axis along which the trend is removed.
      type: ``'constant'`` subtracts the mean along ``axis``; ``'linear'``
        subtracts a least-squares line fitted on every segment between
        consecutive breakpoints.
      bp: breakpoint index or sequence of indices (``'linear'`` only).
      overwrite_data: unsupported; must be left as ``None``.

    Returns:
      The detrended array, with the same axis order as the promoted input.

    Raises:
      NotImplementedError: if ``overwrite_data`` is not ``None``.
      ValueError: if ``type`` is unknown, or a breakpoint is negative or
        larger than the length of ``data`` along ``axis``.
    """
    if overwrite_data is not None:
        raise NotImplementedError("overwrite_data argument not implemented.")
    if type not in ('constant', 'linear'):
        raise ValueError("Trend type must be 'linear' or 'constant'.")

    data, = _promote_dtypes_inexact(jnp.asarray(data))
    if type == 'constant':
        return data - data.mean(axis, keepdims=True)

    length = data.shape[axis]
    # bp is static, so we use np operations to avoid pushing to device.
    breakpoints = np.sort(np.unique(np.r_[0, bp, length]))
    if breakpoints[0] < 0 or breakpoints[-1] > length:
        raise ValueError(
            "Breakpoints must be non-negative and less than length of data along given axis."
        )

    # Work on a 2-D view whose leading axis is the one being detrended.
    data = jnp.moveaxis(data, axis, 0)
    moved_shape = data.shape
    data = data.reshape(length, -1)

    for start, stop in zip(breakpoints[:-1], breakpoints[1:]):
        npts = stop - start
        # Design matrix with columns [1, ramp], the ramp running 1/npts .. 1.
        intercept = jnp.ones(npts, dtype=data.dtype)
        ramp = jnp.arange(1, npts + 1, dtype=data.dtype) / npts
        design = jnp.stack([intercept, ramp], axis=1)

        segment = slice(start, stop)
        coef, *_ = linalg.lstsq(design, data[segment])
        fitted = jnp.matmul(design, coef, precision=lax.Precision.HIGHEST)
        data = data.at[segment].add(-fitted)

    return jnp.moveaxis(data.reshape(moved_shape), 0, axis)
Exemple #4
0
def qr_jvp_rule(primals, tangents, full_matrices):
    """Forward-mode (JVP) rule for the QR primitive.

    Args:
      primals: one-element sequence holding the matrix ``x``.
      tangents: one-element sequence holding the tangent ``dx``.
      full_matrices: requested ``full_matrices`` flag; must be false.

    Returns:
      ``((q, r), (dq, dr))`` with ``q, r`` from the reduced decomposition.

    Raises:
      NotImplementedError: if ``full_matrices`` is set or ``x`` has fewer rows
        than columns.
    """
    # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
    (x,) = primals
    (dx,) = tangents
    q, r = qr_p.bind(x, full_matrices=False)
    *_, m, n = x.shape
    unsupported = full_matrices or m < n
    if unsupported:
        raise NotImplementedError(
            "Unimplemented case of QR decomposition derivative")

    dx_rinv = triangular_solve(r, dx)  # Right side solve by default
    projected = jnp.matmul(_H(q), dx_rinv)
    strictly_lower = jnp.tril(projected, -1)
    # This is skew-symmetric
    omega = strictly_lower - _H(strictly_lower)
    # The following correction is necessary for complex inputs
    non_real_part = projected - jnp.real(projected)
    omega = omega + jnp.eye(n, dtype=omega.dtype) * non_real_part

    dq = jnp.matmul(q, omega - projected) + dx_rinv
    dr = jnp.matmul(projected - omega, r)
    return (q, r), (dq, dr)
Exemple #5
0
def _lstsq(a, b, rcond, *, numpy_resid=False):
    """Least-squares solution of ``a @ x = b`` computed through an SVD.

    Args:
      a: two-dimensional coefficient matrix.
      b: one- or two-dimensional right-hand side whose leading dimension
        matches that of ``a``.
      rcond: relative cutoff for small singular values. ``None`` selects
        ``eps * max(a.shape)``; negative values are replaced by ``eps``.
      numpy_resid: when true, return an empty residual array if
        ``rank < n or m <= n``.

    Returns:
      Tuple ``(x, resid, rank, s)``.

    Raises:
      ValueError: if the leading dimensions of ``a`` and ``b`` differ.
      TypeError: if ``a`` is not 2-D or ``b`` is neither 1-D nor 2-D.
    """
    # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
    # TODO: add custom jvp rule for more robust lstsq differentiation
    a, b = _promote_arg_dtypes(a, b)
    if a.shape[0] != b.shape[0]:
        raise ValueError("Leading dimensions of input arrays must match")

    b_was_vector = b.ndim == 1
    if b_was_vector:
        b = b[:, None]
    if a.ndim != 2:
        raise TypeError(
            f"{a.ndim}-dimensional array given. Array must be two-dimensional")
    if b.ndim != 2:
        raise TypeError(
            f"{b.ndim}-dimensional array given. Array must be one or two-dimensional"
        )

    m, n = a.shape
    eps = jnp.finfo(a.dtype).eps
    if rcond is None:
        rcond = eps * max(n, m)
    else:
        rcond = jnp.where(rcond < 0, eps, rcond)

    u, s, vt = svd(a, full_matrices=False)
    # Singular values below rcond * s[0] are discarded.
    keep = s >= rcond * s[0]
    rank = keep.sum()
    # Invert only the retained singular values; the inner `where` keeps the
    # division away from the discarded entries.
    denominators = jnp.where(keep, s, 1)
    s_inv = jnp.where(keep, 1 / denominators, 0)[:, jnp.newaxis]

    highest = lax.Precision.HIGHEST
    projected_b = jnp.matmul(u.conj().T, b, precision=highest)
    x = jnp.matmul(vt.conj().T, s_inv * projected_b, precision=highest)

    # Numpy returns empty residuals in some cases. To allow compilation, we
    # default to returning full residuals in all cases.
    if numpy_resid and (rank < n or m <= n):
        resid = jnp.asarray([])
    else:
        fitted_b = jnp.matmul(a, x, precision=highest)
        resid = norm(b - fitted_b, axis=0)**2

    if b_was_vector:
        x = x.ravel()
    return x, resid, rank, s
Exemple #6
0
def _lu(a, permute_l):
    """Build explicit LU factor matrices for a 2-D input.

    Args:
      a: array-like with exactly two dimensions (its shape is unpacked as
        ``(m, n)``).
      permute_l: when true, return ``(p @ l, u)``; otherwise ``(p, l, u)``.

    Returns:
      ``(p @ l, u)`` or ``(p, l, u)``, where ``l`` has ones on its diagonal
      and ``u`` is upper triangular.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    packed, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    k = min(m, n)

    # The comparison broadcasts against a column of row indices, giving a 0/1
    # matrix; presumably entry [i, j] is 1 where permutation[j] == i -- this
    # assumes `permutation` is 1-D of length m.
    row_ids = jnp.arange(m)[:, None]
    p = jnp.real(jnp.array(permutation == row_ids, dtype=dtype))

    # Strictly-lower part of the packed factorization plus a unit diagonal.
    l = jnp.tril(packed, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(packed)[:k, :]

    if not permute_l:
        return p, l, u
    return jnp.matmul(p, l), u
Exemple #7
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """Forward-mode (JVP) rule for the SVD primitive.

  The primal outputs are always obtained from a thin decomposition
  (``full_matrices=False``) with ``compute_uv=True``.

  Args:
    primals: one-element sequence holding the input matrix ``A``.
    tangents: one-element sequence holding the input tangent ``dA``.
    full_matrices: requested ``full_matrices`` flag. Combined with
      ``compute_uv=True`` it is not supported by this rule.
    compute_uv: whether tangents for ``U`` and ``Vt`` are produced.

  Returns:
    ``((s,), (ds,))`` when ``compute_uv`` is false, otherwise
    ``((s, U, Vt), (ds, dU, dVt))``.

  Raises:
    NotImplementedError: if both ``compute_uv`` and ``full_matrices`` are set.
  """
  (A,) = primals
  (dA,) = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  # NOTE(review): `_H` / `_T` are defined outside this block; presumably the
  # conjugate transpose and plain transpose of the last two axes -- confirm.
  Ut = _H(U)
  V = _H(Vt)
  s_row = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  # Only the real part of the diagonal is kept as the singular-value tangent.
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

  if not compute_uv:
    return (s,), (ds,)

  s_col = _T(s_row)
  sq_gaps = jnp.square(s_row) - jnp.square(s_col)
  # The identity is added before inverting (avoiding a division by zero on the
  # diagonal) and subtracted afterwards, so F is zero on the diagonal.
  diag_guard = jnp.eye(s.shape[-1], dtype=A.dtype)
  F = 1 / (sq_gaps + diag_guard) - diag_guard
  dS_cols = s_row * dS  # dS.dot(jnp.diag(s))
  dS_rows = s_col * dS  # jnp.diag(s).dot(dS)

  # Safe reciprocal of s: entries where s == 0 map to 0 instead of inf.
  zero_guard = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + zero_guard) - zero_guard
  s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
  diag_term = .5 * (dS - _H(dS)) * s_inv_mat

  dU = jnp.matmul(U, F * (dS_cols + _H(dS_cols)) + diag_term)
  dV = jnp.matmul(V, F * (dS_rows + _H(dS_rows)))

  m, n = A.shape[-2:]
  if m > n:
    # Tall case: add the part of dA @ V lying outside the column space of U.
    complement_u = jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut)
    dU = dU + jnp.matmul(complement_u, jnp.matmul(dA, V)) / s_row
  if n > m:
    # Wide case: the analogous correction for V.
    complement_v = jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt)
    dV = dV + jnp.matmul(complement_v, jnp.matmul(_H(dA), U)) / s_row

  return (s, U, Vt), (ds, dU, _H(dV))
Exemple #8
0
def pinv(a, rcond=None):
    """Pseudo-inverse of ``a`` computed from its SVD.

    Args:
      a: input array; the decomposition acts on its last two axes.
      rcond: relative cutoff for small singular values. ``None`` selects
        ``10 * max(a.shape[-2:]) * eps`` for the dtype of ``a``.

    Returns:
      The pseudo-inverse, converted to the dtype of (the conjugated) ``a``.
    """
    # Uses same algorithm as
    # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
    a = jnp.conj(a)
    if rcond is None:
        largest_dim = max(a.shape[-2:])
        rcond = 10. * largest_dim * jnp.finfo(a.dtype).eps
    rcond = jnp.asarray(rcond)

    u, s, vh = svd(a, full_matrices=False)
    # Singular values less than or equal to ``rcond * largest_singular_value``
    # are set to zero.
    largest_s = jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
    cutoff = rcond[..., jnp.newaxis] * largest_s
    # Discarded values become +inf so the division below yields zero for them.
    divisors = jnp.where(s > cutoff, s, jnp.inf)
    scaled_ut = jnp.divide(_T(u), divisors[..., jnp.newaxis])
    res = jnp.matmul(_T(vh), scaled_ut)
    return lax.convert_element_type(res, a.dtype)
Exemple #9
0
def _lu(a, permute_l):
    """Build explicit LU factor matrices for a 2-D input.

    Args:
      a: array-like with exactly two dimensions (its shape is unpacked as
        ``(m, n)``).
      permute_l: when true, return ``(p @ l, u)``; otherwise ``(p, l, u)``.

    Returns:
      ``(p @ l, u)`` or ``(p, l, u)``, where ``l`` has ones on its diagonal
      and ``u`` is upper triangular.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    factored, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    k = min(m, n)

    # Entry [i, j] of the 0/1 matrix is 1 exactly where permutation[j] == i.
    row_ids = jnp.arange(m, dtype=permutation.dtype)[:, None]
    matches = permutation[None, :] == row_ids
    p = jnp.real(jnp.array(matches, dtype=dtype))

    # Strictly-lower part of the packed factorization plus a unit diagonal.
    unit_diag = jnp.eye(m, k, dtype=dtype)
    l = jnp.tril(factored, -1)[:, :k] + unit_diag
    u = jnp.triu(factored)[:k, :]

    if not permute_l:
        return p, l, u
    return jnp.matmul(p, l), u
Exemple #10
0
def _sqrtm(A):
    """Matrix square root of ``A`` via its complex Schur decomposition.

    ``A`` is factored with ``schur(A, output='complex')`` into ``T`` and ``Z``,
    ``_sqrtm_triu`` is applied to ``T``, and the result is multiplied by ``Z``
    on the left and ``conj(Z.T)`` on the right.
    """
    T, Z = schur(A, output='complex')
    root_T = _sqrtm_triu(T)
    highest = lax.Precision.HIGHEST
    left_product = jnp.matmul(Z, root_T, precision=highest)
    return jnp.matmul(left_product, jnp.conj(Z.T), precision=highest)