Exemplo n.º 1
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """JVP (forward-mode derivative) rule for the ``svd_p`` primitive.

    Differentiates the reduced decomposition ``A = U @ diag(s) @ Vt`` following
    Giles, "An extended collection of matrix derivative results" (NA-08-01).
    For complex operands it uses Hermitian transposes and adds the diagonal
    phase term of ``U^H dU``; both are no-ops for real operands.

    The last two axes of ``A`` are treated as the matrix axes. Leading axes
    broadcast through ``matmul``/``diagonal`` as batch axes.

    Args:
      primals: 1-tuple ``(A,)`` holding the operand.
      tangents: 1-tuple ``(dA,)`` holding the tangent of ``A``.
      full_matrices: full-size ``U``/``Vt`` flag of the primitive.
      compute_uv: whether the primitive was asked for ``U``/``Vt``.

    Returns:
      ``((s, U, Vt), (ds, dU, dVt))``. The primal outputs always come from the
      reduced SVD (``full_matrices=False``), whatever the flags.

    Raises:
      NotImplementedError: if both ``compute_uv`` and ``full_matrices`` are set.
    """
    A, = primals
    dA, = tangents
    # The derivation needs U and V even when only singular values were requested.
    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    if compute_uv and full_matrices:
        # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices"
        )

    k = s.shape[-1]
    Ut, V = _H(U), _H(Vt)
    s_dim = s[..., None, :]  # s_dim[..., i, j] == s[..., j]
    # dS = U^H dA V; the real part of its diagonal is the tangent of s.
    dS = np.matmul(np.matmul(Ut, dA), V)
    ds = np.real(np.diagonal(dS, 0, -2, -1))

    eye_k = np.eye(k, dtype=A.dtype)
    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal and 0 on it.
    # Adding the identity before inverting keeps the diagonal finite.
    F = 1 / (np.square(s_dim) - np.square(_T(s_dim)) + eye_k) - eye_k
    dSS = s_dim * dS      # dS @ diag(s)
    SdS = _T(s_dim) * dS  # diag(s) @ dS

    # Complex phase term: Im(diag(dS)) is absorbed into U (gauge choice
    # diag(V^H dV) = 0). dS - dS^H has a zero diagonal for real inputs.
    # A zero singular value gets a zero inverse here instead of inf.
    s_zeros = (s == 0).astype(s.dtype)
    s_inv = 1 / (s + s_zeros) - s_zeros
    dUdV_diag = 0.5 * (dS - _H(dS)) * (eye_k * s_inv[..., None, :])

    # Hermitian (not plain) transposes: U^H dU and V^H dV are skew-Hermitian.
    dU = np.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
    dV = np.matmul(V, F * (SdS + _H(SdS)))

    m, n = A.shape[-2:]
    if m > n:
        # Tall A: add the part of dU orthogonal to span(U).
        dU = dU + np.matmul(
            np.eye(m, dtype=A.dtype) - np.matmul(U, Ut),
            np.matmul(dA, V)) / s_dim
    if n > m:
        # Wide A: add the part of dV orthogonal to span(V).
        dV = dV + np.matmul(
            np.eye(n, dtype=A.dtype) - np.matmul(V, Vt),
            np.matmul(_H(dA), U)) / s_dim
    # Vt == V^H, so its tangent is dV^H (a plain transpose is wrong for complex).
    return (s, U, Vt), (ds, dU, _H(dV))
Exemplo n.º 2
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """JVP (forward-mode derivative) rule for the ``svd_p`` primitive.

    Differentiates the reduced decomposition ``A = U @ diag(s) @ Vt`` following
    Giles, "An extended collection of matrix derivative results" (NA-08-01).
    Conjugate transposes and the diagonal phase term make the rule correct for
    complex operands; both are no-ops for real ones.

    All transposes act on the last two axes only, so leading axes are carried
    through as batch axes. For 2-D operands this matches the plain ``.T`` /
    ``dot`` / ``diag`` formulation.

    Args:
      primals: 1-tuple ``(A,)`` holding the operand.
      tangents: 1-tuple ``(dA,)`` holding the tangent of ``A``.
      full_matrices: full-size ``U``/``Vt`` flag; not supported.
      compute_uv: accepted for the primitive's signature; not read here.

    Returns:
      ``core.pack((s, U, Vt))`` from the reduced SVD, and
      ``core.pack((ds, dU, dVt))`` with the matching tangents.

    Raises:
      NotImplementedError: whenever ``full_matrices`` is set.
    """
    if full_matrices:
        # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices"
        )

    def _mT(x):
        # Transpose of the trailing two (matrix) axes; other axes untouched.
        return np.swapaxes(x, -1, -2)

    def _mH(x):
        # Conjugate (Hermitian) transpose of the trailing two axes.
        return np.conj(_mT(x))

    A, = primals
    dA, = tangents
    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    k = s.shape[-1]
    Ut, V = _mH(U), _mH(Vt)
    s_dim = s[..., None, :]  # s_dim[..., i, j] == s[..., j]
    # dS = U^H dA V; the real part of its diagonal is the tangent of s.
    dS = np.matmul(np.matmul(Ut, dA), V)
    ds = np.real(np.diagonal(dS, 0, -2, -1))

    eye_k = np.eye(k, dtype=A.dtype)
    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal and 0 on it.
    F = 1 / (np.square(s_dim) - np.square(_mT(s_dim)) + eye_k) - eye_k
    dSS = s_dim * dS       # dS @ diag(s)
    SdS = _mT(s_dim) * dS  # diag(s) @ dS

    # Complex phase term absorbed into U (gauge diag(V^H dV) = 0); zero for
    # real inputs. A zero singular value gets a zero inverse instead of inf.
    s_zeros = (s == 0).astype(s.dtype)
    s_inv = 1 / (s + s_zeros) - s_zeros
    dUdV_diag = 0.5 * (dS - _mH(dS)) * (eye_k * s_inv[..., None, :])

    dU = np.matmul(U, F * (dSS + _mH(dSS)) + dUdV_diag)
    dV = np.matmul(V, F * (SdS + _mH(SdS)))

    m, n = A.shape[-2], A.shape[-1]
    if m > n:
        # Tall A: add the part of dU orthogonal to span(U).
        dU = dU + np.matmul(np.eye(m, dtype=A.dtype) - np.matmul(U, Ut),
                            np.matmul(dA, V)) / s_dim
    if n > m:
        # Wide A: add the part of dV orthogonal to span(V).
        dV = dV + np.matmul(np.eye(n, dtype=A.dtype) - np.matmul(V, Vt),
                            np.matmul(_mH(dA), U)) / s_dim
    # Vt == V^H, so its tangent is dV^H.
    return core.pack((s, U, Vt)), core.pack((ds, dU, _mH(dV)))