Beispiel #1
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """Forward-mode (JVP) rule for the SVD primitive ``svd_p``.

    Args:
      primals: one-element sequence holding the matrix ``A``.
      tangents: one-element sequence holding the tangent ``dA`` of ``A``, or
        ``ad_util.zero`` when ``A`` carries no tangent.
      full_matrices: parameter of the SVD primitive. Differentiation is only
        implemented for the reduced decomposition (``full_matrices`` falsy).
      compute_uv: parameter of the SVD primitive. It is only forwarded to the
        primitive on the zero-tangent path; the derivative itself always
        needs ``U`` and ``Vt``.

    Returns:
      A pair ``(primals_out, tangents_out)``: the packed ``(s, U, Vt)`` and
      the matching tangents ``(ds, dU, dVt)``.

    Raises:
      NotImplementedError: if a non-zero tangent is supplied together with
        ``full_matrices``.
    """
    A, = primals
    dA, = tangents

    if dA is ad_util.zero:
        # No derivative is needed, so the primal outputs must be exactly what
        # the primitive yields for the caller's parameters.  Binding with
        # full_matrices=False here would silently hand back reduced-shape
        # U / Vt to a caller that asked for full matrices.
        s, U, Vt = svd_p.bind(A, full_matrices=full_matrices,
                              compute_uv=compute_uv)
        return (core.pack((s, U, Vt)),
                ad.TangentTuple(ad_util.zero, ad_util.zero, ad_util.zero))

    if full_matrices:
        # Checked before binding so the unsupported case does not pay for a
        # decomposition that is then thrown away.
        # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices"
        )

    # The derivative formulas below need the reduced factors.
    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    k = s.shape[-1]
    Ut, V = np.conj(U).T, np.conj(Vt).T
    s_dim = s[..., None, :]
    # dA expressed in the singular bases; its diagonal is the tangent of s.
    dS = np.dot(np.dot(Ut, dA), V)
    ds = np.real(np.diag(dS))
    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal.  Adding eye(k) keeps the
    # diagonal denominator at 1 instead of 0; subtracting it afterwards zeroes
    # the diagonal of F.
    F = 1 / (np.square(s_dim) - np.square(s_dim.T) + np.eye(k)) - np.eye(k)
    dSS = s_dim * dS      # scales column j of dS by s_j
    SdS = s_dim.T * dS    # scales row i of dS by s_i
    dU = np.dot(U, F * (dSS + dSS.T))
    dV = np.dot(V, F * (SdS + SdS.T))

    m, n = A.shape[-2], A.shape[-1]
    if m > n:
        # Tall input: add the part of dA V lying outside the span of U.
        dU = dU + np.dot(np.eye(m) - np.dot(U, Ut), np.dot(dA, V)) / s_dim
    if n > m:
        # Wide input: the analogous correction for V.
        dV = dV + np.dot(np.eye(n) - np.dot(V, Vt), np.dot(np.conj(dA).T,
                                                           U)) / s_dim
    return core.pack((s, U, Vt)), core.pack((ds, dU, dV.T))
Beispiel #2
0
    def body(k, state):
        """Run elimination step ``k`` of a row-pivoted LU factorisation.

        ``state`` is ``(pivot, perm, a, error)``.  The step picks the row at or
        below ``k`` with the largest magnitude in column ``k``, swaps it into
        row ``k`` (recording the choice in ``pivot`` and applying the same swap
        to ``perm``), divides the sub-diagonal part of column ``k`` by the
        pivot element, and subtracts the resulting outer product from the
        trailing block.  ``error`` is OR-ed with "pivot element equals zero".

        ``m`` and ``n`` come from the enclosing scope (presumably the row and
        column counts of ``a`` -- confirm against the caller).  Every update is
        masked rather than sliced so array shapes do not depend on ``k``.
        """
        pivot, perm, a, error = state
        row_ids = np.arange(m)
        col_ids = np.arange(n)

        # Magnitude of column k; complex entries use |re| + |im|.
        column = a[:, k]
        if np.issubdtype(a.dtype, np.complexfloating):
            magnitude = np.abs(np.real(column)) + np.abs(np.imag(column))
        else:
            magnitude = np.abs(column)

        # Rows above k are excluded from the search by masking them to -inf.
        candidates = np.where(row_ids >= k, magnitude, -np.inf)
        i = np.argmax(candidates)
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of the matrix, and entries k and i of perm.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # Masked form of: a[k+1:, k] /= a[k, k]
        x = a[k, k]
        error = error | lax.eq(x, np._constant_like(a, 0))
        scaled_column = np.where(row_ids > k, a[:, k] / x, a[:, k])
        a = ops.index_update(a, ops.index[:, k], scaled_column)

        # Masked form of: a[k+1:, k+1:] -= outer(a[k+1:, k], a[k, k+1:])
        trailing = (row_ids[:, None] > k) & (col_ids > k)
        rank_one = np.outer(a[:, k], a[k, :])
        a = a - np.where(trailing, rank_one, np.array(0, dtype=a.dtype))
        return pivot, perm, a, error
Beispiel #3
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """Forward-mode (JVP) rule for the SVD primitive ``svd_p``.

    Args:
      primals: one-element sequence holding the matrix ``A``.
      tangents: one-element sequence holding the tangent ``dA``.
      full_matrices: parameter of the SVD primitive.
      compute_uv: parameter of the SVD primitive.

    Returns:
      ``((s, U, Vt), (ds, dU, dVt))`` computed from the reduced decomposition.

    Raises:
      NotImplementedError: if both ``compute_uv`` and ``full_matrices`` are
        set.
    """
    (A,) = primals
    (dA,) = tangents
    # The derivative always works from the reduced factors.
    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    if compute_uv and full_matrices:
        # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices"
        )

    rank = s.shape[-1]
    U_h = _H(U)
    V = _H(Vt)
    s_row = s[..., None, :]
    s_col = _T(s_row)

    # dA expressed in the singular bases; its diagonal is the tangent of s.
    dS = np.matmul(np.matmul(U_h, dA), V)
    ds = np.real(np.diagonal(dS, 0, -2, -1))

    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal.  The added identity keeps
    # the diagonal denominator at 1; subtracting it again zeroes F's diagonal.
    denominator = np.square(s_row) - np.square(s_col) + np.eye(rank)
    F = 1 / denominator - np.eye(rank)

    dS_cols_scaled = s_row * dS   # column j of dS times s_j
    dS_rows_scaled = s_col * dS   # row i of dS times s_i
    dU = np.matmul(U, F * (dS_cols_scaled + _T(dS_cols_scaled)))
    dV = np.matmul(V, F * (dS_rows_scaled + _T(dS_rows_scaled)))

    rows, cols = A.shape[-2], A.shape[-1]
    if rows > cols:
        # Tall input: add the part of dA V lying outside the span of U.
        off_span_u = np.eye(rows) - np.matmul(U, U_h)
        dU = dU + np.matmul(off_span_u, np.matmul(dA, V)) / s_row
    if cols > rows:
        # Wide input: the analogous correction for V.
        off_span_v = np.eye(cols) - np.matmul(V, Vt)
        dV = dV + np.matmul(off_span_v, np.matmul(_H(dA), U)) / s_row

    primals_out = (s, U, Vt)
    tangents_out = (ds, dU, _T(dV))
    return primals_out, tangents_out
Beispiel #4
0
    def body(k, state):
        """Run elimination step ``k`` of a row-pivoted LU factorisation.

        ``state`` is ``(pivot, perm, a)``.  The step picks the row at or below
        ``k`` with the largest magnitude in column ``k``, swaps it into row
        ``k`` (recording the choice in ``pivot`` and applying the same swap to
        ``perm``), divides the sub-diagonal part of column ``k`` by the pivot
        element, and subtracts the resulting outer product from the trailing
        block.

        ``m`` and ``n`` come from the enclosing scope (presumably the row and
        column counts of ``a`` -- confirm against the caller).  Every update is
        masked rather than sliced so array shapes do not depend on ``k``.
        """
        pivot, perm, a = state
        row_ids = jnp.arange(m)
        col_ids = jnp.arange(n)

        # Magnitude of column k; complex entries use |re| + |im|.
        column = a[:, k]
        if jnp.issubdtype(a.dtype, jnp.complexfloating):
            magnitude = jnp.abs(jnp.real(column)) + jnp.abs(jnp.imag(column))
        else:
            magnitude = jnp.abs(column)

        # Rows above k are excluded from the search by masking them to -inf.
        candidates = jnp.where(row_ids >= k, magnitude, -jnp.inf)
        i = jnp.argmax(candidates)
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of the matrix, and entries k and i of perm.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # Masked form of: a[k+1:, k] /= a[k, k]
        x = a[k, k]
        scaled_column = jnp.where(row_ids > k, a[:, k] / x, a[:, k])
        a = ops.index_update(a, ops.index[:, k], scaled_column)

        # Masked form of: a[k+1:, k+1:] -= outer(a[k+1:, k], a[k, k+1:])
        trailing = (row_ids[:, None] > k) & (col_ids > k)
        rank_one = jnp.outer(a[:, k], a[k, :])
        a = a - jnp.where(trailing, rank_one, jnp.array(0, dtype=a.dtype))
        return pivot, perm, a
Beispiel #5
0
def eigh_jvp_rule(primals, tangents, lower):
    """Forward-mode (JVP) rule for ``eigh_p``, assuming distinct eigenvalues.

    This is classic nondegenerate perturbation theory; see also
    https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    The general solution covering degenerate eigenvalues is considerably more
    complicated.  Ambitious readers may refer to the general methods in
    https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
    https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
    or to degenerate perturbation theory in physics.

    Args:
      primals: one-element sequence holding the matrix ``a``.
      tangents: one-element sequence holding the tangent ``a_dot``.
      lower: forwarded unchanged to ``eigh_p.bind``.

    Returns:
      ``((v, w_real), (dv, dw))`` -- the eigenvectors and eigenvalues returned
      by the primitive, followed by their tangents.
    """
    (a,) = primals
    (a_dot,) = tangents

    v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

    # For complex inputs the eigenvalues must carry the full dtype of v and a.
    w = w_real.astype(a.dtype)
    identity = jnp.eye(a.shape[-1], dtype=a.dtype)

    # Reciprocal eigenvalue-gap matrix.  The identity is added before taking
    # the reciprocal so the (zero) diagonal gaps do not produce NaNs, then
    # subtracted again so the diagonal of the result is zero.
    w_along_cols = w[..., jnp.newaxis, :]
    w_along_rows = w[..., jnp.newaxis]
    Fmat = jnp.reciprocal(identity + w_along_cols - w_along_rows) - identity

    # The eigh implementation does not support batch dimensions, but the
    # gradient is written to cope with them anyway.
    if a.ndim == 2:
        product = lax.dot
    else:
        product = lax.batch_matmul
    dot = partial(product, precision=lax.Precision.HIGHEST)

    # Tangent expressed in the eigenbasis: v^H a_dot v.
    a_dot_eigenbasis = dot(dot(_H(v), a_dot), v)
    dv = dot(v, jnp.multiply(Fmat, a_dot_eigenbasis))
    dw = jnp.real(jnp.diagonal(a_dot_eigenbasis, axis1=-2, axis2=-1))
    return (v, w_real), (dv, dw)
Beispiel #6
0
def qr_jvp_rule(primals, tangents, full_matrices):
    """Forward-mode (JVP) rule for the reduced QR decomposition ``qr_p``.

    See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.

    Args:
      primals: one-element sequence holding the matrix ``x``.
      tangents: one-element sequence holding the tangent ``dx``.
      full_matrices: parameter of the QR primitive; only the falsy (reduced)
        case is differentiable here.

    Returns:
      ``((q, r), (dq, dr))``.

    Raises:
      NotImplementedError: if ``full_matrices`` is set, or if ``x`` has fewer
        rows than columns.
    """
    (x,) = primals
    (dx,) = tangents
    q, r = qr_p.bind(x, full_matrices=False)

    rows, cols = x.shape[-2:]
    if full_matrices or rows < cols:
        raise NotImplementedError(
            "Unimplemented case of QR decomposition derivative")

    # Per the original note, triangular_solve performs a right-side solve by
    # default, i.e. this is dx @ r^{-1}.
    dx_rinv = triangular_solve(r, dx)
    qh_dx_rinv = jnp.matmul(_H(q), dx_rinv)

    # Build an anti-Hermitian matrix (skew-symmetric for real inputs) from the
    # strictly lower triangle.
    strict_lower = jnp.tril(qh_dx_rinv, -1)
    do = strict_lower - _H(strict_lower)

    # Complex inputs additionally need the imaginary part of the diagonal.
    imaginary_part = qh_dx_rinv - jnp.real(qh_dx_rinv)
    do = do + jnp.eye(cols, dtype=do.dtype) * imaginary_part

    dq = jnp.matmul(q, do - qh_dx_rinv) + dx_rinv
    dr = jnp.matmul(qh_dx_rinv - do, r)
    return (q, r), (dq, dr)