def cast(tensor, dtype):
    """Casts the given tensor to a new type.

    Args:
        tensor (tensor_like): tensor to cast. Python lists, tuples and scalars
            (``int``, ``float``, ``complex``; ``bool`` is included as a subclass
            of ``int``) are converted to NumPy arrays before casting.
        dtype (str, np.dtype): Any supported NumPy dtype representation; this can be
            a string (``"float64"``), a ``np.dtype`` object (``np.dtype("float64")``), or
            a dtype class (``np.float64``). If ``tensor`` is not a NumPy array, the
            **equivalent** dtype in the dispatched framework is used.

    Returns:
        tensor_like: a tensor with the same shape and values as ``tensor`` and the
        same dtype as ``dtype``

    **Example**

    We can use NumPy dtype specifiers:

    >>> x = torch.tensor([1, 2])
    >>> cast(x, np.float64)
    tensor([1., 2.], dtype=torch.float64)

    We can also use strings:

    >>> x = tf.Variable([1, 2])
    >>> cast(x, "complex128")
    <tf.Tensor: shape=(2,), dtype=complex128, numpy=array([1.+0.j, 2.+0.j])>
    """
    # Builtin containers and scalars are wrapped in a NumPy array first, so that
    # backend inference and the ``astype`` dispatch below operate on an array
    # object rather than on a plain Python value.
    if isinstance(tensor, (list, tuple, int, float, complex)):
        tensor = np.asarray(tensor)

    if not isinstance(dtype, str):
        try:
            # Normalise anything NumPy understands to its canonical string name.
            dtype = np.dtype(dtype).name
        except (AttributeError, TypeError):
            # NumPy could not interpret the object: use its ``name`` attribute
            # when it has one, otherwise pass the object through unchanged.
            dtype = getattr(dtype, "name", dtype)

    backend = ar.infer_backend(tensor)
    return ar.astype(tensor, ar.to_backend_dtype(dtype, like=backend))
def gen_rand(shape, backend, dtype='float64'):
    """Draw a ``random.uniform`` sample of the given shape for ``backend``.

    Args:
        shape: shape of the array to generate (passed as ``shape``/``size``
            to the underlying random routine).
        backend (str): name of the backend to generate the array with.
        dtype (str): requested dtype name. It is forwarded to JAX, applied
            via ``astype`` for generic backends, and not used in the
            ``'sparse'`` branch.

    Returns:
        The generated array. For generic backends its dtype name is asserted
        to equal ``dtype``.
    """
    global JAX_RANDOM_KEY

    if backend == 'jax':
        from jax import random as jax_rng

        # The module-level key is created lazily (seed 42) and split on every
        # call, so each call samples with a fresh subkey.
        if JAX_RANDOM_KEY is None:
            JAX_RANDOM_KEY = jax_rng.PRNGKey(42)
        next_key, sample_key = jax_rng.split(JAX_RANDOM_KEY)
        JAX_RANDOM_KEY = next_key
        return jax_rng.uniform(sample_key, shape=shape, dtype=dtype)

    if backend == 'sparse':
        # Sparse arrays are requested half dense, in COO format, with a zero
        # fill value; no dtype conversion is applied on this path.
        sparse_options = {'density': 0.5, 'format': 'coo', 'fill_value': 0}
        return ar.do('random.uniform', size=shape, like=backend, **sparse_options)

    # Generic path: sample with the backend's own routine, then convert to the
    # backend-specific equivalent of the requested dtype.
    sample = ar.do('random.uniform', size=shape, like=backend)
    backend_dtype = ar.to_backend_dtype(dtype, backend)
    converted = ar.astype(sample, backend_dtype)
    assert ar.get_dtype_name(converted) == dtype
    return converted