def test_connection_send_state() -> None:
    """Check that an OPEN client refuses to send a second handshake request.

    The client is taken from CONNECTING to OPEN through a real h11 server so
    the accept token matches the key the client generated.  Sending another
    ``Request`` must then raise ``LocalProtocolError``, and feeding arbitrary
    bytes afterwards must still produce exactly one event.
    """
    client = WSConnection(CLIENT)
    assert client.state is ConnectionState.CONNECTING

    # Complete the opening handshake against an h11 server.
    server = h11.Connection(h11.SERVER)
    opening_request = Request(
        host="localhost",
        target="/",
    )
    server.receive_data(client.send(opening_request))
    headers = normed_header_dict(server.next_event().headers)
    accept_token = generate_accept_token(headers[b"sec-websocket-key"])
    upgrade_headers = [
        (b"connection", b"Upgrade"),
        (b"upgrade", b"WebSocket"),
        (b"Sec-WebSocket-Accept", accept_token),
    ]
    response = h11.InformationalResponse(
        status_code=101,
        headers=upgrade_headers,
    )
    client.receive_data(server.send(response))
    assert len(list(client.events())) == 1
    assert client.state is ConnectionState.OPEN  # type: ignore # https://github.com/python/mypy/issues/9005

    # A handshake request is no longer valid once the connection is open.
    with pytest.raises(LocalProtocolError) as excinfo:
        client.send(Request(host="localhost", target="/"))

    client.receive_data(b"foobar")
    assert len(list(client.events())) == 1
class MockWebsocketConnection:
    """Test helper that talks to a ``WebsocketServer`` as a wsproto client.

    The client handshake request is pushed into the server during
    construction; ``send``/``receive``/``close`` then exchange frames through
    the mock transport.
    """

    def __init__(
        self,
        path: str,
        event_loop: asyncio.AbstractEventLoop,
        *,
        framework: Type[ASGIFramework] = EchoFramework,
    ) -> None:
        self.transport = MockTransport()
        self.server = WebsocketServer(  # type: ignore
            framework, event_loop, Config(), self.transport
        )
        self.connection = WSConnection(ConnectionType.CLIENT)
        handshake = self.connection.send(Request(target=path, host="hypercorn"))
        self.server.data_received(handshake)

    async def send(self, data: AnyStr) -> None:
        """Serialise *data* as a message and hand it to the server."""
        frame = self.connection.send(Message(data=data))
        self.server.data_received(frame)
        await asyncio.sleep(0)  # Allow the server to respond

    async def receive(self) -> List[Message]:
        """Wait for the transport to be updated and return the parsed events."""
        await self.transport.updated.wait()
        self.connection.receive_data(self.transport.data)
        self.transport.clear()
        return list(self.connection.events())

    def close(self) -> None:
        """Send a close frame with code 1000 to the server."""
        closing = self.connection.send(CloseConnection(code=1000))
        self.server.data_received(closing)
def update_reports(server, agent):
    """Ask *server* to update its reports for *agent*.

    Opens a WebSocket to ``<server>/updateReports?agent=<agent>`` and, as soon
    as the handshake is accepted, sends a normal-closure frame and returns.

    The socket is closed on every exit path (including errors), and the
    receive loop stops once the peer has closed the TCP connection instead of
    polling ``recv`` forever.

    :param server: base URL of the server, e.g. ``ws://host:port``
    :param agent: agent name placed in the query string
    """
    uri = urlparse(server + '/updateReports?agent=%s' % agent)
    connection = WSConnection(CLIENT)
    sock = socket.socket()
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(
            connection.send(
                Request(host=uri.netloc,
                        target='%s?%s' % (uri.path, uri.query))))

        closed = False
        while not closed:
            data = sock.recv(65535)
            connection.receive_data(data)
            for event in connection.events():
                if isinstance(event, AcceptConnection):
                    sock.sendall(
                        connection.send(
                            CloseConnection(code=CloseReason.NORMAL_CLOSURE)))
                    closed = True
            if not data:
                # recv() returning b"" means the peer closed the connection;
                # no further events can ever arrive.
                break
    finally:
        # Best-effort close, mirroring the original tolerance for
        # connection errors while shutting down.
        try:
            sock.close()
        except CONNECTION_EXCEPTIONS:
            pass
def handle_todo_post_save(sender, instance, created, **kwargs):
    """Notify the local WebSocket app that *instance* was saved.

    Does nothing unless *sender* has an ``APP_PORT`` attribute.  Otherwise it
    connects to ``localhost:<APP_PORT>``, performs the WebSocket handshake,
    sends the primary key of *instance* as a text message, closes the
    WebSocket and finally removes ``APP_PORT`` from *sender*.

    The TCP socket is always closed, even when the exchange fails part-way.

    :param sender: object carrying the optional ``APP_PORT`` attribute
    :param instance: saved object; only ``instance.pk`` is read
    :param created: unused
    :param kwargs: unused
    """
    if not hasattr(sender, 'APP_PORT'):
        return

    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        conn.connect(('localhost', int(sender.APP_PORT)))
        ws = WSConnection(ConnectionType.CLIENT)
        # NOTE(review): looks like leftover debug output -- confirm before
        # removing.
        print("hello")

        # Opening handshake.
        net_send(
            ws.send(Request(
                host=f"localhost:{sender.APP_PORT}",
                target="ws/todos")),
            conn
        )
        net_recv(ws, conn)
        handle_events(ws)

        # Send the primary key of the saved instance.
        net_send(ws.send(Message(data=str(instance.pk))), conn)
        net_recv(ws, conn)
        handle_events(ws)

        # Closing handshake: send our close frame, read the reply, then
        # half-close the socket and read once more.
        net_send(ws.send(CloseConnection(code=1000)), conn)
        net_recv(ws, conn)
        conn.shutdown(socket.SHUT_WR)
        net_recv(ws, conn)
    finally:
        conn.close()

    # Only reached when the whole exchange succeeded, as before.
    del sender.APP_PORT
def get_case_count(server):
    """Return the value reported by ``<server>/getCaseCount``.

    Opens a WebSocket, waits for one complete text message, decodes it as
    JSON, answers with a normal-closure frame and returns the decoded value.

    Fixes over the previous version:

    * text fragments are accumulated across ``recv`` calls, so a message that
      arrives in several reads is no longer truncated;
    * the loop stops when the peer closes the connection instead of spinning;
    * the socket is closed on every exit path.

    :param server: base URL of the server, e.g. ``ws://host:port``
    :returns: the decoded JSON value, or ``None`` if the connection ended
        before a complete message was received
    """
    uri = urlparse(server + '/getCaseCount')
    connection = WSConnection(CLIENT)
    sock = socket.socket()
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(
            connection.send(Request(host=uri.netloc, target=uri.path)))

        case_count = None
        # Lives outside the receive loop: one message may span several reads.
        text_parts = []
        while case_count is None:
            in_data = sock.recv(65535)
            connection.receive_data(in_data)
            out_data = b""
            for event in connection.events():
                if isinstance(event, TextMessage):
                    text_parts.append(event.data)
                    if event.message_finished:
                        case_count = json.loads("".join(text_parts))
                        out_data += connection.send(
                            CloseConnection(code=CloseReason.NORMAL_CLOSURE))
            try:
                sock.sendall(out_data)
            except CONNECTION_EXCEPTIONS:
                break
            if not in_data:
                # recv() returning b"" means the peer closed the connection.
                break
    finally:
        sock.close()
    return case_count
class MockWebsocketConnection:
    """Test helper that talks to a trio ``WebsocketServer`` as a wsproto client.

    A pair of in-memory trio streams connects the two sides; the client
    handshake request is fed straight into the server's wsproto connection
    during construction.
    """

    def __init__(
        self, path: str, *, framework: ASGIFramework = echo_framework
    ) -> None:
        self.client_stream, server_stream = trio.testing.memory_stream_pair()
        server_stream.socket = MockSocket()
        self.server = WebsocketServer(framework, Config(), server_stream)
        self.connection = WSConnection(ConnectionType.CLIENT)
        handshake = self.connection.send(Request(target=path, host="hypercorn"))
        self.server.connection.receive_data(handshake)

    async def send(self, data: AnyStr) -> None:
        """Serialise *data* as a message and write it to the client stream."""
        frame = self.connection.send(Message(data=data))
        await self.client_stream.send_all(frame)
        await trio.sleep(0)  # Allow the server to respond

    async def receive(self) -> List[Message]:
        """Read from the client stream and return the parsed events."""
        data = await self.client_stream.receive_some(2**16)
        self.connection.receive_data(data)
        return list(self.connection.events())

    async def close(self) -> None:
        """Write a close frame with code 1000 to the client stream."""
        closing = self.connection.send(CloseConnection(code=1000))
        await self.client_stream.send_all(closing)
def websocket(request):
    """Run an echo WebSocket over the raw socket exposed by the WSGI server.

    Text messages are echoed back; the text ``"quit"`` makes the server start
    a normal closure with reason ``"bye"``.  Pings are answered and a close
    from the client is acknowledged.

    Compared to the previous version, outgoing data is written with
    ``sendall`` (``send`` may write only part of the buffer) and the loop
    ends when the client drops the TCP connection instead of spinning on an
    empty ``recv``.

    :param request: request object whose ``environ`` exposes the socket under
        ``werkzeug.socket`` or ``gunicorn.socket``
    :raises InternalServerError: if the server does not expose its socket
    :returns: an empty 204 response, required by WSGI once the socket is done
    """
    # The underlying socket must be provided by the server. Gunicorn and
    # Werkzeug's dev server are known to support this.
    stream = request.environ.get("werkzeug.socket")
    if stream is None:
        stream = request.environ.get("gunicorn.socket")
    if stream is None:
        raise InternalServerError()

    # Initialize the wsproto connection. Need to recreate the request
    # data that was read by the WSGI server already.
    ws = WSConnection(ConnectionType.SERVER)
    in_data = b"GET %s HTTP/1.1\r\n" % request.path.encode("utf8")
    for header, value in request.headers.items():
        in_data += f"{header}: {value}\r\n".encode()
    in_data += b"\r\n"
    ws.receive_data(in_data)

    running = True
    while True:
        out_data = b""
        for event in ws.events():
            if isinstance(event, WSRequest):
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, Ping):
                out_data += ws.send(event.response())
            elif isinstance(event, TextMessage):
                # echo the incoming message back to the client
                if event.data == "quit":
                    out_data += ws.send(
                        CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
                    )
                    running = False
                else:
                    out_data += ws.send(Message(data=event.data))

        if out_data:
            # sendall() keeps writing until the whole buffer is on the wire.
            stream.sendall(out_data)
        if not running:
            break

        in_data = stream.recv(4096)
        if not in_data:
            # b"" means the client closed the TCP connection without a
            # closing handshake; nothing more will arrive.
            break
        ws.receive_data(in_data)

    # The connection will be closed at this point, but WSGI still
    # requires a response.
    return Response("", status=204)
def handle_connection(stream: socket.socket) -> None:
    """
    Handle a connection.

    The server operates a request/response cycle, so it performs a synchronous
    loop:

    1) Read data from network into wsproto
    2) Get new events and handle them
    3) Send data from wsproto to network

    The loop ends after a close frame has been answered, or as soon as the
    peer closes the TCP connection (``recv`` returns ``b""``).

    :param stream: a socket stream
    """
    ws = WSConnection(ConnectionType.SERVER)
    running = True

    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print("Received {} bytes".format(len(in_data)))
        if not in_data:
            # Peer closed the connection without a closing handshake; without
            # this check the loop would spin on empty reads forever.
            break
        ws.receive_data(in_data)

        # 2) Get new events and handle them
        out_data = b""
        for event in ws.events():
            if isinstance(event, Request):
                # Negotiate new WebSocket connection
                print("Accepting WebSocket upgrade")
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                # Print log message and break out
                print("Connection closed: code={} reason={}".format(
                    event.code, event.reason))
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, TextMessage):
                # Reverse text and send it back to wsproto
                print("Received request and sending response")
                out_data += ws.send(Message(data=event.data[::-1]))
            elif isinstance(event, Ping):
                # The reply is not sent automatically: build it with
                # event.response() and serialise it through ws.send().
                print("Received ping and sending pong")
                out_data += ws.send(event.response())
            else:
                print(f"Unknown event: {event!r}")

        # 3) Send data from wsproto to network
        print("Sending {} bytes".format(len(out_data)))
        # sendall() keeps writing until the whole buffer has been sent,
        # whereas send() may stop after a partial write.
        stream.sendall(out_data)
def wsproto_demo(host, port):
    '''
    Demonstrate wsproto:

    0) Open TCP connection
    1) Negotiate WebSocket opening handshake
    2) Send a message and display response
    3) Send ping and display pong
    4) Negotiate WebSocket closing handshake

    The TCP socket is closed before returning, including when a step fails.

    :param host: host name or address to connect to
    :param port: TCP port to connect to
    '''

    # 0) Open TCP connection
    print('Connecting to {}:{}'.format(host, port))
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        conn.connect((host, port))

        # 1) Negotiate WebSocket opening handshake
        print('Opening WebSocket')
        ws = WSConnection(ConnectionType.CLIENT)
        # Because this is a client WebSocket, we need to initiate the
        # connection handshake by sending a Request event.
        net_send(ws.send(Request(host=host, target='server')), conn)
        net_recv(ws, conn)
        handle_events(ws)

        # 2) Send a message and display response
        message = "wsproto is great"
        print('Sending message: {}'.format(message))
        net_send(ws.send(Message(data=message)), conn)
        net_recv(ws, conn)
        handle_events(ws)

        # 3) Send ping and display pong
        payload = b"table tennis"
        print('Sending ping: {}'.format(payload))
        net_send(ws.send(Ping(payload=payload)), conn)
        net_recv(ws, conn)
        handle_events(ws)

        # 4) Negotiate WebSocket closing handshake
        print('Closing WebSocket')
        net_send(
            ws.send(CloseConnection(code=1000, reason='sample reason')), conn)
        # After sending the closing frame, we won't get any more events. The
        # server should send a reply and then close the connection, so we
        # need to receive twice:
        net_recv(ws, conn)
        conn.shutdown(socket.SHUT_WR)
        net_recv(ws, conn)
    finally:
        conn.close()
def _make_handshake(
    request_headers: Headers,
    accept_headers: Optional[Headers] = None,
    subprotocol: Optional[str] = None,
    extensions: Optional[List[Extension]] = None,
) -> Tuple[h11.InformationalResponse, bytes]:
    """Run a server-side handshake and return the response seen by the client.

    An h11 client sends an upgrade request (the standard WebSocket headers
    followed by *request_headers*) to a wsproto server, which accepts it with
    the given extra headers, subprotocol and extensions.

    :returns: the first event the h11 client produces from the server's
        answer, together with the nonce used as ``Sec-WebSocket-Key``
    """
    client = h11.Connection(h11.CLIENT)
    server = WSConnection(SERVER)
    nonce = generate_nonce()

    base_headers = [
        (b"Host", b"localhost"),
        (b"Connection", b"Keep-Alive, Upgrade"),
        (b"Upgrade", b"WebSocket"),
        (b"Sec-WebSocket-Version", b"13"),
        (b"Sec-WebSocket-Key", nonce),
    ]
    upgrade_request = h11.Request(
        method="GET",
        target="/",
        headers=base_headers + request_headers,
    )
    server.receive_data(client.send(upgrade_request))

    acceptance = AcceptConnection(
        extra_headers=accept_headers or [],
        subprotocol=subprotocol,
        extensions=extensions or [],
    )
    client.receive_data(server.send(acceptance))

    return client.next_event(), nonce
def _make_handshake(
    response_status,
    response_headers,
    subprotocols=None,
    extensions=None,
    auto_accept_key=True,
):
    """Run a client-side handshake and return the events the client produces.

    A wsproto client sends a request to an h11 server, which answers with an
    informational response built from *response_status* and
    *response_headers*.

    The caller's *response_headers* list is no longer modified: the
    ``Sec-WebSocket-Accept`` header is appended to a private copy.

    :param response_status: status code of the server's response
    :param response_headers: headers of the server's response
    :param subprotocols: subprotocols offered by the client (default none)
    :param extensions: extensions offered by the client (default none)
    :param auto_accept_key: when true, add a ``Sec-WebSocket-Accept`` header
        computed from the key in the client's request
    :returns: list of events yielded by the client after the response
    """
    client = WSConnection(CLIENT)
    server = h11.Connection(h11.SERVER)
    server.receive_data(
        client.send(
            Request(
                host="localhost",
                target="/",
                subprotocols=subprotocols or [],
                extensions=extensions or [],
            )
        )
    )
    request = server.next_event()

    # Work on a copy so repeated calls with the same list do not accumulate
    # accept headers in the caller's object.
    headers = list(response_headers)
    if auto_accept_key:
        full_request_headers = normed_header_dict(request.headers)
        headers.append(
            (
                b"Sec-WebSocket-Accept",
                generate_accept_token(full_request_headers[b"sec-websocket-key"]),
            )
        )

    response = h11.InformationalResponse(
        status_code=response_status, headers=headers
    )
    client.receive_data(server.send(response))
    return list(client.events())
class Websocket(WebsocketPrototype):
    """Server-side websocket running a handler parallel to the I/O.
    """

    def __init__(self, socket):
        """Wrap *socket* and create a server-side wsproto connection."""
        super().__init__()
        self.socket = socket
        self.protocol = WSConnection(ConnectionType.SERVER)

    async def upgrade(self, request):
        """Replay *request* into wsproto and accept the WebSocket upgrade.

        The request line and headers are rebuilt from ``request.method``,
        ``request.url`` and ``request.headers``, fed to wsproto, and the
        accept response is written to the socket.

        :raises HTTPError: with ``BAD_REQUEST`` if wsproto rejects the data,
            produces no event, or the first event is not a ``Request``
        """
        head = '{} {} HTTP/1.1\r\n'.format(request.method, request.url)
        head += '\r\n'.join(
            '{}: {}'.format(k, v) for k, v in request.headers.items()
        ) + '\r\n\r\n'
        data = head.encode()

        try:
            self.protocol.receive_data(data)
        except RemoteProtocolError as exc:
            raise HTTPError(HTTPStatus.BAD_REQUEST) from exc

        # A bare next() would raise StopIteration when no event is queued,
        # which a coroutine converts into RuntimeError (PEP 479); use a
        # default so that case is reported as a bad request instead.
        event = next(self.protocol.events(), None)
        if not isinstance(event, Request):
            raise HTTPError(HTTPStatus.BAD_REQUEST)

        data = self.protocol.send(AcceptConnection())
        await self.socket.sendall(data)
def _make_handshake_rejection(
    status_code: int, body: Optional[bytes] = None
) -> List[Event]:
    """Have a wsproto server reject a handshake and collect the client view.

    An h11 client sends a WebSocket upgrade request; the server answers with
    ``RejectConnection`` (plus ``RejectData`` when *body* is given).

    :returns: the h11 events read by the client, up to and including
        ``h11.EndOfMessage``
    """
    client = h11.Connection(h11.CLIENT)
    server = WSConnection(SERVER)
    nonce = generate_nonce()

    upgrade_request = h11.Request(
        method="GET",
        target="/",
        headers=[
            (b"Host", b"localhost"),
            (b"Connection", b"Keep-Alive, Upgrade"),
            (b"Upgrade", b"WebSocket"),
            (b"Sec-WebSocket-Version", b"13"),
            (b"Sec-WebSocket-Key", nonce),
        ],
    )
    server.receive_data(client.send(upgrade_request))

    if body is None:
        rejection = RejectConnection(status_code=status_code)
        client.receive_data(server.send(rejection))
    else:
        # A rejection with a body is sent in two steps: headers, then data.
        rejection = RejectConnection(
            headers=[(b"content-length", b"%d" % len(body))],
            status_code=status_code,
            has_body=True,
        )
        client.receive_data(server.send(rejection))
        client.receive_data(server.send(RejectData(data=body)))

    collected = []
    while True:
        collected.append(client.next_event())
        if isinstance(collected[-1], h11.EndOfMessage):
            return collected
def new_conn(sock: socket.socket) -> None:
    """Serve one echo WebSocket connection on *sock* until it closes.

    Handshake requests are accepted with the per-message-deflate extension,
    messages are echoed back fragment by fragment, pings are answered, and a
    close frame is acknowledged unless wsproto already reports the connection
    as closed.  The module-level ``count`` is printed and then incremented.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1

    ws = WSConnection(SERVER)
    finished = False
    while not finished:
        try:
            chunk: Optional[bytes] = sock.recv(65535)
        except socket.error:
            chunk = None

        # An empty or failed read is passed on as None rather than b"".
        ws.receive_data(chunk if chunk else None)

        reply = b""
        for event in ws.events():
            if isinstance(event, Request):
                acceptance = AcceptConnection(extensions=[PerMessageDeflate()])
                reply += ws.send(acceptance)
            elif isinstance(event, Message):
                echo = Message(
                    data=event.data,
                    message_finished=event.message_finished,
                )
                reply += ws.send(echo)
            elif isinstance(event, Ping):
                reply += ws.send(event.response())
            elif isinstance(event, CloseConnection):
                finished = True
                if ws.state is not ConnectionState.CLOSED:
                    reply += ws.send(event.response())

        if not chunk:
            finished = True

        try:
            sock.sendall(reply)
        except socket.error:
            finished = True

    sock.close()
def _make_handshake_rejection(status_code, body=None):
    """Have an h11 server reject a wsproto client and return its events.

    The server answers the client's request with a plain ``h11.Response`` of
    *status_code*, optionally followed by *body* (announced through a
    ``Content-Length`` header), and then ``h11.EndOfMessage``.

    :returns: list of events yielded by the client
    """
    client = WSConnection(CLIENT)
    server = h11.Connection(h11.SERVER)
    server.receive_data(client.send(Request(host="localhost", target="/")))

    has_body = body is not None
    headers = [("Content-Length", str(len(body)))] if has_body else []

    # Replay the response to the client one h11 event at a time.
    outgoing = [h11.Response(status_code=status_code, headers=headers)]
    if has_body:
        outgoing.append(h11.Data(data=body))
    outgoing.append(h11.EndOfMessage())
    for h11_event in outgoing:
        client.receive_data(server.send(h11_event))

    return list(client.events())
class Websocket(WebsocketPrototype):
    """Client-side websocket over a curio TCP socket."""

    def __init__(self):
        """Create the TCP socket and a client-side wsproto connection."""
        super().__init__()
        self.socket = curio.socket.socket(curio.socket.AF_INET,
                                          curio.socket.SOCK_STREAM)
        self.protocol = WSConnection(ConnectionType.CLIENT)

    async def connect(self, path, host, port):
        """Connect to ``host:port`` and perform the handshake for *path*.

        :raises Exception: ``'Websocket handshake failed.'`` when the data
            read from the server does not yield an ``AcceptConnection`` event
        """
        await self.socket.connect((host, port))
        request = Request(host=f'{host}:{port}', target=path)
        await self.socket.sendall(self.protocol.send(request))

        upgrade_response = await self.socket.recv(8096)
        self.protocol.receive_data(upgrade_response)

        # A bare next() would raise StopIteration when no event is available,
        # which a coroutine turns into RuntimeError (PEP 479); a default
        # routes that case to the regular handshake failure below.
        event = next(self.protocol.events(), None)
        if not isinstance(event, AcceptConnection):
            raise Exception('Websocket handshake failed.')
class Websocket:
    """WebSocket endpoint (client or server) on top of an anyio stream.

    wsproto does the framing; this class moves bytes between the socket and
    wsproto, reassembles fragmented messages and exposes complete messages
    through async iteration.
    """

    # Cancel scope of the task currently iterating, ``True`` while a handshake
    # is in progress, ``None`` otherwise.
    _scope = None
    _sock = None
    _connection = None

    def __init__(self):
        self._byte_buffer = BytesIO()
        self._string_buffer = StringIO()
        self._closed = False

    async def __ainit__(self, addr, path: str, headers: Optional[List] = None,
                        subprotocols=None, **connect_kw):
        """Open a TCP connection to *addr* and start a client on it."""
        sock = await anyio.connect_tcp(*addr, **connect_kw)
        await self.start_client(
            sock, addr, path=path, headers=headers, subprotocols=subprotocols)

    async def start_client(self, sock: anyio.abc.SocketStream, addr,
                           path: str, headers: Optional[List] = None,
                           subprotocols: Optional[List[str]] = None):
        """Start a client WS connection on this socket.

        Returns: the AcceptConnection message.

        Raises: ConnectionError if the first event is not an
        AcceptConnection.
        """
        self._sock = sock
        self._connection = WSConnection(ConnectionType.CLIENT)
        if headers is None:
            headers = []
        if subprotocols is None:
            subprotocols = []
        data = self._connection.send(
            Request(
                host=addr[0],
                target=path,
                extra_headers=headers,
                subprotocols=subprotocols))
        await self._sock.send_all(data)

        assert self._scope is None
        self._scope = True
        try:
            event = await self._next_event()
            if not isinstance(event, AcceptConnection):
                raise ConnectionError("Failed to establish a connection",
                                      event)
            return event
        finally:
            self._scope = None

    async def start_server(self, sock: anyio.abc.SocketStream, filter=None):  # pylint: disable=W0622
        """Start a server WS connection on this socket.

        Filter: an async callable that gets passed the initial Request.
        It may return an AcceptConnection message, a bool, or a string (the
        subprotocol to use).

        Without a filter (or when it returns True) the connection is accepted
        with the first subprotocol the client offered, or with none when the
        client offered none.

        Returns: the Request message.

        Raises: ConnectionError if the first event is not a Request, or if
        the connection was rejected.
        """
        assert self._scope is None
        self._scope = True
        self._sock = sock
        self._connection = WSConnection(ConnectionType.SERVER)

        try:
            event = await self._next_event()
            if not isinstance(event, Request):
                raise ConnectionError("Failed to establish a connection",
                                      event)
            msg = None
            if filter is not None:
                msg = await filter(event)
                if not msg:
                    msg = RejectConnection()
                elif msg is True:
                    msg = None
                elif isinstance(msg, str):
                    msg = AcceptConnection(subprotocol=msg)
            if not msg:
                # Indexing [0] unconditionally raised IndexError for clients
                # that offer no subprotocol at all.
                offered = event.subprotocols
                msg = AcceptConnection(
                    subprotocol=offered[0] if offered else None)
            data = self._connection.send(msg)
            await self._sock.send_all(data)
            if not isinstance(msg, AcceptConnection):
                raise ConnectionError("Not accepted", msg)
            # The docstring has always promised the Request; return it.
            return event
        finally:
            self._scope = None

    async def _next_event(self):
        """
        Gets the next event.

        Message fragments are buffered until the final fragment arrives and
        are then returned as a single message; any other event is returned
        as-is.  An empty read from the socket yields a synthetic
        CloseConnection.
        """
        while True:
            for event in self._connection.events():
                if isinstance(event, Message):
                    # check if we need to buffer
                    if event.message_finished:
                        return self._wrap_data(self._gather_buffers(event))
                    self._buffer(event)
                    break  # exit for loop
                else:
                    return event

            data = await self._sock.receive_some(4096)
            if not data:
                return CloseConnection(code=500, reason="Socket closed")
            self._connection.receive_data(data)

    async def close(self, code: int = 1006,
                    reason: str = "Connection closed"):
        """
        Closes the websocket.

        Calling it again after the first call does nothing.
        """
        if self._closed:
            return
        self._closed = True

        if self._scope is not None:
            await self._scope.cancel()
            # cancel any outstanding listeners

        data = self._connection.send(
            CloseConnection(code=code, reason=reason))
        await self._sock.send_all(data)

        # No, we don't wait for the correct reply
        await self._sock.close()

    async def send(self, data: Union[bytes, str], final: bool = True):
        """
        Sends some data down the connection.

        ``str`` is sent as a text message, anything else as a bytes message;
        *final* marks the last fragment of the message.
        """
        MsgType = TextMessage if isinstance(data, str) else BytesMessage
        message = MsgType(data=data, message_finished=final)
        payload = self._connection.send(event=message)
        await self._sock.send_all(payload)

    def _buffer(self, event: Message):
        """
        Buffers an event, if applicable.
        """
        if isinstance(event, BytesMessage):
            self._byte_buffer.write(event.data)
        elif isinstance(event, TextMessage):
            self._string_buffer.write(event.data)

    def _gather_buffers(self, event: Message):
        """
        Gathers all the data from a buffer.

        The final fragment in *event* is appended first; the buffer is left
        empty afterwards.
        """
        if isinstance(event, BytesMessage):
            buf = self._byte_buffer
        else:
            buf = self._string_buffer

        # yay for code shortening
        buf.write(event.data)
        buf.seek(0)
        data = buf.read()
        buf.seek(0)
        buf.truncate()
        return data

    @staticmethod
    def _wrap_data(data: Union[str, bytes]):
        """
        Wraps data into the right event.
        """
        MsgType = TextMessage if isinstance(data, str) else BytesMessage
        return MsgType(data=data, frame_finished=True, message_finished=True)

    async def __aiter__(self):
        """Yield incoming events until a CloseConnection is received.

        Raises: RuntimeError if another task is already iterating.
        """
        async with anyio.open_cancel_scope() as scope:
            if self._scope is not None:
                raise RuntimeError(
                    "Only one task may iterate on this web socket")
            self._scope = scope
            try:
                while True:
                    msg = await self._next_event()
                    if isinstance(msg, CloseConnection):
                        return
                    yield msg
            finally:
                self._scope = None
sock.sendall(out_data) except CONNECTION_EXCEPTIONS: break sock.close() return case_count def run_case(server, case, agent): uri = urlparse(server + '/runCase?case=%d&agent=%s' % (case, agent)) connection = WSConnection(CLIENT) sock = socket.socket() sock.connect((uri.hostname, uri.port or 80)) sock.sendall( connection.send(Request( host=uri.netloc, target='%s?%s' % (uri.path, uri.query), extensions=[PerMessageDeflate()], )) ) closed = False while not closed: try: data = sock.recv(65535) except CONNECTION_EXCEPTIONS: data = None connection.receive_data(data or None) out_data = b"" for event in connection.events(): if isinstance(event, Message): out_data += connection.send(Message(data=event.data, message_finished=event.message_finished)) elif isinstance(event, Ping):
def test_protocol_error():
    """A malformed handshake response must raise ``RemoteProtocolError``."""
    conn = WSConnection(CLIENT)
    conn.send(Request(host="localhost", target="/"))

    garbage = b"broken nonsense\r\n\r\n"
    with pytest.raises(RemoteProtocolError) as excinfo:
        conn.receive_data(garbage)

    assert str(excinfo.value) == "Bad HTTP message"
def _make_connection_request(request):
    # type: (Request) -> h11.Request
    """Serialise *request* with a wsproto client and parse it with h11.

    :returns: the first event an h11 server reads from the client's bytes
    """
    ws_client = WSConnection(CLIENT)
    h11_server = h11.Connection(h11.SERVER)
    wire_bytes = ws_client.send(request)
    h11_server.receive_data(wire_bytes)
    return h11_server.next_event()
class WebsocketServer(HTTPServer, WebsocketMixin):
    """asyncio protocol serving one WebSocket connection through wsproto.

    Incoming bytes are parsed by a server-side ``WSConnection``; the resulting
    events start the ASGI task, feed ``app_queue`` with messages for the
    application, or are answered directly (ping, close).
    """

    def __init__(
        self,
        app: Type[ASGIFramework],
        loop: asyncio.AbstractEventLoop,
        config: Config,
        transport: asyncio.BaseTransport,
        *,
        upgrade_request: Optional[h11.Request] = None,
    ) -> None:
        super().__init__(loop, config, transport, "wsproto")
        self.stop_keep_alive_timeout()
        self.app = app
        self.connection = WSConnection(ConnectionType.SERVER)
        # Messages waiting to be read by the application via asgi_receive().
        self.app_queue: asyncio.Queue = asyncio.Queue()
        self.response: Optional[dict] = None
        self.scope: Optional[dict] = None
        self.state = ASGIWebsocketState.HANDSHAKE
        self.task: Optional[asyncio.Future] = None
        # Collects message fragments; extend() raises FrameTooLarge (handled
        # in handle_events).
        self.buffer = WebsocketBuffer(self.config.websocket_max_message_size)

        if upgrade_request is not None:
            # The upgrade request was already parsed elsewhere: pass its
            # headers and target to wsproto, then process the resulting
            # events immediately.
            self.connection.initiate_upgrade_connection(
                upgrade_request.headers, upgrade_request.target)
            self.handle_events()

    def connection_lost(self, error: Optional[Exception]) -> None:
        """Queue a disconnect for the application if the loss had an error."""
        if error is not None:
            self.app_queue.put_nowait({"type": "websocket.disconnect"})

    def eof_received(self) -> bool:
        """Forward EOF to wsproto as ``None`` and return True."""
        self.data_received(None)
        return True

    def data_received(self, data: Optional[bytes]) -> None:
        """Feed *data* to wsproto and handle the events it produces."""
        self.connection.receive_data(data)
        self.handle_events()

    def handle_events(self) -> None:
        """Dispatch every pending wsproto event.

        Processing stops early (``break``) after the connection has been
        closed, either because a message was too large or because a close
        event was received.
        """
        for event in self.connection.events():
            if isinstance(event, Request):
                # Handshake request: run the ASGI websocket handler as a task.
                self.task = self.loop.create_task(self.handle_websocket(event))
                self.task.add_done_callback(self.maybe_close)
            elif isinstance(event, Message):
                try:
                    self.buffer.extend(event)
                except FrameTooLarge:
                    # Tell the peer why, tell the application, then close.
                    self.write(
                        self.connection.send(
                            CloseConnection(code=CloseReason.MESSAGE_TOO_BIG)))
                    self.app_queue.put_nowait({"type": "websocket.disconnect"})
                    self.close()
                    break

                if event.message_finished:
                    # Last fragment: hand the whole message to the application.
                    self.app_queue.put_nowait(self.buffer.to_message())
                    self.buffer.clear()
            elif isinstance(event, Ping):
                self.write(self.connection.send(event.response()))
            elif isinstance(event, CloseConnection):
                # Only reply when wsproto reports the REMOTE_CLOSING state.
                if self.connection.state == ConnectionState.REMOTE_CLOSING:
                    self.write(self.connection.send(event.response()))
                self.app_queue.put_nowait({"type": "websocket.disconnect"})
                self.close()
                break

    def maybe_close(self, future: asyncio.Future) -> None:
        """Done-callback of the ASGI task; *future* itself is not inspected."""
        # Close the connection iff a HTTP response was sent
        if self.state == ASGIWebsocketState.HTTPCLOSED:
            self.close()

    async def asend(self, event: Event) -> None:
        """Wait for ``drain()`` and then write the serialised *event*."""
        await self.drain()
        self.write(self.connection.send(event))

    async def asgi_put(self, message: dict) -> None:
        """Queue *message* for the application."""
        await self.app_queue.put(message)

    async def asgi_receive(self) -> dict:
        """Called by the ASGI instance to receive a message."""
        return await self.app_queue.get()

    @property
    def scheme(self) -> str:
        """``"wss"`` when ``ssl_info`` is set, otherwise ``"ws"``."""
        return "wss" if self.ssl_info is not None else "ws"
class WebsocketServer(HTTPServer, WebsocketMixin):
    """trio server for one WebSocket connection, using wsproto for framing.

    ``handle_connection`` is the entry point: it obtains the handshake
    request, runs the ASGI websocket handler and, concurrently, a reader task
    that turns incoming bytes into messages for the application.
    """

    def __init__(
        self,
        app: ASGIFramework,
        config: Config,
        stream: trio.abc.Stream,
        *,
        upgrade_request: Optional[h11.Request] = None,
    ) -> None:
        super().__init__(stream, "wsproto")
        self.app = app
        self.config = config
        self.connection = WSConnection(ConnectionType.SERVER)
        self.response: Optional[dict] = None
        self.scope: Optional[dict] = None
        # Held while writing to the stream (see asend).
        self.send_lock = trio.Lock()
        self.state = ASGIWebsocketState.HANDSHAKE
        # Collects message fragments; extend() raises FrameTooLarge (handled
        # in read_messages).
        self.buffer = WebsocketBuffer(self.config.websocket_max_message_size)
        # Channel carrying messages to the application; created with a
        # buffer size of 10.
        self.app_send_channel, self.app_receive_channel = trio.open_memory_channel(
            10)

        if upgrade_request is not None:
            # The upgrade request was already parsed elsewhere: pass its
            # headers and target to wsproto.
            self.connection.initiate_upgrade_connection(
                upgrade_request.headers, upgrade_request.target)

    async def handle_connection(self) -> None:
        """Serve the connection until it ends, then close the stream.

        Broken or closed stream errors are reported to the application as a
        disconnect; ``MustCloseError`` just ends the connection.
        """
        try:
            request = await self.read_request()
            async with trio.open_nursery() as nursery:
                # The reader runs alongside the ASGI handler.
                nursery.start_soon(self.read_messages)
                await self.handle_websocket(request)
                if self.state == ASGIWebsocketState.HTTPCLOSED:
                    raise MustCloseError()
        except (trio.BrokenResourceError, trio.ClosedResourceError):
            await self.asgi_put({"type": "websocket.disconnect"})
        except MustCloseError:
            pass
        finally:
            await self.aclose()

    async def read_request(self) -> Request:
        """Return the first ``Request`` among the pending wsproto events.

        Nothing is read from the stream here; if no ``Request`` is pending
        the function falls through and returns ``None``.
        """
        for event in self.connection.events():
            if isinstance(event, Request):
                return event

    async def read_messages(self) -> None:
        """Read from the stream forever and dispatch the resulting events.

        Leaves only by raising: ``MustCloseError`` after an oversized message
        or a close event, or whatever the stream raises.
        """
        while True:
            data = await self.stream.receive_some(MAX_RECV)
            if data == b"":
                data = None  # wsproto expects None rather than b"" for EOF
            self.connection.receive_data(data)
            for event in self.connection.events():
                if isinstance(event, Message):
                    try:
                        self.buffer.extend(event)
                    except FrameTooLarge:
                        # Tell the peer why, tell the application, then stop.
                        await self.asend(
                            CloseConnection(code=CloseReason.MESSAGE_TOO_BIG))
                        await self.asgi_put({"type": "websocket.disconnect"})
                        raise MustCloseError()

                    if event.message_finished:
                        # Last fragment: hand the whole message over.
                        await self.asgi_put(self.buffer.to_message())
                        self.buffer.clear()
                elif isinstance(event, Ping):
                    await self.asend(event.response())
                elif isinstance(event, CloseConnection):
                    # Only reply when wsproto reports REMOTE_CLOSING.
                    if self.connection.state == ConnectionState.REMOTE_CLOSING:
                        await self.asend(event.response())
                    await self.asgi_put({"type": "websocket.disconnect"})
                    raise MustCloseError()

    async def asend(self, event: Event) -> None:
        """Serialise *event* and write it to the stream under ``send_lock``."""
        async with self.send_lock:
            await self.stream.send_all(self.connection.send(event))

    async def asgi_put(self, message: dict) -> None:
        """Send *message* to the application through the memory channel."""
        await self.app_send_channel.send(message)

    async def asgi_receive(self) -> dict:
        """Receive the next message queued for the application."""
        return await self.app_receive_channel.receive()

    @property
    def scheme(self) -> str:
        """``"wss"`` when ``_is_ssl`` is truthy, otherwise ``"ws"``."""
        return "wss" if self._is_ssl else "ws"
def wsproto_demo(host, port):
    '''
    Demonstrate wsproto:

    0) Open TCP connection
    1) Negotiate WebSocket opening handshake
    2) Send a message and display response
    3) Send ping and display pong
    4) Negotiate WebSocket closing handshake

    The TCP socket is closed before returning, including when a step fails.

    :param host: host name or address to connect to
    :param port: TCP port to connect to
    :raises Exception: if the server's reply to a step is not the expected
        event type
    '''

    # 0) Open TCP connection
    print('Connecting to {}:{}'.format(host, port))
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        conn.connect((host, port))

        # 1) Negotiate WebSocket opening handshake
        print('Opening WebSocket')
        ws = WSConnection(ConnectionType.CLIENT)
        # Because this is a client WebSocket, we initiate the handshake by
        # sending a Request event and then wait for the server's response.
        net_send(ws.send(Request(host=host, target='server')), conn)
        net_recv(ws, conn)

        # events is a generator that yields websocket event objects. Usually
        # you would say `for event in ws.events()`, but the synchronous nature
        # of this client requires us to use next(event) instead so that we can
        # interleave the network I/O: exactly one event is pulled after each
        # receive.
        events = ws.events()

        event = next(events)
        if isinstance(event, AcceptConnection):
            print('WebSocket negotiation complete')
        else:
            raise Exception('Expected AcceptConnection event!')

        # 2) Send a message and display response
        message = "wsproto is great"
        print('Sending message: {}'.format(message))
        net_send(ws.send(Message(data=message)), conn)
        net_recv(ws, conn)
        event = next(events)
        if isinstance(event, TextMessage):
            print('Received message: {}'.format(event.data))
        else:
            raise Exception('Expected TextMessage event!')

        # 3) Send ping and display pong
        payload = b"table tennis"
        print('Sending ping: {}'.format(payload))
        net_send(ws.send(Ping(payload=payload)), conn)
        net_recv(ws, conn)
        event = next(events)
        if isinstance(event, Pong):
            print('Received pong: {}'.format(event.payload))
        else:
            raise Exception('Expected Pong event!')

        # 4) Negotiate WebSocket closing handshake
        print('Closing WebSocket')
        net_send(
            ws.send(CloseConnection(code=1000, reason='sample reason')), conn)
        # After sending the closing frame, we won't get any more events. The
        # server should send a reply and then close the connection, so we
        # need to receive twice:
        net_recv(ws, conn)
        conn.shutdown(socket.SHUT_WR)
        net_recv(ws, conn)
    finally:
        conn.close()
class Base:
    """Threaded WebSocket endpoint shared by client and server subclasses.

    A background thread reads from the socket, feeds wsproto and stores
    incoming messages in ``input_buffer``; ``receive`` hands them out and
    ``send`` writes outgoing messages.

    :param sock: connected socket-like object providing ``send``/``recv``
    :param connection_type: wsproto ``ConnectionType`` of this endpoint
    :param receive_bytes: maximum number of bytes per ``recv`` call
    :param thread_class: factory used to create the reader thread
    :param event_class: factory used to create the wake-up event
    """

    def __init__(self, sock=None, connection_type=None, receive_bytes=4096,
                 thread_class=threading.Thread, event_class=threading.Event):
        self.sock = sock
        self.receive_bytes = receive_bytes
        self.input_buffer = []
        self.event = event_class()
        self.connected = False
        self.is_server = (connection_type == ConnectionType.SERVER)
        self.ws = WSConnection(connection_type)
        self.handshake()
        self.thread = thread_class(target=self._thread)
        self.thread.start()

        # Block until the reader thread has processed the handshake events.
        self.event.wait()
        self.event.clear()

    def handshake(self):
        """To be implemented by subclasses."""
        pass

    def send(self, data):
        """Send *data*: ``bytes`` as given, anything else as ``str(data)``.

        :raises ConnectionClosed: if the connection is not open
        """
        if not self.connected:
            raise ConnectionClosed()
        if isinstance(data, bytes):
            out_data = self.ws.send(Message(data=data))
        else:
            out_data = self.ws.send(TextMessage(data=str(data)))
        self.sock.send(out_data)

    def receive(self, timeout=None):
        """Return the next received message, waiting if none is buffered.

        :param timeout: seconds to wait for each wake-up, or ``None`` to wait
            indefinitely
        :returns: the message payload, or ``None`` when a wait timed out
        :raises ConnectionClosed: if the connection is closed
        """
        while self.connected and not self.input_buffer:
            if not self.event.wait(timeout=timeout):
                return None
            self.event.clear()
        if not self.connected:
            raise ConnectionClosed()
        return self.input_buffer.pop(0)

    def close(self, reason=None, message=None):
        """Send a close frame (normal closure unless *reason* is given)."""
        out_data = self.ws.send(
            CloseConnection(reason or CloseReason.NORMAL_CLOSURE, message))
        try:
            self.sock.send(out_data)
        except BrokenPipeError:
            pass

    def _thread(self):
        """Reader loop: runs until the connection is closed or fails."""
        self.connected = self._handle_events()
        self.event.set()
        while self.connected:
            try:
                in_data = self.sock.recv(self.receive_bytes)
            except OSError:  # includes ConnectionResetError
                in_data = b''
            if not in_data:
                # A failed read or b"" (peer closed the connection) ends the
                # session; looping on empty reads would spin forever.
                self.connected = False
                break
            self.ws.receive_data(in_data)
            self.connected = self._handle_events()
        # Wake any receive() caller on every exit path, including the case
        # where event handling failed, so it can notice the closed state
        # instead of blocking forever.
        self.event.set()

    def _handle_events(self):
        """Process pending wsproto events and send any replies.

        :returns: ``True`` to keep reading, ``False`` after a close event or
            when handling failed
        """
        keep_going = True
        out_data = b''
        try:
            for event in self.ws.events():
                if isinstance(event, Request):
                    out_data += self.ws.send(AcceptConnection())
                elif isinstance(event, CloseConnection):
                    if self.is_server:
                        out_data += self.ws.send(event.response())
                    self.event.set()
                    keep_going = False
                elif isinstance(event, Ping):
                    out_data += self.ws.send(event.response())
                elif isinstance(event, TextMessage):
                    self.input_buffer.append(event.data)
                    self.event.set()
                elif isinstance(event, BytesMessage):
                    self.input_buffer.append(event.data)
                    self.event.set()
            if out_data:
                self.sock.send(out_data)
            return keep_going
        except Exception:
            # Protocol or socket failure: treat the connection as closed.
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            return False
class Client:
    """WebSocket client that runs its receive loop in a spawned task.

    The constructor opens a TCP connection, sends the WebSocket upgrade
    request and starts ``_run`` through ``spawn``. Event handlers are
    registered on the *class* with the ``on_*`` decorators and are
    therefore shared by every instance.
    """

    # Registered handlers keyed by event type; intentionally class-level
    # because the on_* decorators are classmethods.
    _callbacks: Dict[EventType, Callable] = {}
    # Maximum number of bytes requested per socket read.
    receive_bytes: int = 65535
    # Maximum payload size of a single outgoing message fragment.
    buffer_size: int = io.DEFAULT_BUFFER_SIZE

    # noinspection PyTypeChecker
    def __init__(self,
                 connect_uri: str,
                 headers: Headers = None,
                 extensions: List[str] = None,
                 sub_protocols: List[str] = None):
        """Connect to ``connect_uri`` and start the handshake.

        :param connect_uri: ``ws://<host>[:port][/path]``.
        :param headers: optional list of ``(bytes, bytes)`` extra headers.
        :param extensions: optional list of extension names.
        :param sub_protocols: optional list of sub-protocol names.
        :raises TypeError: if an argument has the wrong type.
        :raises ValueError: if ``connect_uri`` is not a ``ws://`` URI.
        """
        self._check_ws_headers(headers)
        self._check_list_argument('extensions', extensions)
        self._check_list_argument('sub_protocols', sub_protocols)
        self._sock: socket = None
        self._ws: WSConnection = None
        # wsproto does not seem to like empty path, so we provide an arbitrary one
        self._default_path = '/path'
        self._running = True
        self._handshake_finished = AsyncResult()
        host, port, path = self._get_connect_information(connect_uri)
        self._establish_tcp_connection(host, port)
        self._establish_websocket_handshake(host, path, headers, extensions,
                                            sub_protocols)
        self._green = spawn(self._run)

    @staticmethod
    def _check_ws_headers(headers: Headers) -> None:
        """Raise ``TypeError`` unless ``headers`` is None or a list of (bytes, bytes)."""
        if headers is None:
            return
        error_message = 'headers must of a list of tuples of the form [(bytes, bytes), ..]'
        if not isinstance(headers, list):
            raise TypeError(error_message)
        try:
            for key, value in headers:
                if not isinstance(key, bytes) or not isinstance(value, bytes):
                    raise TypeError(error_message)
        except ValueError:  # in case it is not a list of tuples
            raise TypeError(error_message)

    @staticmethod
    def _check_list_argument(name: str, ws_argument: List[str]) -> None:
        """Raise ``TypeError`` unless ``ws_argument`` is None or a list of str."""
        if ws_argument is None:
            return
        error_message = f'{name} must be a list of strings'
        if not isinstance(ws_argument, list):
            raise TypeError(error_message)
        for item in ws_argument:
            if not isinstance(item, str):
                raise TypeError(error_message)

    def _get_connect_information(self,
                                 connect_uri: str) -> Tuple[str, int, str]:
        """Split a ``ws://`` URI into ``(host, port, path)``.

        The port defaults to 80 and the path to ``self._default_path``.

        :raises TypeError: if ``connect_uri`` is not a string.
        :raises ValueError: if it does not start with ``ws://<host>``.
        """
        if not isinstance(connect_uri, str):
            raise TypeError('Your uri must be a string')
        # The host part accepts dots and hyphens so that names such as
        # "example.com" or "10.0.0.1" are not silently truncated.
        regex = re.match(r'ws://([\w.-]+)(:\d+)?(/\w+)?', connect_uri)
        if not regex:
            raise ValueError(
                'Your uri must follow the syntax ws://<host>[:port][/path]')
        host = regex.group(1)
        port = int(regex.group(2)[1:]) if regex.group(2) is not None else 80
        # NOTE(review): a supplied path loses its leading "/" while the
        # default path keeps it; preserved as-is, confirm this is intended.
        path = regex.group(3)[1:] if regex.group(
            3) is not None else self._default_path
        return host, port, path

    @staticmethod
    def _check_callable(method: str, callback: Callable) -> None:
        """Raise ``TypeError`` if ``callback`` cannot be called."""
        # callable() is the predicate; isinstance(x, callable) always raises.
        if not callable(callback):
            raise TypeError(f'{method} callback must be a callable')

    @classmethod
    def _on_callback(cls, event_type: EventType, func: Callback) -> Callback:
        """Register ``func`` for ``event_type`` and return it unchanged."""
        cls._callbacks[event_type] = func
        return func

    @classmethod
    def on_connect(cls, func: EventCallback) -> EventCallback:
        """Register the handler invoked when the handshake is accepted."""
        return cls._on_callback(EventType.CONNECT, func)

    @classmethod
    def on_disconnect(cls, func: EventCallback) -> EventCallback:
        """Register the handler invoked when a close event is received."""
        return cls._on_callback(EventType.DISCONNECT, func)

    @classmethod
    def on_ping(cls, func: BytesCallback) -> BytesCallback:
        """Register the handler invoked with the payload of each ping."""
        return cls._on_callback(EventType.PING, func)

    @classmethod
    def on_pong(cls, func: BytesCallback) -> BytesCallback:
        """Register the handler invoked with the payload of each pong."""
        return cls._on_callback(EventType.PONG, func)

    @classmethod
    def on_text_message(cls, func: StrCallback) -> StrCallback:
        """Register the handler for complete text messages."""
        return cls._on_callback(EventType.TEXT_MESSAGE, func)

    @classmethod
    def on_json_message(cls, func: JsonCallback) -> JsonCallback:
        """Register the handler for text messages that parse as JSON."""
        return cls._on_callback(EventType.JSON_MESSAGE, func)

    @classmethod
    def on_binary_message(cls, func: BytesCallback) -> BytesCallback:
        """Register the handler for complete binary messages."""
        return cls._on_callback(EventType.BINARY_MESSAGE, func)

    def _establish_tcp_connection(self, host: str, port: int) -> None:
        """Open the IPv4 TCP socket to ``(host, port)``."""
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._sock.connect((host, port))

    def _establish_websocket_handshake(self, host: str, path: str,
                                       headers: Headers, extensions: List[str],
                                       sub_protocols: List[str]) -> None:
        """Create the client ``WSConnection`` and send the upgrade request."""
        self._ws = WSConnection(ConnectionType.CLIENT)
        headers = headers if headers is not None else []
        extensions = extensions if extensions is not None else []
        sub_protocols = sub_protocols if sub_protocols is not None else []
        request = Request(host=host,
                          target=path,
                          extra_headers=headers,
                          extensions=extensions,
                          subprotocols=sub_protocols)
        self._sock.sendall(self._ws.send(request))

    # NOTE(review): handlers are invoked with different arities depending
    # on the event (connect/text/json/binary get ``self`` first; disconnect,
    # ping and pong do not). Left unchanged because existing handlers may
    # depend on it; confirm the intended signatures.

    def _handle_accept(self, event: AcceptConnection) -> None:
        """Mark the handshake as done and fire the connect handler."""
        self._handshake_finished.set()
        if EventType.CONNECT in self._callbacks:
            self._callbacks[EventType.CONNECT](self, event)

    def _handle_reject(self, event: RejectConnection) -> None:
        """Fail the handshake for a body-less rejection and stop the loop."""
        self._handshake_finished.set_exception(
            ConnectionRejectedError(event.status_code, event.headers, b''))
        self._running = False

    def _handle_reject_data(self, event: RejectData, data: bytearray,
                            status_code: int, headers: Headers) -> None:
        """Accumulate a rejection body; fail the handshake once complete."""
        data.extend(event.data)
        if event.body_finished:
            self._handshake_finished.set_exception(
                ConnectionRejectedError(status_code, headers, data))
            self._running = False

    def _handle_close(self, event: CloseConnection) -> None:
        """Stop the loop, fire the disconnect handler, answer a remote close."""
        self._running = False
        if EventType.DISCONNECT in self._callbacks:
            self._callbacks[EventType.DISCONNECT](event)
        # if the server sends first a close connection we need to reply with another one
        if self._ws.state is ConnectionState.REMOTE_CLOSING:
            self._sock.sendall(self._ws.send(event.response()))

    def _handle_ping(self, event: Ping) -> None:
        """Fire the ping handler, then answer with the matching pong."""
        if EventType.PING in self._callbacks:
            self._callbacks[EventType.PING](event.payload)
        self._sock.sendall(self._ws.send(event.response()))

    def _handle_pong(self, event: Pong) -> None:
        """Fire the pong handler with the pong payload."""
        if EventType.PONG in self._callbacks:
            self._callbacks[EventType.PONG](event.payload)

    def _handle_text_or_json_message(self, event: TextMessage,
                                     text_message: List[str]) -> None:
        """Collect text fragments and dispatch the completed message.

        A complete message goes to the JSON handler when one is registered
        and the text parses as JSON; otherwise it goes to the text handler.
        """
        text_message.append(event.data)
        if not event.message_finished:
            return
        str_message = ''.join(text_message)
        # Reset the accumulator for every finished message, even when no
        # handler is registered, so fragments never leak into the next one.
        text_message.clear()
        if EventType.JSON_MESSAGE in self._callbacks:
            try:
                payload = json.loads(str_message)
            except json.JSONDecodeError:
                pass
            else:
                # no need to process text handler if json handler already does the job
                self._callbacks[EventType.JSON_MESSAGE](self, payload)
                return
        if EventType.TEXT_MESSAGE in self._callbacks:
            self._callbacks[EventType.TEXT_MESSAGE](self, str_message)

    def _handle_binary_message(self, event: BytesMessage,
                               binary_message: bytearray) -> None:
        """Collect binary fragments and dispatch the completed message."""
        binary_message.extend(event.data)
        if not event.message_finished:
            return
        # Hand out an immutable copy: the accumulator is cleared right
        # after, which would otherwise empty the object the handler holds.
        payload = bytes(binary_message)
        binary_message.clear()
        if EventType.BINARY_MESSAGE in self._callbacks:
            self._callbacks[EventType.BINARY_MESSAGE](self, payload)

    def _run(self) -> None:
        """Receive loop: read from the socket and dispatch wsproto events."""
        reject_data = bytearray()
        reject_status_code = 400
        reject_headers = []
        text_message = []
        binary_message = bytearray()
        try:
            while self._running:
                data = self._sock.recv(self.receive_bytes)
                if not data:
                    # wsproto uses None to signal that the peer closed.
                    data = None
                self._ws.receive_data(data)
                for event in self._ws.events():
                    if isinstance(event, AcceptConnection):
                        self._handle_accept(event)
                    elif isinstance(event, RejectConnection):
                        if not event.has_body:
                            self._handle_reject(event)
                        else:
                            reject_status_code = event.status_code
                            reject_headers = event.headers
                    elif isinstance(event, RejectData):
                        self._handle_reject_data(event, reject_data,
                                                 reject_status_code,
                                                 reject_headers)
                    elif isinstance(event, CloseConnection):
                        self._handle_close(event)
                    elif isinstance(event, Ping):
                        self._handle_ping(event)
                    elif isinstance(event, Pong):
                        self._handle_pong(event)
                    elif isinstance(event, TextMessage):
                        self._handle_text_or_json_message(event, text_message)
                    elif isinstance(event, BytesMessage):
                        self._handle_binary_message(event, binary_message)
                    else:
                        print('unknown event', event)
                if data is None:
                    # Nothing more can arrive after EOF; never spin on it.
                    self._running = False
        finally:
            # Release the socket even when event handling raised.
            self._sock.close()

    def ping(self, data: bytes = b'hello') -> None:
        """Send a ping with ``data`` as payload once the handshake is done.

        :raises TypeError: if ``data`` is not bytes.
        """
        self._handshake_finished.get()
        if not isinstance(data, bytes):
            raise TypeError('data must be bytes')
        self._sock.sendall(self._ws.send(Ping(data)))

    @staticmethod
    def _chunk_payload(data: AnyStr, size: int):
        """Yield ``(fragment, is_last)`` pairs of at most ``size`` items.

        Empty ``data`` yields nothing.
        """
        total = len(data)
        for start in range(0, total, size):
            end = start + size
            yield data[start:end], end >= total

    def _send_data(self, data: AnyStr) -> None:
        """Send ``data`` as one message split into ``buffer_size`` fragments."""
        # Each fragment carries only its own slice, and the final one is
        # always flagged as finished (also when the length is an exact
        # multiple of buffer_size).
        for chunk, is_last in self._chunk_payload(data, self.buffer_size):
            self._sock.sendall(
                self._ws.send(Message(chunk, message_finished=is_last)))

    def send(self, data: AnyStr) -> None:
        """Send text or binary ``data`` once the handshake is done.

        :raises TypeError: if ``data`` is neither bytes nor str.
        """
        self._handshake_finished.get()
        if not isinstance(data, (bytes, str)):
            raise TypeError('data must be bytes or string')
        self._send_data(data)

    def send_json(self, data: Any) -> None:
        """Serialize ``data`` with ``json.dumps`` and send it as text."""
        self.send(json.dumps(data))

    def _close_ws_connection(self):
        """Send a normal-closure (1000) close frame."""
        close_data = self._ws.send(
            CloseConnection(code=1000, reason='nothing more to do'))
        self._sock.sendall(close_data)

    def close(self) -> None:
        """Close the WebSocket if it is open and wait for the receive loop."""
        self._handshake_finished.get()
        if self._ws.state is ConnectionState.OPEN:
            self._close_ws_connection()
        # don't forget to join the run greenlet, if not, you will have some surprises with your event handlers!
        self._green.join()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
class BaseServer(ABC):
    """Abstract WebSocket server on top of ``StreamServer`` and wsproto.

    Subclasses decide how to answer the upgrade request
    (``handle_request``) and how to consume incoming messages
    (``receive_text``/``receive_json``/``receive_bytes``/``handle_pong``).

    NOTE(review): a single ``WSConnection``, ``_client`` socket and
    ``_running`` flag are stored on the server and reused by ``_handler``,
    so the instance appears to support one client connection only --
    confirm before serving several clients.
    """

    # Maximum number of bytes requested per socket read.
    bytes_to_receive: int = 65535
    # Maximum payload size of a single outgoing message fragment.
    buffer_size: int = io.DEFAULT_BUFFER_SIZE

    # noinspection PyTypeChecker
    def __init__(self, host: str, port: int):
        """Validate and store the listening address.

        :raises TypeError: if ``host`` is not a str or ``port`` is not a
            non-negative int.
        """
        self._check_init_arguments(host, port)
        self._host = host
        self._port = port
        self._server: StreamServer = None
        self._ws = WSConnection(ConnectionType.SERVER)
        self._client: socket = None  # client socket provided by the StreamServer
        self._running = True

    @staticmethod
    def _check_ws_headers(headers: Headers) -> None:
        """Raise ``TypeError`` unless ``headers`` is None or a list of (bytes, bytes)."""
        if headers is None:
            return
        error_message = 'headers must of a list of tuples of the form [(bytes, bytes), ..]'
        if not isinstance(headers, list):
            raise TypeError(error_message)
        try:
            for key, value in headers:
                if not isinstance(key, bytes) or not isinstance(value, bytes):
                    raise TypeError(error_message)
        except ValueError:  # in case it is not a list of tuples
            raise TypeError(error_message)

    @abstractmethod
    def handle_request(self, request: Request) -> None:
        """Handle the upgrade request (called for each ``Request`` event)."""
        pass

    def accept_request(self,
                       extra_headers: Headers = None,
                       sub_protocol: str = None) -> None:
        """Accept the pending upgrade request.

        :raises TypeError: if ``extra_headers`` or ``sub_protocol`` has the
            wrong type.
        """
        self._check_ws_headers(extra_headers)
        if sub_protocol is not None and not isinstance(sub_protocol, str):
            raise TypeError('sub_protocol must be a string')
        extra_headers = extra_headers if extra_headers else []
        self._client.sendall(
            self._ws.send(
                AcceptConnection(extra_headers=extra_headers,
                                 subprotocol=sub_protocol)))

    def reject_request(self,
                       status_code: int = 400,
                       reason: str = None) -> None:
        """Reject the pending upgrade request.

        :param status_code: HTTP status code of the rejection.
        :param reason: optional text sent as the response body.
        :raises TypeError: if an argument has the wrong type.
        """
        if not isinstance(status_code, int):
            raise TypeError('status_code must be an integer')
        if reason is not None and not isinstance(reason, str):
            raise TypeError('reason must be a string')
        if not reason:
            self._client.sendall(
                self._ws.send(RejectConnection(status_code=status_code)))
        else:
            # Forward the requested status code here as well; it used to be
            # dropped whenever a reason body was supplied.
            data = bytearray(
                self._ws.send(
                    RejectConnection(status_code=status_code,
                                     has_body=True,
                                     headers=[(b'Content-type', b'text/txt')
                                              ])))
            data.extend(self._ws.send(RejectData(reason.encode())))
            self._client.sendall(bytes(data))

    def close_request(self, code: int = 1000, reason: str = None) -> None:
        """Send a close frame with ``code`` and an optional ``reason``.

        :raises TypeError: if ``code`` is not an int, or ``reason`` is
            neither None nor a str.
        """
        if not isinstance(code, int):
            raise TypeError('code must be an integer')
        # ``reason`` defaults to None, so None must be accepted.
        if reason is not None and not isinstance(reason, str):
            raise TypeError('reason must be a string')
        self._client.sendall(self._ws.send(CloseConnection(code, reason)))

    def _handle_close_event(self, event: CloseConnection) -> None:
        """Answer a close initiated by the remote side."""
        if self._ws.state is ConnectionState.REMOTE_CLOSING:
            self._client.sendall(self._ws.send(event.response()))

    def _handle_ping(self, event: Ping) -> None:
        """Answer a ping with the matching pong."""
        self._client.sendall(self._ws.send(event.response()))

    @abstractmethod
    def receive_text(self, data: str) -> None:
        """Handle a complete text message that is not valid JSON."""
        pass

    @abstractmethod
    def receive_json(self, data: Any) -> None:
        """Handle a complete text message decoded from JSON."""
        pass

    @abstractmethod
    def receive_bytes(self, data: bytes) -> None:
        """Handle a complete binary message."""
        pass

    @abstractmethod
    def handle_pong(self, data: bytes) -> None:
        """Handle the payload of a received pong."""
        pass

    @staticmethod
    def _chunk_payload(data: AnyStr, size: int):
        """Yield ``(fragment, is_last)`` pairs of at most ``size`` items.

        Empty ``data`` yields nothing.
        """
        total = len(data)
        for start in range(0, total, size):
            end = start + size
            yield data[start:end], end >= total

    def _send_data(self, data: AnyStr) -> None:
        """Send ``data`` as one message split into ``buffer_size`` fragments."""
        # Each fragment carries only its own slice, and the final one is
        # always flagged as finished (also when the length is an exact
        # multiple of buffer_size).
        for chunk, is_last in self._chunk_payload(data, self.buffer_size):
            self._client.sendall(
                self._ws.send(Message(chunk, message_finished=is_last)))

    def ping(self, data: bytes = b'hello') -> None:
        """Send a ping with ``data`` as payload.

        :raises TypeError: if ``data`` is not bytes.
        """
        if not isinstance(data, bytes):
            raise TypeError('data must be bytes')
        self._client.sendall(self._ws.send(Ping(data)))

    def send(self, data: AnyStr) -> None:
        """Send text or binary ``data`` to the client.

        :raises TypeError: if ``data`` is neither bytes nor str.
        """
        if not isinstance(data, (bytes, str)):
            raise TypeError('data must be either a string or binary data')
        self._send_data(data)

    def send_json(self, data: Any) -> None:
        """Serialize ``data`` with ``json.dumps`` and send it as text."""
        self.send(json.dumps(data))

    @staticmethod
    def _check_init_arguments(host: str, port: int) -> None:
        """Raise ``TypeError`` unless ``host`` is a str and ``port`` an int >= 0."""
        if not isinstance(host, str):
            raise TypeError('host must be a string')
        error_message = 'custom_port must a positive integer'
        if not isinstance(port, int):
            raise TypeError(error_message)
        if port < 0:
            raise TypeError(error_message)

    def _handler(self, client: socket, address: Tuple[str, int]) -> None:
        """Serve one client socket: read, dispatch events, until closed."""
        self._client = client
        text_message = []
        binary_message = bytearray()
        while self._running:
            data = client.recv(self.bytes_to_receive)
            # An empty read means the peer closed the socket; wsproto
            # expects None for that.
            self._ws.receive_data(data or None)
            for event in self._ws.events():
                if isinstance(event, Request):
                    self.handle_request(event)
                elif isinstance(event, CloseConnection):
                    self._handle_close_event(event)
                    self._running = False
                elif isinstance(event, Ping):
                    self._handle_ping(event)
                elif isinstance(event, Pong):
                    self.handle_pong(event.payload)
                elif isinstance(event, TextMessage):
                    text_message.append(event.data)
                    if event.message_finished:
                        str_data = ''.join(text_message)
                        try:
                            self.receive_json(json.loads(str_data))
                        except json.JSONDecodeError:
                            self.receive_text(str_data)
                        text_message.clear()
                elif isinstance(event, BytesMessage):
                    binary_message.extend(event.data)
                    if event.message_finished:
                        self.receive_bytes(bytes(binary_message))
                        binary_message.clear()
                else:
                    print('unknown event:', event)
            if not data:
                # Nothing more can arrive after EOF; stop instead of
                # spinning on a closed socket.
                self._running = False

    def run(self, backlog: int = 256, spawn: str = 'default', **kwargs) -> None:
        """Create the ``StreamServer`` and serve until it is closed.

        ``backlog``, ``spawn`` and any extra keyword arguments are passed
        through to ``StreamServer``.
        """
        self._server = StreamServer((self._host, self._port),
                                    self._handler,
                                    backlog=backlog,
                                    spawn=spawn,
                                    **kwargs)
        self._server.serve_forever()

    def close(self) -> None:
        """Close the underlying server if it has been started."""
        if self._server is not None:
            self._server.close()
def handle_connection(stream):
    '''
    Handle a connection.

    The server operates a request/response cycle, so it performs a synchronous
    loop:

    1) Read data from network into wsproto
    2) Get new wsproto events
    3) Handle each event
    4) Send data from wsproto to network

    :param stream: a socket stream
    '''
    ws = WSConnection(ConnectionType.SERVER)
    running = True

    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print('Received {} bytes'.format(len(in_data)))
        if not in_data:
            # An empty read means the peer closed the socket without a
            # WebSocket close handshake.
            print('Client connection dropped unexpectedly')
            return
        ws.receive_data(in_data)

        # 2) Get new wsproto events. A fresh iterator is used for every read:
        # one read may carry several events, or only part of a frame (no
        # event yet), and neither case means the client went away.
        out_data = b''
        for event in ws.events():
            # 3) Handle event
            if isinstance(event, Request):
                # Negotiate new WebSocket connection
                print('Accepting WebSocket upgrade')
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                # Print log message and break out
                print('Connection closed: code={}/{} reason={}'.format(
                    event.code.value, event.code.name, event.reason))
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, TextMessage):
                # Reverse text and send it back to wsproto
                print('Received request and sending response')
                out_data += ws.send(Message(data=event.data[::-1]))
            elif isinstance(event, Ping):
                # Answer the ping with the matching pong frame.
                print('Received ping and sending pong')
                out_data += ws.send(event.response())
            else:
                # Nothing to send for events this server does not handle.
                print('Unknown event: {!r}'.format(event))

        # 4) Send data from wsproto to network
        print('Sending {} bytes'.format(len(out_data)))
        stream.send(out_data)