Example 1
    async def write(
        self,
        msg: dict,
        serializers=("cuda", "dask", "pickle", "error"),
        on_error: str = "message",
    ):
        with log_errors():
            if self.closed():
                raise CommClosedError(
                    "Endpoint is closed -- unable to send message")
            try:
                if serializers is None:
                    serializers = ("cuda", "dask", "pickle", "error")
                # msg can also be a list of dicts when sending batched messages
                frames = await to_frames(
                    msg,
                    serializers=serializers,
                    on_error=on_error,
                    allow_offload=self.allow_offload,
                )
                nframes = len(frames)
                cuda_frames = tuple(
                    hasattr(f, "__cuda_array_interface__") for f in frames)
                sizes = tuple(nbytes(f) for f in frames)
                cuda_send_frames, send_frames = zip(
                    *((is_cuda, each_frame)
                      for is_cuda, each_frame in zip(cuda_frames, frames)
                      if nbytes(each_frame) > 0))

                # Send metadata

                # Send close flag and number of frames (bool, uint64)
                await self.ep.send(struct.pack("?Q", False, nframes))
                # Send which frames are CUDA (bool) and
                # how large each frame is (uint64)
                await self.ep.send(
                    struct.pack(nframes * "?" + nframes * "Q", *cuda_frames,
                                *sizes))

                # Send frames

                # It is necessary to synchronize the default stream before we start
                # sending. We synchronize the default stream because UCX is not
                # stream-ordered and syncing the default stream will wait for other
                # non-blocking CUDA streams. Note this is only sufficient if the memory
                # being sent is not currently in use on non-blocking CUDA streams.
                if any(cuda_send_frames):
                    synchronize_stream(0)

                for each_frame in send_frames:
                    await self.ep.send(each_frame)
                return sum(sizes)
            except ucp.exceptions.UCXBaseException:
                self.abort()
                raise CommClosedError(
                    "While writing, the connection was closed")
Example 2
    async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
        with log_errors():
            if deserializers is None:
                deserializers = ("cuda", "dask", "pickle", "error")

            try:
                # Recv metadata

                # Recv close flag and number of frames (bool, uint64)
                msg = host_array(struct.calcsize("?Q"))
                await self.ep.recv(msg)
                (shutdown, nframes) = struct.unpack("?Q", msg)

                if shutdown:  # The writer is closing the connection
                    raise CommClosedError("Connection closed by writer")

                # Recv which frames are CUDA (bool) and
                # how large each frame is (uint64)
                header_fmt = nframes * "?" + nframes * "Q"
                header = host_array(struct.calcsize(header_fmt))
                await self.ep.recv(header)
                header = struct.unpack(header_fmt, header)
                cuda_frames, sizes = header[:nframes], header[nframes:]
            except (
                    ucp.exceptions.UCXCloseError,
                    ucp.exceptions.UCXCanceled,
            ) + (getattr(ucp.exceptions, "UCXConnectionReset", ()), ):
                self.abort()
                raise CommClosedError("Connection closed by writer")
            else:
                # Recv frames
                frames = [
                    device_array(each_size)
                    if is_cuda else host_array(each_size)
                    for is_cuda, each_size in zip(cuda_frames, sizes)
                ]
                cuda_recv_frames, recv_frames = zip(
                    *((is_cuda, each_frame)
                      for is_cuda, each_frame in zip(cuda_frames, frames)
                      if nbytes(each_frame) > 0))

                # It is necessary to first populate `frames` with CUDA arrays and
                # synchronize the default stream before starting to receive, to
                # ensure buffers have been allocated.
                if any(cuda_recv_frames):
                    synchronize_stream(0)

                for each_frame in recv_frames:
                    await self.ep.recv(each_frame)
                msg = await from_frames(
                    frames,
                    deserialize=self.deserialize,
                    deserializers=deserializers,
                    allow_offload=self.allow_offload,
                )
                return msg
Example 3
    async def read(self, deserializers="ignored"):
        if self._closed:
            raise CommClosedError()

        msg = await self._read_q.get()
        if msg is _EOF:
            self._closed = True
            self._finalizer.detach()
            raise CommClosedError()

        if self.deserialize:
            msg = nested_deserialize(msg)
        return msg
Example 4
    async def read(self, deserializers=None):
        stream = self.stream
        if stream is None:
            raise CommClosedError()

        fmt = "Q"
        fmt_size = struct.calcsize(fmt)

        try:
            frames_nbytes = await stream.read_bytes(fmt_size)
            (frames_nbytes, ) = struct.unpack(fmt, frames_nbytes)

            frames = host_array(frames_nbytes)
            for i, j in sliding_window(
                    2,
                    range(0, frames_nbytes + OPENSSL_MAX_CHUNKSIZE,
                          OPENSSL_MAX_CHUNKSIZE),
            ):
                chunk = frames[i:j]
                chunk_nbytes = len(chunk)
                n = await stream.read_into(chunk)
                assert n == chunk_nbytes, (n, chunk_nbytes)
        except StreamClosedError as e:
            self.stream = None
            self._closed = True
            if not sys.is_finalizing():
                convert_stream_closed_error(self, e)
        except Exception:
            # Some OSError or another "low-level" exception. We do not really know what
            # was already read from the underlying socket, so it is not even safe to retry
            # here using the same stream. The only safe thing to do is to abort.
            # (See also GitHub #4133).
            self.abort()
            raise
        else:
            try:
                frames = unpack_frames(frames)

                msg = await from_frames(
                    frames,
                    deserialize=self.deserialize,
                    deserializers=deserializers,
                    allow_offload=self.allow_offload,
                )
            except EOFError:
                # Frames possibly garbled or truncated by communication error
                self.abort()
                raise CommClosedError("aborted stream on truncated data")
            return msg
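
The chunked read above iterates over (start, stop) offset pairs produced by sliding_window (from the toolz package); a minimal sketch with deliberately small, made-up numbers (the real OPENSSL_MAX_CHUNKSIZE is far larger):

from toolz import sliding_window

OPENSSL_MAX_CHUNKSIZE = 4  # illustrative only
frames_nbytes = 10

bounds = list(
    sliding_window(
        2, range(0, frames_nbytes + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE)
    )
)
# Consecutive offset pairs; the final slice is clipped by the buffer length itself.
assert bounds == [(0, 4), (4, 8), (8, 12)]

Slicing a 10-byte buffer with frames[8:12] simply yields the remaining 2 bytes, which is why the over-long last window is harmless.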
Example 5
def convert_stream_closed_error(obj, exc):
    """
    Re-raise StreamClosedError as CommClosedError.
    """
    if exc.real_error is not None:
        # The stream was closed because of an underlying OS error
        exc = exc.real_error
        if ssl and isinstance(exc, ssl.SSLError):
            if "UNKNOWN_CA" in exc.reason:
                raise FatalCommClosedError(
                    f"in {obj}: {exc.__class__.__name__}: {exc}")
        raise CommClosedError(
            f"in {obj}: {exc.__class__.__name__}: {exc}") from exc
    else:
        raise CommClosedError(f"in {obj}: {exc}") from exc
Example 6
    async def connect(self,
                      address: str,
                      deserialize=True,
                      **connection_args) -> UCX:
        logger.debug("UCXConnector.connect: %s", address)
        ip, port = parse_host_port(address)
        init_once()
        try:
            ep = await ucp.create_endpoint(ip, port)
        except (
                ucp.exceptions.UCXCloseError,
                ucp.exceptions.UCXCanceled,
        ) + (
                getattr(ucp.exceptions, "UCXConnectionReset", ()),
                getattr(ucp.exceptions, "UCXNotConnected", ()),
                getattr(ucp.exceptions, "UCXUnreachable", ()),
        ):  # type: ignore
            raise CommClosedError(
                "Connection closed before handshake completed")
        return self.comm_class(
            ep,
            local_addr="",
            peer_addr=self.prefix + address,
            deserialize=deserialize,
        )
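
The except clause above assembles its tuple with getattr so that exception classes absent from older ucx-py releases are skipped; a minimal, self-contained sketch of that idiom (SomeOptionalError is a made-up name):

import builtins

# getattr falls back to an empty tuple when the class does not exist.
MaybeMissing = getattr(builtins, "SomeOptionalError", ())

try:
    raise ConnectionResetError("peer went away")
except (ConnectionError, OSError) + (MaybeMissing,):
    # `except` matches nested tuples recursively and an empty tuple matches
    # nothing, so absent classes are effectively ignored.
    print("handled")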
Example 7
    async def write(self, msg, serializers=None, on_error=None):
        if self.closed():
            raise CommClosedError()

        # Ensure we feed the queue in the same thread it is read from.
        self._write_loop.add_callback(self._write_q.put_nowait, msg)

        return 1
Example 8
    async def read(self, deserializers=None):
        try:
            n_frames = await self.handler.q.get()
        except RuntimeError:  # Event loop is closed
            raise CommClosedError()

        if n_frames is CommClosedError:
            raise CommClosedError()
        else:
            n_frames = struct.unpack("Q", n_frames)[0]
        frames = [(await self.handler.q.get()) for _ in range(n_frames)]
        return await from_frames(
            frames,
            deserialize=self.deserialize,
            deserializers=deserializers,
            allow_offload=self.allow_offload,
        )
Example 9
    async def read(self) -> list[bytes]:
        """Read a single message from the comm."""
        # Even if comm is closed, we still yield all received data before
        # erroring
        if self._queue is not None:
            out = await self._queue.get()
            if out is not _COMM_CLOSED:
                return out
            self._queue = None
        raise CommClosedError("Connection closed")
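
A minimal sketch of the sentinel pattern used above: received messages land in an asyncio.Queue and a dedicated closed marker is enqueued on disconnect (here _CLOSED stands in for _COMM_CLOSED), so the reader drains remaining data before reporting the closure.

import asyncio

_CLOSED = object()  # stand-in for the _COMM_CLOSED sentinel


async def main():
    q: asyncio.Queue = asyncio.Queue()
    q.put_nowait(b"payload")   # data received before the peer went away
    q.put_nowait(_CLOSED)      # enqueued on disconnect (compare Example 12)

    while True:
        item = await q.get()
        if item is _CLOSED:
            print("connection closed")
            break
        print("received", item)


asyncio.run(main())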
Example 10
    async def read(self, deserializers=None):
        try:
            n_frames = await self.sock.read_message()
            if n_frames is None:
                # Connection is closed
                self.abort()
                raise CommClosedError()
            n_frames = struct.unpack("Q", n_frames)[0]
        except WebSocketClosedError as e:
            raise CommClosedError(e)

        frames = [(await self.sock.read_message()) for _ in range(n_frames)]

        msg = await from_frames(
            frames,
            deserialize=self.deserialize,
            deserializers=deserializers,
            allow_offload=self.allow_offload,
        )
        return msg
Example 11
    async def read(self, deserializers=None):
        frames = await self._protocol.read()
        try:
            return await from_frames(
                frames,
                deserialize=self.deserialize,
                deserializers=deserializers,
                allow_offload=self.allow_offload,
            )
        except EOFError:
            # Frames possibly garbled or truncated by communication error
            self.abort()
            raise CommClosedError("aborted stream on truncated data")
Example 12
    def connection_lost(self, exc=None):
        self._transport = None
        self._closed_waiter.set_result(None)

        # Unblock read, if any
        self._queue.put_nowait(_COMM_CLOSED)

        # Unblock write, if any
        if self._paused:
            waiter = self._drain_waiter
            if waiter is not None:
                self._drain_waiter = None
                if not waiter.done():
                    waiter.set_exception(CommClosedError("Connection closed"))
Example 13
    async def connect(self, address, deserialize=True, **connection_args):
        kwargs = self._get_connect_args(**connection_args)
        try:
            request = HTTPRequest(f"{self.prefix}{address}", **kwargs)
            sock = await websocket_connect(request,
                                           max_message_size=MAX_MESSAGE_SIZE)
            if sock.stream.closed() and sock.stream.error:
                raise StreamClosedError(sock.stream.error)
        except StreamClosedError as e:
            convert_stream_closed_error(self, e)
        except SSLError as err:
            raise FatalCommClosedError(
                "TLS expects a `ssl_context` argument of type "
                "ssl.SSLContext (perhaps check your TLS configuration?)"
            ) from err
        except HTTPClientError as e:
            raise CommClosedError(f"in {self}: {e}") from e
        return self.comm_class(sock, deserialize=deserialize)
Example 14
    async def write(self, msg, serializers=None, on_error=None):
        frames = await to_frames(
            msg,
            allow_offload=self.allow_offload,
            serializers=serializers,
            on_error=on_error,
            context={
                "sender": self.local_info,
                "recipient": self.remote_info,
                **self.handshake_options,
            },
            frame_split_size=BIG_BYTES_SHARD_SIZE,
        )
        n = struct.pack("Q", len(frames))
        try:
            await self.sock.write_message(n, binary=True)
            for frame in frames:
                await self.sock.write_message(ensure_bytes(frame), binary=True)
        except WebSocketClosedError as e:
            raise CommClosedError(e)

        return sum(map(nbytes, frames))
Example 15
    async def write(self, frames: list[bytes]) -> int:
        """Write a message to the comm."""
        if self.is_closed:
            raise CommClosedError("Connection closed")
        elif self._paused:
            # Wait until there's room in the write buffer
            drain_waiter = self._drain_waiter = self._loop.create_future()
            await drain_waiter

        # Ensure all memoryviews are in single-byte format
        frames = [
            f.cast("B") if isinstance(f, memoryview) else f for f in frames
        ]

        nframes = len(frames)
        frames_nbytes = [len(f) for f in frames]
        # TODO: the old TCP comm included an extra `msg_nbytes` prefix that
        # isn't really needed. We include it here for backwards compatibility,
        # but this could be removed if we ever want to make another breaking
        # change to the comms.
        msg_nbytes = sum(frames_nbytes) + (nframes + 1) * 8
        header = struct.pack(f"{nframes + 2}Q", msg_nbytes, nframes,
                             *frames_nbytes)

        if msg_nbytes < 4 * 1024:
            # Always concatenate small messages
            buffers = [b"".join([header, *frames])]
        else:
            buffers = coalesce_buffers([header, *frames])

        if len(buffers) > 1:
            self._transport.writelines(buffers)
        else:
            self._transport.write(buffers[0])

        return msg_nbytes
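
A minimal sketch of the header layout written above, with made-up frames: total message size, frame count, then one uint64 length per frame.

import struct

frames = [b"small", b"x" * 100]
nframes = len(frames)
frames_nbytes = [len(f) for f in frames]

# msg_nbytes covers the payload plus the (nframes + 1) * 8 header bytes that
# follow the leading size field (the backwards-compatibility prefix noted above).
msg_nbytes = sum(frames_nbytes) + (nframes + 1) * 8
header = struct.pack(f"{nframes + 2}Q", msg_nbytes, nframes, *frames_nbytes)

# A receiver peels the header apart the same way.
decoded = struct.unpack(f"{nframes + 2}Q", header)
assert decoded == (msg_nbytes, nframes, *frames_nbytes)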
Example 16
    async def write(self, msg, serializers=None, on_error="message"):
        stream = self.stream
        if stream is None:
            raise CommClosedError()

        frames = await to_frames(
            msg,
            allow_offload=self.allow_offload,
            serializers=serializers,
            on_error=on_error,
            context={
                "sender": self.local_info,
                "recipient": self.remote_info,
                **self.handshake_options,
            },
            frame_split_size=self.max_shard_size,
        )
        frames_nbytes = [nbytes(f) for f in frames]
        frames_nbytes_total = sum(frames_nbytes)

        header = pack_frames_prelude(frames)
        header = struct.pack("Q",
                             nbytes(header) + frames_nbytes_total) + header

        frames = [header, *frames]
        frames_nbytes = [nbytes(header), *frames_nbytes]
        frames_nbytes_total += frames_nbytes[0]

        if frames_nbytes_total < 2**17:  # 128kiB
            # small enough, send in one go
            frames = [b"".join(frames)]
            frames_nbytes = [frames_nbytes_total]

        try:
            # trick to enqueue all frames for writing beforehand
            for each_frame_nbytes, each_frame in zip(frames_nbytes, frames):
                if each_frame_nbytes:
                    if stream._write_buffer is None:
                        raise StreamClosedError()

                    if isinstance(each_frame, memoryview):
                        # Make sure that `len(data) == data.nbytes`
                        # See <https://github.com/tornadoweb/tornado/pull/2996>
                        each_frame = memoryview(each_frame).cast("B")

                    stream._write_buffer.append(each_frame)
                    stream._total_write_index += each_frame_nbytes

            # start writing frames
            stream.write(b"")
        except StreamClosedError as e:
            self.stream = None
            self._closed = True
            if not sys.is_finalizing():
                convert_stream_closed_error(self, e)
        except Exception:
            # Some OSError or another "low-level" exception. We do not really know
            # what was already written to the underlying socket, so it is not even safe
            # to retry here using the same stream. The only safe thing to do is to
            # abort. (See also GitHub #4133).
            self.abort()
            raise

        return frames_nbytes_total
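
A minimal sketch of the length-prefixed framing built above, using an inline stand-in for pack_frames_prelude (assumed here to emit the frame count followed by each frame's length, all uint64):

import struct

frames = [b"msgpack-header", b"payload" * 3]


def pack_prelude(frames):
    # Stand-in for pack_frames_prelude (an assumption about its exact layout).
    return struct.pack(f"Q{len(frames)}Q", len(frames), *(len(f) for f in frames))


prelude = pack_prelude(frames)
# The leading uint64 counts everything that follows: prelude plus frame payloads.
header = struct.pack("Q", len(prelude) + sum(len(f) for f in frames)) + prelude

wire = b"".join([header, *frames])
(total,) = struct.unpack("Q", wire[:8])
assert total == len(wire) - 8

Example 4 reads this layout back: it unpacks the leading size prefix, fills a host buffer in chunks, and then splits it with unpack_frames.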
Example 17
    def ep(self):
        if self._ep is not None:
            return self._ep
        else:
            raise CommClosedError("UCX Endpoint is closed")