def test_jit_on_nondefault_backend(self):
    """Check where jit places results when the default backend is not CPU.

    Covers three placement rules:
      * an un-annotated jit follows its input's device,
      * ``device=`` pins the result to that device, and
      * ``backend=`` pins the result to that backend.
    """
    cpu_devices = api.devices("cpu")
    self.assertNotEmpty(cpu_devices)
    first_cpu = cpu_devices[0]

    # This test is meant for non-CPU runs, so the default backend must be
    # something other than the CPU.
    default_device = api.devices()[0]
    self.assertNotEqual(default_device.platform, "cpu")

    cpu_value = api.device_put(1, device=first_cpu)
    self.assertEqual(cpu_value.device_buffer.device(), first_cpu)

    def my_sin(x):
        return jnp.sin(x)

    # No device spec: a Python scalar lands on the default device, while an
    # input that already lives on the CPU keeps the computation there.
    from_scalar = api.jit(my_sin)(2)
    self.assertEqual(from_scalar.device_buffer.device(), default_device)
    from_cpu_input = api.jit(my_sin)(cpu_value)
    self.assertEqual(from_cpu_input.device_buffer.device(), first_cpu)

    # `device=` forces placement on that exact device.
    pinned_to_device = api.jit(my_sin, device=first_cpu)(2)
    self.assertEqual(pinned_to_device.device_buffer.device(), first_cpu)

    # `backend=` forces placement on that backend; the result is expected
    # on the first CPU device.
    pinned_to_backend = api.jit(my_sin, backend="cpu")(2)
    self.assertEqual(pinned_to_backend.device_buffer.device(), first_cpu)
def create_global_mesh(mesh_shape, axis_names):
    """Build a ``Mesh`` over the lowest-id devices, or skip the test.

    Args:
      mesh_shape: shape of the device grid. Its product is the number of
        devices required.
      axis_names: names for the mesh axes, passed straight to ``Mesh``.

    Returns:
      A ``Mesh`` over the first ``prod(mesh_shape)`` devices, taken in
      ascending ``id`` order and reshaped to ``mesh_shape``.

    Raises:
      unittest.SkipTest: if fewer devices are available than required.
    """
    needed = prod(mesh_shape)
    available = api.devices()
    if len(available) < needed:
        raise unittest.SkipTest(f"Test requires {needed} global devices.")
    # Sort by id so the same devices are chosen on every run.
    by_id = sorted(available, key=lambda dev: dev.id)
    grid = np.array(by_id[:needed]).reshape(mesh_shape)
    return Mesh(grid, axis_names)
def test_closed_over_values_device_placement(self):
    """A jitted closure over constants honours the ``backend`` argument.

    Regression test for https://github.com/google/jax/issues/1431.
    """
    def f():
        return jnp.add(3., 4.)

    cpu0 = api.devices('cpu')[0]

    # With no backend given, the result should not land on the CPU
    # (this test expects a non-CPU default backend).
    default_placed = api.jit(f)().device_buffer.device()
    self.assertNotEqual(default_placed, cpu0)

    # Asking for the CPU backend explicitly puts the result there.
    cpu_placed = api.jit(f, backend='cpu')().device_buffer.device()
    self.assertEqual(cpu_placed, cpu0)
def testJitCpu(self):
    """Ops applied to the output of a CPU-backend jit stay on the CPU."""
    @partial(api.jit, backend='cpu')
    def get_arr(scale):
        return scale + jnp.ones((2, 2))

    cpu_arr = get_arr(0.1)
    # Each follow-up op mixes the CPU result with fresh constants or
    # scalars; all results must remain on the CPU.
    derived = (
        cpu_arr / cpu_arr.shape[0],
        cpu_arr + jnp.ones_like(cpu_arr),
        cpu_arr + jnp.eye(2),
    )
    for value in derived:
        self.assertEqual(value.device_buffer.device(), api.devices('cpu')[0])
def test_sum(self):
    """Reducing a CPU-resident array keeps the result on that CPU.

    Regression test for https://github.com/google/jax/issues/2905.
    """
    target = api.devices("cpu")[0]
    total = api.device_put(np.ones(2), target).sum()
    self.assertEqual(total.device_buffer.device(), target)