Exemple #1
0
def _lu_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the LU decomposition primitive.

    Args:
        primals: one-element sequence holding the matrix ``a``; its last two
            axes are the matrix dimensions ``(m, n)``.
        tangents: one-element sequence holding the tangent ``a_dot`` of ``a``,
            or ``ad_util.zero``.

    Returns:
        A pair ``((lu, pivots), (lu_dot, pivots_dot))``. ``pivots_dot`` is
        always ``ad_util.zero``; ``lu_dot`` is also ``ad_util.zero`` when the
        incoming tangent is zero.
    """
    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    if a_dot is ad_util.zero:
        # Return the same plain-tuple structure as the general path at the end
        # of this function so callers see one consistent output shape.
        return (lu, pivots), (ad_util.zero, ad_util.zero)

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    # Reorder the rows of a_dot with the row permutation derived from the
    # pivots. The batch index arrays come from np.ix_; the trailing size-1
    # entry is dropped again (presumably it is only there so the batch indices
    # broadcast against `permutation` -- confirm).
    permutation = lu_pivots_to_permutation(pivots, m)
    batch_dims = a_shape[:-2]
    iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1, )))
    x = a_dot[iotas[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    # Square (m, m) unit-lower-triangular factor: strict lower triangle of the
    # first k columns of `lu`, zero-padded on the right, plus the identity.
    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    # Square (n, n) upper-triangular factor: upper triangle of the first k
    # rows of `lu`, zero-padded at the bottom, with an identity block placed in
    # the trailing (n - k) x (n - k) corner.
    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # lau = inv(L) . x . inv(U), computed as two triangular solves.
    la = triangular_solve(l,
                          x,
                          left_side=True,
                          transpose_a=False,
                          lower=True,
                          unit_diagonal=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    lu_dot = l_dot + u_dot
    return (lu, pivots), (lu_dot, ad_util.zero)
Exemple #2
0
def lu_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the LU decomposition primitive.

    Args:
        primals: one-element sequence holding the matrix ``a``.
        tangents: one-element sequence holding the tangent of ``a``.

    Returns:
        ``(core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero)))``:
        the packed primal outputs and their tangents; the pivots carry a zero
        tangent.
    """
    (a,) = primals
    (a_dot,) = tangents
    lu, pivots = lu_p.bind(a)

    shape = np.shape(a)
    rank = len(shape)
    m, n = shape[-2:]
    k = min(m, n)
    dtype = lax._dtype(a)
    no_pad = (0, 0, 0)

    # TODO(phawkins): use a gather rather than a matrix multiplication here.
    # Row-permute the tangent by multiplying with a 0/1 permutation matrix.
    permutation = lu_pivots_to_permutation(pivots, m)
    perm_matrix = np.array(permutation[:, None] == np.arange(m), dtype=dtype)
    permuted_tangent = np.matmul(perm_matrix, a_dot)

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    zero = np._constant_like(lu, 0)

    # (m, m) lower factor: strict lower triangle of the first k columns,
    # padded with zero columns on the right, plus an explicit unit diagonal.
    lower_padding = [no_pad] * (rank - 1) + [(0, m - k, 0)]
    lower = lax.pad(np.tril(lu[..., :, :k], -1), zero, lower_padding)
    lower = lower + np.eye(m, m, dtype=dtype)

    # (n, n) upper factor: upper triangle of the first k rows, padded with
    # zero rows at the bottom, plus an identity block in the trailing corner.
    corner_identity = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                              ((k, 0, 0), (k, 0, 0)))
    upper_padding = [no_pad] * (rank - 2) + [(0, n - k, 0), no_pad]
    upper = lax.pad(np.triu(lu[..., :k, :]), zero, upper_padding)
    upper = upper + corner_identity

    # inv(L) . P . A' . inv(U) via a left solve followed by a right solve.
    left_solved = triangular_solve(lower,
                                   permuted_tangent,
                                   left_side=True,
                                   transpose_a=False,
                                   lower=True)
    both_solved = triangular_solve(upper,
                                   left_solved,
                                   left_side=False,
                                   transpose_a=False,
                                   lower=False)

    lower_dot = np.matmul(lower, np.tril(both_solved, -1))
    upper_dot = np.matmul(np.triu(both_solved), upper)
    primals_out = core.pack((lu, pivots))
    tangents_out = ad.TangentTuple((lower_dot + upper_dot, ad_util.zero))
    return primals_out, tangents_out
Exemple #3
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """JVP of ``triangular_solve`` with respect to its triangular operand.

    Args:
        g_a: tangent of ``a``. Only the triangle selected by ``lower`` is used;
            with ``unit_diagonal`` the diagonal is excluded as well.
        ans: the primal solution (``X`` in the comments below).
        a: the triangular operand of the primal solve.
        b: right-hand side of the primal solve; only its trailing two
            dimensions ``(m, n)`` are read.
        left_side, lower, transpose_a, conjugate_a, unit_diagonal: the flags
            of the primal solve, forwarded unchanged to the inner solve.

    Returns:
        The tangent of the solution: ``-A^{-1} dA X`` for a left-side solve,
        ``-X dA A^{-1}`` for a right-side solve.
    """
    m, n = b.shape[-2:]
    k = 1 if unit_diagonal else 0
    # Keep only the part of the tangent the solve actually reads, negate it,
    # and apply the same transpose/conjugate as the primal operand.
    g_a = np.tril(g_a, k=-k) if lower else np.triu(g_a, k=k)
    g_a = lax.neg(g_a)
    g_a = np.swapaxes(g_a, -1, -2) if transpose_a else g_a
    g_a = np.conj(g_a) if conjugate_a else g_a
    dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
                  precision=lax.Precision.HIGHEST)

    def a_inverse(rhs):
        # Flags are passed by keyword so the call does not depend on the
        # positional order of triangular_solve's boolean parameters.
        return triangular_solve(a, rhs,
                                left_side=left_side,
                                lower=lower,
                                transpose_a=transpose_a,
                                conjugate_a=conjugate_a,
                                unit_diagonal=unit_diagonal)

    # triangular_solve is about the same cost as matrix multiplication (~n^2
    # FLOPs for matrix/vector inputs). Order these operations in whichever
    # order is cheaper.
    if left_side:
        assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (
            m, n)
        if m > n:
            return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
        else:
            return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
    else:
        assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (
            m, n)
        if m < n:
            return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
        else:
            return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
Exemple #4
0
def cholesky_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the Cholesky primitive.

    Args:
        primals: one-element sequence holding the input matrix.
        tangents: one-element sequence holding the tangent of that matrix.

    Returns:
        ``(L, L_dot)``: the lower-triangular part of the primitive's output
        and its tangent.
    """
    (x,) = primals
    (sigma_dot,) = tangents
    L = np.tril(cholesky_p.bind(x))

    # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
    def halve_diagonal_of_tril(mat):
        # Lower triangle of `mat` with its diagonal divided by two (the
        # denominator is 2 on the diagonal and 1 everywhere else).
        lower_part = np.tril(mat)
        denominator = (np._constant_like(mat, 1) +
                       np.eye(mat.shape[-1], dtype=mat.dtype))
        return lower_part / denominator

    # sigma_dot . inv(L)^H, then inv(L) applied from the left.
    right_solved = triangular_solve(L,
                                    sigma_dot,
                                    left_side=False,
                                    transpose_a=True,
                                    conjugate_a=True,
                                    lower=True)
    both_solved = triangular_solve(L,
                                   right_solved,
                                   left_side=True,
                                   transpose_a=False,
                                   lower=True)
    L_dot = lax.batch_matmul(L,
                             halve_diagonal_of_tril(both_solved),
                             precision=lax.Precision.HIGHEST)
    return L, L_dot
Exemple #5
0
def cholesky_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the Cholesky primitive.

    Args:
        primals: one-element sequence holding the input matrix.
        tangents: one-element sequence holding the tangent of that matrix.

    Returns:
        ``(L, L_dot)``: the lower-triangular part of the primitive's output
        and its tangent.
    """
    x, = primals
    sigma_dot, = tangents
    L = np.tril(cholesky_p.bind(x))

    # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
    def phi(X):
        # Lower triangle of X with the diagonal halved.
        return np.tril(X) / (1 + np.eye(X.shape[-1], dtype=X.dtype))

    # Right-multiply by inv(L)^H. The conjugate is a no-op for real input and
    # is required for complex input, where the factorization is L . L^H and a
    # plain transpose would give the wrong tangent.
    tmp = triangular_solve(L,
                           sigma_dot,
                           left_side=False,
                           transpose_a=True,
                           conjugate_a=True,
                           lower=True)
    L_dot = lax.batch_matmul(
        L,
        phi(
            triangular_solve(L,
                             tmp,
                             left_side=True,
                             transpose_a=False,
                             lower=True)))
    return L, L_dot
Exemple #6
0
def cholesky_jvp_rule(primals, tangents):
  """Forward-mode (JVP) rule for the Cholesky primitive.

  Args:
    primals: one-element sequence holding the input matrix.
    tangents: one-element sequence holding the tangent of that matrix.

  Returns:
    ``(L, L_dot)``: the primitive's output and its tangent.
  """
  x, = primals
  sigma_dot, = tangents
  L = cholesky_p.bind(x)

  # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
  # Average the tangent with `_T` of itself (presumably its transpose, i.e.
  # this symmetrizes the tangent -- confirm against the helper's definition).
  sigma_dot = (sigma_dot + _T(sigma_dot)) / 2

  def phi(X):
    # Lower triangle of X with the diagonal halved. The identity is built in
    # X's own dtype so the result is not promoted away from the operand dtype,
    # and the size comes from the argument rather than the enclosing `x`.
    return np.tril(X) / (1 + np.eye(X.shape[-1], dtype=X.dtype))

  tmp = triangular_solve(L, sigma_dot,
                         left_side=False, transpose_a=True, lower=True)
  L_dot = lax.dot(L, phi(triangular_solve(
      L, tmp, left_side=True, transpose_a=False, lower=True)))
  return L, L_dot
Exemple #7
0
def qr_jvp_rule(primals, tangents, full_matrices):
    """Forward-mode (JVP) rule for the QR decomposition primitive.

    Args:
        primals: one-element sequence holding the input matrix ``x``.
        tangents: one-element sequence holding the tangent ``dx`` of ``x``.
        full_matrices: must be falsy; the full decomposition is unsupported.

    Returns:
        ``((q, r), (dq, dr))``: the reduced factors and their tangents.

    Raises:
        NotImplementedError: if ``full_matrices`` is set or ``x`` has fewer
            rows than columns.
    """
    # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
    x, = primals
    dx, = tangents
    # Reject unsupported cases before running the decomposition.
    x_shape = np.shape(x)
    if full_matrices or x_shape[-2] < x_shape[-1]:
        raise NotImplementedError(
            "Unimplemented case of QR decomposition derivative")
    n = x_shape[-1]
    q, r = qr_p.bind(x, full_matrices=False)
    dx_rinv = triangular_solve(r, dx)  # Right side solve by default
    qt_dx_rinv = np.matmul(_H(q), dx_rinv)
    qt_dx_rinv_lower = np.tril(qt_dx_rinv, -1)
    domega = qt_dx_rinv_lower - _H(qt_dx_rinv_lower)  # This is skew-symmetric
    # Diagonal correction needed for complex inputs: add the imaginary part of
    # the diagonal of qt_dx_rinv. For real inputs the added term is zero.
    domega = domega + np.eye(n, dtype=domega.dtype) * (
        qt_dx_rinv - np.real(qt_dx_rinv))
    dq = np.matmul(q, domega - qt_dx_rinv) + dx_rinv
    dr = np.matmul(qt_dx_rinv - domega, r)
    return (q, r), (dq, dr)
Exemple #8
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """JVP of ``triangular_solve`` with respect to its triangular operand.

    Args:
        g_a: tangent of ``a``. Only the triangle selected by ``lower`` is used;
            with ``unit_diagonal`` the diagonal is excluded as well.
        ans: the primal solution.
        a: the triangular operand of the primal solve.
        b: right-hand side of the primal solve (not read by this rule).
        left_side, lower, transpose_a, conjugate_a, unit_diagonal: the flags
            of the primal solve, forwarded unchanged to the inner solve.

    Returns:
        ``solve(a, -g_a) . ans`` for a left-side solve, and
        ``ans . solve(a, -g_a)`` for a right-side solve.
    """
    k = 1 if unit_diagonal else 0
    # Keep only the part of the tangent the solve actually reads, negate it,
    # and apply the same transpose/conjugate as the primal operand.
    g_a = np.tril(g_a, k=-k) if lower else np.triu(g_a, k=k)
    g_a = lax.neg(g_a)
    g_a = np.swapaxes(g_a, -1, -2) if transpose_a else g_a
    g_a = np.conj(g_a) if conjugate_a else g_a
    # Flags are passed by keyword so the call does not depend on the
    # positional order of triangular_solve's boolean parameters.
    tmp = triangular_solve(a, g_a,
                           left_side=left_side,
                           lower=lower,
                           transpose_a=transpose_a,
                           conjugate_a=conjugate_a,
                           unit_diagonal=unit_diagonal)
    dot = lax.dot if g_a.ndim == 2 else lax.batch_matmul
    if left_side:
        return dot(tmp, ans)
    else:
        return dot(ans, tmp)
Exemple #9
0
def qr_jvp_rule(primals, tangents, full_matrices):
    """Forward-mode (JVP) rule for the QR decomposition primitive.

    Args:
        primals: one-element sequence holding the input matrix ``x``.
        tangents: one-element sequence holding the tangent ``dx`` of ``x``.
        full_matrices: must be falsy; the full decomposition is unsupported.

    Returns:
        ``((q, r), (dq, dr))``: the reduced factors and their tangents.

    Raises:
        NotImplementedError: if ``full_matrices`` is set or ``x`` has fewer
            rows than columns.
    """
    # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
    (x,) = primals
    (dx,) = tangents
    q, r = qr_p.bind(x, full_matrices=False)
    *_, m, n = x.shape
    unsupported = full_matrices or m < n
    if unsupported:
        raise NotImplementedError(
            "Unimplemented case of QR decomposition derivative")

    dx_rinv = triangular_solve(r, dx)  # Right side solve by default
    projected = jnp.matmul(_H(q), dx_rinv)
    strict_lower = jnp.tril(projected, -1)
    # Skew-symmetric combination of the strict lower triangle.
    omega_dot = strict_lower - _H(strict_lower)
    # The following correction is necessary for complex inputs: it adds the
    # non-real part of the diagonal of `projected`.
    non_real_part = projected - jnp.real(projected)
    omega_dot = omega_dot + jnp.eye(n, dtype=omega_dot.dtype) * non_real_part

    dq = jnp.matmul(q, omega_dot - projected) + dx_rinv
    dr = jnp.matmul(projected - omega_dot, r)
    return (q, r), (dq, dr)
Exemple #10
0
def cholesky(x, symmetrize_input=True):
    """Return the lower-triangular part of the Cholesky primitive's output.

    Args:
        x: the input matrix (or batch of matrices).
        symmetrize_input: when true (the default), ``x`` is first passed
            through ``symmetrize``.

    Returns:
        ``np.tril`` of the result of binding ``cholesky_p``.
    """
    operand = symmetrize(x) if symmetrize_input else x
    factor = cholesky_p.bind(operand)
    return np.tril(factor)
Exemple #11
0
 def phi(X):
     """Return the lower triangle of ``X`` with its diagonal halved."""
     lower_part = np.tril(X)
     # Denominator is 2 on the diagonal and 1 everywhere else.
     ones = np._constant_like(X, 1)
     identity = np.eye(X.shape[-1], dtype=X.dtype)
     return lower_part / (ones + identity)