def _initialize_outfeed_receiver(
        clients: Optional[List[XlaLocalClient]] = None,
        max_callback_queue_size_bytes: int = int(256 * 1e6)):
    """Creates and starts the process-wide outfeed receiver, at most once.

    Called lazily, only when an id_tap is compiled. If a receiver has already
    been started, this returns without doing anything.

    Args:
      * clients: the clients (backends) on whose devices to listen. When None,
        every local backend except the "interpreter" one is used.
      * max_callback_queue_size_bytes: an optional integer bounding the total
        size of the arrays in the callback queue. When this limit is reached
        the device listener pauses.
    """
    try:
        receiver_module = xla_extension.outfeed_receiver
    except AttributeError as err:
        raise NotImplementedError(
            "id_tap works only with jaxlib version 0.1.51 and higher") from err

    with _outfeed_receiver.lock:
        if _outfeed_receiver.receiver is not None:
            return  # Already started by an earlier call.

        if clients is None:
            # Default: all devices on all local backends, minus the interpreter.
            local_backends = xla_client._get_local_backends().values()  # type: ignore[protected-class]
            clients = tuple(
                client for client in local_backends
                if client.platform != "interpreter")  # type: ignore

        listened_devices = list(
            itertools.chain.from_iterable(
                client.devices() for client in clients))

        _outfeed_receiver.clients = clients  # type: ignore[assignment]
        _outfeed_receiver.devices = listened_devices  # type: ignore[assignment]
        logging.vlog(
            2,
            f"Starting outfeed_receiver for {[str(d) for d in listened_devices]}. "
            f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
        _outfeed_receiver.receiver = receiver_module.start(
            _outfeed_receiver_callback, tuple(clients),
            max_callback_queue_size_bytes)

        def _on_interpreter_exit():
            # Prevent logging usage during compilation, gives errors under pytest
            xla._on_exit = True
            barrier_wait("at_exit")

        # At interpreter exit, wait for as long as callbacks are pending.
        atexit.register(_on_interpreter_exit)
def outfeed_receiver(*,
                     timeout_sec=10,
                     backends: Optional[Sequence[str]] = None,
                     devices: Optional[Sequence[XlaDevice]] = None,
                     receiver_name=""):  # TODO: better timeout management.
    """Starts receivers for the :func:`id_tap` outfeed from several devices.

    The receivers will run in a threadpool. The tapped functions will be
    invoked in those threads. If a tap function raises an exception, an error
    is logged, but the receiving continues until the body of the context
    manager terminates and all outfeeds from all devices have been received.
    Only then will a :exc:`TapFunctionException` be raised.

    Args:
      timeout_sec: (optional) number of seconds to wait, after the body of the
        context manager exits, for every device receiver to finish. If they do
        not all finish in time, the timeout error raised by
        ``futures.as_completed`` propagates.
      backends: (optional) sequence of backend names for which to listen.
        Will listen to all devices on those backends. By default, listen to
        all devices on all local backends.
      devices: (optional) sequence of devices to listen to. At most one of
        `backends` or `devices` may be given.
      receiver_name: (optional) a name to use with debug logging and in the
        receiver thread names.

    Raises:
      ValueError: if both `backends` and `devices` are given, or if another
        ``outfeed_receiver`` is already running.
      TapFunctionException: on exit, if any tap function raised or if outfeed
        data arrived for an unknown tap consumer.

    Usage::

      with outfeed_receiver():
        jax.jit(func)(args)
        ...
        jax.pmap(another_func)(args)

    The ``outfeed_receiver`` must be started outside any jitted computation.
    """
    global _outfeed_receiver_started
    if not devices:
        backends = backends or xla_client._get_local_backends().keys()
        devices = tuple(
            itertools.chain(*[api.devices(backend) for backend in backends]))
    elif backends:
        raise ValueError(
            "At most one of `devices` or `backends` must be given.")

    # Reject a second receiver *before* starting any threads. Threads started
    # by a rejected call would never be sent an END marker, so they would
    # never stop. They would also presumably compete with the running receiver
    # for the same outfeeds.
    if _outfeed_receiver_started:
        raise ValueError(
            "At most one outfeed_receiver can be running at once.")

    executor = futures.ThreadPoolExecutor(
        thread_name_prefix=f"outfeed_receiver_{receiver_name}",
        max_workers=len(devices))
    # Incremented from several receiver threads without a lock. It only grows
    # from 0 and is only compared against 0, so a lost concurrent increment
    # cannot hide a failure.
    count_tap_exceptions = 0

    def device_receiver_loop(device: XlaDevice) -> XlaDevice:
        """Polls the outfeed for `device` until the END marker arrives."""
        nonlocal count_tap_exceptions
        while True:
            consumer_id, arrays = _receive_outfeed(device, receiver_name)
            if _LOGGING:
                logging.info(
                    f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} "
                    + (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
            if consumer_id == _end_consumer:
                assert not arrays
                if _LOGGING:
                    logging.info(
                        f"[{receiver_name}:{device}] Outfeed received END_OUTFEED")
                return device
            consumer = _consumer_registry_by_id.get(consumer_id)
            if consumer is None:
                logging.error(
                    f"Ignoring received outfeed for unknown tap consumer {consumer_id}")
                count_tap_exceptions += 1
                continue  # We need to read the entire outfeed
            try:
                arg = api.tree_unflatten(consumer.arg_treedef, arrays)
                consumer.func(arg, **dict(consumer.kwargs))  # type: ignore[attribute-error]
            except Exception as e:
                logging.error(
                    f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}")
                count_tap_exceptions += 1
                # We continue for now, we need to keep reading the outfeed

    _outfeed_receiver_started = True
    try:
        receiver_futures = [
            executor.submit(device_receiver_loop, d) for d in devices
        ]
        # Register a callback to raise errors if any. These exceptions come
        # from bugs in our code, not from the tap functions.
        for rf in receiver_futures:
            rf.add_done_callback(lambda fut: fut.result())

        xla.can_execute_outfeed_computations = True
        try:
            yield
        finally:
            try:
                for d in devices:  # Signal the end of printing
                    api.jit(lambda x: id_tap(_end_consumer, None, result=x),
                            device=d)(0)  # type: ignore[arg-type]
            finally:
                # Reset even if signalling one of the devices failed.
                xla.can_execute_outfeed_computations = False
            for f in futures.as_completed(receiver_futures, timeout=timeout_sec):
                finished_device = f.result()  # Throw exceptions here
                if _LOGGING:
                    logging.info(
                        f"[{receiver_name}:{finished_device}] Outfeed receiver finished")
            if count_tap_exceptions > 0:
                raise TapFunctionException
    finally:
        # Always release the single-receiver slot, even if setup, the END
        # signalling, or the drain failed. Otherwise every later
        # outfeed_receiver would be rejected.
        _outfeed_receiver_started = False
        # Let idle worker threads exit. wait=False so that a receiver still
        # blocked on an outfeed (e.g. after a timeout) cannot hang us here.
        executor.shutdown(wait=False)