Example #1
0
def test_device(t, FWDevice, check_lazy_shapes):
    """Check the type of the device reported by ``B.device``.

    The device of a freshly sampled ``2 x 2`` random tensor must be an
    instance of both the framework-specific ``FWDevice`` type and the generic
    ``B.Device`` type, and converting it to ``str`` must yield a ``str``.
    """
    tensor = B.randn(t, 2, 2)
    for expected_type in (FWDevice, B.Device):
        assert isinstance(B.device(tensor), expected_type)

    # Conversion of the device to a string must produce an actual ``str``.
    as_string = convert(B.device(tensor), str)
    assert isinstance(as_string, str)
Example #2
0
def test_device_and_to_active_device(check_lazy_shapes):
    """Check the reported device and ``B.to_active_device`` on tensors and numbers.

    Every form of a ``2 x 2`` test tensor must report a device whose name
    contains ``"cpu"`` (case-insensitively). Moving it with
    ``B.to_active_device`` must leave its values approximately unchanged.
    Plain Python numbers must be passed through as the very same object.
    """
    for tensor in Tensor(2, 2).forms():
        device_name = str(B.device(tensor)).lower()
        assert "cpu" in device_name
        approx(B.to_active_device(tensor), tensor)

    # Numbers are not moved anywhere: the identical object comes back.
    number = 1
    assert B.to_active_device(number) is number
Example #3
0
def test_on_device(f, t, check_lazy_shapes):
    """Check that ``B.on_device`` temporarily changes where tensors are allocated.

    Allocation inside ``B.on_device`` is checked in three ways: with an
    explicit device name (``"cpu"``), with a device inferred from an existing
    tensor, and with a non-existing device name, which must raise. After the
    contexts exit, ``B.ActiveDevice.active_name`` must again hold the value
    that was set before entering them.

    The global ``B.ActiveDevice.active_name`` is restored in a ``finally``
    clause. A failing assertion therefore cannot leak the temporary value
    into other tests.
    """
    # Remember the global state so it can be restored no matter what happens.
    original_name = B.ActiveDevice.active_name

    f_t = f(t)  # Construct on the current and existing device.

    # Set the active device to something else.
    B.ActiveDevice.active_name = "previous"
    try:
        # Check that explicit allocation on CPU works.
        with B.on_device("cpu"):
            assert B.device(f(t)) == B.device(f_t)

        # Also test inferring the device from a tensor.
        with B.on_device(f_t):
            assert B.device(f(t)) == B.device(f_t)

        # Check that allocation on a non-existing device breaks.
        with pytest.raises(Exception):
            with B.on_device("magic-device"):
                f(t)

        # Check that the active device is reset again.
        assert B.ActiveDevice.active_name == "previous"
    finally:
        # Always undo the mutation of the global state, even on failure.
        B.ActiveDevice.active_name = original_name
Example #4
0
def test_framework_jax(t, check_lazy_shapes):
    """Check that JAX objects are recognized as instances of ``t``.

    Covers a JAX array, a JAX dtype, a random state created for that dtype,
    and the device of a JAX array.
    """
    array = jnp.asarray(1)
    assert isinstance(array, t)

    dtype = jnp.float32
    assert isinstance(dtype, t)

    random_state = B.create_random_state(dtype)
    assert isinstance(random_state, t)

    device = B.device(jnp.asarray(1))
    assert isinstance(device, t)
Example #5
0
def test_framework_torch(t, check_lazy_shapes):
    """Check that PyTorch objects are recognized as instances of ``t``.

    Covers a PyTorch tensor, a PyTorch dtype, a random state created for that
    dtype, and the device of a PyTorch tensor.
    """
    tensor = torch.tensor(1)
    assert isinstance(tensor, t)

    dtype = torch.float32
    assert isinstance(dtype, t)

    random_state = B.create_random_state(dtype)
    assert isinstance(random_state, t)

    device = B.device(torch.tensor(1))
    assert isinstance(device, t)
Example #6
0
def test_framework_tf(t, check_lazy_shapes):
    """Check that TensorFlow objects are recognized as instances of ``t``.

    Covers a TensorFlow constant, a TensorFlow dtype, a random state created
    for that dtype, and the device of a TensorFlow constant.
    """
    constant = tf.constant(1)
    assert isinstance(constant, t)

    dtype = tf.float32
    assert isinstance(dtype, t)

    random_state = B.create_random_state(dtype)
    assert isinstance(random_state, t)

    device = B.device(tf.constant(1))
    assert isinstance(device, t)