def dtype(self, array) -> DType:
    """Return the ``DType`` describing `array`.

    Python scalars are mapped to fixed default precisions (bool -> bool,
    int -> 32 bit, float -> 64 bit, complex -> 128 bit). Anything else is
    converted with ``jnp.array`` (if it is not already a ``jnp.ndarray``) and
    its NumPy dtype is translated via ``from_numpy_dtype``.
    """
    # bool must be tested before int: `bool` is a subclass of `int` in Python,
    # so `isinstance(True, int)` is True and would report a 32-bit int.
    if isinstance(array, bool):
        return DType(bool)
    if isinstance(array, int):
        return DType(int, 32)
    if isinstance(array, float):
        return DType(float, 64)
    if isinstance(array, complex):
        return DType(complex, 128)
    if not isinstance(array, jnp.ndarray):
        array = jnp.array(array)
    return from_numpy_dtype(array.dtype)
def test_from_numpy_dtype(self):
    """from_numpy_dtype maps NumPy scalar types and array dtypes to DType.

    Uses the builtins ``bool`` / ``object`` rather than the removed aliases
    ``np.bool`` / ``np.object`` (gone since NumPy 1.24); those aliases were
    identical to the builtins, so the checked behavior is unchanged.
    """
    self.assertEqual(from_numpy_dtype(bool), DType(bool))
    self.assertEqual(from_numpy_dtype(np.bool_), DType(bool))
    self.assertEqual(from_numpy_dtype(np.int32), DType(int, 32))
    self.assertEqual(from_numpy_dtype(np.array(0, np.int32).dtype), DType(int, 32))
    self.assertEqual(from_numpy_dtype(np.array(0, bool).dtype), DType(bool))
    self.assertEqual(from_numpy_dtype(np.array(0, object).dtype), DType(object))
def test_random_int(self):
    """Integer sampling in [-1, 1) must yield exactly the values -1 and 0."""
    for backend in BACKENDS:
        with backend:
            for bits in (32, 64):
                samples = math.random_uniform(instance(values=1000), low=-1, high=1, dtype=(int, bits))
                if bits == 32:
                    self.assertEqual(samples.dtype, DType(int, 32), msg=backend.name)
                else:
                    # Jax may downcast 64-bit to 32, so only the kind is checked here.
                    self.assertEqual(samples.dtype.kind, int, msg=backend.name)
                self.assertEqual(samples.min, -1, msg=backend.name)
                self.assertEqual(samples.max, 0, msg=backend.name)
def test_repr(self):
    """Print reprs of eager tensors, then of tensors seen while jit-tracing."""
    eager_factories = [
        lambda: math.zeros(batch(b=10)),
        lambda: math.zeros(batch(b=10)) > 0,
        lambda: math.ones(channel(vector=3)),
        lambda: math.ones(channel(vector=3), dtype=DType(int, 64)),
        lambda: math.ones(channel(vector=3), dtype=DType(float, 64)),
        lambda: math.ones(batch(vector=3)),
        lambda: math.random_normal(batch(b=10)),
        lambda: math.random_normal(batch(b=10), dtype=DType(float, 64)) * 1e-6,
    ]
    print("--- Eager ---")
    for make_tensor in eager_factories:
        print(repr(make_tensor()))

    def tracable(x):
        # Called under jit_compile, so this prints the traced (placeholder) value.
        print(x)
        return x

    print("--- Placeholders ---")
    for backend in BACKENDS:
        if not backend.supports(Backend.jit_compile):
            continue
        with backend:
            math.jit_compile(tracable)(math.ones(channel(vector=3)))
def from_torch_dtype(torch_dtype):
    """Convert a ``torch.dtype`` to a ``DType``.

    Dtypes present in ``_FROM_TORCH`` are looked up directly. Other dtypes are
    classified from the flags ``torch.dtype`` actually exposes
    (``is_complex``, ``is_floating_point``, ``is_signed``) -- unlike NumPy
    dtypes, ``torch.dtype`` has no ``kind`` attribute.

    Raises:
        ValueError: if the dtype is neither complex, floating point, nor a
            signed integer type (e.g. unsigned integers).
    """
    if torch_dtype in _FROM_TORCH:
        return _FROM_TORCH[torch_dtype]
    if torch_dtype.is_complex:
        kind = complex
    elif torch_dtype.is_floating_point:
        kind = float
    elif torch_dtype.is_signed:
        kind = int
    else:
        raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
    # `torch.dtype.itemsize` may be missing on older PyTorch versions; fall back to a 0-d tensor.
    itemsize = getattr(torch_dtype, 'itemsize', None)
    if itemsize is None:
        itemsize = torch.empty((), dtype=torch_dtype).element_size()
    return DType(kind, itemsize * 8)
def imag(self, x):
    """Return the imaginary part of `x`.

    Complex tensors go through ``torch.imag``; for any other kind a zero
    tensor of shape ``x.shape`` is returned, typed as a float with the same
    precision as `x`.
    """
    x_dtype = self.dtype(x)
    if x_dtype.kind != complex:
        return self.zeros(x.shape, DType(float, x_dtype.precision))
    return torch.imag(x)
def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
    """Return a 1D tensor ``[start, start + delta, ...)`` stopping before `limit`.

    With a single argument, behaves like ``range(start)``: values run from 0
    up to (excluding) `start`. The result is placed on the backend's default
    device, matching the placement used by ``random_uniform``.
    """
    if limit is None:
        start, limit = 0, start
    return torch.arange(start, limit, delta,
                        dtype=to_torch_dtype(dtype),
                        device=self.get_default_device().ref)
def random_uniform(self, shape, low, high, dtype: DType or None):
    """Sample a tensor of `shape` uniformly from [low, high).

    Args:
        shape: Output shape.
        low, high: Bounds. For complex dtypes the real and imaginary parts are
            sampled independently using ``.real`` / ``.imag`` of the bounds.
        dtype: Target ``DType``; ``None`` selects ``self.float_type``.

    Raises:
        ValueError: if the dtype kind is not float, complex or int.

    All branches place the result on the backend's default device.
    """
    dtype = dtype or self.float_type
    device = self.get_default_device().ref
    if dtype.kind == float:
        return low + (high - low) * torch.rand(size=shape, dtype=to_torch_dtype(dtype), device=device)
    elif dtype.kind == complex:
        real_dtype = to_torch_dtype(DType(float, dtype.precision))
        real = low.real + (high.real - low.real) * torch.rand(size=shape, dtype=real_dtype, device=device)
        imag = low.imag + (high.imag - low.imag) * torch.rand(size=shape, dtype=real_dtype, device=device)
        return real + 1j * imag
    elif dtype.kind == int:
        # Previously created without `device`, so integer samples ignored the backend default device.
        return torch.randint(low, high, shape, dtype=to_torch_dtype(dtype), device=device)
    else:
        raise ValueError(dtype)
def to_torch_dtype(dtype: DType): return _TO_TORCH[dtype] def from_torch_dtype(torch_dtype): if torch_dtype in _FROM_TORCH: return _FROM_TORCH[torch_dtype] else: kind = {'i': int, 'b': bool, 'f': float, 'c': complex}[torch_dtype.kind] return DType(kind, torch_dtype.itemsize * 8) _TO_TORCH = { DType(float, 16): torch.float16, DType(float, 32): torch.float32, DType(float, 64): torch.float64, DType(complex, 64): torch.complex64, DType(complex, 128): torch.complex128, DType(int, 8): torch.int8, DType(int, 16): torch.int16, DType(int, 32): torch.int32, DType(int, 64): torch.int64, DType(bool): torch.bool, } _FROM_TORCH = {np: dtype for dtype, np in _TO_TORCH.items()} @torch.jit._script_if_tracing def torch_sparse_cg(lin, y, x0, rtol, atol, max_iter):
def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
    """Return ``jnp.arange`` over [start, limit) with step `delta`.

    A single positional argument is interpreted as the upper bound, with the
    sequence starting at 0 (like Python's ``range``).
    """
    if limit is None:
        start, limit = 0, start
    numpy_dtype = to_numpy_dtype(dtype)
    return jnp.arange(start, limit, delta, dtype=numpy_dtype)
def random_uniform(self, shape, low, high, dtype: DType or None):
    """Sample a jax array of `shape` uniformly from [low, high).

    Advances ``self.rnd_key`` on every call. For complex dtypes the real and
    imaginary parts are drawn with independent keys from
    ``[low.real, high.real)`` and ``[low.imag, high.imag)``. Integer sampling
    warns with ``RuntimeWarning`` if jax returns a different dtype than
    requested. The result is placed on ``self._default_device``.

    Raises:
        ValueError: if the dtype kind is not float, complex or int.
    """
    self._check_float64()
    self.rnd_key, subkey = jax.random.split(self.rnd_key)
    dtype = dtype or self.float_type
    jdt = to_numpy_dtype(dtype)
    if dtype.kind == float:
        tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
    elif dtype.kind == complex:
        # Reusing one key for both parts would make real and imaginary parts
        # perfectly correlated, so derive two independent keys.
        real_key, imag_key = jax.random.split(subkey)
        real_dtype = to_numpy_dtype(DType(float, dtype.precision))
        real = random.uniform(real_key, shape, minval=low.real, maxval=high.real, dtype=real_dtype)
        imag = random.uniform(imag_key, shape, minval=low.imag, maxval=high.imag, dtype=real_dtype)
        # Falls through to device_put like the other branches.
        tensor = real + 1j * imag
    elif dtype.kind == int:
        tensor = random.randint(subkey, shape, low, high, dtype=jdt)
        if tensor.dtype != jdt:
            warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
    else:
        raise ValueError(dtype)
    return jax.device_put(tensor, self._default_device.ref)
def test_as_dtype(self):
    """DType.as_dtype accepts None, DTypes, Python types and (kind, bits) tuples; rejects str."""
    expectations = [
        (None, None),
        (DType(int, 32), DType(int, 32)),
        (DType(int, 32), int),
        (DType(float, 32), float),
        (DType(complex, 64), complex),
        (DType(bool), bool),
        (DType(int, 8), (int, 8)),
    ]
    for expected, spec in expectations:
        self.assertEqual(expected, DType.as_dtype(spec))
    self.assertEqual(object, DType.as_dtype(object).kind)
    with self.assertRaises(ValueError):
        DType.as_dtype(str)
def test_object_dtype(self):
    """An object DType must report a bit width of either 32 or 64."""
    object_bits = DType(object).bits
    self.assertIn(object_bits, (32, 64))
def test_cast(self):
    """Check precision-dependent and explicit dtype conversions on every backend."""
    for backend in BACKENDS:
        with backend:
            msg = backend.name
            x = math.random_uniform(dtype=DType(float, 64))
            # Outside a precision context, to_float / to_complex yield 32-bit float / 64-bit complex.
            self.assertEqual(DType(float, 32), math.to_float(x).dtype, msg=msg)
            self.assertEqual(DType(complex, 64), math.to_complex(x).dtype, msg=msg)
            with math.precision(64):
                self.assertEqual(DType(float, 64), math.to_float(x).dtype, msg=msg)
                self.assertEqual(DType(complex, 128), math.to_complex(x).dtype, msg=msg)
            explicit_casts = [
                (DType(int, 64), math.to_int64),
                (DType(int, 32), math.to_int32),
                (DType(float, 16), lambda t: math.cast(t, DType(float, 16))),
                (DType(complex, 128), lambda t: math.cast(t, DType(complex, 128))),
            ]
            for expected, convert in explicit_casts:
                self.assertEqual(expected, convert(x).dtype, msg=msg)
            # A bit width with no backend equivalent must raise KeyError.
            with self.assertRaises(KeyError, msg=msg):
                math.cast(x, DType(float, 3))
def test_random_complex(self):
    """Complex uniform sampling honours the dtype; real bounds give zero imaginary parts."""
    for backend in BACKENDS:
        with backend:
            samples = math.random_uniform(instance(values=4), low=-1, high=0, dtype=(complex, 64))
            self.assertEqual(samples.dtype, DType(complex, 64), msg=backend.name)
            # low and high are real, so both imaginary bounds are 0.
            math.assert_close(samples.imag, 0, msg=backend.name)