def test_complex_creation(backend, real_dtype):
    """Building a complex array from two real arrays yields the matching complex dtype.

    Two random ``(3, 4)`` arrays of ``real_dtype`` are combined with the
    backend's ``complex`` function; ``float32`` inputs must give
    ``complex64`` and ``float64`` inputs must give ``complex128``.
    """
    if backend == 'torch':
        pytest.xfail("Pytorch doesn't support complex numbers yet...")
    if backend == 'sparse' and real_dtype == 'float32':
        pytest.xfail("Bug in sparse where single precision isn't maintained "
                     "after scalar multiplication.")

    complex_for_real = {
        'float32': 'complex64',
        'float64': 'complex128',
    }

    def _random_real():
        # Draw normal samples on the target backend, then force the real dtype.
        sample = ar.do('random.normal', size=(3, 4), like=backend)
        return ar.astype(sample, real_dtype)

    real_part = _random_real()
    imag_part = _random_real()
    z = ar.do('complex', real_part, imag_part)

    assert ar.get_dtype_name(z) == complex_for_real[real_dtype]
def tf_qr(x):
    """Factor ``x`` into ``(Q, R)`` by way of a singular value decomposition.

    ``Q`` is the matrix of left singular vectors ``U`` returned by
    ``linalg.svd``. ``R`` is formed by scaling each row of ``VH`` by the
    corresponding singular value, i.e. ``diag(s) @ VH``.

    NOTE(review): unlike a textbook QR, ``R`` here is generally *not*
    upper-triangular; presumably callers only rely on ``Q @ R`` reproducing
    ``x`` (which holds for a reduced SVD) -- confirm against callers.

    Returns:
        tuple: ``(Q, R)`` as described above.
    """
    left_vecs, sing_vals, right_vecs_h = autoray.do('linalg.svd', x)

    # If U is complex, bring the singular values to the same dtype before
    # multiplying them into VH (presumably they come back real).
    left_dtype = autoray.get_dtype_name(left_vecs)
    if 'complex' in left_dtype:
        sing_vals = autoray.astype(sing_vals, left_dtype)

    # A column vector of singular values broadcasts across the rows of VH,
    # scaling row i of VH by s[i].
    scaled_rows = autoray.reshape(sing_vals, (-1, 1)) * right_vecs_h
    return left_vecs, scaled_rows
def test_dtype_specials(backend, creation, dtype):
    """An array from ``creation`` keeps ``dtype`` through ``astype`` and ``to_numpy``.

    A ``(2, 3)`` array is created on ``backend`` with the ``creation``
    function, cast to ``dtype``, and then converted to NumPy. The dtype name
    must equal ``dtype`` both before and after the conversion.
    """
    import numpy as np

    arr = ar.do(creation, shape=(2, 3), like=backend)

    if backend == 'torch' and 'complex' in dtype:
        pytest.xfail("Pytorch doesn't support complex numbers yet...")

    arr = ar.astype(arr, dtype)
    assert ar.get_dtype_name(arr) == dtype

    # Round-trip to NumPy: must be a real ndarray and keep the same dtype.
    as_numpy = ar.to_numpy(arr)
    assert isinstance(as_numpy, np.ndarray)
    assert ar.get_dtype_name(as_numpy) == dtype
def cast(tensor, dtype):
    """Casts the given tensor to a new type.

    Args:
        tensor (tensor_like): tensor to cast. Python lists, tuples and
            built-in scalars (``int``, ``float``, ``complex``, ``bool``) are
            first converted to NumPy arrays.
        dtype (str, np.dtype): Any supported NumPy dtype representation; this can be
            a string (``"float64"``), a ``np.dtype`` object (``np.dtype("float64")``), or
            a dtype class (``np.float64``). If ``tensor`` is not a NumPy array, the
            **equivalent** dtype in the dispatched framework is used.

    Returns:
        tensor_like: a tensor with the same shape and values as ``tensor`` and the
        same dtype as ``dtype``

    **Example**

    We can use NumPy dtype specifiers:

    >>> x = torch.tensor([1, 2])
    >>> cast(x, np.float64)
    tensor([1., 2.], dtype=torch.float64)

    We can also use strings:

    >>> x = tf.Variable([1, 2])
    >>> cast(x, "complex128")
    <tf.Tensor: shape=(2,), dtype=complex128, numpy=array([1.+0.j, 2.+0.j])>

    Bare Python scalars are accepted too:

    >>> cast(1, "float32")
    array(1., dtype=float32)
    """
    # Sequences and bare Python scalars have no array backend for
    # ``ar.astype`` to dispatch to, so promote them to NumPy arrays first.
    # NumPy scalars (e.g. ``np.float64``, which subclasses ``float``) already
    # dispatch to NumPy and are deliberately left untouched.
    is_builtin_scalar = isinstance(tensor, (int, float, complex)) and not isinstance(
        tensor, np.generic
    )
    if isinstance(tensor, (list, tuple)) or is_builtin_scalar:
        tensor = np.asarray(tensor)

    if not isinstance(dtype, str):
        # Normalise NumPy dtype objects / classes to their canonical name.
        try:
            dtype = np.dtype(dtype).name
        except (AttributeError, TypeError):
            # NumPy could not interpret it (e.g. a framework-native dtype);
            # fall back to a ``name`` attribute if the object has one.
            dtype = getattr(dtype, "name", dtype)

    backend = ar.infer_backend(tensor)
    return ar.astype(tensor, ar.to_backend_dtype(dtype, like=backend))
def gen_rand(shape, backend, dtype='float64'):
    """Return a uniformly random array of ``shape`` on ``backend``.

    Args:
        shape: shape of the array to generate.
        backend: name of the array backend to create the array with.
        dtype (str): name of the requested dtype, default ``'float64'``.

    Returns:
        A random array. For ``'jax'``, ``dtype`` is passed directly to
        ``jax.random.uniform`` and the result's dtype is not checked here.
        For every other backend, including ``'sparse'`` (COO format,
        ``density=0.5``), the array is cast to ``dtype`` and its dtype name is
        asserted to equal ``dtype``.
    """
    global JAX_RANDOM_KEY

    if backend == 'jax':
        from jax import random as jrandom

        # Lazily seed a module-level key and split it on every call so that
        # successive arrays are drawn from fresh subkeys.
        if JAX_RANDOM_KEY is None:
            JAX_RANDOM_KEY = jrandom.PRNGKey(42)
        JAX_RANDOM_KEY, subkey = jrandom.split(JAX_RANDOM_KEY)
        return jrandom.uniform(subkey, shape=shape, dtype=dtype)

    if backend == 'sparse':
        x = ar.do('random.uniform', size=shape, like=backend,
                  density=0.5, format='coo', fill_value=0)
    else:
        x = ar.do('random.uniform', size=shape, like=backend)

    # Previously the sparse branch returned early and silently ignored
    # ``dtype``; every non-jax backend now gets the same cast and check.
    x = ar.astype(x, ar.to_backend_dtype(dtype, backend))
    assert ar.get_dtype_name(x) == dtype
    return x