Exemplo n.º 1
0
def _lu_jvp_rule(primals, tangents):
    """JVP rule for the LU primitive ``lu_p``.

    Args:
      primals: one-element sequence holding the matrix ``a``.
      tangents: one-element sequence holding the tangent of ``a``, or the
        symbolic ``ad_util.zero``.

    Returns:
      A pair ``(primal_out, tangent_out)``. ``primal_out`` is the packed
      ``(lu, pivots)`` produced by ``lu_p.bind``; ``tangent_out`` is a
      ``TangentTuple`` holding the tangent of ``lu`` and a symbolic zero for
      ``pivots``.
    """
    (operand,) = primals
    (operand_dot,) = tangents
    lu, pivots = lu_p.bind(operand)

    # A symbolic zero tangent propagates as symbolic zeros for both outputs.
    if operand_dot is ad_util.zero:
        zero_tangents = ad.TangentTuple((ad_util.zero, ad_util.zero))
        return core.pack((lu, pivots)), zero_tangents

    shape = np.shape(operand)
    batch_shape = shape[:-2]
    rows, cols = shape[-2:]
    rank = min(rows, cols)
    dtype = lax.dtype(operand)
    ndims = len(shape)

    # Reorder the rows of the tangent with the pivot permutation. The index
    # arrays built by np.ix_ address the batch dimensions; the extra trailing
    # size-1 entry only shapes them for broadcasting and is dropped again.
    # (presumably `perm` has shape batch_shape + (rows,) -- TODO confirm)
    perm = lu_pivots_to_permutation(pivots, rows)
    batch_index = np.ix_(
        *[lax.iota(np.int32, size) for size in batch_shape + (1,)])
    permuted_dot = operand_dot[batch_index[:-1] + (perm, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    zero = np._constant_like(lu, 0)
    no_padding = (0, 0, 0)

    # Square (rows x rows) unit-lower-triangular factor: the strict lower
    # triangle of the first `rank` columns, zero-padded on the right, plus I.
    lower_padding = [no_padding] * ndims
    lower_padding[-1] = (0, rows - rank, 0)
    strict_lower = np.tril(lu[..., :, :rank], -1)
    lower_factor = lax.pad(strict_lower, zero, lower_padding)
    lower_factor = lower_factor + np.eye(rows, rows, dtype=dtype)

    # Square (cols x cols) upper-triangular factor: the upper triangle of the
    # first `rank` rows, zero-padded at the bottom, plus an identity block in
    # the bottom-right corner.
    corner_identity = lax.pad(
        np.eye(cols - rank, cols - rank, dtype=dtype), zero,
        ((rank, 0, 0), (rank, 0, 0)))
    upper_padding = [no_padding] * ndims
    upper_padding[-2] = (0, cols - rank, 0)
    upper_triangle = np.triu(lu[..., :rank, :])
    upper_factor = lax.pad(upper_triangle, zero, upper_padding) + corner_identity

    # inv(L) . A' followed by a right-side solve with U: inv(L) . A' . inv(U).
    l_inv_dot = triangular_solve(
        lower_factor, permuted_dot,
        left_side=True, transpose_a=False, lower=True, unit_diagonal=True)
    l_inv_dot_u_inv = triangular_solve(
        upper_factor, l_inv_dot,
        left_side=False, transpose_a=False, lower=False)

    # L' + U', stored in the same combined layout as `lu`.
    lu_dot = (np.matmul(lower_factor, np.tril(l_inv_dot_u_inv, -1)) +
              np.matmul(np.triu(l_inv_dot_u_inv), upper_factor))
    return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
Exemplo n.º 2
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """JVP rule for the SVD primitive ``svd_p`` (reduced decomposition only).

  Args:
    primals: one-element sequence holding the matrix ``A``.
    tangents: one-element sequence holding the tangent ``dA`` of ``A``, or the
      symbolic ``ad_util.zero``.
    full_matrices: must be falsy; the full-matrices case is not implemented.
    compute_uv: accepted for interface compatibility but not consulted here;
      the primitive is always bound with ``compute_uv=True``.

  Returns:
    A pair ``(primals_out, tangents_out)`` for the outputs ``(s, U, Vt)``.
    With a symbolic zero input tangent the primals are packed with
    ``core.pack`` and the tangents are a ``TangentTuple`` of symbolic zeros;
    otherwise both are plain tuples.

  Raises:
    NotImplementedError: if ``full_matrices`` is truthy.
  """
  # Reject the unsupported case up front: the primitive is always bound with
  # full_matrices=False below, so even the zero-tangent shortcut would hand
  # back reduced factors for a full-matrices request.
  if full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if dA is ad_util.zero:
    # TangentTuple is built from ONE sequence, as at its other call sites;
    # passing the zeros as separate positional arguments presumably raises
    # TypeError if TangentTuple subclasses tuple.
    return (core.pack((s, U, Vt)),
            ad.TangentTuple((ad_util.zero, ad_util.zero, ad_util.zero)))

  # NOTE(review): `.T` and `np.diag` below look 2-D only -- confirm before
  # relying on this rule with batch dimensions.
  k = s.shape[-1]
  Ut, V = np.conj(U).T, np.conj(Vt).T
  s_dim = s[..., None, :]
  dS = np.dot(np.dot(Ut, dA), V)
  ds = np.real(np.diag(dS))
  # F[i, j] = 1 / (s_j**2 - s_i**2) off the diagonal and 0 on it; adding and
  # then subtracting eye(k) avoids dividing by zero on the diagonal.
  F = 1 / (np.square(s_dim) - np.square(s_dim.T) + np.eye(k)) - np.eye(k)
  dSS = s_dim * dS
  SdS = s_dim.T * dS
  dU = np.dot(U, F * (dSS + dSS.T))
  dV = np.dot(V, F * (SdS + SdS.T))

  # Extra terms for rectangular inputs: the component of the tangent that
  # lies outside the span of U (tall case) or V (wide case).
  m, n = A.shape[-2], A.shape[-1]
  if m > n:
    dU = dU + np.dot(np.eye(m) - np.dot(U, Ut), np.dot(dA, V)) / s_dim
  if n > m:
    dV = dV + np.dot(np.eye(n) - np.dot(V, Vt), np.dot(np.conj(dA).T, U)) / s_dim
  return (s, U, Vt), (ds, dU, dV.T)
Exemplo n.º 3
0
def eigh_jvp_rule(primals, tangents, lower):
  """JVP rule for the primitive ``eigh_p``, assuming distinct eigenvalues.

  Args:
    primals: one-element sequence holding the matrix ``a``.
    tangents: one-element sequence holding the tangent of ``a``, or the
      symbolic ``ad_util.zero``.
    lower: forwarded unchanged to ``eigh_p.bind``.

  Returns:
    A pair ``(primals_out, tangents_out)`` for the outputs ``(v, w)`` of
    ``eigh_p.bind``. With a symbolic zero input tangent the primals are packed
    with ``core.pack`` and the tangents are a ``TangentTuple`` of symbolic
    zeros; otherwise both are plain tuples and ``w`` has been cast to
    ``a.dtype``.
  """
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w = eigh_p.bind(symmetrize(a), lower=lower)

  if a_dot is ad_util.zero:
    # TangentTuple is built from ONE sequence, as at its other call sites;
    # passing the zeros as separate positional arguments presumably raises
    # TypeError if TangentTuple subclasses tuple.
    return core.pack((v, w)), ad.TangentTuple((ad_util.zero, ad_util.zero))

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w.astype(a.dtype)
  eye_n = np.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs:
  # Fmat[i, j] = 1 / (w[j] - w[i]) off the diagonal and 0 on it.
  Fmat = np.reciprocal(eye_n + w - w[..., np.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = lax.dot if a.ndim == 2 else lax.batch_matmul
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, np.multiply(Fmat, vdag_adot_v))
  # NOTE(review): np.diagonal is called with its default axes, which matches
  # the trailing two dimensions only for 2-D input -- confirm before batching.
  dw = np.diagonal(vdag_adot_v)
  return (v, w), (dv, dw)
Exemplo n.º 4
0
def _scan_jvp(primals, tangents, forward, length, jaxpr):
    """JVP rule for the scan primitive ``scan_p``.

    Args:
      primals: triple ``(consts, init, xs)``.
      tangents: triple ``(consts_dot, init_dot, xs_dot)`` of matching tangents.
      forward: forwarded unchanged to ``scan_p.bind``.
      length: forwarded unchanged to ``scan_p.bind``.
      jaxpr: the scan body; it is differentiated with ``ad.jvp_jaxpr`` and the
        result is what gets bound.

    Returns:
      A pair ``(primal_out, tangent_out)``: the packed ``(carry_out, ys)`` and
      a ``TangentTuple`` ``(carry_out_dot, ys_dot)``.

    Raises:
      FixedPointError: if the carry non-zero pattern has not stabilised after
        1000 iterations.
    """
    consts, init, xs = primals
    consts_dot, init_dot, xs_dot = tangents
    # These names are not used afterwards; the unpacking itself still requires
    # in_avals to have three entries and out_aval to have two.
    consts_aval, carry_aval, x_aval = jaxpr.in_avals
    _, y_aval = jaxpr.out_aval

    consts_nz = ad.get_nonzeros(consts_dot)
    carry_nz = ad.get_nonzeros(init_dot)
    xs_nz = ad.get_nonzeros(xs_dot)  # same as x_nonzeros b/c arrays

    # Iterate to a fixed point of the carry's non-zero pattern: differentiate
    # the body with the current guess, and if the carry tangent it produces
    # has a different pattern, join the two and try again.
    converged = False
    for _ in range(1000):
        jaxpr_jvp, (carry_nz_out, ys_nz) = ad.jvp_jaxpr(
            jaxpr,
            (consts_nz, carry_nz, xs_nz),
            instantiate=(carry_nz, False))
        if carry_nz_out == carry_nz:
            converged = True
            break
        carry_nz = _binary_lattice_join(carry_nz_out, carry_nz)
    if not converged:
        raise FixedPointError

    # convert_zeros is like strip_zeros but uses explicit lattice information
    # to instantiate zeros in some cases, namely in init_dot based on the
    # fixed point found above.
    init_tangent = _convert_zeros(carry_nz, init, init_dot)
    consts_tangent = _convert_zeros(consts_nz, consts, consts_dot)
    xs_tangent = _convert_zeros(xs_nz, xs, xs_dot)

    # Pair every primal with its tangent and run the differentiated body.
    consts_dual = core.pack((consts, consts_tangent))
    init_dual = core.pack((init, init_tangent))
    xs_dual = core.pack((xs, xs_tangent))
    carry_out_dual, ys_dual = scan_p.bind(
        consts_dual, init_dual, xs_dual,
        forward=forward, length=length, jaxpr=jaxpr_jvp)

    # Split the duals apart again and restore the symbolic zeros according to
    # the non-zero patterns reported by the final jvp_jaxpr call.
    ys, ys_dot = ys_dual
    ys_dot = ad.put_zeros(ad.TangentTuple, ys_nz, ys_dot)

    carry_out, carry_out_dot = carry_out_dual
    carry_out_dot = ad.put_zeros(ad.TangentTuple, carry_nz_out, carry_out_dot)

    primal_out = core.pack((carry_out, ys))
    tangent_out = ad.TangentTuple((carry_out_dot, ys_dot))
    return primal_out, tangent_out
Exemplo n.º 5
0
def lu_jvp_rule(primals, tangents):
    """JVP rule for the LU primitive ``lu_p``.

    Args:
      primals: one-element sequence holding the matrix ``a``.
      tangents: one-element sequence holding the tangent of ``a``, or the
        symbolic ``ad_util.zero``.

    Returns:
      A pair ``(primal_out, tangent_out)``. ``primal_out`` is the packed
      ``(lu, pivots)`` produced by ``lu_p.bind``; ``tangent_out`` is a
      ``TangentTuple`` holding the tangent of ``lu`` and a symbolic zero for
      ``pivots``.
    """
    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    # A symbolic zero tangent cannot be fed to the matmul below; it simply
    # propagates as symbolic zeros for both outputs.
    if a_dot is ad_util.zero:
        return (core.pack((lu, pivots)),
                ad.TangentTuple((ad_util.zero, ad_util.zero)))

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax._dtype(a)
    k = min(m, n)

    # TODO(phawkins): use a gather rather than a matrix multiplication here.
    # p[i, j] is 1 where permutation[i] == j, so p @ a_dot reorders the rows
    # of a_dot by the pivot permutation.
    # NOTE(review): `permutation[:, None]` looks like it assumes a 1-D
    # permutation, i.e. no batch dimensions -- confirm.
    permutation = lu_pivots_to_permutation(pivots, m)
    p = np.array(permutation[:, None] == np.arange(m), dtype=dtype)
    x = np.matmul(p, a_dot)

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    # Square (m x m) unit-lower-triangular factor: the strict lower triangle
    # of the first k columns, zero-padded on the right, plus the identity.
    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    # Square (n x n) upper-triangular factor: the upper triangle of the first
    # k rows, zero-padded at the bottom, plus an identity block in the
    # bottom-right corner.
    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # inv(L) . A' followed by a right-side solve with U: inv(L) . A' . inv(U).
    la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    # L' + U', stored in the same combined layout as `lu`.
    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    lu_dot = l_dot + u_dot
    return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
Exemplo n.º 6
0
def qr_jvp_rule(primals, tangents, full_matrices):
  """JVP rule for the QR primitive ``qr_p`` (reduced decomposition only).

  Args:
    primals: one-element sequence holding the matrix ``x``.
    tangents: one-element sequence holding the tangent ``dx`` of ``x``, or the
      symbolic ``ad_util.zero``.
    full_matrices: must be falsy; the full-matrices case is not implemented.

  Returns:
    A pair ``(primals_out, tangents_out)`` for the outputs ``(q, r)``. With a
    symbolic zero input tangent the primals are packed with ``core.pack`` and
    the tangents are a ``TangentTuple`` of symbolic zeros; otherwise both are
    plain tuples.

  Raises:
    NotImplementedError: if ``full_matrices`` is truthy, or if the tangent is
      not a symbolic zero and ``x`` has fewer rows than columns.
  """
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  # Reject full_matrices up front: the primitive is always bound with
  # full_matrices=False below, so even the zero-tangent shortcut would hand
  # back reduced factors for a full-matrices request.
  if full_matrices:
    raise NotImplementedError
  x, = primals
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  if dx is ad_util.zero:
    # TangentTuple is built from ONE sequence, as at its other call sites;
    # passing the zeros as separate positional arguments presumably raises
    # TypeError if TangentTuple subclasses tuple.
    return core.pack((q, r)), ad.TangentTuple((ad_util.zero, ad_util.zero))
  # The derivative below is only implemented for rows >= columns.
  if np.shape(x)[-2] < np.shape(x)[-1]:
    raise NotImplementedError
  dx_rinv = triangular_solve(r, dx)  # Right side solve by default
  qt_dx_rinv = np.matmul(_T(q), dx_rinv)
  qt_dx_rinv_lower = np.tril(qt_dx_rinv, -1)
  domega = qt_dx_rinv_lower - _T(qt_dx_rinv_lower)  # This is skew-symmetric
  dq = np.matmul(q, domega - qt_dx_rinv) + dx_rinv
  dr = np.matmul(qt_dx_rinv - domega, r)
  return (q, r), (dq, dr)