예제 #1
0
def _initialize_outfeed_receiver(
    clients: Optional[List[XlaLocalClient]] = None,
    max_callback_queue_size_bytes: int = int(256 * 1e6)):
    """Creates and starts the outfeed_receiver.

    This function is called lazily only when we compile an id_tap. It does
    nothing if the receiver has already been started.

    Args:
      clients: the list of clients (backends) on whose devices to listen on.
        Defaults to all local backends except the "interpreter" platform.
        The value is materialized into a tuple once, so any iterable works.
      max_callback_queue_size_bytes: an optional integer to bound the maximum
        size of arrays in the callback queue. When this limit is reached the
        device listener pauses.

    Raises:
      NotImplementedError: if the installed jaxlib does not provide
        ``xla_extension.outfeed_receiver``.
    """
    try:
        outfeed_receiver_module = xla_extension.outfeed_receiver
    except AttributeError as err:
        raise NotImplementedError(
            "id_tap works only with jaxlib version 0.1.51 and higher") from err

    with _outfeed_receiver.lock:
        # Already started: nothing to do.
        if _outfeed_receiver.receiver is not None:
            return

        if clients is None:
            # By default, all devices on all backends, except the interpreter.
            local_backends = xla_client._get_local_backends()  # type: ignore[protected-class]
            clients = [
                c for c in local_backends.values() if c.platform != "interpreter"
            ]
        # `clients` is iterated twice below (to collect the devices and to start
        # the receiver); materialize it once so that a one-shot iterator is not
        # silently exhausted by the first pass.
        client_tuple = tuple(clients)
        devices = list(
            itertools.chain.from_iterable(
                backend.devices() for backend in client_tuple))
        _outfeed_receiver.clients = client_tuple  # type: ignore[assignment]
        _outfeed_receiver.devices = devices  # type: ignore[assignment]
        logging.vlog(
            2, f"Starting outfeed_receiver for {[str(d) for d in devices]}. "
            f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
        _outfeed_receiver.receiver = outfeed_receiver_module.start(
            _outfeed_receiver_callback, client_tuple,
            max_callback_queue_size_bytes)

        def exit_handler():
            # Prevent logging usage during compilation, gives errors under pytest
            xla._on_exit = True
            barrier_wait("at_exit")

        atexit.register(exit_handler)  # We wait as long as we have callbacks
예제 #2
0
def outfeed_receiver(*,
                     timeout_sec=10,
                     backends: Optional[Sequence[str]] = None,
                     devices: Optional[Sequence[XlaDevice]] = None,
                     receiver_name=""):
    # TODO: better timeout management.
    """Starts receivers for the :func:`id_tap` outfeed from several devices.

    The receivers will run in a threadpool. The tapped functions will be invoked
    in those threads. If a tap function raises an exception, an error is
    printed, but the receiving continues until the body of the context manager
    terminates and all outfeeds from all devices have been received. Only then
    will a :exc:`TapFunctionException` be raised.

    Args:
      timeout_sec: (optional) how many seconds to wait, after the body of the
        context manager exits, for all device receivers to finish.
      backends: (optional) sequence of backend names for which to listen.
        Will listen to all devices on those backends. By default, listen to
        all devices on all known backends.
      devices: (optional) sequence of devices to listen to. At most one
        of `backends` or `devices` may be given.
      receiver_name: (optional) a name to use with debug logging

    Raises:
      ValueError: if both `backends` and `devices` are given, or if another
        outfeed_receiver is already running.
      TapFunctionException: on exit, if any tap function raised, or if an
        outfeed for an unknown tap consumer was received.
      concurrent.futures.TimeoutError: if the device receivers do not all
        finish within `timeout_sec`.

    Usage::

      with outfeed_receiver():
        jax.jit(func)(args)
        ...
        jax.pmap(another_func)(args)

    The ``outfeed_receiver`` must be started outside any jitted computation.
    """
    global _outfeed_receiver_started
    if not devices:
        backends = backends or xla_client._get_local_backends().keys()
        devices = tuple(
            itertools.chain(*[api.devices(backend) for backend in backends]))
    elif backends:
        raise ValueError(
            "At most one of `devices` or `backends` must be given.")
    # Check before spawning any receiver threads. Otherwise a rejected second
    # receiver would leave threads polling the outfeed that steal data from
    # the receiver that is already running.
    if _outfeed_receiver_started:
        raise ValueError(
            "At most one outfeed_receiver can be running at once.")

    executor = futures.ThreadPoolExecutor(
        thread_name_prefix=f"outfeed_receiver_{receiver_name}",
        max_workers=len(devices))

    # Incremented from the receiver threads; only ever compared against 0.
    count_tap_exceptions = 0

    def device_receiver_loop(device: XlaDevice) -> XlaDevice:
        """Polls the outfeed for `device` until the end marker is received.

        Returns the device once its END_OUTFEED marker has been read.
        """
        nonlocal count_tap_exceptions
        while True:
            consumer_id, arrays = _receive_outfeed(device, receiver_name)
            if _LOGGING:
                logging.info(
                    f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} "
                    + (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
            if consumer_id == _end_consumer:
                assert not arrays
                if _LOGGING:
                    logging.info(
                        f"[{receiver_name}:{device}] Outfeed received END_OUTFEED"
                    )
                return device
            consumer = _consumer_registry_by_id.get(consumer_id)
            if consumer is None:
                logging.error(
                    f"Ignoring received outfeed for unknown tap consumer {consumer_id}")
                count_tap_exceptions += 1
                continue  # We need to read the entire outfeed
            try:
                arg = api.tree_unflatten(consumer.arg_treedef, arrays)
                consumer.func(arg, **dict(
                    consumer.kwargs))  # type: ignore[attribute-error]
            except Exception as e:
                logging.error(
                    f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}"
                )
                count_tap_exceptions += 1
                # We continue for now, we need to keep reading the outfeed

    receiver_futures = [
        executor.submit(device_receiver_loop, d) for d in devices
    ]
    # Surface errors in the receiver loops as soon as they happen. These come
    # from bugs in our code, not from the tap functions. concurrent.futures
    # logs (and otherwise ignores) exceptions raised inside done callbacks;
    # they are re-raised below by `f.result()`.
    for rf in receiver_futures:
        rf.add_done_callback(lambda rf: rf.result())
    _outfeed_receiver_started = True
    xla.can_execute_outfeed_computations = True
    try:
        yield
    finally:
        try:
            for d in devices:  # Signal the end of printing
                api.jit(lambda x: id_tap(_end_consumer, None, result=x),
                        device=d)(0)  # type: ignore[arg-type]
        finally:
            # Reset even if sending an end marker failed, so that a new
            # receiver can be started later.
            xla.can_execute_outfeed_computations = False
            _outfeed_receiver_started = False
        try:
            for f in futures.as_completed(receiver_futures, timeout=timeout_sec):
                finished_device = f.result()  # Throw exceptions here
                if _LOGGING:
                    logging.info(
                        f"[{receiver_name}:{finished_device}] Outfeed receiver finished"
                    )
        finally:
            # Release the executor without blocking: on a timeout, threads
            # still stuck reading the outfeed cannot be interrupted anyway.
            executor.shutdown(wait=False)
        if count_tap_exceptions > 0:
            raise TapFunctionException