def test_diff_executables(self):
    """Check that two different computations yield unequal cached executables.

    An addition and a multiplication are each compiled and stored in a cache
    rooted at a fresh temporary directory; the entries read back for the two
    computations must not compare equal.
    """
    with tempfile.TemporaryDirectory() as cache_dir:
        cc.initialize_cache(cache_dir)

        add_computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
        mul_computation = jax.xla_computation(lambda x, y: x * y)(2, 2)
        computations = (add_computation, mul_computation)

        options = jax.lib.xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        backend = jax.lib.xla_bridge.get_backend()

        # Compile both computations first, then store both, in that order.
        executables = [backend.compile(c, options) for c in computations]
        for computation, executable in zip(computations, executables):
            cc.put_executable(computation, options, executable)

        fetched_add = cc.get_executable(add_computation, options)
        fetched_mul = cc.get_executable(mul_computation, options)
        self.assertNotEqual(fetched_add, fetched_mul)
def compile_or_get_cached(backend, computation, compile_options):
    """Compile `computation`, going through the persistent cache when enabled.

    Args:
      backend: the backend to compile for; its `platform` attribute decides
        whether the persistent cache is consulted.
      computation: either an `ir.Module` (converted to its string form before
        compiling) or an object exposing `name()` (and `as_hlo_text()`, used
        when the IR is dumped).
      compile_options: passed through unchanged to the cache and the compiler.

    Returns:
      The executable found in the persistent cache, or otherwise the result of
      `backend_compile` (which is also stored in the cache when it is active).
    """
    # Avoid import cycle between jax and jax.experimental
    from jax.experimental.compilation_cache import compilation_cache as cc

    if isinstance(computation, ir.Module):
        sym_name = computation.operation.attributes['sym_name']
        module_name = ir.StringAttr(sym_name).value
        computation = mlir.module_to_string(computation)
    else:
        module_name = computation.name()

    # Dump the IR before consulting the cache: every path through the cache
    # branch below returns, so dumping afterwards would silently ignore
    # `jax_dump_ir_to` whenever the persistent cache is active.
    if FLAGS.jax_dump_ir_to:
        ir_str = (computation if isinstance(computation, str)
                  else computation.as_hlo_text())
        _dump_ir_to_file(module_name, ir_str)

    # Persistent compilation cache only implemented on TPU.
    # TODO(skye): add warning when initializing cache on unsupported default platform
    if not (cc.is_initialized() and backend.platform == 'tpu'):
        return backend_compile(backend, computation, compile_options)

    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
        logging.info('Persistent compilation cache hit for %s.', module_name)
        return cached_executable

    # Cache miss: compile and store the result for later lookups.
    compiled = backend_compile(backend, computation, compile_options)
    cc.put_executable(module_name, computation, compile_options, compiled,
                      backend)
    return compiled
def test_get_no_executable(self):
    """Check that a lookup with nothing stored compares equal to None.

    The cache is initialized in a fresh temporary directory and queried
    without any prior `put_executable` call.
    """
    with tempfile.TemporaryDirectory() as cache_dir:
        cc.initialize_cache(cache_dir)
        add_computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
        options = jax.lib.xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        lookup_result = cc.get_executable(add_computation, options)
        self.assertEqual(lookup_result, None)
def compile_or_get_cached(backend, computation, compile_options):
    """Compile `computation`, reusing a persistently cached executable if any.

    The persistent cache is consulted only when it has been initialized and
    `backend.platform` is 'tpu'; otherwise the computation is compiled
    directly. On a cache miss the freshly compiled executable is stored
    before being returned.
    """
    # Avoid import cycle between jax and jax.experimental
    from jax.experimental.compilation_cache import compilation_cache as cc

    # Persistent compilation cache only implemented on TPU.
    # TODO(skye): add warning when initializing cache on unsupported default platform
    cache_enabled = cc.is_initialized() and backend.platform == 'tpu'
    if not cache_enabled:
        return backend_compile(backend, computation, compile_options)

    hit = cc.get_executable(computation, compile_options, backend)
    if hit is not None:
        logging.info('Persistent compilation cache hit')
        return hit

    executable = backend_compile(backend, computation, compile_options)
    cc.put_executable(computation, compile_options, executable, backend)
    return executable
def test_put_executable(self):
    """Check that an executable read back from the cache matches the original.

    An addition is compiled, stored under the name "alambda", and retrieved
    again; both the original and the retrieved executable are then run on the
    same int32 inputs and their results must be equal.
    """
    xla_bridge = jax._src.lib.xla_bridge
    xla_client = jax._src.lib.xla_client

    with tempfile.TemporaryDirectory() as cache_dir:
        cc.initialize_cache(cache_dir)
        computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
        options = xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        backend = xla_bridge.get_backend()

        original = backend.compile(computation, options)
        cc.put_executable("alambda", computation, options, original, backend)
        restored = cc.get_executable(computation, options, backend)

        args = tuple(np.array(value, dtype=np.int32) for value in (1, 2))

        def run(executable):
            # Execute on the shared inputs using the same backend.
            return xla_client.execute_with_python_values(
                executable, args, backend)

        expected = run(original)
        actual = run(restored)
        self.assertEqual(expected, actual)
def compile_or_get_cached(backend, computation, compile_options):
    """Compile `computation`, going through the persistent cache when enabled.

    Args:
      backend: the backend to compile for. Its `platform` attribute decides
        whether the persistent cache is consulted, and an optional
        `needs_str_ir` attribute (treated as True when absent) decides whether
        an `ir.Module` is converted to a string before compiling.
      computation: either an `ir.Module` or an object exposing `name()`.
      compile_options: passed through unchanged to the cache and the compiler.

    Returns:
      The executable found in the persistent cache, or otherwise the result of
      `backend_compile` (which is also stored in the cache when it is active).
    """
    # Avoid import cycle between jax and jax.experimental
    from jax.experimental.compilation_cache import compilation_cache as cc

    if isinstance(computation, ir.Module):
        sym_name = computation.operation.attributes['sym_name']
        module_name = ir.StringAttr(sym_name).value
        # Convert ir.Module to its str representation (the default), unless
        # the backend explicitly flags the ability to handle a module directly
        # (avoiding the overhead of back-and-forth conversions).
        if getattr(backend, "needs_str_ir", True):
            computation = mlir.module_to_string(computation)
    else:
        module_name = computation.name()

    # Dump the IR before consulting the cache: every path through the cache
    # branch below returns, so dumping afterwards would silently ignore
    # `jax_dump_ir_to` whenever the persistent cache is active.
    if FLAGS.jax_dump_ir_to:
        if isinstance(computation, xc.XlaComputation):
            ir_str = computation.as_hlo_text()
        elif isinstance(computation, ir.Module):
            ir_str = mlir.module_to_string(computation)
        else:
            assert isinstance(computation, str)
            ir_str = computation
        _dump_ir_to_file(module_name, ir_str)

    # Persistent compilation cache only implemented on TPU.
    # TODO(skye): add warning when initializing cache on unsupported default platform
    if not (cc.is_initialized() and backend.platform == 'tpu'):
        return backend_compile(backend, computation, compile_options)

    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
        logging.info('Persistent compilation cache hit for %s.', module_name)
        return cached_executable

    # Cache miss: compile and store the result for later lookups.
    compiled = backend_compile(backend, computation, compile_options)
    cc.put_executable(module_name, computation, compile_options, compiled,
                      backend)
    return compiled