Example #1
0
 def val_to_typecode(val):
     """Return the typecode string for ``val``.

     The typecode is looked up in ``dtype_to_typecode`` using the JAX
     result dtype of ``val``. When ``val`` is weakly typed, the last
     character of the typecode is replaced with ``'*'``.
     """
     resolved_dtype = dtypes.result_type(val)
     is_weak = dtypes.is_weakly_typed(val)
     code = dtype_to_typecode[resolved_dtype]
     return code[:-1] + '*' if is_weak else code
Example #2
0
 def val_to_typecode(val):
   """Return the typecode string for ``val``.

   The typecode comes from ``dtype_to_typecode`` keyed by the JAX result
   dtype of ``val``. Python scalars (treated here as weakly typed) get the
   last character of their typecode replaced with ``'*'``.
   """
   resolved_dtype = dtypes.result_type(val)
   is_scalar = dtypes.is_python_scalar(val)
   code = dtype_to_typecode[resolved_dtype]
   if not is_scalar:
     return code
   return code[:-1] + '*'
Example #3
0
 def fix_float0(arg_jax, ct_arg_jax):
     """Make a cotangent's dtype match the tangent dtype of its primal.

     Args:
       arg_jax: the primal argument. It may be a scalar, so its dtype is
         taken from ``dtypes.result_type`` rather than an attribute.
       ct_arg_jax: the cotangent value associated with ``arg_jax``.

     Returns:
       ``ct_arg_jax`` itself when its dtype already equals the tangent dtype
       of ``arg_jax``. Otherwise, the result of ``ad_util.zeros_like_aval``
       for a ``ShapedArray`` with ``arg_jax``'s shape and that tangent dtype.
     """
     tangent_dtype = core.primal_dtype_to_tangent_dtype(
         dtypes.result_type(arg_jax))
     if tangent_dtype == ct_arg_jax.dtype:
         return ct_arg_jax
     zeros_aval = core.ShapedArray(np.shape(arg_jax), tangent_dtype)
     return ad_util.zeros_like_aval(zeros_aval)
Example #4
0
def std_basis(pytree: PyTree) -> PyTree:
    """Return an identity basis shaped like ``pytree``'s flattened leaves.

    This is similar to `jax.api._std_basis`. The leaves are counted element
    by element, an identity matrix of that size is built in their common
    result dtype, and the matrix is handed to ``unravel_array_into_pytree``
    along axis 1.

    NOTE(review): the original claim "without host-side ops" holds only if
    ``np`` here is a device-side array module (e.g. ``jax.numpy``). Confirm
    against this file's imports.
    """
    flat_leaves = tree_flatten(pytree)[0]
    total_size = sum(np.size(leaf) for leaf in flat_leaves)
    basis_dtype = dtypes.result_type(*flat_leaves)
    identity = np.eye(total_size, dtype=basis_dtype)
    return unravel_array_into_pytree(pytree, 1, identity)
Example #5
0
File: fft.py Project: qqsun8819/jax
def _promote_to_real(arg):
    """Convert ``arg`` to the JAX result dtype of ``arg`` combined with float32.

    On jaxlib 0.1.47 and older, a float64 result is narrowed to float32,
    because XLA's FFT op there only supports F32.
    """
    # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
    target_dtype = dtypes.result_type(arg, np.float32)
    old_jaxlib = lib.version <= (0, 1, 47)
    if old_jaxlib and target_dtype == np.float64:
        target_dtype = np.float32
    return lax.convert_element_type(arg, target_dtype)
Example #6
0
File: fft.py Project: qqsun8819/jax
def _promote_to_complex(arg):
    """Convert ``arg`` to the JAX result dtype of ``arg`` combined with complex64.

    On jaxlib 0.1.47 and older, a complex128 result is narrowed to
    complex64, because XLA's FFT op there only supports C64.
    """
    # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
    target_dtype = dtypes.result_type(arg, np.complex64)
    old_jaxlib = lib.version <= (0, 1, 47)
    if old_jaxlib and target_dtype == np.complex128:
        target_dtype = np.complex64
    return lax.convert_element_type(arg, target_dtype)
Example #7
0
def _ravel_list(lst):
    """Flatten a list of arrays into one 1-D array and return its inverse.

    Every element is converted to the common ``dtypes.result_type`` of all
    element dtypes, raveled, and concatenated. The returned ``unravel``
    splits a flat array at the original element sizes, reshapes each piece
    to its original shape, and converts it back to its original dtype.
    Complex-to-real cast warnings are suppressed during that conversion.

    Returns:
      ``(raveled, unravel)``. For an empty list, this is an empty float32
      array and a function that always returns ``[]``.
    """
    if not lst:
        return jnp.array([], jnp.float32), lambda _: []

    orig_dtypes = [dtypes.dtype(item) for item in lst]
    common_dtype = dtypes.result_type(*orig_dtypes)
    sizes, shapes = unzip2((jnp.size(item), jnp.shape(item)) for item in lst)
    split_points = np.cumsum(sizes)

    def unravel(arr):
        pieces = jnp.split(arr, split_points[:-1])
        restored = []
        with warnings.catch_warnings():
            # Silence the complex-to-real cast warning.
            warnings.simplefilter("ignore")
            for piece, shape, dtype in zip(pieces, shapes, orig_dtypes):
                restored.append(
                    lax.convert_element_type(piece.reshape(shape), dtype))
        return restored

    flat = jnp.concatenate(
        [jnp.ravel(lax.convert_element_type(item, common_dtype)) for item in lst])
    return flat, unravel
Example #8
0
def _promote_to_real(arg):
    """Convert ``arg`` to the JAX result dtype of ``arg`` combined with float32."""
    return lax.convert_element_type(arg, dtypes.result_type(arg, np.float32))
Example #9
0
def _promote_to_complex(arg):
    """Convert ``arg`` to the JAX result dtype of ``arg`` combined with complex64."""
    return lax.convert_element_type(arg, dtypes.result_type(arg, np.complex64))
Example #10
0
 def testUnaryPromotion(self, dtype, weak_type):
     # Regression test for https://github.com/google/jax/issues/6051
     x = lax._convert_element_type(0, dtype, weak_type=weak_type)
     y = jnp.array(0, dtype=dtypes.result_type(x))
     assert x.dtype == y.dtype
Example #11
0
def abstractify(x):
    """Describe ``x`` as a ``ShapedArray`` with its shape and JAX result dtype."""
    shape = np.shape(x)
    dtype = dtypes.result_type(x)
    return ShapedArray(shape, dtype)
Example #12
0
def _std_basis(pytree):
    """Return a ``jax.numpy`` identity basis laid out like ``pytree``.

    The leaves of ``pytree`` are counted element by element, a ``jax.numpy``
    identity matrix of that size is built in the leaves' common result
    dtype, and the matrix is passed to ``_unravel_array_into_pytree`` along
    axis 1.
    """
    flat_leaves = tree_flatten(pytree)[0]
    total_size = sum(np.size(leaf) for leaf in flat_leaves)
    basis_dtype = dtypes.result_type(*flat_leaves)
    identity = jax.numpy.eye(total_size, dtype=basis_dtype)
    return _unravel_array_into_pytree(pytree, 1, identity)