Example #1
0
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """Tangent of ``triangular_solve``'s output with respect to ``a``.

    Args:
      g_a: tangent for the matrix operand ``a``.
      ans: the primal solution (called X in the comments below); presumably
        the output of ``triangular_solve(a, b, ...)`` -- confirm at call site.
      a: the triangular matrix operand (possibly batched).
      b: the right-hand side; only its trailing ``(m, n)`` shape is read.
      left_side, lower, transpose_a, conjugate_a, unit_diagonal: the same
        static flags that were passed to ``triangular_solve``.

    Returns:
      ``-A^{-1} ∂A X`` when ``left_side`` is true, otherwise
      ``-X ∂A A^{-1}``. Here ∂A is ``g_a`` restricted to the triangle selected
      by ``lower`` (the diagonal is excluded for ``unit_diagonal``), and it
      gets the same transpose/conjugation flags as ``a``.
    """
    m, n = b.shape[-2:]

    # Keep only the triangle of the tangent that the solve actually reads.
    # With a unit diagonal the diagonal entries are implicit, so drop them too.
    offset = 1 if unit_diagonal else 0
    if lower:
        tangent = jnp.tril(g_a, k=-offset)
    else:
        tangent = jnp.triu(g_a, k=offset)
    tangent = lax.neg(tangent)
    if transpose_a:
        tangent = jnp.swapaxes(tangent, -1, -2)
    if conjugate_a:
        tangent = jnp.conj(tangent)

    # Plain 2-D inputs use lax.dot; anything with batch dimensions uses
    # lax.batch_matmul. Both run at the highest available precision.
    matmul = lax.dot if tangent.ndim == 2 else lax.batch_matmul
    dot = partial(matmul, precision=lax.Precision.HIGHEST)

    def solve(rhs):
        # Apply A^{-1} (with the same side/triangle/transpose flags) to rhs.
        return triangular_solve(a, rhs, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal)

    # A triangular solve costs about as much as a matrix multiplication
    # (~n^2 FLOPs for matrix/vector inputs). Associate the product in
    # whichever order is cheaper for the given m and n.
    if left_side:
        assert (tangent.shape[-2:] == a.shape[-2:] == (m, m)
                and ans.shape[-2:] == (m, n))
        if m > n:
            return solve(dot(tangent, ans))  # A^{-1} (∂A X)
        return dot(solve(tangent), ans)  # (A^{-1} ∂A) X

    assert (tangent.shape[-2:] == a.shape[-2:] == (n, n)
            and ans.shape[-2:] == (m, n))
    if m < n:
        return solve(dot(ans, tangent))  # (X ∂A) A^{-1}
    return dot(ans, solve(tangent))  # X (∂A A^{-1})
Example #2
0
from functools import partial

import numpy as np
import textwrap
import operator
from typing import Tuple, Union, cast

from jax import jit, custom_jvp
from jax import lax
from jax._src.lax import linalg as lax_linalg
from jax._src import dtypes
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import canonicalize_axis

def _T(x):
    """Swap the last two axes of ``x`` (batched matrix transpose)."""
    return jnp.swapaxes(x, -1, -2)


def _H(x):
    """Conjugate transpose of ``x`` over its last two axes."""
    # Written out in full rather than via _T so it does not depend on later
    # rebindings of the module-level name _T.
    return jnp.conjugate(jnp.swapaxes(x, -1, -2))


def _promote_arg_dtypes(*args):
    """Convert every argument to one shared inexact (float/complex) dtype.

    The common dtype comes from JAX's type-promotion lattice. If that result
    is not inexact (e.g. integer or bool), the default float type
    ``jnp.float_`` is used instead, as a strong (non-weak) type. The chosen
    dtype is then canonicalized.

    Returns:
      The single converted array when called with one argument, otherwise a
      list of the converted arrays in argument order.
    """
    common, weak = dtypes._lattice_result_type(*args)
    if jnp.issubdtype(common, jnp.inexact):
        target, target_is_weak = common, weak
    else:
        target, target_is_weak = jnp.float_, False
    target = dtypes.canonicalize_dtype(target)
    converted = [lax._convert_element_type(x, target, target_is_weak)
                 for x in args]
    return converted[0] if len(converted) == 1 else converted
Example #3
0
import numpy as np
import scipy.linalg
import textwrap
import warnings

import jax
from jax import jit, vmap, jvp
from jax import lax
from jax._src import dtypes
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import qdwh
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _promote_dtypes_complex
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as np_linalg

def _T(x):
    """Swap the last two axes of ``x`` (batched matrix transpose)."""
    return jnp.swapaxes(x, -1, -2)


# Shared docstring fragments describing SciPy arguments that are unsupported.
_no_chkfinite_doc = textwrap.dedent("""
Does not support the Scipy argument ``check_finite=True``,
because compiled JAX code cannot perform checks of array values at runtime.
""")
_no_overwrite_and_chkfinite_doc = _no_chkfinite_doc + (
    "\nDoes not support the Scipy argument ``overwrite_*=True``.")


@partial(jit, static_argnames=('lower', ))
def _cholesky(a, lower):
    """Return a triangular Cholesky factor of ``a``.

    ``a`` is first promoted to an inexact dtype. When ``lower`` is true, the
    factor from ``lax_linalg.cholesky`` is returned directly. Otherwise the
    factorization is applied to the conjugate transpose of ``a``, and the
    result is conjugate-transposed back. ``lower`` is static under ``jit``,
    so it may drive ordinary Python branching.
    """
    (mat,) = _promote_dtypes_inexact(jnp.asarray(a))
    if lower:
        return lax_linalg.cholesky(mat, symmetrize_input=False)
    # With symmetrize_input=False the matrix is used as given; transposing
    # first presumably makes the factorization consume the opposite triangle
    # of ``a`` -- confirm against lax_linalg.cholesky's documentation.
    factor = lax_linalg.cholesky(jnp.conj(_T(mat)), symmetrize_input=False)
    return jnp.conj(_T(factor))

Example #4
0
def _T(x):
    """Transpose ``x`` over its trailing two axes; leading axes are kept."""
    # Swapping (-2, -1) is the same permutation as swapping (-1, -2).
    return jnp.swapaxes(x, -2, -1)
Example #5
0
def _T(x):
    """Transpose ``x`` over its final two axes."""
    return jnp.swapaxes(x, -2, -1)


def _H(x):
    """Conjugate transpose of ``x`` over its final two axes."""
    transposed = _T(x)
    return jnp.conj(transposed)
Example #6
0
def _H(x):
    """Hermitian (conjugate) transpose of ``x`` over its two trailing axes."""
    swapped = jnp.swapaxes(x, -2, -1)
    return jnp.conjugate(swapped)