Пример #1
0
def eigh_jvp_rule(primals, tangents, lower):
    """JVP rule for ``eigh_p`` (Hermitian eigendecomposition).

    Derivative for eigh in the simplest case of distinct eigenvalues.
    This is classic nondegenerate perturbation theory, but also see
    https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    The general solution treating the case of degenerate eigenvalues is
    considerably more complicated. Ambitious readers may refer to the general
    methods below or refer to degenerate perturbation theory in physics.
    https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
    https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf

    Args:
      primals: 1-tuple ``(a,)`` holding the input matrix.
      tangents: 1-tuple ``(a_dot,)`` holding the tangent of ``a``.
      lower: forwarded unchanged to ``eigh_p.bind``.

    Returns:
      ``((v, w), (dv, dw))``: the outputs of ``eigh_p`` applied to
      ``symmetrize(a)``, and their tangents. The eigenvalues ``w`` are
      returned exactly as ``eigh_p`` produced them (NOTE(review): presumably
      a real dtype for complex Hermitian input), and ``dw`` is real-valued
      to match.
    """
    a, = primals
    a_dot, = tangents

    v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

    # For complex numbers the delta-eigenvalue arithmetic below needs the
    # eigenvalues in the full dtype of v and a. The primal output, however,
    # must keep the dtype eigh_p itself returned, so the cast is local only.
    w = w_real.astype(a.dtype)
    eye_n = np.eye(a.shape[-1], dtype=a.dtype)
    # Carefully build the reciprocal delta-eigenvalue matrix, avoiding NaNs.
    # Fmat[..., i, j] = 1 / (w[j] - w[i]) off the diagonal. The diagonal
    # differences are 0, so adding eye_n makes them 1 before the reciprocal.
    # Subtracting eye_n afterwards zeroes the diagonal again.
    Fmat = np.reciprocal(eye_n + w[..., np.newaxis, :] -
                         w[..., np.newaxis]) - eye_n
    # eigh impl doesn't support batch dims, but future-proof the grad.
    dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                  precision=lax.Precision.HIGHEST)
    vdag_adot_v = dot(dot(_H(v), a_dot), v)
    dv = dot(v, np.multiply(Fmat, vdag_adot_v))
    # Eigenvalue tangents are the diagonal of V^H a_dot V. Keep only the real
    # part so the tangent's dtype agrees with the real-valued primal w_real
    # (for real inputs np.real is a no-op).
    dw = np.real(np.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
    return (v, w_real), (dv, dw)
Пример #2
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """JVP rule for ``svd_p`` (thin singular value decomposition).

    Follows the forward-mode SVD results in
    https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf. Conjugate
    transposes are used throughout so the rule is also valid for complex
    inputs; for real inputs they coincide with plain transposes.

    Args:
      primals: 1-tuple ``(A,)`` holding the input matrix (or batch of them).
      tangents: 1-tuple ``(dA,)`` holding the tangent of ``A``.
      full_matrices: must be false when ``compute_uv`` is true.
      compute_uv: whether singular vectors were requested.

    Returns:
      ``((s, U, Vt), (ds, dU, dVt))``. The primals come from
      ``svd_p.bind(A, full_matrices=False, compute_uv=True)``.

    Raises:
      NotImplementedError: if ``compute_uv`` and ``full_matrices`` are both
        true.
    """
    A, = primals
    dA, = tangents
    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    if compute_uv and full_matrices:
        # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices"
        )

    k = s.shape[-1]
    eye_k = np.eye(k, dtype=A.dtype)
    Ut, V = _H(U), _H(Vt)
    # s_dim[..., i, j] == s[j], and _T(s_dim)[..., i, j] == s[i].
    s_dim = s[..., None, :]
    # Tangent expressed in the singular bases: dS = U^H dA V.
    dS = np.matmul(np.matmul(Ut, dA), V)
    ds = np.real(np.diagonal(dS, 0, -2, -1))

    # F[..., i, j] = 1 / (s[j]**2 - s[i]**2) off the diagonal, and 0 on it.
    # The factored form (s_j + s_i) * (s_j - s_i) avoids the cancellation of
    # subtracting two rounded squares. Adding then subtracting the identity
    # keeps the zero diagonal from turning into inf/NaN.
    s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
    F = 1 / (s_diffs + eye_k) - eye_k
    dSS = s_dim * dS      # dS @ diag(s)
    SdS = _T(s_dim) * dS  # diag(s) @ dS

    # For complex A, the diagonal of U^H dU must carry i * Im(dS_ii) / s_i so
    # that the tangents reproduce dA. For real A, dS - dS^H has a zero
    # diagonal and this term vanishes. Zero singular values get a zero
    # reciprocal instead of inf.
    s_zeros = np.ones((), dtype=A.dtype) * (s == 0.)
    s_inv = 1 / (s + s_zeros) - s_zeros
    dUdV_diag = .5 * (dS - _H(dS)) * (eye_k * s_inv[..., None, :])

    dU = np.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
    dV = np.matmul(V, F * (SdS + _H(SdS)))

    m, n = A.shape[-2:]
    if m > n:
        # Component of dU outside the column space of U (tall A).
        dU = dU + np.matmul(
            np.eye(m, dtype=A.dtype) - np.matmul(U, Ut),
            np.matmul(dA, V)) / s_dim
    if n > m:
        # Component of dV outside the column space of V (wide A).
        dV = dV + np.matmul(
            np.eye(n, dtype=A.dtype) - np.matmul(V, Vt),
            np.matmul(_H(dA), U)) / s_dim
    # Vt = V^H, so its tangent is dV^H. A plain transpose would drop the
    # conjugation for complex inputs.
    return (s, U, Vt), (ds, dU, _H(dV))