Example #1
def debug_triangular_solve():
    from jax.lax_linalg import triangular_solve
    from scipy.linalg import solve_triangular
    from jax.scipy.linalg import solve_triangular as solve_triangular_jax
    import jax.numpy as jnp
    ndims = 2
    A = jnp.diag(jnp.ones(ndims))
    A = jnp.where(A == 0., 0.95, A)
    b = jnp.ones(ndims)
    L = jnp.linalg.cholesky(A)
    assert jnp.allclose(L @ L.T, A)  # floating point: compare with tolerance

    x = jnp.linalg.solve(L, b)
    print("Solving L.x  = b with scipy")
    print("x should be {}".format(x))
    scipy_x = solve_triangular(L, b, lower=True)
    assert jnp.all(scipy_x == x)
    print("Works as expected!")

    print("Now note JAX's solution to L^T.x = b corresponds to scipy's L.x = b")
    jax_x = triangular_solve(L, b, lower=True, transpose_a=True)
    assert jnp.allclose(jax_x, scipy_x)
    print("Likewise, JAX's solution to L.x = b corresponds to scipy's L^T.x = b")
    assert jnp.allclose(triangular_solve(L, b, lower=True),
                        solve_triangular(L, b, lower=True, trans=1))

    print("Note, I have not tested for the L^H.x=b case.")

    jax_x = solve_triangular_jax(L, b, lower=True)
    assert jnp.allclose(scipy_x, jax_x)
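
The transpose swap observed above has a simple explanation: lax_linalg.triangular_solve defaults to left_side=False, i.e. it solves x @ a = b rather than a @ x = b, so transpose_a=True makes it agree with scipy's left-sided a @ x = b. A minimal sketch (assuming the same-era jax.lax_linalg module; in newer JAX this lives at jax.lax.linalg) showing that left_side=True recovers scipy's convention directly:

import jax.numpy as jnp
from jax.lax_linalg import triangular_solve

L = jnp.array([[1.0, 0.0], [0.5, 2.0]])
b = jnp.ones((2, 1))
x = triangular_solve(L, b, left_side=True, lower=True)  # solves L @ x = b
print(jnp.allclose(x, jnp.linalg.solve(L, b)))  # expected: True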
Example #2
def _cho_solve(c, b, lower):
    c, b = np_linalg._promote_arg_dtypes(jnp.asarray(c), jnp.asarray(b))
    np_linalg._check_solve_shapes(c, b)
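    # Given the Cholesky factor c of A (A = L @ L^H if lower, A = U^H @ U if
    # upper), solve A x = b with two triangular solves against c and its
    # conjugate transpose; the transpose/conjugate flags below pick the
    # right order for each case.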
    b = lax_linalg.triangular_solve(c,
                                    b,
                                    left_side=True,
                                    lower=lower,
                                    transpose_a=not lower,
                                    conjugate_a=not lower)
    b = lax_linalg.triangular_solve(c,
                                    b,
                                    left_side=True,
                                    lower=lower,
                                    transpose_a=lower,
                                    conjugate_a=lower)
    return b
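
A quick sanity sketch (my addition, not part of the snippet): jax.scipy.linalg.cho_solve is the public wrapper around this pair of solves, computing A^-1 b from a Cholesky factorization of A.

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
b = jnp.array([[1.0], [2.0]])
c, low = cho_factor(A, lower=True)
x = cho_solve((c, low), b)
print(jnp.allclose(A @ x, b))  # expected: True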
Example #3
def _solve_triangular(a, b, trans, lower, unit_diagonal):
    if trans == 0 or trans == "N":
        transpose_a, conjugate_a = False, False
    elif trans == 1 or trans == "T":
        transpose_a, conjugate_a = True, False
    elif trans == 2 or trans == "C":
        transpose_a, conjugate_a = True, True
    else:
        raise ValueError("Invalid 'trans' value {}".format(trans))

    a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))

    # lax_linalg.triangular_solve only supports matrix 'b's at the moment.
    b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1
    if b_is_vector:
        b = b[..., None]
    out = lax_linalg.triangular_solve(a,
                                      b,
                                      left_side=True,
                                      lower=lower,
                                      transpose_a=transpose_a,
                                      conjugate_a=conjugate_a,
                                      unit_diagonal=unit_diagonal)
    if b_is_vector:
        return out[..., 0]
    else:
        return out
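
This helper appears to back jax.scipy.linalg.solve_triangular, which accepts the scipy-style trans codes handled above. A short usage sketch of the public function:

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

L = jnp.array([[2.0, 0.0], [1.0, 3.0]])
b = jnp.array([2.0, 4.0])
x = solve_triangular(L, b, lower=True)             # solves L @ x = b
xt = solve_triangular(L, b, trans=1, lower=True)   # solves L.T @ x = b
print(jnp.allclose(L @ x, b), jnp.allclose(L.T @ xt, b))  # True True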
Example #4
  def testLaxLinalgTriangularSolve(self):
    a = onp.random.RandomState(0).randn(4, 10, 4).astype(onp.float32)
    a += onp.eye(4, dtype=onp.float32)[:, None, :]
    b = onp.random.RandomState(0).randn(5, 4, 10).astype(onp.float32)

    ans = vmap(lax_linalg.triangular_solve, in_axes=(1, 2))(a, b)
    expected = onp.stack(
      [lax_linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    ans = vmap(lax_linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
    expected = onp.stack(
      [lax_linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    ans = vmap(lax_linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
    expected = onp.stack(
      [lax_linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)
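
A minimal standalone sketch of the batching rule this test exercises (my addition, assuming the same-era jax.lax_linalg module): in_axes picks which axis of each argument is mapped, pairing a[:, i] with b[..., i]; identity systems keep the expected output obvious.

import jax
import jax.numpy as jnp
from jax.lax_linalg import triangular_solve

a = jnp.tile(jnp.eye(4), (10, 1, 1)).transpose(1, 0, 2)  # shape (4, 10, 4)
b = jnp.ones((5, 4, 10))
out = jax.vmap(triangular_solve, in_axes=(1, 2))(a, b)
print(out.shape)  # (10, 5, 4)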
Example #5
def logpdf(x, mean, cov):
    x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
    if not mean.shape:
        return (-1 / 2 * jnp.square(x - mean) / cov - 1 / 2 *
                (np.log(2 * np.pi) + jnp.log(cov)))
    else:
        n = mean.shape[-1]
        if not np.shape(cov):
            y = x - mean
            return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) / cov - n / 2 *
                    (np.log(2 * np.pi) + jnp.log(cov)))
        else:
            if cov.ndim < 2 or cov.shape[-2:] != (n, n):
                raise ValueError(
                    "multivariate_normal.logpdf got incompatible shapes")
            L = cholesky(cov)
            y = triangular_solve(L, x - mean, lower=True, transpose_a=True)
            return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) -
                    n / 2 * np.log(2 * np.pi) - jnp.log(L.diagonal()).sum())
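
This appears to be the body of jax.scipy.stats.multivariate_normal.logpdf (cholesky and triangular_solve here are module-level lax_linalg imports in the JAX source). The right-side solve with transpose_a=True whitens the residual, so the einsum equals (x - mean)^T cov^-1 (x - mean), and jnp.log(L.diagonal()).sum() is half of log det(cov). A short usage sketch of the public function:

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

x = jnp.array([1.0, 2.0])
mean = jnp.zeros(2)
cov = jnp.array([[2.0, 0.3], [0.3, 1.0]])
print(multivariate_normal.logpdf(x, mean, cov))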
Example #6
def multivariate_gaussian_logp(x, mu, cov):
    """
    compute the log probability of a multivariate gaussian given x, mu, and a covariance matrix
    arguments
        x : jnp.array(N)
            position in latent space
        mu : jnp.array(N)
            mean of Gaussian
        cov : jnp.array(N,N)
            covariance of gaussian
    returns
        logp : float
            log probability of N(x | mu, cov)
    """
    from jax.lax_linalg import cholesky, triangular_solve
    n = mu.shape[0]
    L = cholesky(cov)
    y = triangular_solve(L, x - mu, lower=True, transpose_a=True)
    return -1. / 2. * jnp.einsum('...i,...i->...', y, y) - n / 2. * jnp.log(
        2 * jnp.pi) - jnp.log(L.diagonal()).sum()
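
A small usage sketch (hypothetical driver; assumes multivariate_gaussian_logp above is in scope). At x = mu the quadratic term vanishes, so the result reduces to -n/2 * log(2*pi) - 1/2 * log det(cov):

import jax.numpy as jnp

mu = jnp.zeros(2)
cov = jnp.array([[1.0, 0.2], [0.2, 1.0]])
logp = multivariate_gaussian_logp(mu, mu, cov)  # density evaluated at the mean
expected = -1.0 * jnp.log(2 * jnp.pi) - 0.5 * jnp.linalg.slogdet(cov)[1]  # n = 2
print(jnp.allclose(logp, expected))  # expected: True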
Example #7
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 @ b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} =
  x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
  x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, the result of the solve is checked for NaN or Inf
  values, and a matrix of zeros is returned instead, since the gradient of
  the determinant of a matrix with rank less than n-1 is 0.
  This still returns the correct value for rank n-1 matrices, because the
  check is applied *after* the lower-right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    if a_shape[-1] == 1:
        return a[..., 0, 0], b  # also correct for batched 1x1 systems
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.array(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[:, -1] contains the full determinant and
    # partial_det[:, -2] contains det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    lu = ops.index_update(lu, ops.index[..., -1, -1],
                          1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    iotas = jnp.ix_(*(lax.iota(jnp.int32, n) for n in batch_dims + (1, )))
    # filter out any matrices that are not full rank
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu,
                                    x,
                                    left_side=True,
                                    lower=True,
                                    unit_diagonal=True)
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x
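
A small sanity sketch (hypothetical driver; assumes _cofactor_solve above is in scope): for a nonsingular a the result should match det(a) and det(a) * solve(a, b), i.e. adjugate(a) @ b.

import jax.numpy as jnp

a = jnp.array([[2.0, 1.0], [0.0, 3.0]])
b = jnp.eye(2)
det, adj_b = _cofactor_solve(a, b)
print(jnp.allclose(det, jnp.linalg.det(a)))               # expected: True
print(jnp.allclose(adj_b, det * jnp.linalg.solve(a, b)))  # expected: True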