def test_device(t, FWDevice, check_lazy_shapes):
    """`B.device` must return the framework-specific device type, which is
    also a `B.Device`, and that device must be convertible to a string."""
    x = B.randn(t, 2, 2)
    device = B.device(x)
    assert isinstance(device, FWDevice)
    assert isinstance(device, B.Device)
    # Test conversion to string.
    assert isinstance(convert(device, str), str)
def test_device_and_to_active_device(check_lazy_shapes):
    """Moving tensors to the active device must be a no-op on the CPU, and
    plain Python numbers must be passed through untouched."""
    # Every form of the test tensor should already live on the CPU, and
    # moving it to the active device should leave its values unchanged.
    for tensor in Tensor(2, 2).forms():
        assert "cpu" in str(B.device(tensor)).lower()
        approx(B.to_active_device(tensor), tensor)

    # Numbers are not framework tensors: they must come back as-is,
    # identity included.
    number = 1
    assert B.to_active_device(number) is number
def test_on_device(f, t, check_lazy_shapes):
    """`B.on_device` must honour explicit device names, infer devices from
    tensors, fail on unknown devices, and restore the active device."""
    # Construct a reference tensor on the current, existing device.
    f_t = f(t)

    # Point the active device at something else so we can check restoration.
    B.ActiveDevice.active_name = "previous"

    # Explicit allocation on the CPU must work.
    with B.on_device("cpu"):
        assert B.device(f(t)) == B.device(f_t)

    # The target device may also be inferred from an existing tensor.
    with B.on_device(f_t):
        assert B.device(f(t)) == B.device(f_t)

    # Allocation on a device that does not exist must raise.
    with pytest.raises(Exception):
        with B.on_device("magic-device"):
            f(t)

    # Leaving the context managers must have restored the active device.
    assert B.ActiveDevice.active_name == "previous"
    B.ActiveDevice.active_name = None
def test_framework_jax(t, check_lazy_shapes):
    """The JAX framework type must cover arrays, data types, random state,
    and devices."""
    samples = (
        jnp.asarray(1),
        jnp.float32,
        B.create_random_state(jnp.float32),
        B.device(jnp.asarray(1)),
    )
    for sample in samples:
        assert isinstance(sample, t)
def test_framework_torch(t, check_lazy_shapes):
    """The PyTorch framework type must cover tensors, data types, random
    state, and devices."""
    samples = (
        torch.tensor(1),
        torch.float32,
        B.create_random_state(torch.float32),
        B.device(torch.tensor(1)),
    )
    for sample in samples:
        assert isinstance(sample, t)
def test_framework_tf(t, check_lazy_shapes):
    """The TensorFlow framework type must cover tensors, data types, random
    state, and devices."""
    samples = (
        tf.constant(1),
        tf.float32,
        B.create_random_state(tf.float32),
        B.device(tf.constant(1)),
    )
    for sample in samples:
        assert isinstance(sample, t)