Example #1
0
 def dtype(self, array) -> DType:
     """
     Determine the `DType` of `array`.

     Python scalars are mapped directly without creating an array:
     `bool` -> `DType(bool)`, `int` -> `DType(int, 32)`, `float` -> `DType(float, 64)`,
     `complex` -> `DType(complex, 128)`.
     Any other value that is not already a `jnp.ndarray` is converted with `jnp.array` first,
     and its NumPy dtype is translated with `from_numpy_dtype`.

     Args:
         array: Python scalar, `jnp.ndarray`, or any value accepted by `jnp.array`.

     Returns:
         The corresponding `DType`.
     """
     # `bool` is a subclass of `int`, so it has to be tested first;
     # otherwise `True` / `False` would be reported as DType(int, 32).
     if isinstance(array, bool):
         return DType(bool)
     if isinstance(array, int):
         return DType(int, 32)
     if isinstance(array, float):
         return DType(float, 64)
     if isinstance(array, complex):
         return DType(complex, 128)
     if not isinstance(array, jnp.ndarray):
         array = jnp.array(array)
     return from_numpy_dtype(array.dtype)
Example #2
0
 def test_from_numpy_dtype(self):
     # `np.bool` / `np.object` were deprecated aliases of the builtins (NumPy 1.20) and were
     # removed in NumPy 1.24, where referencing them raises AttributeError. The boolean case is
     # covered by `np.bool_` and `np.array(0, bool)`; the builtin `object` replaces `np.object`.
     self.assertEqual(from_numpy_dtype(np.bool_), DType(bool))
     self.assertEqual(from_numpy_dtype(np.int32), DType(int, 32))
     self.assertEqual(from_numpy_dtype(np.array(0, np.int32).dtype), DType(int, 32))
     self.assertEqual(from_numpy_dtype(np.array(0, bool).dtype), DType(bool))
     self.assertEqual(from_numpy_dtype(np.array(0, object).dtype), DType(object))
Example #3
0
 def test_random_int(self):
     # Sampling integers in [-1, 1) must yield exactly the values -1 and 0,
     # checked for 32-bit and 64-bit integer requests on every backend.
     for backend in BACKENDS:
         with backend:
             for bits in (32, 64):
                 sample = math.random_uniform(instance(values=1000), low=-1, high=1, dtype=(int, bits))
                 if bits == 32:
                     self.assertEqual(sample.dtype, DType(int, 32), msg=backend.name)
                 else:
                     # Jax may downcast 64-bit to 32, so only the kind is checked here
                     self.assertEqual(sample.dtype.kind, int, msg=backend.name)
                 self.assertEqual(sample.min, -1, msg=backend.name)
                 self.assertEqual(sample.max, 0, msg=backend.name)
Example #4
0
    def test_repr(self):
        # Smoke test: prints representations of eager tensors and of jit placeholders.
        def show(value):
            print(repr(value))

        print("--- Eager ---")
        show(math.zeros(batch(b=10)))
        show(math.zeros(batch(b=10)) > 0)
        show(math.ones(channel(vector=3)))
        show(math.ones(channel(vector=3), dtype=DType(int, 64)))
        show(math.ones(channel(vector=3), dtype=DType(float, 64)))
        show(math.ones(batch(vector=3)))
        show(math.random_normal(batch(b=10)))
        show(math.random_normal(batch(b=10), dtype=DType(float, 64)) * 1e-6)

        def tracable(x):
            # Printed while tracing, so `x` is the placeholder the backend passes in
            print(x)
            return x

        print("--- Placeholders ---")
        for backend in BACKENDS:
            if not backend.supports(Backend.jit_compile):
                continue
            with backend:
                math.jit_compile(tracable)(math.ones(channel(vector=3)))
Example #5
0
def from_torch_dtype(torch_dtype):
    """
    Convert a PyTorch dtype to the corresponding `DType`.

    Dtypes present in `_FROM_TORCH` are looked up directly. For any other dtype, the kind is
    derived from the dtype's flags and the bit width from its element size.
    Note that `torch.bfloat16` therefore maps to `DType(float, 16)`, the same as `torch.float16`.

    Args:
        torch_dtype: A `torch.dtype`.

    Returns:
        The matching `DType`.

    Raises:
        KeyError: If `torch_dtype` is an unsigned integer type, which this mapping does not support.
    """
    if torch_dtype in _FROM_TORCH:
        return _FROM_TORCH[torch_dtype]
    # `torch.dtype` has no NumPy-style `.kind` attribute, so the kind is derived from its flags.
    if torch_dtype == torch.bool:
        kind = bool
    elif torch_dtype.is_complex:
        kind = complex
    elif torch_dtype.is_floating_point:
        kind = float
    elif torch_dtype.is_signed:
        kind = int
    else:
        raise KeyError(torch_dtype)
    # `torch.dtype.itemsize` only exists in recent PyTorch releases; `element_size()` works everywhere.
    bits = torch.empty((), dtype=torch_dtype).element_size() * 8
    return DType(kind, bits)
Example #6
0
 def imag(self, x):
     """
     Return the imaginary part of `x`.

     Complex input is delegated to `torch.imag`. For any other dtype, a zero tensor of the same
     shape is returned, with float kind and `precision` taken from the input's `DType`.
     """
     x_dtype = self.dtype(x)
     if x_dtype.kind != complex:
         return self.zeros(x.shape, DType(float, x_dtype.precision))
     return torch.imag(x)
Example #7
0
 def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
     """
     Create a 1D tensor with values from `start` (inclusive) to `limit` (exclusive) in steps of `delta`.

     If `limit` is omitted, the range is `[0, start)`, mirroring Python's built-in `range`.

     Args:
         start: First value, or the exclusive upper bound when `limit` is `None`.
         limit: Exclusive upper bound.
         delta: Step size.
         dtype: Data type of the result.

     Returns:
         Tensor allocated on the backend's default device, consistent with `random_uniform`.
     """
     if limit is None:
         start, limit = 0, start
     # Without `device`, torch.arange allocates on the CPU regardless of the backend's default device.
     return torch.arange(start, limit, delta, dtype=to_torch_dtype(dtype), device=self.get_default_device().ref)
Example #8
0
 def random_uniform(self, shape, low, high, dtype: DType or None):
     """
     Sample uniformly distributed values between `low` and `high`.

     Args:
         shape: Shape of the tensor to create.
         low: Lower bound. For complex dtypes, `low.real` / `low.imag` bound the real and imaginary parts.
         high: Upper bound (`torch.rand` samples [0, 1) and `torch.randint` excludes `high`).
             For complex dtypes, `high.real` / `high.imag` bound the two parts.
         dtype: Data type of the result. Defaults to `self.float_type` if `None`.

     Returns:
         Tensor allocated on the backend's default device.

     Raises:
         ValueError: If `dtype.kind` is not float, complex or int.
     """
     dtype = dtype or self.float_type
     device = self.get_default_device().ref
     if dtype.kind == float:
         return low + (high - low) * torch.rand(size=shape, dtype=to_torch_dtype(dtype), device=device)
     elif dtype.kind == complex:
         # Real and imaginary parts are sampled separately as floats of the complex dtype's precision.
         component_dtype = to_torch_dtype(DType(float, dtype.precision))
         real = low.real + (high.real - low.real) * torch.rand(size=shape, dtype=component_dtype, device=device)
         imag = low.imag + (high.imag - low.imag) * torch.rand(size=shape, dtype=component_dtype, device=device)
         return real + 1j * imag
     elif dtype.kind == int:
         # Pass `device` here too; otherwise integer samples would be allocated on the CPU.
         return torch.randint(low, high, shape, dtype=to_torch_dtype(dtype), device=device)
     else:
         raise ValueError(dtype)
Example #9
0

def to_torch_dtype(dtype: DType):
    """
    Look up the PyTorch dtype that corresponds to `dtype`.

    Raises:
        KeyError: If `dtype` has no entry in `_TO_TORCH` (e.g. an unsupported bit width).
    """
    torch_dtype = _TO_TORCH[dtype]
    return torch_dtype


def from_torch_dtype(torch_dtype):
    """
    Convert a PyTorch dtype to the corresponding `DType`.

    Dtypes present in `_FROM_TORCH` are looked up directly. For any other dtype, the kind is
    derived from the dtype's flags and the bit width from its element size.
    Note that `torch.bfloat16` therefore maps to `DType(float, 16)`, the same as `torch.float16`.

    Args:
        torch_dtype: A `torch.dtype`.

    Returns:
        The matching `DType`.

    Raises:
        KeyError: If `torch_dtype` is an unsigned integer type, which this mapping does not support.
    """
    if torch_dtype in _FROM_TORCH:
        return _FROM_TORCH[torch_dtype]
    # `torch.dtype` has no NumPy-style `.kind` attribute, so the kind is derived from its flags.
    if torch_dtype == torch.bool:
        kind = bool
    elif torch_dtype.is_complex:
        kind = complex
    elif torch_dtype.is_floating_point:
        kind = float
    elif torch_dtype.is_signed:
        kind = int
    else:
        raise KeyError(torch_dtype)
    # `torch.dtype.itemsize` only exists in recent PyTorch releases; `element_size()` works everywhere.
    bits = torch.empty((), dtype=torch_dtype).element_size() * 8
    return DType(kind, bits)


# Forward mapping from `DType` to the equivalent PyTorch dtype, used by `to_torch_dtype`.
_TO_TORCH = {
    # floating point
    DType(float, 16): torch.float16,
    DType(float, 32): torch.float32,
    DType(float, 64): torch.float64,
    # complex
    DType(complex, 64): torch.complex64,
    DType(complex, 128): torch.complex128,
    # signed integers
    DType(int, 8): torch.int8,
    DType(int, 16): torch.int16,
    DType(int, 32): torch.int32,
    DType(int, 64): torch.int64,
    # boolean
    DType(bool): torch.bool,
}
# Inverse lookup used by `from_torch_dtype`; each torch dtype above appears exactly once.
_FROM_TORCH = {torch_dt: phi_dt for phi_dt, torch_dt in _TO_TORCH.items()}


@torch.jit._script_if_tracing
def torch_sparse_cg(lin, y, x0, rtol, atol, max_iter):
Example #10
0
 def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
     """
     Return a 1D array of values in `[start, limit)` spaced by `delta`.

     When `limit` is `None`, `start` is used as the exclusive upper bound and the range begins at 0.
     """
     lower, upper = (0, start) if limit is None else (start, limit)
     return jnp.arange(lower, upper, delta, dtype=to_numpy_dtype(dtype))
Example #11
0
    def random_uniform(self, shape, low, high, dtype: DType or None):
        """
        Sample uniformly distributed values between `low` and `high`.

        Advances `self.rnd_key` on every call.

        Args:
            shape: Shape of the array to create.
            low: Lower bound. For complex dtypes, `low.real` / `low.imag` bound the real and imaginary parts.
            high: Upper bound. For complex dtypes, `high.real` / `high.imag` bound the two parts.
            dtype: Data type of the result. Defaults to `self.float_type` if `None`.

        Returns:
            Array placed on `self._default_device` via `jax.device_put`.

        Raises:
            ValueError: If `dtype.kind` is not float, complex or int.
        """
        self._check_float64()
        self.rnd_key, subkey = jax.random.split(self.rnd_key)

        dtype = dtype or self.float_type
        jdt = to_numpy_dtype(dtype)
        if dtype.kind == float:
            tensor = random.uniform(subkey, shape, minval=low, maxval=high, dtype=jdt)
        elif dtype.kind == complex:
            # Use independent keys for the two components: drawing both from the same key yields
            # the same underlying uniform samples, making the imaginary part perfectly correlated
            # with the real part.
            real_key, imag_key = jax.random.split(subkey)
            part_dtype = to_numpy_dtype(DType(float, dtype.precision))
            real = random.uniform(real_key, shape, minval=low.real, maxval=high.real, dtype=part_dtype)
            imag = random.uniform(imag_key, shape, minval=low.imag, maxval=high.imag, dtype=part_dtype)
            # Fall through so complex samples are also placed on the default device.
            tensor = real + 1j * imag
        elif dtype.kind == int:
            tensor = random.randint(subkey, shape, low, high, dtype=jdt)
            if tensor.dtype != jdt:
                warnings.warn(f"Jax failed to sample random integers with dtype {dtype}, returned {tensor.dtype} instead.", RuntimeWarning)
        else:
            raise ValueError(dtype)
        return jax.device_put(tensor, self._default_device.ref)
Example #12
0
 def test_as_dtype(self):
     # None and existing DType instances pass through unchanged
     self.assertEqual(None, DType.as_dtype(None))
     self.assertEqual(DType(int, 32), DType.as_dtype(DType(int, 32)))
     # Bare Python types map to the widths asserted here
     self.assertEqual(DType(int, 32), DType.as_dtype(int))
     self.assertEqual(DType(float, 32), DType.as_dtype(float))
     self.assertEqual(DType(complex, 64), DType.as_dtype(complex))
     self.assertEqual(DType(bool), DType.as_dtype(bool))
     # (kind, bits) tuples are accepted
     self.assertEqual(DType(int, 8), DType.as_dtype((int, 8)))
     self.assertEqual(object, DType.as_dtype(object).kind)
     # Unsupported types are rejected
     with self.assertRaises(ValueError):
         DType.as_dtype(str)
Example #13
0
 def test_object_dtype(self):
     # The bit width of an object DType is expected to be 32 or 64
     # (presumably the platform pointer size - confirm against DType).
     object_bits = DType(object).bits
     self.assertIn(object_bits, (32, 64))
Example #14
0
 def test_cast(self):
     for backend in BACKENDS:
         with backend:
             name = backend.name
             x = math.random_uniform(dtype=DType(float, 64))
             # Outside a precision context, float / complex conversions yield 32 / 64 bits
             self.assertEqual(DType(float, 32), math.to_float(x).dtype, msg=name)
             self.assertEqual(DType(complex, 64), math.to_complex(x).dtype, msg=name)
             # Inside precision(64) they yield 64 / 128 bits
             with math.precision(64):
                 self.assertEqual(DType(float, 64), math.to_float(x).dtype, msg=name)
                 self.assertEqual(DType(complex, 128), math.to_complex(x).dtype, msg=name)
             # Explicit integer casts
             self.assertEqual(DType(int, 64), math.to_int64(x).dtype, msg=name)
             self.assertEqual(DType(int, 32), math.to_int32(x).dtype, msg=name)
             # Generic cast to explicit dtypes
             self.assertEqual(DType(float, 16), math.cast(x, DType(float, 16)).dtype, msg=name)
             self.assertEqual(DType(complex, 128), math.cast(x, DType(complex, 128)).dtype, msg=name)
             # An unsupported bit width must be rejected with KeyError
             with self.assertRaises(KeyError, msg=name):
                 math.cast(x, DType(float, 3))
Example #15
0
 def test_random_complex(self):
     for backend in BACKENDS:
         with backend:
             label = backend.name
             samples = math.random_uniform(instance(values=4), low=-1, high=0, dtype=(complex, 64))
             self.assertEqual(samples.dtype, DType(complex, 64), msg=label)
             # Both bounds are real, so the imaginary part of every sample must be zero
             math.assert_close(samples.imag, 0, msg=label)