Exemplo n.º 1
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """Forward-mode (JVP) rule for the ``svd_p`` primitive.

    Only the reduced decomposition is handled. ``compute_uv`` is accepted
    but never consulted: the primitive is always bound with
    ``full_matrices=False, compute_uv=True``.

    Args:
        primals: one-element sequence holding the input matrix ``A``.
        tangents: one-element sequence holding the tangent ``dA``.
        full_matrices: must be falsy; the full-matrices case is not
            implemented.
        compute_uv: unused by this rule.

    Returns:
        The pair ``(core.pack((s, U, Vt)), core.pack((ds, dU, dV.T)))``:
        the packed primal outputs and the packed tangents.

    Raises:
        NotImplementedError: if ``full_matrices`` is truthy.
    """
    if full_matrices:
        # TODO: implement full matrices case, documented here:
        # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices"
        )

    (A,) = primals
    (dA,) = tangents
    s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

    num_sv = s.shape[-1]
    eye_k = np.eye(num_sv)
    U_h = np.conj(U).T
    V = np.conj(Vt).T
    s_row = s[..., None, :]
    s_col = s_row.T

    # Tangent expressed in the singular-vector bases: U^H dA V.
    dA_proj = np.dot(np.dot(U_h, dA), V)
    ds = np.real(np.diag(dA_proj))

    # Adding the identity keeps the diagonal of the denominator non-zero;
    # subtracting it afterwards zeroes the diagonal of the result.
    inv_gap = 1 / (np.square(s_row) - np.square(s_col) + eye_k) - eye_k

    col_scaled = s_row * dA_proj  # column j multiplied by s[j]
    row_scaled = s_col * dA_proj  # row i multiplied by s[i]
    dU = np.dot(U, inv_gap * (col_scaled + col_scaled.T))
    dV = np.dot(V, inv_gap * (row_scaled + row_scaled.T))

    m = A.shape[-2]
    n = A.shape[-1]
    if m > n:
        # Extra term for the part of dA V outside the span of U.
        U_complement = np.eye(m) - np.dot(U, U_h)
        dU = dU + np.dot(U_complement, np.dot(dA, V)) / s_row
    if n > m:
        # Extra term for the part of dA^H U outside the span of V.
        V_complement = np.eye(n) - np.dot(V, Vt)
        dV = dV + np.dot(V_complement, np.dot(np.conj(dA).T, U)) / s_row

    primal_out = core.pack((s, U, Vt))
    tangent_out = core.pack((ds, dU, dV.T))
    return primal_out, tangent_out
Exemplo n.º 2
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """JVP of ``triangular_solve`` with respect to its matrix argument ``a``.

    Args:
        g_a: tangent of ``a``.
        ans: the primal solution ``X`` returned by the solve.
        a: the triangular matrix argument of the solve.
        b: the right-hand side; only its trailing two dimensions are read.
        left_side, lower, transpose_a, conjugate_a, unit_diagonal: the
            solve's flags, forwarded unchanged to ``triangular_solve``.

    Returns:
        ``A^{-1} (-dA) X`` when ``left_side`` is set, ``X (-dA) A^{-1}``
        otherwise, evaluated in whichever association is cheaper.
    """
    m, n = b.shape[-2:]

    # Mask the tangent to the triangle selected by `lower`; with a unit
    # diagonal the diagonal itself is masked out as well.
    offset = 1 if unit_diagonal else 0
    if lower:
        g_a = np.tril(g_a, k=-offset)
    else:
        g_a = np.triu(g_a, k=offset)
    g_a = lax.neg(g_a)
    if transpose_a:
        g_a = np.swapaxes(g_a, -1, -2)
    if conjugate_a:
        g_a = np.conj(g_a)

    matmul_impl = lax.dot if g_a.ndim == 2 else lax.batch_matmul
    matmul = partial(matmul_impl, precision=lax.Precision.HIGHEST)

    def solve_with_a(rhs):
        return triangular_solve(a, rhs, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal)

    # triangular_solve is about the same cost as matrix multiplication
    # (~n^2 FLOPs for matrix/vector inputs), so apply the two operations in
    # whichever order is cheaper.
    if left_side:
        assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (
            m, n)
        if m > n:
            return solve_with_a(matmul(g_a, ans))  # A^{-1} (∂A X)
        return matmul(solve_with_a(g_a), ans)  # (A^{-1} ∂A) X

    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (
        m, n)
    if m < n:
        return solve_with_a(matmul(ans, g_a))  # (X ∂A) A^{-1}
    return matmul(ans, solve_with_a(g_a))  # X (∂A A^{-1})
Exemplo n.º 3
0
def triangular_solve_jvp_rule_a(
    g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a):
  """JVP of ``triangular_solve`` with respect to its matrix argument ``a``.

  Args:
    g_a: tangent of ``a``.
    ans: the primal solution returned by the solve.
    a: the matrix argument of the solve.
    b: the right-hand side; not read by this rule.
    left_side, lower, transpose_a, conjugate_a: the solve's flags,
      forwarded unchanged to ``triangular_solve``.

  Returns:
    The solve applied to the negated (and optionally transposed and
    conjugated) tangent, multiplied with ``ans`` on the right when
    ``left_side`` is set and on the left otherwise.
  """
  neg_tangent = lax.neg(g_a)
  if transpose_a:
    neg_tangent = np.swapaxes(neg_tangent, -1, -2)
  if conjugate_a:
    neg_tangent = np.conj(neg_tangent)

  solved = triangular_solve(
      a, neg_tangent, left_side, lower, transpose_a, conjugate_a)

  # Plain 2-D operands use lax.dot; anything else goes through the batched
  # matmul.
  matmul = lax.dot if neg_tangent.ndim == 2 else lax.batch_matmul
  return matmul(solved, ans) if left_side else matmul(ans, solved)
Exemplo n.º 4
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """JVP of ``triangular_solve`` with respect to its matrix argument ``a``.

    Args:
        g_a: tangent of ``a``.
        ans: the primal solution returned by the solve.
        a: the triangular matrix argument of the solve.
        b: the right-hand side; not read by this rule.
        left_side, lower, transpose_a, conjugate_a, unit_diagonal: the
            solve's flags, forwarded unchanged to ``triangular_solve``.

    Returns:
        The solve applied to the masked, negated (and optionally transposed
        and conjugated) tangent, multiplied with ``ans`` on the right when
        ``left_side`` is set and on the left otherwise.
    """
    # Mask the tangent to the triangle selected by `lower`; with a unit
    # diagonal the diagonal itself is masked out as well.
    offset = 1 if unit_diagonal else 0
    if lower:
        masked = np.tril(g_a, k=-offset)
    else:
        masked = np.triu(g_a, k=offset)

    neg_tangent = lax.neg(masked)
    if transpose_a:
        neg_tangent = np.swapaxes(neg_tangent, -1, -2)
    if conjugate_a:
        neg_tangent = np.conj(neg_tangent)

    solved = triangular_solve(a, neg_tangent, left_side, lower, transpose_a,
                              conjugate_a, unit_diagonal)

    # Plain 2-D operands use lax.dot; anything else goes through the batched
    # matmul.
    matmul = lax.dot if neg_tangent.ndim == 2 else lax.batch_matmul
    return matmul(solved, ans) if left_side else matmul(ans, solved)
Exemplo n.º 5
0
def _H(x):
    """Return the complex conjugate of ``_T(x)``.

    NOTE(review): presumably the conjugate (Hermitian) transpose, assuming
    ``_T`` is a transpose helper -- ``_T`` is defined outside this view.
    """
    transposed = _T(x)
    return np.conj(transposed)
Exemplo n.º 6
0
def _H(x):
    """Return the complex conjugate of ``_T(x)``.

    NOTE(review): presumably the conjugate (Hermitian) transpose, assuming
    ``_T`` is a transpose helper -- ``_T`` is defined outside this view.
    """
    transposed = _T(x)
    return np.conj(transposed)
def symmetrize(x):
    """Return the average of ``x`` and ``_H(x)``, i.e. ``(x + _H(x)) / 2``."""
    total = x + _H(x)
    return total / 2