class Promise:
    '''A settable placeholder for a value (or an exception).

    Tasks wait on the promise with `get()`. It is fulfilled either by
    `set(data)` or by using the promise as an async context manager, in
    which case an exception raised inside the `async with` body is
    captured and re-raised to the tasks calling `get()`.
    '''

    def __init__(self):
        self._event = Event()
        self._data = None
        self._exception = None

    def __repr__(self):
        res = super().__repr__()
        if not self.is_set():
            extra = 'unset'
        elif self._exception is not None:
            # Same `is not None` test as `get()`, so the repr always shows
            # what `get()` would actually deliver, even for an exception
            # object that happens to be falsy.
            extra = repr(self._exception)
        else:
            extra = repr(self._data)
        return f'<{res[1:-1]} [{extra}]>'

    def is_set(self):
        '''Return `True` if the promise is set'''
        return self._event.is_set()

    def clear(self):
        '''Clear the promise: drop the data, the exception and the flag.'''
        self._data = None
        self._exception = None
        self._event.clear()

    async def set(self, data):
        '''Set the promise to `data`. Wake all waiting tasks (if any).

        Any exception captured earlier is discarded so that `get()`
        returns `data` instead of raising a stale error.
        '''
        self._data = data
        self._exception = None
        await self._event.set()

    async def get(self):
        '''Wait for the promise to be set, and return the data.

        If an exception was set, it will be raised.
        '''
        await self._event.wait()
        if self._exception is not None:
            raise self._exception
        return self._data

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        '''Capture an exception raised in the `async with` body.

        The exception is stored for `get()`, waiters are woken, and the
        exception is suppressed (truthy return value).
        '''
        if exc_type is not None:
            self._exception = exc
            await self._event.set()
            return True
class WebsocketPrototype(ABC):
    '''Base websocket connection driven by two queues.

    `outgoing` holds protocol events waiting to be serialized and written
    to `socket`; `incoming` holds message payloads read from it. The
    `closing` event marks the connection as closed and `closure` keeps the
    close event that ended it. `socket` and `protocol` are not assigned
    here; they must be set before `flow()` is run.
    '''

    __slots__ = ('socket', 'protocol', 'outgoing',
                 'incoming', 'closure', 'closing')

    def __init__(self):
        self.outgoing = Queue()
        self.incoming = Queue()
        self.closure = None
        self.closing = Event()

    @property
    def closed(self):
        '''`True` once the `closing` event has been set.'''
        return self.closing.is_set()

    async def send(self, data):
        '''Queue `data` to be sent as a message.

        Raises `WebsocketClosedError` if the connection is closed.
        '''
        if self.closed:
            raise WebsocketClosedError()
        await self.outgoing.put(Message(data=data))

    async def recv(self):
        '''Return the next incoming message payload.

        Returns `None` if the connection is already closed, or if it
        closes before a message arrives.
        '''
        if self.closed:
            return None
        # Race "next message" against "connection closing".
        async with TaskGroup(wait=any) as g:
            receiver = await g.spawn(self.incoming.get)
            await g.spawn(self.closing.wait)
        if g.completed is receiver:
            return receiver.result
        return None

    async def __aiter__(self):
        '''Yield incoming message payloads as they arrive.'''
        async for msg in self.incoming:
            yield msg

    async def close(self, code=1000, reason='Closed.'):
        '''Queue a close event carrying `code` and `reason`.'''
        await self.outgoing.put(CloseConnection(code=code, reason=reason))

    async def _handle_incoming(self):
        '''Read from the socket and dispatch protocol events until closed.

        One event is pulled from the protocol per socket read. A close
        event is recorded in `closure`, answered and marks the connection
        closed; a message payload goes to `incoming`; a ping is answered
        through `outgoing`.
        '''
        events = self.protocol.events()
        while not self.closed:
            try:
                data = await self.socket.recv(4096)
            except ConnectionResetError:
                return await self.closing.set()

            self.protocol.receive_data(data)
            try:
                event = next(events)
            except StopIteration:
                # Connection dropped unexpectedly
                return await self.closing.set()

            if isinstance(event, CloseConnection):
                self.closure = event
                await self.outgoing.put(event.response())
                await self.closing.set()
            elif isinstance(event, Message):
                await self.incoming.put(event.data)
            elif isinstance(event, Ping):
                await self.outgoing.put(event.response())

    async def _handle_outgoing(self):
        '''Serialize queued events and write them to the socket.

        Stops, marking the connection closed, when it dequeues `None`,
        when the protocol is already in the CLOSED state, when a close
        event has been sent, or when the socket write fails.
        '''
        async for event in self.outgoing:
            if event is None or self.protocol.state is ConnectionState.CLOSED:
                return await self.closing.set()

            data = self.protocol.send(event)
            try:
                await self.socket.sendall(data)
            except socket.error:
                return await self.closing.set()

            # Test the event, not `data`: `data` is the serialized payload
            # handed to `sendall`, so it is never a CloseConnection.
            if isinstance(event, CloseConnection):
                self.closure = event
                return await self.closing.set()

    async def flow(self, *tasks):
        '''Run the incoming/outgoing handlers alongside `tasks`.

        Waits for the first task of the group to finish. If it is the
        incoming handler, the outgoing handler is told to stop with a
        `None` sentinel. If it is one of `tasks`, a close event is queued
        (code 1011 when the task ended with an exception). In both cases
        the outgoing handler is then joined.
        '''
        async with TaskGroup(tasks=tasks) as ws:
            incoming = await ws.spawn(self._handle_incoming)
            outgoing = await ws.spawn(self._handle_outgoing)
            finished = await ws.next_done()

            if finished is incoming:
                await self.outgoing.put(None)
                await outgoing.join()
            elif finished in tasks:
                # Task is finished.
                # We ask for the outgoing to finish
                if finished.exception:
                    await self.close(1011, 'Task died prematurely.')
                else:
                    await self.close()
                await outgoing.join()