Example No. 1
 def test_diff_executables(self):
     with tempfile.TemporaryDirectory() as tmpdir:
         cc.initialize_cache(tmpdir)
         computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
         computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
         compile_options = jax.lib.xla_bridge.get_compile_options(
             num_replicas=1, num_partitions=1)
         backend = jax.lib.xla_bridge.get_backend()
         executable1 = backend.compile(computation1, compile_options)
         executable2 = backend.compile(computation2, compile_options)
         cc.put_executable(computation1, compile_options, executable1)
         cc.put_executable(computation2, compile_options, executable2)
         self.assertNotEqual(
             cc.get_executable(computation1, compile_options),
             cc.get_executable(computation2, compile_options))
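
The test above drives the cache API directly: two different computations are compiled with identical options, stored, and looked up again, and the cache must return distinct executables for them. In application code the cache is normally enabled once and then used implicitly by jit compilation. A minimal sketch of that usage, assuming the same compilation_cache module as in the examples and a backend on which the persistent cache is supported (TPU in these versions):

import tempfile

import jax
import jax.numpy as jnp
from jax.experimental.compilation_cache import compilation_cache as cc

# Point the persistent cache at a directory. On supported backends, compiled
# executables written here can be reused by later processes that initialize
# the cache with the same path.
cache_dir = tempfile.mkdtemp()
cc.initialize_cache(cache_dir)

@jax.jit
def add(x, y):
    return x + y

# The first call compiles the function (and, on a supported backend, writes
# the executable to the cache); later compilations of the same computation
# can be served from disk.
print(add(jnp.arange(4), jnp.arange(4)))
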
Example No. 2
def compile_or_get_cached(backend, computation, compile_options):
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc

  if isinstance(computation, ir.Module):
    sym_name = computation.operation.attributes['sym_name']
    module_name = ir.StringAttr(sym_name).value
    computation = mlir.module_to_string(computation)
  else:
    module_name = computation.name()

  # Persistent compilation cache only implemented on TPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  if cc.is_initialized() and backend.platform == 'tpu':
    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit for %s.', module_name)
      return cached_executable
    else:
      compiled = backend_compile(backend, computation, compile_options)
      cc.put_executable(module_name, computation, compile_options, compiled,
                        backend)
      return compiled

  if FLAGS.jax_dump_ir_to:
    ir_str = (computation if isinstance(computation, str)
              else computation.as_hlo_text())
    _dump_ir_to_file(module_name, ir_str)
  return backend_compile(backend, computation, compile_options)
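
compile_or_get_cached above is a standard compile-on-miss pattern: look the serialized computation up, compile and store it on a miss, and fall back to plain compilation whenever the persistent cache is not initialized or the platform is unsupported. A minimal sketch of the same pattern in isolation, using a hypothetical in-memory dict in place of the persistent on-disk cache and a caller-supplied compile function:

# Hypothetical in-memory stand-in for the persistent cache, keyed by the
# serialized computation together with its compile options.
_memo = {}

def compile_with_memo(backend, computation_str, compile_options, compile_fn):
    key = (computation_str, str(compile_options))
    executable = _memo.get(key)
    if executable is not None:
        # Cache hit: reuse the previously compiled executable.
        return executable
    # Cache miss: compile, then remember the result for later lookups.
    executable = compile_fn(backend, computation_str, compile_options)
    _memo[key] = executable
    return executable
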
Example No. 3
 def test_get_no_executable(self):
     with tempfile.TemporaryDirectory() as tmpdir:
         cc.initialize_cache(tmpdir)
         computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
         compile_options = jax.lib.xla_bridge.get_compile_options(
             num_replicas=1, num_partitions=1)
         self.assertEqual(cc.get_executable(computation, compile_options),
                          None)
Example No. 4
def compile_or_get_cached(backend, computation, compile_options):
    # Avoid import cycle between jax and jax.experimental
    from jax.experimental.compilation_cache import compilation_cache as cc
    # Persistent compilation cache only implemented on TPU.
    # TODO(skye): add warning when initializing cache on unsupported default platform
    if cc.is_initialized() and backend.platform == 'tpu':
        cached_executable = cc.get_executable(computation, compile_options, backend)
        if cached_executable is not None:
            logging.info('Persistent compilation cache hit')
            return cached_executable
        else:
            compiled = backend_compile(backend, computation, compile_options)
            cc.put_executable(computation, compile_options, compiled, backend)
            return compiled
    return backend_compile(backend, computation, compile_options)
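
Because this version only consults the persistent cache on TPU, callers typically gate cache initialization on the default backend's platform. A minimal sketch, assuming the same compilation_cache module as in the examples and a hypothetical cache directory:

import jax
from jax.experimental.compilation_cache import compilation_cache as cc

# Only initialize the persistent cache when the default backend can actually
# use it (TPU in the versions shown here); on other platforms the cache branch
# in compile_or_get_cached is skipped anyway.
if jax.default_backend() == "tpu" and not cc.is_initialized():
    cc.initialize_cache("/tmp/jax_compilation_cache")
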
Example No. 5
 def test_put_executable(self):
   with tempfile.TemporaryDirectory() as tmpdir:
     cc.initialize_cache(tmpdir)
     computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
     compile_options = jax._src.lib.xla_bridge.get_compile_options(
         num_replicas=1, num_partitions=1)
     backend = jax._src.lib.xla_bridge.get_backend()
     executable = backend.compile(computation, compile_options)
     cc.put_executable("alambda", computation, compile_options, executable,
                       backend)
     deserialized_executable = cc.get_executable(computation, compile_options, backend)
     inputs_to_executable = (np.array(1, dtype=np.int32), np.array(2, dtype=np.int32))
     expected = jax._src.lib.xla_client.execute_with_python_values(executable, inputs_to_executable, backend)
     actual = jax._src.lib.xla_client.execute_with_python_values(deserialized_executable, inputs_to_executable, backend)
     self.assertEqual(expected, actual)
Example No. 6
def compile_or_get_cached(backend, computation, compile_options):
    # Avoid import cycle between jax and jax.experimental
    from jax.experimental.compilation_cache import compilation_cache as cc

    if isinstance(computation, ir.Module):
        sym_name = computation.operation.attributes['sym_name']
        module_name = ir.StringAttr(sym_name).value
        # Convert ir.Module to its string representation (the default), unless
        # the backend explicitly flags the ability to handle a module directly
        # (avoiding the overhead of back-and-forth conversions).
        if getattr(backend, "needs_str_ir", True):
            computation = mlir.module_to_string(computation)
    else:
        module_name = computation.name()

    # Persistent compilation cache only implemented on TPU.
    # TODO(skye): add warning when initializing cache on unsupported default platform
    if cc.is_initialized() and backend.platform == 'tpu':
        cached_executable = cc.get_executable(computation, compile_options,
                                              backend)
        if cached_executable is not None:
            logging.info('Persistent compilation cache hit for %s.',
                         module_name)
            return cached_executable
        else:
            compiled = backend_compile(backend, computation, compile_options)
            cc.put_executable(module_name, computation, compile_options,
                              compiled, backend)
            return compiled

    if FLAGS.jax_dump_ir_to:
        if isinstance(computation, xc.XlaComputation):
            ir_str = computation.as_hlo_text()
        elif isinstance(computation, ir.Module):
            ir_str = mlir.module_to_string(computation)
        else:
            assert isinstance(computation, str)
            ir_str = computation
        _dump_ir_to_file(module_name, ir_str)
    return backend_compile(backend, computation, compile_options)
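
The final branch dumps the IR text to a file whenever the jax_dump_ir_to flag is set. Like other JAX flags it can be passed on the command line (--jax_dump_ir_to=/some/dir) or, conventionally, through the corresponding environment variable before JAX is imported. A minimal sketch of enabling it that way, assuming the flag reads its default from JAX_DUMP_IR_TO as JAX flags customarily do, and using a hypothetical dump directory:

import os

# Set the dump directory before importing jax, since the flag's default is
# read when JAX initializes its configuration.
os.environ["JAX_DUMP_IR_TO"] = "/tmp/jax_ir_dumps"

import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    return x * 2.0

# Triggering a compilation writes the module's IR text into the dump directory.
double(jnp.float32(3.0))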