def test_cpu_fallback_warning(self):
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    xb.get_backend()
    self.assertLen(w, 1)
    msg = str(w[-1].message)
    self.assertIn("No GPU/TPU found, falling back to CPU", msg)
def start_trace(log_dir, create_perfetto_link: bool = False):
  """Starts a profiler trace.

  The trace will capture CPU, GPU, and/or TPU activity, including Python
  functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
  and save the results to ``log_dir``.

  The resulting trace can be viewed with TensorBoard. Note that TensorBoard
  doesn't need to be running when collecting the trace.

  Only one trace may be collected at a time. A RuntimeError will be raised if
  ``start_trace()`` is called while another trace is running.

  Args:
    log_dir: The directory to save the profiler trace to (usually the
      TensorBoard log directory).
    create_perfetto_link: A boolean which, if true, creates and prints a link
      to the Perfetto trace viewer UI (https://ui.perfetto.dev). The program
      will block until the link is opened and Perfetto loads the trace.
  """
  with _profile_state.lock:
    if _profile_state.profile_session is not None:
      raise RuntimeError("Profile has already been started. "
                         "Only one profile may be run at a time.")
    # Make sure backends are initialized before creating a profiler
    # session. Otherwise on Cloud TPU, libtpu may not be initialized before
    # creating the tracer, which will cause the TPU tracer initialization to
    # fail and no TPU operations will be included in the profile.
    xla_bridge.get_backend()

    _profile_state.profile_session = xla_client.profiler.ProfilerSession()
    _profile_state.create_perfetto_link = create_perfetto_link
    _profile_state.log_dir = log_dir
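# Usage sketch (assumption, not part of the source above): collect a trace of a
# small computation and write it to a TensorBoard log directory. It assumes
# ``jax.profiler.start_trace`` / ``stop_trace`` are the public entry points for
# the function defined here; the log directory path is illustrative.
#
#   import jax
#   import jax.numpy as jnp
#
#   jax.profiler.start_trace("/tmp/jax-trace")
#   jnp.dot(jnp.ones((1000, 1000)), jnp.ones((1000, 1000))).block_until_ready()
#   jax.profiler.stop_trace()
#   # View with: tensorboard --logdir /tmp/jax-trace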
def test_factory_returns_none(self):
  xb.register_backend_factory("none", lambda: None, priority=10)

  default_backend = xb.get_backend()
  self.assertEqual(default_backend.platform, "cpu")

  with self.assertRaisesRegex(
      RuntimeError, "Backend 'none' failed to initialize: "
      "Could not initialize backend 'none'"):
    xb.get_backend("none")
def test_backend_init_error(self):
  def factory():
    raise RuntimeError("I'm not a real backend")

  xb.register_backend_factory("error", factory, priority=10)
  # No error raised if there's a fallback backend.
  default_backend = xb.get_backend()
  self.assertEqual(default_backend.platform, "cpu")

  with self.assertRaisesRegex(RuntimeError, "I'm not a real backend"):
    xb.get_backend("error")
def device_memory_profile(backend: Optional[str] = None) -> bytes:
  """Captures a JAX device memory profile as a ``pprof``-format protocol buffer.

  A device memory profile is a snapshot of the state of memory that describes
  the JAX :class:`jax.DeviceArray` and executable objects present in memory and
  their allocation sites.

  For more information on how to use the device memory profiler, see
  :doc:`/device_memory_profiling`.

  The profiling system works by instrumenting JAX on-device allocations,
  capturing a Python stack trace for each allocation. The instrumentation is
  always enabled; :func:`device_memory_profile` provides an API to capture it.

  The output of :func:`device_memory_profile` is a binary protocol buffer that
  can be interpreted and visualized by the
  `pprof tool <https://github.com/google/pprof>`_.

  Args:
    backend: optional; the name of the JAX backend for which the device memory
      profile should be collected.

  Returns:
    A byte string containing a binary ``pprof``-format protocol buffer.
  """
  return xla_client.heap_profile(xla_bridge.get_backend(backend))
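# Usage sketch (assumption): dump the current device memory profile to a file
# that pprof can read, via the public ``jax.profiler`` wrapper assumed to
# expose the function above. The output filename is illustrative.
#
#   import jax
#
#   with open("memory.prof", "wb") as f:
#     f.write(jax.profiler.device_memory_profile())
#   # Inspect with: pprof --web memory.prof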
def test_specific_platform(self):
  self._register_factory("platform_A", 20)
  self._register_factory("platform_B", 10)

  backend = xb.get_backend("platform_B")
  self.assertEqual(backend.platform, "platform_B")
  # All backends initialized.
  self.assertEqual(len(xb._backends), len(xb._backend_factories))
def from_dlpack(dlpack):
  """Returns a ``DeviceArray`` representation of a DLPack tensor.

  The returned ``DeviceArray`` shares memory with ``dlpack``.

  Args:
    dlpack: a DLPack tensor, on either CPU or GPU.
  """
  cpu_backend = xla_bridge.get_backend("cpu")
  try:
    gpu_backend = xla_bridge.get_backend("gpu")
  except RuntimeError:
    gpu_backend = None
  buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
      dlpack, cpu_backend, gpu_backend)

  xla_shape = buf.xla_shape()
  assert not xla_shape.is_tuple()
  aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
  return device_array.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
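# Usage sketch (assumption): move a PyTorch tensor into JAX through the DLPack
# protocol, assuming ``jax.dlpack.from_dlpack`` is the public name of the
# function above. The resulting array shares memory with the torch tensor.
#
#   import torch
#   import torch.utils.dlpack
#   import jax.dlpack
#
#   t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
#   x = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
#   print(x.shape)  # (2, 3)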
def testTorchToJaxFailure(self):
  x = torch.arange(6).reshape((2, 3))
  y = torch.utils.dlpack.to_dlpack(x[:, :2])
  backend = xla_bridge.get_backend()
  client = getattr(backend, "client", backend)
  regex_str = (r'UNIMPLEMENTED: Only DLPack tensors with trivial \(compact\) '
               r'striding are supported')
  with self.assertRaisesRegex(RuntimeError, regex_str):
    xla_client._xla.dlpack_managed_tensor_to_buffer(y, client)
def test_jax_platforms_flag(self):
  self._register_factory("platform_A", 20)
  self._register_factory("platform_B", 10)

  orig_jax_platforms = config._read("jax_platforms")
  try:
    config.FLAGS.jax_platforms = "cpu,platform_A"

    backend = xb.get_backend()
    self.assertEqual(backend.platform, "cpu")
    # Only specified backends initialized.
    self.assertEqual(len(xb._backends), 2)

    backend = xb.get_backend("platform_A")
    self.assertEqual(backend.platform, "platform_A")

    with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
      backend = xb.get_backend("platform_B")
  finally:
    config.FLAGS.jax_platforms = orig_jax_platforms
def start_server(port: int):
  """Starts the profiler server on port `port`.

  Using the "TensorFlow profiler" feature in `TensorBoard
  <https://www.tensorflow.org/tensorboard>`_ 2.2 or newer, you can connect to
  the profiler server and sample execution traces that show CPU, GPU, and/or
  TPU device activity.
  """
  global _profiler_server
  if _profiler_server is not None:
    raise ValueError("Only one profiler server can be active at a time.")

  # Make sure backends are initialized before creating a profiler
  # session. Otherwise on Cloud TPU, libtpu may not be initialized before
  # creating the tracer, which will cause the TPU tracer initialization to
  # fail and no TPU operations will be included in the profile.
  # NOTE(skyewm): I'm not sure this is necessary for start_server (it
  # definitely is for start_trace), but I'm putting it here to be safe.
  xla_bridge.get_backend()

  _profiler_server = xla_client.profiler.start_server(port)
  return _profiler_server
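# Usage sketch (assumption): start the profiler server at program startup so a
# TensorBoard profiler client can connect and sample traces while the workload
# runs. ``jax.profiler.start_server`` / ``stop_server`` are assumed to be the
# public entry points for the function above; the port number is illustrative.
#
#   import jax
#
#   server = jax.profiler.start_server(9999)
#   # ... run the workload; point TensorBoard's profiler at localhost:9999 ...
#   jax.profiler.stop_server()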
def _xla_callable_device(nreps, backend, device, arg_devices):
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  else:
    if device is None and backend is None:
      return _device_from_arg_devices(arg_devices)
    elif device is not None and backend is None:
      return device
    elif device is None and backend is not None:
      return xb.get_backend(backend).get_default_device_assignment(1)[0]
    else:
      assert False  # Unreachable given the error check in _xla_callable
def _xla_callable_device(nreps, backend, device,
                         arg_devices) -> Optional[Device]:
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  else:
    # TODO(skye): dedup with C++ jit logic for determining jit device?
    if device is not None:
      assert backend is None
      return device
    if backend is not None:
      return xb.get_backend(backend).get_default_device_assignment(1)[0]
    arg_device = _device_from_arg_devices(arg_devices)
    if arg_device is not None:
      return arg_device
    return config.jax_default_device
def test_no_devices(self):
  self._register_factory("no_devices", -10, device_count=0)
  default_backend = xb.get_backend()
  self.assertEqual(default_backend.platform, "cpu")
  with self.assertRaisesRegex(
      RuntimeError,
      "Backend 'no_devices' failed to initialize: "
      "Backend 'no_devices' provides no devices."):
    xb.get_backend("no_devices")

  self._reset_backend_state()

  self._register_factory("no_devices2", 10, device_count=0)
  default_backend = xb.get_backend()
  self.assertEqual(default_backend.platform, "cpu")
  with self.assertRaisesRegex(
      RuntimeError,
      "Backend 'no_devices2' failed to initialize: "
      "Backend 'no_devices2' provides no devices."):
    xb.get_backend("no_devices2")
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(abstract_args, in_parts,
                                         local_in_parts)
  ]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(global_out_avals, out_parts,
                                         local_out_parts)
  ]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  axis_env = xla.AxisEnv(nrep, (), ())
  unordered_effects = [
      eff for eff in jaxpr.effects if eff not in core.ordered_effects
  ]
  ordered_effects = [
      eff for eff in jaxpr.effects if eff in core.ordered_effects
  ]
  module, _ = mlir.lower_jaxpr_to_module(
      f"spjit_{fun.__name__}",
      core.ClosedJaxpr(jaxpr, consts),
      unordered_effects,
      ordered_effects,
      platform=platform,
      axis_context=mlir.ReplicaAxisContext(axis_env),
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
      donated_args=[False] * len(in_parts),
      arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
      result_shardings=safe_map(xla.sharding_to_proto, out_parts))
  built = xc._xla.mlir.mlir_module_to_xla_computation(
      mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)
  ]
  input_indices = [
      pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
      for aval, spec in zip(abstract_args, input_specs)
  ]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
def device_under_test():
  return (getattr(FLAGS, 'jax_test_dut', None) or
          xla_bridge.get_backend().platform)
def is_device_rocm():
  return xla_bridge.get_backend().platform_version.startswith('rocm')
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, *arg_specs):
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  abstract_args, arg_devices = util.unzip2(arg_specs)
  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
        fun, abstract_args, pe.debug_info_final(fun, "jit"))
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")
  jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
  consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
  pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
  abstract_args, arg_devices = util.unzip2(pruned_arg_specs)
  donated_invars = [
      x for i, x in enumerate(donated_invars) if i in kept_var_idx
  ]
  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to evaluate their arguments.
  if not jaxpr.eqns:
    return XlaComputation(
        name, None, True, None, jaxpr=jaxpr, consts=consts, device=device,
        in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for args {abstract_args}."
    logging.log(log_priority, msg)

  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  # pass long arg lists as tuple for TPU
  tuple_args = len(abstract_args) > 100
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = xla.new_name_stack(xla.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module: Union[str, xc.XlaComputation]
  module_name = f"jit_{fun.__name__}"
  if config.jax_enable_mlir:
    module = mlir.lower_jaxpr_to_module(
        module_name, closed_jaxpr, backend.platform,
        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
  else:
    module = xla.lower_jaxpr_to_xla_module(
        module_name, closed_jaxpr, backend.platform, axis_env, name_stack,
        tuple_args, donated_invars, replicated_args=None, arg_partitions=None,
        out_partitions=None)
  return XlaComputation(
      name, module, False, donated_invars, nreps=nreps, device=device,
      backend=backend, tuple_args=tuple_args, in_avals=abstract_args,
      out_avals=out_avals, kept_var_idx=kept_var_idx)
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, always_lower: bool, keep_unused: bool,
                       *arg_specs):
  """Lower into XLA.

  Args:
    always_lower: If `True`, even trivial programs (not doing any computation
      such as lambda x: x) will be lowered into an XLA program.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
  """
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))
  abstract_args, arg_devices = util.unzip2(arg_specs)
  if fun.in_type is not None:
    abstract_args, which_explicit = util.unzip2(fun.in_type)
  else:
    which_explicit = None
  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
        fun, abstract_args, pe.debug_info_final(fun, "jit"), which_explicit)
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")

  # TODO(mattjj): handle argument pruning w/ dynamic shapes
  if fun.in_type is None and not keep_unused:
    jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
    abstract_args, arg_devices = util.unzip2(
        [a for i, a in enumerate(arg_specs) if i in kept_var_idx])
    donated_invars = [
        x for i, x in enumerate(donated_invars) if i in kept_var_idx
    ]
    del kept_const_idx
  else:
    kept_var_idx = set(range(len(abstract_args)))
  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
      not _backend_supports_unbounded_dynamic_shapes(backend)):
    jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to evaluate their arguments.
  if not jaxpr.eqns and not always_lower:
    return XlaComputation(
        name, None, True, None, None, jaxpr=jaxpr, consts=consts,
        device=device, in_avals=abstract_args, out_avals=out_avals,
        has_unordered_effects=False, ordered_effects=[],
        kept_var_idx=kept_var_idx, keepalive=None)

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for args {abstract_args}."
    logging.log(log_priority, msg)

  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  # pass long arg lists as tuple for TPU
  tuple_args = len(abstract_args) > 100
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module_name = f"jit_{fun.__name__}"
  unordered_effects = [
      eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects
  ]
  ordered_effects = [
      eff for eff in closed_jaxpr.effects if eff in core.ordered_effects
  ]
  module, keepalive = mlir.lower_jaxpr_to_module(
      module_name, closed_jaxpr, unordered_effects, ordered_effects,
      backend.platform, mlir.ReplicaAxisContext(axis_env), name_stack,
      donated_invars)
  return XlaComputation(
      name, module, False, donated_invars, which_explicit, nreps=nreps,
      device=device, backend=backend, tuple_args=tuple_args,
      in_avals=abstract_args, out_avals=out_avals,
      has_unordered_effects=bool(unordered_effects),
      ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
      keepalive=keepalive)
def is_device_cuda():
  return xla_bridge.get_backend().platform_version.startswith('cuda')
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(abstract_args, in_parts,
                                         local_in_parts)
  ]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  if xb.get_backend().platform not in ["tpu", "gpu"]:
    # TODO(skye): fall back to regular jit?
    raise ValueError("sharded_jit not supported for " +
                     xb.get_backend().platform)

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(global_out_avals, out_parts,
                                         local_out_parts)
  ]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  c = xb.make_computation_builder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xb.constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  out_nodes = xla.jaxpr_subcomp(
      c, jaxpr, None, axis_env, xla_consts,
      extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
  out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = xla.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)
  ]
  input_indices = [
      pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
      for aval, spec in zip(abstract_args, input_specs)
  ]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
def test_unknown_backend_error(self):
  with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
    xb.get_backend("foo")
def device_under_test():
  return FLAGS.jax_test_dut or xla_bridge.get_backend().platform