예제 #1
0
파일: linalg.py 프로젝트: ahoenselaar/jax
def solve(a, b):
    a, b = _promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
    return lax_linalg._solve(a, b)
예제 #2
0
def solve(a, b):
  a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
  return lax_linalg._solve(a, b)