예제 #1
0
 def device_receiver_loop(device: XlaDevice) -> XlaDevice:
     """Polls the outfeed for a device in a loop."""
     nonlocal count_tap_exceptions
     while (True):
         consumer_id, arrays = _receive_outfeed(device, receiver_name)
         if _LOGGING:
             logging.info(
                 f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} "
                 + (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
         if consumer_id == _end_consumer:
             assert not arrays
             if _LOGGING:
                 logging.info(
                     f"[{receiver_name}:{device}] Outfeed received END_OUTFEED"
                 )
             return device
         consumer = _consumer_registry_by_id.get(consumer_id)
         if consumer is None:
             logging.error(
                 f"Ignoring received outfeed for unknown tap consumer")
             count_tap_exceptions += 1
             continue  # We need to read the entire outfeed
         try:
             arg = api.tree_unflatten(consumer.arg_treedef, arrays)
             consumer.func(arg, **dict(
                 consumer.kwargs))  # type: ignore[attribute-error]
         except Exception as e:
             logging.error(
                 f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}"
             )
             count_tap_exceptions += 1
예제 #2
0
def _outfeed_receiver_callback(device, consumer_id, arrays):
  """Hand one received outfeed to the tap consumer registered for it.

  The consumer is looked up by ``consumer_id`` in
  ``_outfeed_receiver.consumer_registry_by_id``; ``arrays`` are unflattened
  with the consumer's ``arg_treedef`` and passed to ``consumer.func`` together
  with the keyword arguments from ``consumer.unpack_kwargs()``.

  An exception raised while unflattening or while running the tap function is
  logged and counted in ``_outfeed_receiver.num_tap_exceptions``; it is not
  propagated to the caller.

  Args:
    device: the device the outfeed came from (only referenced by the disabled
      debug logging below).
    consumer_id: key identifying the registered tap consumer.
    arrays: flat sequence of received arrays.
  """
  # Debug logging, currently disabled:
  #logging.vlog(
  #    2, f"Outfeed received on device {device} for consumer {consumer_id} " +
  #    (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
  registry = _outfeed_receiver.consumer_registry_by_id
  consumer = registry.get(consumer_id)
  assert consumer is not None, "We should have crashed in the runtime"
  try:
    arg = api.tree_unflatten(consumer.arg_treedef, arrays)
    tap_func = consumer.func
    tap_func(arg, **consumer.unpack_kwargs())  # type: ignore[attribute-error]
  except Exception as exc:
    reason = str(exc)
    trace = traceback.format_exc()
    logging.error("Postponing exception raised in tap function: %s\n%s",
                  reason, trace)
    _outfeed_receiver.num_tap_exceptions += 1
예제 #3
0
def _outfeed_receiver_callback(device, consumer_id, arrays):
  """Hand one received outfeed to the tap consumer registered for it.

  The consumer is looked up by ``consumer_id`` in
  ``_outfeed_receiver.consumer_registry_by_id``; ``arrays`` are unflattened
  with the consumer's ``arg_treedef`` and passed to ``consumer.func`` with the
  result of ``consumer.unpack_transforms()`` as a second positional argument.

  An exception raised while unflattening or while running the tap function is
  logged and counted in ``_outfeed_receiver.num_tap_exceptions``; it is not
  propagated to the caller. Any ``TypeError`` additionally logs a hint about
  the changed calling convention for ``transforms``.

  Args:
    device: the device the outfeed came from (only referenced by the disabled
      debug logging below).
    consumer_id: key identifying the registered tap consumer.
    arrays: flat sequence of received arrays.
  """
  # Debug logging, currently disabled:
  #logging.vlog(
  #    2, f"Outfeed received on device {device} for consumer {consumer_id} " +
  #    (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
  registry = _outfeed_receiver.consumer_registry_by_id
  consumer = registry.get(consumer_id)
  assert consumer is not None, "We should have crashed in the runtime"
  try:
    arg = api.tree_unflatten(consumer.arg_treedef, arrays)
    tap_func = consumer.func
    tap_func(arg, consumer.unpack_transforms())  # type: ignore[attribute-error]
  except Exception as exc:
    if isinstance(exc, TypeError):
      # A TypeError may come from a tap function still written for the old
      # keyword-style ``transforms`` argument, so log the hint first.
      signature_hint = (
          "The signature host_callback.id_tap uses to calls wrapped "
          "functions has changed: ``transforms`` was previously "
          "passed as a keyword argument, but is now passed by "
          "position.")
      logging.error(signature_hint)
    reason = str(exc)
    trace = traceback.format_exc()
    logging.error("Postponing exception raised in tap function: %s\n%s",
                  reason, trace)
    _outfeed_receiver.num_tap_exceptions += 1