def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
  """JVP of triangular_solve with respect to the triangular operand ``a``.

  Args:
    g_a: tangent of ``a``.
    ans: the primal solve result ``X`` (the output of the triangular solve).
    a: the triangular matrix operand of the primal solve.
    b: the right-hand side; its trailing two dims give ``(m, n)``.
    left_side, lower, transpose_a, conjugate_a, unit_diagonal: the solve's
      static options, forwarded unchanged to the inner solves below.

  Returns:
    The tangent of the solve output induced by ``g_a``.
  """
  m, n = b.shape[-2:]
  # When the diagonal is implicitly unit, it does not participate in the
  # solve, so exclude it from the tangent as well (offset the mask by 1).
  k = 1 if unit_diagonal else 0
  # Keep only the triangle of g_a that the solve actually reads, negate it
  # (d/dA A^{-1} = -A^{-1} dA A^{-1}), and apply the same transpose /
  # conjugation the primal solve applies to `a`.
  g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k)
  g_a = lax.neg(g_a)
  g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a
  g_a = jnp.conj(g_a) if conjugate_a else g_a
  # Use batch_matmul when there are leading batch dims; always request the
  # highest precision so the tangent is not computed in reduced precision.
  dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)

  def a_inverse(rhs):
    # Apply A^{-1} (with the primal's options) via another triangular solve.
    return triangular_solve(a, rhs, left_side, lower, transpose_a,
                            conjugate_a, unit_diagonal)

  # triangular_solve is about the same cost as matrix multiplication (~n^2
  # FLOPs for matrix/vector inputs). Order these operations in whichever
  # order is cheaper.
  if left_side:
    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (
        m, n)
    if m > n:
      return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
    else:
      return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
  else:
    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (
        m, n)
    if m < n:
      return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
    else:
      return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
from functools import partial
import numpy as np
import textwrap
import operator
from typing import Tuple, Union, cast

from jax import jit, custom_jvp
from jax import lax
from jax._src.lax import linalg as lax_linalg
from jax._src import dtypes
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import canonicalize_axis

# Transpose of the trailing two axes (batched matrix transpose).
_T = lambda x: jnp.swapaxes(x, -1, -2)
# Conjugate (Hermitian) transpose of the trailing two axes.
_H = lambda x: jnp.conjugate(jnp.swapaxes(x, -1, -2))


def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact (floating/complex) dtype.

  Returns a single array when given one argument, otherwise a list of
  arrays, all converted to the promoted dtype.
  """
  dtype, weak_type = dtypes._lattice_result_type(*args)
  if not jnp.issubdtype(dtype, jnp.inexact):
    # Integer/bool inputs: fall back to the default float type, and drop
    # weak typing since the result type is no longer input-derived.
    dtype, weak_type = jnp.float_, False
  dtype = dtypes.canonicalize_dtype(dtype)
  args = [lax._convert_element_type(arg, dtype, weak_type) for arg in args]
  if len(args) == 1:
    return args[0]
  else:
    return args
import numpy as np
import scipy.linalg
import textwrap
import warnings

import jax
from jax import jit, vmap, jvp
from jax import lax
from jax._src import dtypes
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import qdwh
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _promote_dtypes_complex
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as np_linalg

# Transpose of the trailing two axes (batched matrix transpose).
_T = lambda x: jnp.swapaxes(x, -1, -2)

# Shared docstring addenda appended to scipy-wrapper docs via _wraps.
_no_chkfinite_doc = textwrap.dedent("""
Does not support the Scipy argument ``check_finite=True``, because compiled JAX code cannot perform checks of array values at runtime.
""")
_no_overwrite_and_chkfinite_doc = _no_chkfinite_doc + "\nDoes not support the Scipy argument ``overwrite_*=True``."


# `lower` must be static: it selects which branch is traced/compiled.
@partial(jit, static_argnames=('lower',))
def _cholesky(a, lower):
  """Cholesky factor of `a`; lower-triangular if `lower`, else upper.

  The underlying lax_linalg.cholesky produces a lower factor, so the upper
  case is computed by factoring conj(a^T) and conjugate-transposing back.
  """
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  l = lax_linalg.cholesky(a if lower else jnp.conj(_T(a)),
                          symmetrize_input=False)
  return l if lower else jnp.conj(_T(l))
def _T(x): return jnp.swapaxes(x, -1, -2)
def _T(x): return jnp.swapaxes(x, -1, -2) def _H(x): return jnp.conj(_T(x))
def _H(x): return jnp.conjugate(jnp.swapaxes(x, -1, -2))