Esempio n. 1
0
def _lu_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the LU primitive ``lu_p``.

    Args:
      primals: 1-tuple ``(a,)`` holding the input matrix (possibly with
        leading batch dimensions; the last two axes are ``(m, n)``).
      tangents: 1-tuple ``(a_dot,)`` holding the tangent of ``a``.

    Returns:
      ``(primal_out, tangent_out)`` where ``primal_out`` is
      ``(lu, pivots, permutation)`` from ``lu_p.bind(a)`` and
      ``tangent_out`` is ``(lu_dot, zero, zero)``. ``pivots`` and
      ``permutation`` are treated as non-differentiable and get symbolic
      zero tangents.
    """
    a, = primals
    a_dot, = tangents
    lu, pivots, permutation = lu_p.bind(a)

    a_shape = jnp.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    # Permute the rows of a_dot by `permutation`, independently for each batch
    # element. jnp.ix_ builds open-mesh index arrays; the extra trailing size-1
    # axis (dropped via iotas[:-1]) gives each batch index array a trailing
    # singleton dimension so it broadcasts against `permutation` (presumably
    # shaped batch_dims + (m,) -- confirm against lu_p's shape rule).
    batch_dims = a_shape[:-2]
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    x = a_dot[iotas[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    # Here A is the row-permuted input, so its tangent A' is `x` above.
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    # Build a square (m, m) unit lower-triangular L: take the strictly lower
    # part of the first k columns of `lu`, pad with m - k zero columns on the
    # right, then add the identity for the unit diagonal.
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = jnp._constant_like(lu, 0)
    l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + jnp.eye(m, m, dtype=dtype)

    # Build a square (n, n) upper-triangular U: take the upper part of the
    # first k rows of `lu`, pad with n - k zero rows at the bottom, and put an
    # identity in the trailing (n - k, n - k) block so U is nonsingular there
    # and can be used as the matrix operand of triangular_solve.
    u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # la = inv(L) . A'  (left-side solve against unit lower-triangular L).
    la = triangular_solve(l,
                          x,
                          left_side=True,
                          transpose_a=False,
                          lower=True,
                          unit_diagonal=True)
    # lau = inv(L) . A' . inv(U)  (right-side solve against upper U).
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    # L' and U' per the formulas above. Their sum is the tangent of the packed
    # `lu` output, with trailing shape (m, n).
    # NOTE(review): these matmuls use the default precision, whereas
    # triangular_solve_jvp_rule_a requests lax.Precision.HIGHEST -- consider
    # doing the same here if accuracy on accelerators matters.
    l_dot = jnp.matmul(l, jnp.tril(lau, -1))
    u_dot = jnp.matmul(jnp.triu(lau), u)
    lu_dot = l_dot + u_dot
    return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
                                       ad_util.Zero.from_value(permutation))
Esempio n. 2
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """JVP of ``triangular_solve`` with respect to the matrix operand ``a``.

    Write op(A) for ``a`` after the optional transpose/conjugate selected by
    ``transpose_a`` / ``conjugate_a``, and X for ``ans``. The tangent of the
    solution is ``-op(A)^{-1} op(dA) X`` when ``left_side`` is true, and
    ``-X op(dA) op(A)^{-1}`` otherwise.

    Args:
      g_a: tangent of ``a``. Only the triangle that the solve reads is used.
      ans: primal solution X, trailing shape ``(m, n)``.
      a: triangular matrix operand of the primal solve.
      b: right-hand side of the primal solve; only its shape is read here.
      left_side, lower, transpose_a, conjugate_a, unit_diagonal: the same
        flags that were passed to the primal ``triangular_solve``.

    Returns:
      The tangent of ``ans``.
    """
    rows, cols = b.shape[-2:]

    # Mask the tangent to the triangle triangular_solve actually reads. With a
    # unit diagonal, the diagonal itself is ignored as well.
    offset = 1 if unit_diagonal else 0
    if lower:
        tangent = jnp.tril(g_a, k=-offset)
    else:
        tangent = jnp.triu(g_a, k=offset)
    tangent = lax.neg(tangent)
    if transpose_a:
        tangent = jnp.swapaxes(tangent, -1, -2)
    if conjugate_a:
        tangent = jnp.conj(tangent)

    unbatched = tangent.ndim == 2

    def mul(lhs, rhs):
        # Full-precision (batched) matrix product.
        if unbatched:
            return lax.dot(lhs, rhs, precision=lax.Precision.HIGHEST)
        return lax.batch_matmul(lhs, rhs, precision=lax.Precision.HIGHEST)

    def solve_a(rhs):
        # Apply op(A)^{-1} on the same side as the primal solve.
        return triangular_solve(a, rhs, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal)

    # A triangular solve costs about as much as a matrix multiplication
    # (~n^2 FLOPs per matrix/vector input), so associate the product in
    # whichever order touches the smaller intermediate.
    if left_side:
        assert (tangent.shape[-2:] == a.shape[-2:] == (rows, rows)
                and ans.shape[-2:] == (rows, cols))
        if rows > cols:
            return solve_a(mul(tangent, ans))  # A^{-1} (∂A X)
        return mul(solve_a(tangent), ans)  # (A^{-1} ∂A) X

    assert (tangent.shape[-2:] == a.shape[-2:] == (cols, cols)
            and ans.shape[-2:] == (rows, cols))
    if rows < cols:
        return solve_a(mul(ans, tangent))  # (X ∂A) A^{-1}
    return mul(ans, solve_a(tangent))  # X (∂A A^{-1})
Esempio n. 3
0
def funm(A, func, disp=True):
  """Evaluate a matrix function ``func(A)`` via the complex Schur form.

  ``A`` is reduced to complex Schur form ``A = Z T Z^H``. ``func`` is applied
  elementwise to the diagonal of ``T``, the rest of ``F = func(T)`` is filled
  in by ``_algorithm_11_1_1`` (presumably the Parlett recurrence, Golub & Van
  Loan Algorithm 11.1.1 -- confirm), and the result is mapped back with
  ``Z F Z^H``.

  Args:
    A: square 2-D array_like.
    func: callable applied elementwise to the eigenvalues (diagonal of T).
    disp: if True, return only ``F``; if False, also return an error
      estimate.

  Returns:
    ``F`` if ``disp`` is True, otherwise ``(F, err)``. ``err`` is ``inf`` when
    ``F`` has any non-finite entry, and otherwise
    ``min(1, max(tol, (tol / minden) * ||triu(T, 1)||_1))``, where ``tol`` is
    the machine epsilon for ``F``'s precision.

  Raises:
    ValueError: if ``A`` is not a square 2-D array.
  """
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected square array_like input')

  T, Z = schur(A)
  T, Z = rsf2csf(T, Z)

  F = jnp.diag(func(jnp.diag(T)))
  F = F.astype(T.dtype.char)

  F, minden = _algorithm_11_1_1(F, T)
  F = Z @ F @ Z.conj().T

  if disp:
    return F

  # Epsilon matching F's precision ('e' = half, 'f'/'F' = single, otherwise
  # double). This must be one if/elif chain: with two independent `if`s, a
  # half-precision F fell through to the final `else` and got float64 eps.
  dtype_char = F.dtype.char.lower()
  if dtype_char == 'e':
    tol = jnp.finfo(jnp.float16).eps
  elif dtype_char == 'f':
    tol = jnp.finfo(jnp.float32).eps
  else:
    tol = jnp.finfo(jnp.float64).eps

  # Guard against division by zero in the error bound.
  minden = jnp.where(minden == 0.0, tol, minden)
  # Any non-finite entry (inf OR NaN) makes the estimate meaningless, so
  # report inf, matching scipy.linalg.funm.
  has_nonfinite = jnp.logical_not(jnp.all(jnp.isfinite(F)))
  err = jnp.where(has_nonfinite, jnp.inf, jnp.minimum(1, jnp.maximum(
      tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

  return F, err
Esempio n. 4
0
def _lu(a, permute_l):
    """LU-factorize a 2-D matrix and return SciPy-style factors.

    Args:
      a: array_like of shape ``(m, n)``; dtype-promoted with
        ``np_linalg._promote_arg_dtypes`` before factorizing.
      permute_l: if True, fold the permutation into ``l`` and return
        ``(p @ l, u)``; otherwise return ``(p, l, u)``.

    Returns:
      With ``k = min(m, n)``:
      ``p``: ``(m, m)`` real 0/1 matrix with ``p[i, j] == 1`` exactly where
      ``permutation[j] == i``;
      ``l``: ``(m, k)`` unit lower-triangular factor;
      ``u``: ``(k, n)`` upper-triangular factor.
      Assuming ``lax_linalg.lu``'s convention ``a[permutation] == L @ U``,
      these satisfy ``a == p @ l @ u`` up to rounding.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    # Build the permutation matrix by comparing each permutation entry with
    # every row index. jnp.real makes p real-valued even when `a` is complex.
    row_ids = jnp.arange(m, dtype=permutation.dtype)[:, None]
    p = jnp.real(jnp.array(permutation[None, :] == row_ids, dtype=dtype))
    k = min(m, n)
    # `lu` packs both factors: strictly-lower part -> L (unit diagonal
    # implied), upper part including the diagonal -> U.
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        # p is a 0/1 permutation matrix, so p @ l should just reorder the rows
        # of l. Request HIGHEST precision so backends whose DEFAULT precision
        # uses reduced-precision passes do not round the entries of l.
        return jnp.matmul(p, l, precision=lax.Precision.HIGHEST), u
    else:
        return p, l, u
Esempio n. 5
0
def _lu(a, permute_l):
    """LU-factorize a 2-D matrix and return SciPy-style factors.

    Args:
      a: array_like of shape ``(m, n)``; promoted via
        ``_promote_dtypes_inexact`` before factorizing.
      permute_l: if True, fold the permutation into ``l`` and return
        ``(p @ l, u)``; otherwise return ``(p, l, u)``.

    Returns:
      With ``k = min(m, n)``:
      ``p``: ``(m, m)`` real 0/1 matrix with ``p[i, j] == 1`` exactly where
      ``permutation[j] == i``;
      ``l``: ``(m, k)`` unit lower-triangular factor;
      ``u``: ``(k, n)`` upper-triangular factor.
      Assuming ``lax_linalg.lu``'s convention ``a[permutation] == L @ U``,
      these satisfy ``a == p @ l @ u`` up to rounding.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    # Build the permutation matrix by comparing each permutation entry with
    # every row index. jnp.real makes p real-valued even when `a` is complex.
    row_ids = jnp.arange(m, dtype=permutation.dtype)[:, None]
    p = jnp.real(jnp.array(permutation[None, :] == row_ids, dtype=dtype))
    k = min(m, n)
    # `lu` packs both factors: strictly-lower part -> L (unit diagonal
    # implied), upper part including the diagonal -> U.
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        # p is a 0/1 permutation matrix, so p @ l should just reorder the rows
        # of l. Request HIGHEST precision so backends whose DEFAULT precision
        # uses reduced-precision passes do not round the entries of l.
        return jnp.matmul(p, l, precision=lax.Precision.HIGHEST), u
    else:
        return p, l, u
Esempio n. 6
0
def triu(m, k=0):
    """Return the upper triangle of ``m`` relative to the ``k``-th diagonal.

    Thin wrapper over ``jnp.triu``. Entries below diagonal ``k`` are zeroed:
    ``k=0`` keeps the main diagonal, ``k > 0`` also clears diagonals above it,
    and ``k < 0`` keeps that many sub-diagonals.
    """
    upper = jnp.triu(m, k=k)
    return upper