Example #1
0
    def test_jit_on_nondefault_backend(self):
        """Check jit output placement with and without `device`/`backend`.

        The test asserts that at least one CPU device exists and that the
        default device is not a CPU, then checks where jitted results land.
        """
        cpu_devices = api.devices("cpu")
        self.assertNotEmpty(cpu_devices)
        first_cpu = cpu_devices[0]

        # This test only makes sense when the default backend is not CPU.
        default_device = api.devices()[0]
        self.assertNotEqual(default_device.platform, "cpu")

        def device_of(arr):
            return arr.device_buffer.device()

        cpu_value = api.device_put(1, device=first_cpu)
        self.assertEqual(device_of(cpu_value), first_cpu)

        def my_sin(x):
            return jnp.sin(x)

        # No device spec: a Python scalar input yields a result on the default
        # device, while an input already placed on CPU yields a CPU result.
        self.assertEqual(device_of(api.jit(my_sin)(2)), default_device)
        self.assertEqual(device_of(api.jit(my_sin)(cpu_value)), first_cpu)

        # An explicit `device=` puts the result on that device.
        self.assertEqual(
            device_of(api.jit(my_sin, device=first_cpu)(2)), first_cpu)

        # An explicit `backend=` puts the result on that backend's device.
        self.assertEqual(
            device_of(api.jit(my_sin, backend="cpu")(2)), first_cpu)
Example #2
0
def create_global_mesh(mesh_shape, axis_names):
    """Build a ``Mesh`` of shape ``mesh_shape`` from the lowest-id devices.

    Args:
      mesh_shape: shape of the device grid; ``prod(mesh_shape)`` devices are
        required.
      axis_names: names for the mesh axes, passed through to ``Mesh``.

    Returns:
      A ``Mesh`` built from the first ``prod(mesh_shape)`` devices (ordered by
      ``id``), reshaped to ``mesh_shape``.

    Raises:
      unittest.SkipTest: if fewer than ``prod(mesh_shape)`` devices are
        available.
    """
    size = prod(mesh_shape)
    # Query the device list once, so the availability check and the devices
    # placed in the mesh come from the same result.
    all_devices = api.devices()
    if len(all_devices) < size:
        raise unittest.SkipTest(f"Test requires {size} global devices.")
    # Sort by id so the layout does not depend on the order api.devices() uses.
    devices = sorted(all_devices, key=lambda d: d.id)
    mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
    return Mesh(mesh_devices, axis_names)
Example #3
0
    def test_closed_over_values_device_placement(self):
        """Check placement of a jitted computation built only from constants.

        Without a backend the result must not be on the first CPU device; with
        ``backend='cpu'`` it must be. Regression test for
        https://github.com/google/jax/issues/1431.
        """
        def add_constants():
            return jnp.add(3., 4.)

        default_result_device = api.jit(add_constants)().device_buffer.device()
        self.assertNotEqual(default_result_device, api.devices('cpu')[0])

        cpu_result_device = (
            api.jit(add_constants, backend='cpu')().device_buffer.device())
        self.assertEqual(cpu_result_device, api.devices('cpu')[0])
Example #4
0
    def testJitCpu(self):
        """Check that arrays derived from a CPU-jitted result stay on CPU.

        ``x`` is produced by a function jitted with ``backend='cpu'``. Dividing
        it, and adding it to ``jnp`` arrays created without an explicit device,
        must all yield arrays on the first CPU device.
        """
        def get_arr(scale):
            return scale + jnp.ones((2, 2))
        get_arr = partial(api.jit, backend='cpu')(get_arr)

        x = get_arr(0.1)
        derived = (
            x / x.shape[0],
            x + jnp.ones_like(x),
            x + jnp.eye(2),
        )

        for arr in derived:
            self.assertEqual(arr.device_buffer.device(), api.devices('cpu')[0])
Example #5
0
    def test_sum(self):
        """Check that summing an array placed on CPU gives a result on that CPU.

        Regression test for https://github.com/google/jax/issues/2905.
        """
        cpu_list = api.devices("cpu")
        on_cpu = api.device_put(np.ones(2), cpu_list[0])
        total = on_cpu.sum()
        result_device = total.device_buffer.device()
        self.assertEqual(result_device, cpu_list[0])