def compile_or_get_cached(backend, computation, compile_options):
  """Compiles `computation` for `backend`, consulting the persistent cache.

  Args:
    backend: the backend to compile for. Its `platform` attribute decides
      whether the persistent compilation cache is consulted (TPU only).
    computation: either an `ir.Module` or an object exposing `name()` and
      `as_hlo_text()` (presumably an `xc.XlaComputation`). An `ir.Module` is
      serialized with `mlir.module_to_string` before compilation.
    compile_options: forwarded to `backend_compile` and to the persistent
      cache lookup/store.

  Returns:
    The executable loaded from the persistent cache on a hit, otherwise the
    result of `backend_compile`.
  """
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc

  if isinstance(computation, ir.Module):
    sym_name = computation.operation.attributes['sym_name']
    module_name = ir.StringAttr(sym_name).value
    computation = mlir.module_to_string(computation)
  else:
    module_name = computation.name()

  # Dump the IR before consulting the persistent cache. Previously the dump
  # lived after the cache branch, so --jax_dump_ir_to was silently ignored
  # whenever the TPU persistent cache was in use.
  if FLAGS.jax_dump_ir_to:
    ir_str = (computation if isinstance(computation, str)
              else computation.as_hlo_text())
    _dump_ir_to_file(module_name, ir_str)

  # Persistent compilation cache only implemented on TPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  if cc.is_initialized() and backend.platform == 'tpu':
    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit for %s.', module_name)
      return cached_executable
    # Cache miss: compile now and store the result for later runs.
    compiled = backend_compile(backend, computation, compile_options)
    cc.put_executable(module_name, computation, compile_options, compiled,
                      backend)
    return compiled

  return backend_compile(backend, computation, compile_options)
def hlo(self) -> xc.XlaComputation:
  """Returns this computation as an `xc.XlaComputation`.

  If `self._hlo` is already an `xc.XlaComputation` it is returned unchanged;
  otherwise it is serialized with `mlir.module_to_string` and converted,
  honoring the `tuple_args` entry of `self.compile_args`.

  Raises:
    ValueError: if the computation is trivial and therefore has no HLO.
  """
  if self.is_trivial():
    raise ValueError("A trivial computation has no HLO")
  lowered = self._hlo
  if not isinstance(lowered, xc.XlaComputation):
    # Not yet HLO: presumably an MLIR module (it is fed to
    # mlir.module_to_string), so convert it via its textual form.
    module_text = mlir.module_to_string(lowered)
    tuple_args = self.compile_args["tuple_args"]
    lowered = xe.mlir.mlir_module_to_xla_computation(
        module_text, use_tuple_args=tuple_args)
  return lowered
def compile_or_get_cached(backend, computation, compile_options):
  """Compiles `computation` for `backend`, consulting the persistent cache.

  Args:
    backend: the backend to compile for. Its `platform` attribute decides
      whether the persistent compilation cache is consulted (TPU only). If it
      defines `needs_str_ir = False`, an `ir.Module` is passed through
      unserialized; otherwise (the default) it is converted to a string.
    computation: an `ir.Module`, or an object exposing `name()` and
      `as_hlo_text()` (presumably an `xc.XlaComputation`).
    compile_options: forwarded to `backend_compile` and to the persistent
      cache lookup/store.

  Returns:
    The executable loaded from the persistent cache on a hit, otherwise the
    result of `backend_compile`.
  """
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc

  if isinstance(computation, ir.Module):
    sym_name = computation.operation.attributes['sym_name']
    module_name = ir.StringAttr(sym_name).value
    # Convert ir.Module to str representation (the default), unless the
    # back-end explicitly flags the ability to handle a module directly
    # (avoiding the overhead of back and forth conversions)
    if getattr(backend, "needs_str_ir", True):
      computation = mlir.module_to_string(computation)
  else:
    module_name = computation.name()

  # Dump the IR before consulting the persistent cache. Previously the dump
  # lived after the cache branch, so --jax_dump_ir_to was silently ignored
  # whenever the TPU persistent cache was in use.
  if FLAGS.jax_dump_ir_to:
    if isinstance(computation, xc.XlaComputation):
      ir_str = computation.as_hlo_text()
    elif isinstance(computation, ir.Module):
      ir_str = mlir.module_to_string(computation)
    else:
      assert isinstance(computation, str)
      ir_str = computation
    _dump_ir_to_file(module_name, ir_str)

  # Persistent compilation cache only implemented on TPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  # NOTE(review): when the backend sets needs_str_ir=False, an ir.Module
  # reaches cc.get_executable/put_executable here; confirm the cache accepts it.
  if cc.is_initialized() and backend.platform == 'tpu':
    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit for %s.', module_name)
      return cached_executable
    # Cache miss: compile now and store the result for later runs.
    compiled = backend_compile(backend, computation, compile_options)
    cc.put_executable(module_name, computation, compile_options, compiled,
                      backend)
    return compiled

  return backend_compile(backend, computation, compile_options)
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  """Traces, lowers and compiles `fun` for sharded_jit; returns an executor.

  `fun` is traced at the global avals, lowered to an MLIR module with the
  given input/output shardings, converted to an XLA computation and compiled
  for `nparts` devices with a single replica.

  Args:
    fun: the wrapped function to trace.
    nparts: requested number of partitions; reconciled against the traced
      jaxpr by `pxla.reconcile_num_partitions`.
    in_parts: per-argument partition specs (global view).
    out_parts_thunk: returns per-output partition specs. It is called only
      after `fun` has been traced.
    local_in_parts: per-argument partition specs for this process; `None`
      means "same as `in_parts`".
    local_out_parts_thunk: returns per-output partition specs for this
      process, or `None` to reuse the result of `out_parts_thunk`.
    local_nparts: number of partitions on this process; when `None` it
      defaults to `nparts` (only allowed if `nparts` fits on local devices).
    name: name wrapped into the lowering's name stack under "sharded_jit".
    *abstract_args: abstract values of the arguments as seen by this process;
      they are mapped to global avals with `pxla.get_global_aval`.

  Returns:
    A `functools.partial` of `_execute_spatially_partitioned` bound to the
    compiled executable and the argument/result handlers.

  Raises:
    ValueError: if `nparts` exceeds the total device count, if `local_nparts`
      is `None` while `nparts` exceeds the local device count, or if
      `local_nparts` exceeds the local device count.
    NotImplementedError: if `nparts` is more than the local device count but
      fewer than the total device count.
  """
  # Only a single replica is used; the computation is instead split into
  # `nparts` partitions.
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  # Tracing happens at the global shapes, derived from the local avals and
  # the global/local partitioning of each argument.
  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(
          abstract_args, in_parts, local_in_parts)
  ]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  # Multi-host use is only supported when every device participates.
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  # NOTE(review): the output-partition thunks are evaluated only after
  # tracing; presumably they read values recorded while tracing `fun`, so
  # keep this ordering.
  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  # Per-process view of the outputs, used to build the result handler.
  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(
          global_out_avals, out_parts, local_out_parts)
  ]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  axis_env = xla.AxisEnv(nrep, (), ())
  # Split the jaxpr's effects by membership in core.ordered_effects.
  unordered_effects = [
      eff for eff in jaxpr.effects if eff not in core.ordered_effects
  ]
  ordered_effects = [
      eff for eff in jaxpr.effects if eff in core.ordered_effects
  ]
  module, _ = mlir.lower_jaxpr_to_module(
      f"spjit_{fun.__name__}",
      core.ClosedJaxpr(jaxpr, consts),
      unordered_effects,
      ordered_effects,
      platform=platform,
      axis_context=mlir.ReplicaAxisContext(axis_env),
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
      donated_args=[False] * len(in_parts),
      arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
      result_shardings=safe_map(xla.sharding_to_proto, out_parts))
  # Round-trip the MLIR module through text into an XlaComputation for
  # dispatch.backend_compile.
  built = xc._xla.mlir.mlir_module_to_xla_computation(
      mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

  # Use the first `nparts` local devices when they suffice; otherwise (the
  # multi-host case validated above) every device in the system.
  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d for d in devices]])
  # The array should already be (1, nparts) here; the reshape just asserts
  # the layout of one replica row by `nparts` partition columns.
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # TODO(skye): replace with default device assignment (device_assignment = None)?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  # Argument sharding is described with the process-local partitions and the
  # caller-supplied (local) avals.
  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)
  ]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
def mhlo(self) -> str:
  """Returns the textual MHLO form of this computation.

  An `xc.XlaComputation` held in `self._hlo` is converted with
  `xe.mlir.xla_computation_to_mlir_module`; anything else is serialized
  directly with `mlir.module_to_string`.

  Raises:
    ValueError: if the computation is trivial and therefore has no MHLO.
  """
  if self.is_trivial():
    raise ValueError("A trivial computation has no MHLO")
  lowered = self._hlo
  if not isinstance(lowered, xc.XlaComputation):
    return mlir.module_to_string(lowered)
  return xe.mlir.xla_computation_to_mlir_module(lowered)