def test_get_no_executable(self):
    """A lookup in a freshly initialized, empty cache must miss (return None).

    Uses the same cache API shape as the other tests in this file:
    ``cc.get_executable(computation, compile_options, backend)``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)
        computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
        compile_options = jax._src.lib.xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        backend = jax._src.lib.xla_bridge.get_backend()
        # Nothing was put into the cache, so the lookup must miss.
        self.assertIsNone(
            cc.get_executable(computation, compile_options, backend))
def test_jit(self):
    """Every new jit specialization should persist exactly one cache file.

    Calling with an int and then a float forces two distinct compilations,
    so the cache directory is expected to hold 1 and then 2 entries.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)
        square = jit(lambda x: x * x)
        for expected_files, arg in enumerate((1, 1.0), start=1):
            square(arg)
            self.assertEqual(len(os.listdir(tmpdir)), expected_files)
def test_pmap(self):
    """Every new pmap specialization should persist exactly one cache file.

    The same pmapped function is run on an int64 and then a float32 input,
    each of length ``jax.device_count()``. The cache directory is expected to
    grow to 1 and then 2 entries.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)
        centered = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
        num_devices = jax.device_count()
        for expected_files, dtype in enumerate((np.int64, np.float32), start=1):
            centered(np.arange(num_devices, dtype=dtype))
            self.assertEqual(len(os.listdir(tmpdir)), expected_files)
def test_diff_executables(self):
    """Two different computations must be cached and retrieved separately.

    Uses the same cache API shape as the other tests in this file:
    ``put_executable(module_name, computation, options, executable, backend)``
    and ``get_executable(computation, options, backend)``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)
        computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
        computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
        compile_options = jax._src.lib.xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        backend = jax._src.lib.xla_bridge.get_backend()
        executable1 = backend.compile(computation1, compile_options)
        executable2 = backend.compile(computation2, compile_options)
        cc.put_executable("computation1", computation1, compile_options,
                          executable1, backend)
        cc.put_executable("computation2", computation2, compile_options,
                          executable2, backend)
        # Distinct computations must not collide on the same cache entry.
        self.assertNotEqual(
            cc.get_executable(computation1, compile_options, backend),
            cc.get_executable(computation2, compile_options, backend))
def test_put_executable(self):
    """An executable round-tripped through the cache must compute the same result.

    Compiles ``x + y``, stores it under the module name "alambda", reads it
    back, and runs both the original and the deserialized executables on the
    same int32 inputs.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)
        computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
        compile_options = jax._src.lib.xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        backend = jax._src.lib.xla_bridge.get_backend()
        executable = backend.compile(computation, compile_options)
        cc.put_executable("alambda", computation, compile_options, executable,
                          backend)
        deserialized_executable = cc.get_executable(computation,
                                                    compile_options, backend)
        # Report a cache miss directly instead of letting None flow into
        # execute_with_python_values and fail there with an unrelated error.
        self.assertIsNotNone(deserialized_executable)
        inputs_to_executable = (np.array(1, dtype=np.int32),
                                np.array(2, dtype=np.int32))
        expected = jax._src.lib.xla_client.execute_with_python_values(
            executable, inputs_to_executable, backend)
        actual = jax._src.lib.xla_client.execute_with_python_values(
            deserialized_executable, inputs_to_executable, backend)
        self.assertEqual(expected, actual)
def test_xmap(self):
    """Every new xmap specialization should persist exactly one cache file.

    Runs the same xmapped function on an int64 and then a float32 input of
    shape (2, 2, 2). The cache directory is expected to grow to 1 and then 2
    entries.

    NOTE(review): ``axis_resources={'a': 'x'}`` needs an active mesh with an
    'x' axis. This method does not create one, so presumably a decorator or
    fixture outside this view provides it. Confirm.
    """
    # Skip before creating a temp dir or initializing the cache, so a skipped
    # run leaves no cache state behind.
    if jax.local_device_count() < 2:
        raise SkipTest("Test requires 2 devices")
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)

        def f(x):
            return x * 2

        for expected_files, dtype in enumerate((np.int64, np.float32), start=1):
            x = np.arange(8, dtype=dtype).reshape((2, 2, 2))
            xmap(f, in_axes=['a', ...], out_axes=['a', ...],
                 axis_resources={'a': 'x'})(x)
            self.assertEqual(len(os.listdir(tmpdir)), expected_files)
def test_pjit(self):
    """Every new pjit specialization should persist exactly one cache file.

    Adds two (8, 8) arrays, sharded along 'x', first as int64 and then as
    float32. The cache directory is expected to grow to 1 and then 2 entries.

    NOTE(review): ``P('x')`` needs an active mesh with an 'x' axis. None is
    created here, so presumably a decorator or fixture outside this view
    provides it. Confirm.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)

        def f(x, y):
            return x + y

        f = pjit(f, in_axis_resources=(P('x'), P('x')),
                 out_axis_resources=None)
        shape = (8, 8)
        for expected_files, dtype in enumerate((np.int64, np.float32), start=1):
            x = np.arange(prod(shape), dtype=dtype).reshape(shape)
            f(x, x + 1)
            self.assertEqual(len(os.listdir(tmpdir)), expected_files)