def val_to_typecode(val):
    """Return the typecode string for `val`.

    The typecode is looked up in `dtype_to_typecode` using the dtype reported
    by `dtypes.result_type(val)`. When `dtypes.is_weakly_typed(val)` is true,
    the last character of that typecode is replaced with '*'.
    """
    val_dtype = dtypes.result_type(val)
    is_weak = dtypes.is_weakly_typed(val)
    base_code = dtype_to_typecode[val_dtype]
    return base_code[:-1] + '*' if is_weak else base_code
def val_to_typecode(val):
    """Return the typecode string for `val`.

    The typecode is looked up in `dtype_to_typecode` using the dtype reported
    by `dtypes.result_type(val)`. When `dtypes.is_python_scalar(val)` is true,
    the last character of that typecode is replaced with '*'.
    """
    val_dtype = dtypes.result_type(val)
    is_scalar = dtypes.is_python_scalar(val)
    base_code = dtype_to_typecode[val_dtype]
    return base_code[:-1] + '*' if is_scalar else base_code
def fix_float0(arg_jax, ct_arg_jax):
    """Make a cotangent's dtype agree with the tangent dtype of its primal.

    Args:
      arg_jax: the primal argument (may be a scalar).
      ct_arg_jax: the cotangent associated with `arg_jax`.

    Returns:
      `ct_arg_jax` itself if its dtype already equals
      `core.primal_dtype_to_tangent_dtype` of the primal's dtype. Otherwise,
      the result of `ad_util.zeros_like_aval` applied to a `core.ShapedArray`
      with the primal's shape and that tangent dtype.
    """
    # `dtypes.result_type` rather than `.dtype`, since the primal may be a scalar.
    tangent_dtype = core.primal_dtype_to_tangent_dtype(dtypes.result_type(arg_jax))
    if tangent_dtype == ct_arg_jax.dtype:
        return ct_arg_jax
    replacement_aval = core.ShapedArray(np.shape(arg_jax), tangent_dtype)
    return ad_util.zeros_like_aval(replacement_aval)
def std_basis(pytree: PyTree) -> PyTree:
    """Similar to `jax.api._std_basis` without host-side ops.

    Builds a NumPy identity matrix whose side length is the total number of
    elements across all leaves of `pytree`. Its dtype is
    `dtypes.result_type` of all the leaves. The matrix is then handed to
    `unravel_array_into_pytree(pytree, 1, ...)`.
    """
    leaves, _ = tree_flatten(pytree)
    total_elems = sum(np.size(leaf) for leaf in leaves)
    basis_dtype = dtypes.result_type(*leaves)
    identity = np.eye(total_elems, dtype=basis_dtype)
    return unravel_array_into_pytree(pytree, 1, identity)
def _promote_to_real(arg):
    """Convert `arg` to the dtype obtained by promoting it with `np.float32`.

    On jaxlib 0.1.47 and earlier, a float64 result is narrowed to float32,
    because XLA's FFT op there only supports F32.
    """
    target_dtype = dtypes.result_type(arg, np.float32)
    # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
    legacy_jaxlib = lib.version <= (0, 1, 47)
    if legacy_jaxlib and target_dtype == np.float64:
        target_dtype = np.float32
    return lax.convert_element_type(arg, target_dtype)
def _promote_to_complex(arg):
    """Convert `arg` to the dtype obtained by promoting it with `np.complex64`.

    On jaxlib 0.1.47 and earlier, a complex128 result is narrowed to
    complex64, because XLA's FFT op there only supports C64.
    """
    target_dtype = dtypes.result_type(arg, np.complex64)
    # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
    legacy_jaxlib = lib.version <= (0, 1, 47)
    if legacy_jaxlib and target_dtype == np.complex128:
        target_dtype = np.complex64
    return lax.convert_element_type(arg, target_dtype)
def _ravel_list(lst):
    """Flatten a list of arrays into one 1-D array and return an inverse.

    Every element is converted to `dtypes.result_type` of all the element
    dtypes, raveled, and concatenated. The returned `unravel(arr)` splits a
    flat array at the original element sizes. It reshapes each piece to its
    original shape and converts it back to its original dtype.

    Returns:
      A `(raveled, unravel)` pair. For an empty `lst`, `raveled` is an empty
      float32 array and `unravel` always returns `[]`.
    """
    if not lst:
        return jnp.array([], jnp.float32), lambda _: []

    orig_dtypes = [dtypes.dtype(item) for item in lst]
    common_dtype = dtypes.result_type(*orig_dtypes)
    sizes, shapes = unzip2((jnp.size(item), jnp.shape(item)) for item in lst)
    # Running end offsets; all but the last are the split points.
    offsets = np.cumsum(sizes)

    def unravel(arr):
        pieces = jnp.split(arr, offsets[:-1])
        restored = []
        with warnings.catch_warnings():
            # Every warning is suppressed here. The motivating case is the
            # complex-to-real cast warning.
            warnings.simplefilter("ignore")
            for piece, shape, orig_dtype in zip(pieces, shapes, orig_dtypes):
                restored.append(
                    lax.convert_element_type(piece.reshape(shape), orig_dtype))
        return restored

    flat_parts = [
        jnp.ravel(lax.convert_element_type(item, common_dtype)) for item in lst
    ]
    return jnp.concatenate(flat_parts), unravel
def _promote_to_real(arg): dtype = dtypes.result_type(arg, np.float32) return lax.convert_element_type(arg, dtype)
def _promote_to_complex(arg): dtype = dtypes.result_type(arg, np.complex64) return lax.convert_element_type(arg, dtype)
def testUnaryPromotion(self, dtype, weak_type): # Regression test for https://github.com/google/jax/issues/6051 x = lax._convert_element_type(0, dtype, weak_type=weak_type) y = jnp.array(0, dtype=dtypes.result_type(x)) assert x.dtype == y.dtype
def abstractify(x):
    """Return a `ShapedArray` with `x`'s shape and `dtypes.result_type(x)`."""
    x_shape = np.shape(x)
    x_dtype = dtypes.result_type(x)
    return ShapedArray(x_shape, x_dtype)
def _std_basis(pytree):
    """Build an identity basis and unravel it into the structure of `pytree`.

    The identity (from `jax.numpy.eye`) has one row and one column per
    element across all leaves. Its dtype is `dtypes.result_type` of the
    leaves. The result is passed to
    `_unravel_array_into_pytree(pytree, 1, ...)`.
    """
    leaves, _ = tree_flatten(pytree)
    total_elems = sum(np.size(leaf) for leaf in leaves)
    basis_dtype = dtypes.result_type(*leaves)
    identity = jax.numpy.eye(total_elems, dtype=basis_dtype)
    return _unravel_array_into_pytree(pytree, 1, identity)