async def write(
    self,
    msg: dict,
    serializers=("cuda", "dask", "pickle", "error"),
    on_error: str = "message",
):
    """Serialize ``msg`` into frames and send them over the UCX endpoint.

    Three groups of sends are issued, in order:

    1. a ``"?Q"`` struct carrying the close flag (``False``) and the number
       of frames;
    2. a struct carrying one CUDA flag and one size per frame;
    3. every frame whose size is non-zero.

    Parameters
    ----------
    msg : dict
        Message to send; may also be a list of dicts for batched messages.
    serializers : tuple of str, optional
        Forwarded to ``to_frames``; ``None`` selects the default tuple.
    on_error : str
        Forwarded to ``to_frames``.

    Returns
    -------
    int
        Sum of the sizes of all frames (empty ones included).

    Raises
    ------
    CommClosedError
        If the comm is already closed, or if a UCX exception is raised while
        writing (the comm is aborted before re-raising).
    """
    with log_errors():
        if self.closed():
            raise CommClosedError("Endpoint is closed -- unable to send message")
        try:
            if serializers is None:
                serializers = ("cuda", "dask", "pickle", "error")

            # msg can also be a list of dicts when sending batched messages
            frames = await to_frames(
                msg,
                serializers=serializers,
                on_error=on_error,
                allow_offload=self.allow_offload,
            )
            frame_count = len(frames)
            cuda_flags = tuple(
                hasattr(frame, "__cuda_array_interface__") for frame in frames
            )
            sizes = tuple(nbytes(frame) for frame in frames)

            # Zero-sized frames are described in the metadata but never sent.
            non_empty = [
                (flag, frame)
                for flag, frame in zip(cuda_flags, frames)
                if nbytes(frame) > 0
            ]
            cuda_send_flags, payload = zip(*non_empty)

            # Metadata, part 1: close flag and number of frames (_Bool, int64)
            await self.ep.send(struct.pack("?Q", False, frame_count))

            # Metadata, part 2: which frames are CUDA (bool) and how large
            # each frame is (uint64)
            layout = "?" * frame_count + "Q" * frame_count
            await self.ep.send(struct.pack(layout, *cuda_flags, *sizes))

            # The default stream must be synchronized before sending: UCX is
            # not stream-ordered, and syncing the default stream waits for
            # other non-blocking CUDA streams. This is only sufficient if the
            # memory being sent is not currently in use on non-blocking CUDA
            # streams.
            if any(cuda_send_flags):
                synchronize_stream(0)

            for frame in payload:
                await self.ep.send(frame)
            return sum(sizes)
        except (ucp.exceptions.UCXBaseException):
            self.abort()
            raise CommClosedError("While writing, the connection was closed")
async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
    """Receive one message from the UCX endpoint and deserialize it.

    The metadata is received first: a ``"?Q"`` struct (close flag and frame
    count), then one CUDA flag and one size per frame. A buffer is allocated
    for every frame, but only non-empty buffers are actually received into.

    Parameters
    ----------
    deserializers : tuple of str, optional
        Forwarded to ``from_frames``; ``None`` selects the default tuple.

    Returns
    -------
    The value produced by ``from_frames``.

    Raises
    ------
    CommClosedError
        If the writer set the close flag, or if one of the handled UCX
        exceptions is raised while receiving the metadata (the comm is
        aborted in that case).
    """
    with log_errors():
        if deserializers is None:
            deserializers = ("cuda", "dask", "pickle", "error")

        try:
            # Metadata, part 1: close flag and number of frames (_Bool, int64)
            prelude = host_array(struct.calcsize("?Q"))
            await self.ep.recv(prelude)
            shutdown, nframes = struct.unpack("?Q", prelude)

            if shutdown:  # The writer is closing the connection
                raise CommClosedError("Connection closed by writer")

            # Metadata, part 2: which frames are CUDA (bool) and how large
            # each frame is (uint64)
            header_fmt = "?" * nframes + "Q" * nframes
            raw_header = host_array(struct.calcsize(header_fmt))
            await self.ep.recv(raw_header)
            fields = struct.unpack(header_fmt, raw_header)
            cuda_flags = fields[:nframes]
            sizes = fields[nframes:]
        except (
            ucp.exceptions.UCXCloseError,
            ucp.exceptions.UCXCanceled,
            # Not every ucp version defines this; an empty tuple matches nothing.
            getattr(ucp.exceptions, "UCXConnectionReset", ()),
        ):
            self.abort()
            raise CommClosedError("Connection closed by writer")
        else:
            # Allocate a destination buffer for every frame
            frames = []
            for flag, size in zip(cuda_flags, sizes):
                frames.append(device_array(size) if flag else host_array(size))

            non_empty = [
                (flag, buf)
                for flag, buf in zip(cuda_flags, frames)
                if nbytes(buf) > 0
            ]
            cuda_recv_flags, targets = zip(*non_empty)

            # `frames` must be populated with CUDA arrays and the default
            # stream synchronized before receiving starts, to ensure the
            # buffers have been allocated.
            if any(cuda_recv_flags):
                synchronize_stream(0)

            for buf in targets:
                await self.ep.recv(buf)

            return await from_frames(
                frames,
                deserialize=self.deserialize,
                deserializers=deserializers,
                allow_offload=self.allow_offload,
            )
async def read(self, deserializers="ignored"):
    """Take the next message from the read queue.

    ``deserializers`` is accepted but not used by this implementation.

    Returns the queued message, passed through ``nested_deserialize`` when
    ``self.deserialize`` is truthy.

    Raises
    ------
    CommClosedError
        If the comm is already marked closed, or if the ``_EOF`` sentinel is
        dequeued (the comm is then marked closed and its finalizer detached).
    """
    if self._closed:
        raise CommClosedError()

    received = await self._read_q.get()
    if received is _EOF:
        # End-of-stream marker: mark closed before reporting it.
        self._closed = True
        self._finalizer.detach()
        raise CommClosedError()

    return nested_deserialize(received) if self.deserialize else received
async def read(self, deserializers=None):
    """Read one length-prefixed message from the stream and deserialize it.

    An unsigned 64-bit length is read first, then that many bytes are read
    into a single host buffer in slices of at most ``OPENSSL_MAX_CHUNKSIZE``
    bytes. The buffer is then split with ``unpack_frames`` and handed to
    ``from_frames``.

    Returns the deserialized message. If the stream closes while the
    interpreter is finalizing, no conversion is attempted and ``None`` is
    returned implicitly.

    Raises
    ------
    CommClosedError
        If there is no stream, if the stream is closed while reading (via
        ``convert_stream_closed_error``), or if the frames are truncated.
    """
    stream = self.stream
    if stream is None:
        raise CommClosedError()

    length_fmt = "Q"
    length_size = struct.calcsize(length_fmt)

    try:
        prefix = await stream.read_bytes(length_size)
        (total_nbytes,) = struct.unpack(length_fmt, prefix)

        frames = host_array(total_nbytes)
        boundaries = range(
            0, total_nbytes + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE
        )
        for start, stop in sliding_window(2, boundaries):
            window = frames[start:stop]
            expected = len(window)
            received = await stream.read_into(window)
            assert received == expected, (received, expected)
    except StreamClosedError as e:
        self.stream = None
        self._closed = True
        if not sys.is_finalizing():
            convert_stream_closed_error(self, e)
    except Exception:
        # Some OSError or another "low-level" exception. We do not really
        # know what was already read from the underlying socket, so it is not
        # even safe to retry here using the same stream. The only safe thing
        # to do is to abort. (See also GitHub #4133).
        self.abort()
        raise
    else:
        try:
            msg = await from_frames(
                unpack_frames(frames),
                deserialize=self.deserialize,
                deserializers=deserializers,
                allow_offload=self.allow_offload,
            )
        except EOFError:
            # Frames possibly garbled or truncated by communication error
            self.abort()
            raise CommClosedError("aborted stream on truncated data")
        return msg
def convert_stream_closed_error(obj, exc):
    """
    Re-raise StreamClosedError as CommClosedError.

    Parameters
    ----------
    obj
        Object named in the error message (interpolated with ``f"{obj}"``).
    exc
        The caught exception; its ``real_error`` attribute, when not
        ``None``, replaces it as the reported cause.

    Raises
    ------
    FatalCommClosedError
        If the underlying error is an ``ssl.SSLError`` whose reason mentions
        ``UNKNOWN_CA``.
    CommClosedError
        In every other case, chained to the underlying error.

    This function never returns normally.
    """
    real_error = exc.real_error
    if real_error is None:
        raise CommClosedError(f"in {obj}: {exc}") from exc

    # The stream was closed because of an underlying OS error
    exc = real_error
    if ssl and isinstance(exc, ssl.SSLError):
        # ``reason`` may be None (or absent) when OpenSSL gives no mnemonic
        # for the failure; treat that as "no reason" instead of letting the
        # substring test raise TypeError/AttributeError and mask the
        # CommClosedError.
        reason = getattr(exc, "reason", None) or ""
        if "UNKNOWN_CA" in reason:
            raise FatalCommClosedError(
                f"in {obj}: {exc.__class__.__name__}: {exc}"
            )
    raise CommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") from exc
async def connect(self, address: str, deserialize=True, **connection_args) -> UCX:
    """Create a UCX endpoint to ``address`` and wrap it in ``self.comm_class``.

    ``address`` is split into host and port with ``parse_host_port``, and
    ``init_once()`` is called before the endpoint is created.
    ``connection_args`` is accepted but not used here.

    Raises
    ------
    CommClosedError
        If endpoint creation raises one of the handled UCX exceptions.
    """
    logger.debug("UCXConnector.connect: %s", address)
    ip, port = parse_host_port(address)
    init_once()

    try:
        endpoint = await ucp.create_endpoint(ip, port)
    except (
        ucp.exceptions.UCXCloseError,
        ucp.exceptions.UCXCanceled,
        # The names below are looked up defensively; when one is missing the
        # empty tuple stands in and matches nothing.
        getattr(ucp.exceptions, "UCXConnectionReset", ()),
        getattr(ucp.exceptions, "UCXNotConnected", ()),
        getattr(ucp.exceptions, "UCXUnreachable", ()),
    ):  # type: ignore
        raise CommClosedError("Connection closed before handshake completed")

    peer_addr = self.prefix + address
    return self.comm_class(
        endpoint,
        local_addr="",
        peer_addr=peer_addr,
        deserialize=deserialize,
    )
async def write(self, msg, serializers=None, on_error=None):
    """Schedule ``msg`` to be put on the write queue and return ``1``.

    ``serializers`` and ``on_error`` are accepted but not used.

    Raises
    ------
    CommClosedError
        If ``self.closed()`` is true.
    """
    if self.closed():
        raise CommClosedError()

    # Ensure we feed the queue in the same thread it is read from.
    enqueue = self._write_q.put_nowait
    self._write_loop.add_callback(enqueue, msg)
    return 1
async def read(self, deserializers=None):
    """Read one message from the handler's queue and deserialize it.

    The first queue item is the frame count packed as ``"Q"``; that many
    further items are then taken from the queue and passed to
    ``from_frames``.

    Raises
    ------
    CommClosedError
        If getting the first item raises ``RuntimeError``, or if that item is
        the ``CommClosedError`` class itself.
    """
    queue = self.handler.q
    try:
        packed_count = await queue.get()
    except RuntimeError:  # Event loop is closed
        raise CommClosedError()

    if packed_count is CommClosedError:
        raise CommClosedError()

    (n_frames,) = struct.unpack("Q", packed_count)

    frames = []
    for _ in range(n_frames):
        frames.append(await self.handler.q.get())

    return await from_frames(
        frames,
        deserialize=self.deserialize,
        deserializers=deserializers,
        allow_offload=self.allow_offload,
    )
async def read(self) -> list[bytes]:
    """Read a single message from the comm.

    Raises ``CommClosedError`` once the queue is gone or the ``_COMM_CLOSED``
    sentinel has been dequeued.
    """
    queue = self._queue
    if queue is None:
        raise CommClosedError("Connection closed")

    # Even if comm is closed, we still yield all received data before
    # erroring
    item = await queue.get()
    if item is _COMM_CLOSED:
        # Drop the queue so later reads fail immediately.
        self._queue = None
        raise CommClosedError("Connection closed")
    return item
async def read(self, deserializers=None):
    """Read one message from the websocket and deserialize it.

    The first websocket message is the frame count packed as ``"Q"``; that
    many further messages are then read as frames and passed to
    ``from_frames``.

    Raises
    ------
    CommClosedError
        If the socket reports closure (``read_message`` returns ``None``)
        while reading the frame count *or any frame*, or if
        ``WebSocketClosedError`` is raised at any point while reading.
    """
    try:
        packed_count = await self.sock.read_message()
        if packed_count is None:
            # Connection is closed
            self.abort()
            raise CommClosedError()
        n_frames = struct.unpack("Q", packed_count)[0]

        frames = []
        for _ in range(n_frames):
            frame = await self.sock.read_message()
            if frame is None:
                # The connection closed in the middle of a message. Passing
                # None on to from_frames would fail with an unrelated error,
                # so report the closure explicitly instead.
                self.abort()
                raise CommClosedError("aborted stream on truncated data")
            frames.append(frame)
    except WebSocketClosedError as e:
        raise CommClosedError(e) from e

    return await from_frames(
        frames,
        deserialize=self.deserialize,
        deserializers=deserializers,
        allow_offload=self.allow_offload,
    )
async def read(self, deserializers=None):
    """Read frames from the protocol object and deserialize them.

    Raises
    ------
    CommClosedError
        If ``from_frames`` raises ``EOFError``; the comm is aborted first.
        Whatever ``self._protocol.read()`` raises propagates unchanged.
    """
    frames = await self._protocol.read()
    try:
        msg = await from_frames(
            frames,
            deserialize=self.deserialize,
            deserializers=deserializers,
            allow_offload=self.allow_offload,
        )
    except EOFError:
        # Frames possibly garbled or truncated by communication error
        self.abort()
        raise CommClosedError("aborted stream on truncated data")
    return msg
def connection_lost(self, exc=None):
    """Record that the connection is gone and wake up any waiters.

    Clears the transport, resolves ``_closed_waiter``, enqueues the
    ``_COMM_CLOSED`` sentinel for readers, and, if writes were paused, fails
    a still-pending drain waiter with ``CommClosedError``. ``exc`` is not
    used.
    """
    self._transport = None
    self._closed_waiter.set_result(None)

    # Unblock read, if any
    self._queue.put_nowait(_COMM_CLOSED)

    # Unblock write, if any
    if not self._paused:
        return
    pending = self._drain_waiter
    if pending is None:
        return
    self._drain_waiter = None
    if not pending.done():
        pending.set_exception(CommClosedError("Connection closed"))
async def connect(self, address, deserialize=True, **connection_args):
    """Open a websocket to ``self.prefix + address`` and wrap it in a comm.

    Request keyword arguments come from ``self._get_connect_args``.

    Raises
    ------
    CommClosedError
        On ``HTTPClientError``, or (via ``convert_stream_closed_error``) when
        the stream is already closed with an error after connecting.
    FatalCommClosedError
        On ``SSLError``.
    """
    kwargs = self._get_connect_args(**connection_args)
    url = f"{self.prefix}{address}"
    try:
        request = HTTPRequest(url, **kwargs)
        sock = await websocket_connect(request, max_message_size=MAX_MESSAGE_SIZE)
        stream = sock.stream
        if stream.closed() and stream.error:
            # Surface an error recorded on the already-closed stream.
            raise StreamClosedError(stream.error)
    except StreamClosedError as e:
        convert_stream_closed_error(self, e)
    except SSLError as err:
        raise FatalCommClosedError(
            "TLS expects a `ssl_context` argument of type "
            "ssl.SSLContext (perhaps check your TLS configuration?)"
        ) from err
    except HTTPClientError as e:
        raise CommClosedError(f"in {self}: {e}") from e
    return self.comm_class(sock, deserialize=deserialize)
async def write(self, msg, serializers=None, on_error=None):
    """Serialize ``msg`` and send it as a sequence of binary websocket messages.

    The frame count (packed as ``"Q"``) is sent first, followed by one
    message per frame.

    Returns
    -------
    int
        Sum of ``nbytes`` over all frames.

    Raises
    ------
    CommClosedError
        If ``WebSocketClosedError`` is raised while sending.
    """
    context = {
        "sender": self.local_info,
        "recipient": self.remote_info,
        **self.handshake_options,
    }
    frames = await to_frames(
        msg,
        allow_offload=self.allow_offload,
        serializers=serializers,
        on_error=on_error,
        context=context,
        frame_split_size=BIG_BYTES_SHARD_SIZE,
    )
    packed_count = struct.pack("Q", len(frames))
    try:
        await self.sock.write_message(packed_count, binary=True)
        for frame in frames:
            payload = ensure_bytes(frame)
            await self.sock.write_message(payload, binary=True)
    except WebSocketClosedError as e:
        raise CommClosedError(e)

    return sum(nbytes(frame) for frame in frames)
async def write(self, frames: list[bytes]) -> int:
    """Write a message to the comm.

    The frames are preceded by a header of unsigned 64-bit integers: the
    message size, the number of frames, and the length of every frame.

    Returns the message size, i.e. the total frame bytes plus
    ``(nframes + 1) * 8``.

    Raises ``CommClosedError`` if the comm is closed.
    """
    if self.is_closed:
        raise CommClosedError("Connection closed")
    elif self._paused:
        # Wait until there's room in the write buffer
        waiter = self._loop.create_future()
        self._drain_waiter = waiter
        await waiter

    # Ensure all memoryviews are in single-byte format
    normalized = []
    for frame in frames:
        if isinstance(frame, memoryview):
            frame = frame.cast("B")
        normalized.append(frame)
    frames = normalized

    count = len(frames)
    lengths = [len(frame) for frame in frames]

    # TODO: the old TCP comm included an extra `msg_nbytes` prefix that
    # isn't really needed. We include it here for backwards compatibility,
    # but this could be removed if we ever want to make another breaking
    # change to the comms.
    msg_nbytes = sum(lengths) + 8 * (count + 1)
    header = struct.pack(f"{count + 2}Q", msg_nbytes, count, *lengths)
    parts = [header, *frames]

    if msg_nbytes < 4 * 1024:
        # Always concatenate small messages
        buffers = [b"".join(parts)]
    else:
        buffers = coalesce_buffers(parts)

    if len(buffers) > 1:
        self._transport.writelines(buffers)
    else:
        self._transport.write(buffers[0])

    return msg_nbytes
async def write(self, msg, serializers=None, on_error="message"):
    """Serialize ``msg`` and queue it on the stream, length-prefixed.

    The frames from ``to_frames`` are preceded by a header made of an
    unsigned 64-bit total (prelude size plus frame bytes) followed by the
    prelude from ``pack_frames_prelude``. Messages smaller than 128 KiB are
    joined into a single buffer. Buffers are appended directly to the
    stream's write buffer and flushed with one ``stream.write(b"")`` call.

    Returns
    -------
    int
        Total number of bytes queued, header included. If the stream closes
        while the interpreter is finalizing, no exception is raised and this
        total is still returned.

    Raises
    ------
    CommClosedError
        If there is no stream, or (via ``convert_stream_closed_error``) if
        the stream is closed while writing.
    """
    stream = self.stream
    if stream is None:
        raise CommClosedError()

    context = {
        "sender": self.local_info,
        "recipient": self.remote_info,
        **self.handshake_options,
    }
    frames = await to_frames(
        msg,
        allow_offload=self.allow_offload,
        serializers=serializers,
        on_error=on_error,
        context=context,
        frame_split_size=self.max_shard_size,
    )
    sizes = [nbytes(frame) for frame in frames]
    total = sum(sizes)

    prelude = pack_frames_prelude(frames)
    header = struct.pack("Q", nbytes(prelude) + total) + prelude
    header_size = nbytes(header)

    frames = [header, *frames]
    sizes = [header_size, *sizes]
    total += header_size

    if total < 2**17:  # 128kiB
        # small enough, send in one go
        frames = [b"".join(frames)]
        sizes = [total]

    try:
        # trick to enque all frames for writing beforehand
        for size, frame in zip(sizes, frames):
            if not size:
                continue
            if stream._write_buffer is None:
                raise StreamClosedError()

            if isinstance(frame, memoryview):
                # Make sure that `len(data) == data.nbytes`
                # See <https://github.com/tornadoweb/tornado/pull/2996>
                frame = memoryview(frame).cast("B")

            stream._write_buffer.append(frame)
            stream._total_write_index += size

        # start writing frames
        stream.write(b"")
    except StreamClosedError as e:
        self.stream = None
        self._closed = True
        if not sys.is_finalizing():
            convert_stream_closed_error(self, e)
    except Exception:
        # Some OSError or another "low-level" exception. We do not really
        # know what was already written to the underlying socket, so it is
        # not even safe to retry here using the same stream. The only safe
        # thing to do is to abort. (See also GitHub #4133).
        self.abort()
        raise

    return total
def ep(self):
    """Return the stored endpoint, or raise ``CommClosedError`` if it is ``None``."""
    endpoint = self._ep
    if endpoint is None:
        raise CommClosedError("UCX Endpoint is closed")
    return endpoint