def device_receiver_loop(device: XlaDevice) -> XlaDevice: """Polls the outfeed for a device in a loop.""" nonlocal count_tap_exceptions while (True): consumer_id, arrays = _receive_outfeed(device, receiver_name) if _LOGGING: logging.info( f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} " + (" ".join([f"({a.dtype}{a.shape})" for a in arrays]))) if consumer_id == _end_consumer: assert not arrays if _LOGGING: logging.info( f"[{receiver_name}:{device}] Outfeed received END_OUTFEED" ) return device consumer = _consumer_registry_by_id.get(consumer_id) if consumer is None: logging.error( f"Ignoring received outfeed for unknown tap consumer") count_tap_exceptions += 1 continue # We need to read the entire outfeed try: arg = api.tree_unflatten(consumer.arg_treedef, arrays) consumer.func(arg, **dict( consumer.kwargs)) # type: ignore[attribute-error] except Exception as e: logging.error( f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}" ) count_tap_exceptions += 1
# NOTE(review): another definition of ``_outfeed_receiver_callback`` appears
# later in this file and will shadow this one at import time — confirm which
# version is intended.
def _outfeed_receiver_callback(device, consumer_id, arrays):
  """Dispatches one received outfeed payload to its registered tap consumer.

  Args:
    device: the device the payload came from (not used by this function).
    consumer_id: key looked up in ``_outfeed_receiver.consumer_registry_by_id``.
    arrays: the received arrays; unflattened with the consumer's
      ``arg_treedef`` before the consumer's function is called.

  Raises:
    AssertionError: if ``consumer_id`` is not a registered consumer.
  """
  consumer = _outfeed_receiver.consumer_registry_by_id.get(consumer_id)
  # Checked explicitly instead of with ``assert`` so the check survives
  # ``python -O``; otherwise a missing consumer would surface below as an
  # AttributeError on None that gets reported as a tap-function failure.
  if consumer is None:
    raise AssertionError("We should have crashed in the runtime")
  try:
    arg = api.tree_unflatten(consumer.arg_treedef, arrays)
    consumer.func(arg, **consumer.unpack_kwargs())  # type: ignore[attribute-error]
  except Exception as e:
    # Exceptions are not propagated out of the callback; they are logged and
    # counted in ``_outfeed_receiver.num_tap_exceptions``.
    logging.error("Postponing exception raised in tap function: %s\n%s",
                  str(e), traceback.format_exc())
    _outfeed_receiver.num_tap_exceptions += 1
# NOTE(review): this redefines an earlier ``_outfeed_receiver_callback`` in
# this file; only this later definition is bound at import time.
def _outfeed_receiver_callback(device, consumer_id, arrays):
  """Dispatches one received outfeed payload to its registered tap consumer.

  The consumer's function is called as ``func(arg, transforms)``, with the
  transforms passed positionally.

  Args:
    device: the device the payload came from (not used by this function).
    consumer_id: key looked up in ``_outfeed_receiver.consumer_registry_by_id``.
    arrays: the received arrays; unflattened with the consumer's
      ``arg_treedef`` before the consumer's function is called.

  Raises:
    AssertionError: if ``consumer_id`` is not a registered consumer.
  """
  consumer = _outfeed_receiver.consumer_registry_by_id.get(consumer_id)
  # Checked explicitly instead of with ``assert`` so the check survives
  # ``python -O``; otherwise a missing consumer would surface below as an
  # AttributeError on None that gets reported as a tap-function failure.
  if consumer is None:
    raise AssertionError("We should have crashed in the runtime")
  try:
    arg = api.tree_unflatten(consumer.arg_treedef, arrays)
    consumer.func(arg, consumer.unpack_transforms())  # type: ignore[attribute-error]
  except Exception as e:
    if isinstance(e, TypeError):
      # Hint aimed at tap functions written for the old keyword-argument
      # convention. Note it is logged for any TypeError raised in the try
      # block, including unrelated ones from inside the tap function.
      logging.error("The signature host_callback.id_tap uses to call wrapped "
                    "functions has changed: ``transforms`` was previously "
                    "passed as a keyword argument, but is now passed by "
                    "position.")
    # Exceptions are not propagated out of the callback; they are logged and
    # counted in ``_outfeed_receiver.num_tap_exceptions``.
    logging.error("Postponing exception raised in tap function: %s\n%s",
                  str(e), traceback.format_exc())
    _outfeed_receiver.num_tap_exceptions += 1