Example #1
 def cpu_fallback_warning(self):
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
         xb.get_backend()
         self.assertLen(w, 1)
         msg = str(w[-1].message)
         self.assertIn("No GPU/TPU found, falling back to CPU", msg)
Example #2
def start_trace(log_dir, create_perfetto_link: bool = False):
    """Starts a profiler trace.

  The trace will capture CPU, GPU, and/or TPU activity, including Python
  functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
  and save the results to ``log_dir``.

  The resulting trace can be viewed with TensorBoard. Note that TensorBoard
  doesn't need to be running when collecting the trace.

  Only one trace may be collected at a time. A RuntimeError will be raised if
  ``start_trace()`` is called while another trace is running.

  Args:
    log_dir: The directory to save the profiler trace to (usually the
      TensorBoard log directory).
    create_perfetto_link: A boolean which, if true, creates and prints a link to
      the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
      block until the link is opened and Perfetto loads the trace.
  """
    with _profile_state.lock:
        if _profile_state.profile_session is not None:
            raise RuntimeError("Profile has already been started. "
                               "Only one profile may be run at a time.")
        # Make sure backends are initialized before creating a profiler
        # session. Otherwise on Cloud TPU, libtpu may not be initialized before
        # creating the tracer, which will cause the TPU tracer initialization to
        # fail and no TPU operations will be included in the profile.
        xla_bridge.get_backend()

        _profile_state.profile_session = xla_client.profiler.ProfilerSession()
        _profile_state.create_perfetto_link = create_perfetto_link
        _profile_state.log_dir = log_dir
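The public wrappers around this function live in ``jax.profiler``. Below is a minimal usage sketch, assuming the ``jax.profiler.start_trace``/``stop_trace`` API and an arbitrary example log directory; it is an illustration, not part of the listing above.

# Minimal sketch: collect a short trace that TensorBoard can display.
import jax
import jax.numpy as jnp

jax.profiler.start_trace("/tmp/jax-trace")     # example log_dir; any writable path works
x = jnp.ones((1024, 1024))
(x @ x).block_until_ready()                    # make sure the on-device work is captured
jax.profiler.stop_trace()                      # writes the trace files under log_dir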
Example #3
 def test_factory_returns_none(self):
     xb.register_backend_factory("none", lambda: None, priority=10)
     default_backend = xb.get_backend()
     self.assertEqual(default_backend.platform, "cpu")
     with self.assertRaisesRegex(
             RuntimeError, "Backend 'none' failed to initialize: "
             "Could not initialize backend 'none'"):
         xb.get_backend("none")
Example #4
    def test_backend_init_error(self):
        def factory():
            raise RuntimeError("I'm not a real backend")

        xb.register_backend_factory("error", factory, priority=10)
        # No error raised if there's a fallback backend.
        default_backend = xb.get_backend()
        self.assertEqual(default_backend.platform, "cpu")

        with self.assertRaisesRegex(RuntimeError, "I'm not a real backend"):
            xb.get_backend("error")
Example #5
def device_memory_profile(backend: Optional[str] = None) -> bytes:
    """Captures a JAX device memory profile as ``pprof``-format protocol buffer.

  A device memory profile is a snapshot of the state of memory that describes the JAX
  :class:`jax.DeviceArray` and executable objects present in memory and their
  allocation sites.

  For more information on how to use the device memory profiler, see
  :doc:`/device_memory_profiling`.

  The profiling system works by instrumenting JAX on-device allocations,
  capturing a Python stack trace for each allocation. The instrumentation is
  always enabled; :func:`device_memory_profile` provides an API to capture it.

  The output of :func:`device_memory_profile` is a binary protocol buffer that
  can be interpreted and visualized by the `pprof tool
  <https://github.com/google/pprof>`_.

  Args:
    backend: optional; the name of the JAX backend for which the device memory
      profile should be collected.

  Returns:
    A byte string containing a binary `pprof`-format protocol buffer.
  """
    return xla_client.heap_profile(xla_bridge.get_backend(backend))
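A minimal usage sketch for this API, assuming the public wrapper ``jax.profiler.device_memory_profile`` and a hypothetical output path; the resulting file can be opened with the pprof tool mentioned in the docstring.

# Minimal sketch: write a pprof-format memory profile to disk.
import jax
import jax.numpy as jnp

arrays = [jnp.zeros((1000, 1000)) for _ in range(4)]   # allocate something worth profiling
with open("/tmp/memory.prof", "wb") as f:              # hypothetical output path
    f.write(jax.profiler.device_memory_profile())
# Inspect with, e.g.:  pprof --web /tmp/memory.prof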
Example #6
    def test_specific_platform(self):
        self._register_factory("platform_A", 20)
        self._register_factory("platform_B", 10)

        backend = xb.get_backend("platform_B")
        self.assertEqual(backend.platform, "platform_B")
        # All backends initialized.
        self.assertEqual(len(xb._backends), len(xb._backend_factories))
Example #7
File: dlpack.py  Project: frederikwilde/jax
def from_dlpack(dlpack):
    """Returns a ``DeviceArray`` representation of a DLPack tensor.

  The returned ``DeviceArray`` shares memory with ``dlpack``.

  Args:
    dlpack: a DLPack tensor, on either CPU or GPU.
  """
    cpu_backend = xla_bridge.get_backend("cpu")
    try:
        gpu_backend = xla_bridge.get_backend("gpu")
    except RuntimeError:
        gpu_backend = None
    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
        dlpack, cpu_backend, gpu_backend)

    xla_shape = buf.xla_shape()
    assert not xla_shape.is_tuple()
    aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
    return device_array.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
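A hedged usage sketch for the function above, going through the public wrappers ``jax.dlpack.from_dlpack`` and ``torch.utils.dlpack.to_dlpack`` (PyTorch must be installed); calling ``contiguous()`` avoids the striding error demonstrated in the next example.

# Minimal sketch: share a PyTorch tensor with JAX via DLPack.
import torch
import torch.utils.dlpack
import jax.dlpack

t = torch.arange(6, dtype=torch.float32).reshape(2, 3).contiguous()
x = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
print(x.shape)   # (2, 3)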
Example #8
  def testTorchToJaxFailure(self):
    x = torch.arange(6).reshape((2, 3))
    y = torch.utils.dlpack.to_dlpack(x[:, :2])

    backend = xla_bridge.get_backend()
    client = getattr(backend, "client", backend)

    regex_str = (r'UNIMPLEMENTED: Only DLPack tensors with trivial \(compact\) '
                 r'striding are supported')
    with self.assertRaisesRegex(RuntimeError, regex_str):
      xla_client._xla.dlpack_managed_tensor_to_buffer(
          y, client)
Example #9
  def test_jax_platforms_flag(self):
    self._register_factory("platform_A", 20)
    self._register_factory("platform_B", 10)

    orig_jax_platforms = config._read("jax_platforms")
    try:
      config.FLAGS.jax_platforms = "cpu,platform_A"

      backend = xb.get_backend()
      self.assertEqual(backend.platform, "cpu")
      # Only specified backends initialized.
      self.assertEqual(len(xb._backends), 2)

      backend = xb.get_backend("platform_A")
      self.assertEqual(backend.platform, "platform_A")

      with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
        backend = xb.get_backend("platform_B")

    finally:
      config.FLAGS.jax_platforms = orig_jax_platforms
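The user-facing counterpart of the flag exercised in this test is the ``JAX_PLATFORMS`` environment variable (an assumption about the external spelling of ``jax_platforms``); a minimal sketch:

# Minimal sketch: restrict backend initialization before importing jax.
import os
os.environ["JAX_PLATFORMS"] = "cpu"   # assumed env-var form of the jax_platforms flag

import jax
print(jax.default_backend())          # expected: "cpu"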
Example #10
def start_server(port: int):
    """Starts the profiler server on port `port`.

  Using the "TensorFlow profiler" feature in `TensorBoard
  <https://www.tensorflow.org/tensorboard>`_ 2.2 or newer, you can
  connect to the profiler server and sample execution traces that show CPU,
  GPU, and/or TPU device activity.
  """
    global _profiler_server
    if _profiler_server is not None:
        raise ValueError("Only one profiler server can be active at a time.")

    # Make sure backends are initialized before creating a profiler
    # session. Otherwise on Cloud TPU, libtpu may not be initialized before
    # creating the tracer, which will cause the TPU tracer initialization to
    # fail and no TPU operations will be included in the profile.
    # NOTE(skyewm): I'm not sure this is necessary for start_server (it definitely
    # is for start_trace), but I'm putting it here to be safe.
    xla_bridge.get_backend()

    _profiler_server = xla_client.profiler.start_server(port)
    return _profiler_server
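A minimal usage sketch, assuming the public wrapper ``jax.profiler.start_server`` and an arbitrary port; the process must stay alive while TensorBoard samples a trace from it.

# Minimal sketch: expose a profiler server that TensorBoard can connect to.
import time
import jax.profiler

server = jax.profiler.start_server(9999)   # example port
# ... run the JAX workload to be profiled ...
time.sleep(60)                             # keep the server alive while sampling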
Example #11
File: dispatch.py  Project: jbampton/jax
def _xla_callable_device(nreps, backend, device, arg_devices):
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  else:
    if device is None and backend is None:
      return _device_from_arg_devices(arg_devices)
    elif device is not None and backend is None:
      return device
    elif device is None and backend is not None:
      return xb.get_backend(backend).get_default_device_assignment(1)[0]
    else:
      assert False  # Unreachable given the error check in _xla_callable
Example #12
File: dispatch.py  Project: romanngg/jax
def _xla_callable_device(nreps, backend, device,
                         arg_devices) -> Optional[Device]:
    if nreps > 1:
        if device is not None or backend is not None:
            raise ValueError(
                f"can't specify device or backend for jit-of-pmap, "
                f"got device={device} and backend={backend}")
        return None
    else:
        # TODO(skye): dedup with C++ jit logic for determining jit device?
        if device is not None:
            assert backend is None
            return device

        if backend is not None:
            return xb.get_backend(backend).get_default_device_assignment(1)[0]

        arg_device = _device_from_arg_devices(arg_devices)
        if arg_device is not None:
            return arg_device

        return config.jax_default_device
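A hedged illustration (not JAX-internal code) of the precedence this version implements: an explicit device wins, then an explicit backend's default device, then the device the arguments already live on, then ``config.jax_default_device``. The ``device=``/``backend=`` arguments to ``jax.jit`` and the ``DeviceArray.device()`` method are as they existed in the JAX version these examples come from.

# Illustration of the branches above via jit's device/backend arguments.
import jax
import jax.numpy as jnp

cpu0 = jax.devices("cpu")[0]
f = jax.jit(lambda x: x + 1, device=cpu0)    # `device is not None` branch
g = jax.jit(lambda x: x + 1, backend="cpu")  # `backend is not None` branch

print(f(jnp.ones(3)).device())   # cpu0
print(g(jnp.ones(3)).device())   # the cpu backend's default device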
Example #13
    def test_no_devices(self):
        self._register_factory("no_devices", -10, device_count=0)
        default_backend = xb.get_backend()
        self.assertEqual(default_backend.platform, "cpu")
        with self.assertRaisesRegex(
                RuntimeError, "Backend 'no_devices' failed to initialize: "
                "Backend 'no_devices' provides no devices."):
            xb.get_backend("no_devices")

        self._reset_backend_state()

        self._register_factory("no_devices2", 10, device_count=0)
        default_backend = xb.get_backend()
        self.assertEqual(default_backend.platform, "cpu")
        with self.assertRaisesRegex(
                RuntimeError, "Backend 'no_devices2' failed to initialize: "
                "Backend 'no_devices2' provides no devices."):
            xb.get_backend("no_devices2")
Example #14
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    axis_env = xla.AxisEnv(nrep, (), ())
    unordered_effects = [
        eff for eff in jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in jaxpr.effects if eff in core.ordered_effects
    ]
    module, _ = mlir.lower_jaxpr_to_module(
        f"spjit_{fun.__name__}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects,
        ordered_effects,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env),
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
        donated_args=[False] * len(in_parts),
        arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
        result_shardings=safe_map(xla.sharding_to_proto, out_parts))
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Example #15
def device_under_test():
  return getattr(FLAGS, 'jax_test_dut', None) or xla_bridge.get_backend().platform
Example #16
def is_device_rocm():
    return xla_bridge.get_backend().platform_version.startswith('rocm')
Example #17
File: dispatch.py  Project: jbampton/jax
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, *arg_specs):
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  abstract_args, arg_devices = util.unzip2(arg_specs)
  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
        fun, abstract_args, pe.debug_info_final(fun, "jit"))
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")
  jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
  consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
  pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
  abstract_args, arg_devices = util.unzip2(pruned_arg_specs)
  donated_invars = [
      x for i, x in enumerate(donated_invars) if i in kept_var_idx
  ]
  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to evaluate their arguments.
  if not jaxpr.eqns:
    return XlaComputation(
        name, None, True, None, jaxpr=jaxpr, consts=consts, device=device,
        in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
    logging.log(log_priority, msg)

  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
         "jit-of-pmap can lead to inefficient data movement, as the outer jit "
         "does not preserve sharded data representations and instead collects "
         "input and output arrays onto a single device. "
         "Consider removing the outer jit unless you know what you're doing. "
         "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  # pass long arg lists as tuple for TPU
  tuple_args = len(abstract_args) > 100
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = xla.new_name_stack(xla.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module: Union[str, xc.XlaComputation]
  module_name = f"jit_{fun.__name__}"
  if config.jax_enable_mlir:
    module = mlir.lower_jaxpr_to_module(
        module_name, closed_jaxpr, backend.platform,
        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
  else:
    module = xla.lower_jaxpr_to_xla_module(
        module_name, closed_jaxpr, backend.platform, axis_env,
        name_stack, tuple_args, donated_invars, replicated_args=None,
        arg_partitions=None, out_partitions=None)
  return XlaComputation(
      name, module, False, donated_invars, nreps=nreps, device=device,
      backend=backend, tuple_args=tuple_args, in_avals=abstract_args,
      out_avals=out_avals, kept_var_idx=kept_var_idx)
Example #18
File: dispatch.py  Project: cloudhan/jax
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, always_lower: bool, keep_unused: bool,
                       *arg_specs):
    """Lower into XLA.

  Args:
    always_lower: If `True`, even trivial programs (those doing no computation,
      such as lambda x: x) will be lowered into an XLA program.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
  """
    if device is not None and backend is not None:
        raise ValueError("can't specify both a device and a backend for jit, "
                         "got device={} and backend={}".format(
                             device, backend))
    abstract_args, arg_devices = util.unzip2(arg_specs)
    if fun.in_type is not None:
        abstract_args, which_explicit = util.unzip2(fun.in_type)
    else:
        which_explicit = None
    with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                          "for jit in {elapsed_time} sec"):
        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
            fun, abstract_args, pe.debug_info_final(fun, "jit"),
            which_explicit)
    if any(isinstance(c, core.Tracer) for c in consts):
        raise UnexpectedTracerError("Encountered an unexpected tracer.")
    # TODO(mattjj): handle argument pruning w/ dynamic shapes
    if fun.in_type is None and not keep_unused:
        jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
        consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
        abstract_args, arg_devices = util.unzip2(
            [a for i, a in enumerate(arg_specs) if i in kept_var_idx])
        donated_invars = [
            x for i, x in enumerate(donated_invars) if i in kept_var_idx
        ]
        del kept_const_idx
    else:
        kept_var_idx = set(range(len(abstract_args)))
    map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
    jaxpr = apply_outfeed_rewriter(jaxpr)

    nreps = jaxpr_replicas(jaxpr)
    device = _xla_callable_device(nreps, backend, device, arg_devices)
    backend = xb.get_device_backend(device) if device else xb.get_backend(
        backend)

    if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr)
            and not _backend_supports_unbounded_dynamic_shapes(backend)):
        jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

    # Computations that only produce constants and/or only rearrange their inputs,
    # which are often produced from partial evaluation, don't need compilation,
    # and don't need to evaluate their arguments.
    if not jaxpr.eqns and not always_lower:
        return XlaComputation(name,
                              None,
                              True,
                              None,
                              None,
                              jaxpr=jaxpr,
                              consts=consts,
                              device=device,
                              in_avals=abstract_args,
                              out_avals=out_avals,
                              has_unordered_effects=False,
                              ordered_effects=[],
                              kept_var_idx=kept_var_idx,
                              keepalive=None)

    if not _on_exit:
        log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
        if len(abstract_args) > 10:
            msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
        else:
            msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
        logging.log(log_priority, msg)

    if nreps > 1:
        warnings.warn(
            f"The jitted function {name} includes a pmap. Using "
            "jit-of-pmap can lead to inefficient data movement, as the outer jit "
            "does not preserve sharded data representations and instead collects "
            "input and output arrays onto a single device. "
            "Consider removing the outer jit unless you know what you're doing. "
            "See https://github.com/google/jax/issues/2926.")

    if nreps > xb.device_count(backend):
        raise ValueError(
            f"compiling computation `{name}` that requires {nreps} replicas, but "
            f"only {xb.device_count(backend)} XLA devices are available.")

    if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
        raise NotImplementedError(
            "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
            "extra data movement anyway, so maybe you don't want it after all)."
        )

    # pass long arg lists as tuple for TPU
    tuple_args = len(abstract_args) > 100
    axis_env = xla.AxisEnv(nreps, (), ())
    name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
    closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
    module_name = f"jit_{fun.__name__}"
    unordered_effects = [
        eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in closed_jaxpr.effects if eff in core.ordered_effects
    ]
    module, keepalive = mlir.lower_jaxpr_to_module(
        module_name, closed_jaxpr,
        unordered_effects, ordered_effects, backend.platform,
        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
    return XlaComputation(name,
                          module,
                          False,
                          donated_invars,
                          which_explicit,
                          nreps=nreps,
                          device=device,
                          backend=backend,
                          tuple_args=tuple_args,
                          in_avals=abstract_args,
                          out_avals=out_avals,
                          has_unordered_effects=bool(unordered_effects),
                          ordered_effects=ordered_effects,
                          kept_var_idx=kept_var_idx,
                          keepalive=keepalive)
Example #19
def is_device_cuda():
    return xla_bridge.get_backend().platform_version.startswith('cuda')
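A small sketch showing one way the two predicates above (``is_device_rocm``/``is_device_cuda``) might be combined into a test-skip helper; the helper name is made up for illustration.

# Hypothetical helper: skip a test unless running on a CUDA backend.
import unittest

def skip_unless_cuda(test_func):
    return unittest.skipIf(not is_device_cuda(),
                           "test requires a CUDA backend")(test_func)

# Usage:
# @skip_unless_cuda
# def test_cuda_only_feature(self): ...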
Example #20
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [pxla.get_global_aval(arg, parts, lparts)
                          for arg, parts, lparts
                          in safe_zip(abstract_args, in_parts, local_in_parts)]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(fun, global_abstract_args)

  if xb.get_backend().platform not in ["tpu", "gpu"]:
    # TODO(skye): fall back to regular jit?
    raise ValueError("sharded_jit not supported for " +
                     xb.get_backend().platform)

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
        "Specify 'local_nparts' when using cross-process sharded_jit "
        "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [pxla.get_local_aval(out, parts, lparts)
                     for out, parts, lparts
                     in safe_zip(global_out_avals, out_parts, local_out_parts)]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority,
              "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  c = xb.make_computation_builder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xb.constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  out_nodes = xla.jaxpr_subcomp(
      c, jaxpr, None, axis_env, xla_consts,
      extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
  out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = xla.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
Example #21
 def test_unknown_backend_error(self):
     with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
         xb.get_backend("foo")
Example #22
def device_under_test():
    return FLAGS.jax_test_dut or xla_bridge.get_backend().platform
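A small sketch of how ``device_under_test()`` is typically consumed in test code; the helper and tolerance values below are illustrative, not part of the listing.

# Hypothetical helper: pick a numerical tolerance based on the test device.
def default_tolerance():
    return 1e-3 if device_under_test() == "tpu" else 1e-6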