def create_connection(self):
    """Run a complete in-memory opening handshake.

    The client's upgrade request is fed to a fresh server connection, which
    must report ``ConnectionRequested`` and is then accepted; the server's
    response is fed back to the client, which must report
    ``ConnectionEstablished``.

    :returns: ``(client, server)`` -- both connections, already open.
    """
    server = WSConnection(SERVER)
    client = WSConnection(CLIENT, host='localhost', resource='foo')

    # Client -> server: the upgrade request.
    server.receive_bytes(client.bytes_to_send())
    request_event = next(server.events())
    assert isinstance(request_event, ConnectionRequested)
    server.accept(request_event)

    # Server -> client: the 101 response.
    client.receive_bytes(server.bytes_to_send())
    established = next(client.events())
    assert isinstance(established, ConnectionEstablished)

    return client, server
def test_accept_subprotocol(self):
    """Accepting one of the client's proposed subprotocols echoes it back.

    The client offers ``one, two``; the server picks ``two`` and the 101
    response must carry ``Sec-WebSocket-Protocol: two``.
    """
    test_host = 'frob.nitz'
    test_path = '/fnord'

    ws = WSConnection(SERVER)

    nonce = bytes(random.getrandbits(8) for x in range(0, 16))
    nonce = base64.b64encode(nonce)

    request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n'
    request += b'Host: ' + test_host.encode('ascii') + b'\r\n'
    request += b'Connection: Upgrade\r\n'
    request += b'Upgrade: WebSocket\r\n'
    request += b'Sec-WebSocket-Version: 13\r\n'
    request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n'
    request += b'Sec-WebSocket-Protocol: one, two\r\n'
    request += b'\r\n'

    ws.receive_bytes(request)
    event = next(ws.events())
    assert isinstance(event, ConnectionRequested)
    assert event.proposed_subprotocols == ['one', 'two']

    ws.accept(event, 'two')

    data = ws.bytes_to_send()
    response, headers = data.split(b'\r\n', 1)
    # maxsplit=2: an HTTP reason phrase may itself contain spaces, which
    # would otherwise break the three-way unpacking.
    version, code, reason = response.split(b' ', 2)
    headers = parse_headers(headers)

    assert int(code) == 101
    assert headers['sec-websocket-protocol'] == 'two'
def test_accept_wrong_subprotocol(self):
    """Accepting with a subprotocol the client never offered is a ValueError."""
    host = 'frob.nitz'
    path = '/fnord'

    ws = WSConnection(SERVER)

    key = base64.b64encode(bytes(random.getrandbits(8) for _ in range(16)))

    header_lines = [
        b'GET ' + path.encode('ascii') + b' HTTP/1.1',
        b'Host: ' + host.encode('ascii'),
        b'Connection: Upgrade',
        b'Upgrade: WebSocket',
        b'Sec-WebSocket-Version: 13',
        b'Sec-WebSocket-Key: ' + key,
        b'Sec-WebSocket-Protocol: one, two',
    ]
    # Every header line is CRLF-terminated and an empty line ends the head.
    ws.receive_bytes(b'\r\n'.join(header_lines) + b'\r\n\r\n')

    requested = next(ws.events())
    assert isinstance(requested, ConnectionRequested)
    assert requested.proposed_subprotocols == ['one', 'two']

    with pytest.raises(ValueError):
        ws.accept(requested, 'three')
def test_unwanted_extension_negotiation(self):
    """An extension that declines the client's offer is omitted from the reply.

    ``FakeExtension(accept_response=False)`` refuses the offered
    ``pretend`` extension, so the (still successful) 101 response must not
    contain a ``Sec-WebSocket-Extensions`` header.
    """
    test_host = 'frob.nitz'
    test_path = '/fnord'
    ext = FakeExtension(accept_response=False)

    ws = WSConnection(SERVER, extensions=[ext])

    nonce = bytes(random.getrandbits(8) for x in range(0, 16))
    nonce = base64.b64encode(nonce)

    request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n"
    request += b'Host: ' + test_host.encode('ascii') + b'\r\n'
    request += b'Connection: Upgrade\r\n'
    request += b'Upgrade: WebSocket\r\n'
    request += b'Sec-WebSocket-Version: 13\r\n'
    request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n'
    request += b'Sec-WebSocket-Extensions: pretend\r\n'
    request += b'\r\n'

    ws.receive_bytes(request)
    event = next(ws.events())
    assert isinstance(event, ConnectionRequested)
    ws.accept(event)

    data = ws.bytes_to_send()
    response, headers = data.split(b'\r\n', 1)
    # maxsplit=2: the reason phrase may contain spaces.
    version, code, reason = response.split(b' ', 2)
    headers = parse_headers(headers)

    # Declining the extension must not fail the handshake itself.
    assert int(code) == 101
    assert 'sec-websocket-extensions' not in headers
def test_not_an_http_request_at_all(self):
    """Bytes that are not an HTTP request must produce ``ConnectionFailed``."""
    garbage = b'<xml>Good god, what is this?</xml>\r\n\r\n'

    server = WSConnection(SERVER)
    server.receive_bytes(garbage)

    first = next(server.events())
    assert isinstance(first, ConnectionFailed)
def get_case_count(server):
    """Fetch the case count from ``<server>/getCaseCount`` over a WebSocket.

    Opens a blocking TCP connection, performs the client handshake, and
    decodes the first complete text message as JSON. It then starts the
    closing handshake.

    :param server: base URL of the test server (e.g. ``ws://host:port``).
    :returns: the JSON-decoded payload of the first complete text message,
        or ``None`` if the server closes the TCP connection before one
        arrives.
    """
    uri = urlparse(server + '/getCaseCount')
    connection = WSConnection(CLIENT, uri.netloc, uri.path)
    sock = socket.socket()
    case_count = None
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(connection.bytes_to_send())

        # A text message may arrive across several recv() calls/frames, so
        # the accumulator must live outside the read loop.
        text = ""
        while case_count is None:
            data = sock.recv(65535)
            if not data:
                # recv() returns b'' once the peer has closed the socket;
                # nothing more will ever arrive, so stop instead of spinning.
                break
            connection.receive_bytes(data)
            for event in connection.events():
                if isinstance(event, TextReceived):
                    text += event.data
                    if event.message_finished:
                        case_count = json.loads(text)
                        connection.close()
                        try:
                            sock.sendall(connection.bytes_to_send())
                        except CONNECTION_EXCEPTIONS:
                            break
    finally:
        sock.close()
    return case_count
def new_conn(sock):
    """Serve one echo WebSocket connection on a blocking socket.

    Accepts the upgrade (with permessage-deflate offered), echoes every data
    event back to the peer, and stops on close, EOF or a socket error. The
    socket is always closed before returning.

    :param sock: a connected, blocking socket.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = sock.recv(65535)
            except socket.error:
                data = None
            # EOF (b'') and socket errors are both reported to wsproto as
            # None, i.e. "the transport is gone".
            ws.receive_bytes(data or None)
            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True
            if not data:
                closed = True
            try:
                data = ws.bytes_to_send()
                sock.sendall(data)
            except socket.error:
                closed = True
    finally:
        # Release the socket even if wsproto or a handler raised.
        sock.close()
def new_conn(reader, writer):
    """Echo WebSocket data on one asyncio stream pair (``yield from`` coroutine).

    Accepts the upgrade (with permessage-deflate offered) and echoes every
    data event back. It stops when the peer closes, on EOF, or on a
    transport error. *writer* is always closed before returning.

    :param reader: the connection's ``asyncio.StreamReader``.
    :param writer: the connection's ``asyncio.StreamWriter``.
    """
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = yield from reader.read(65535)
            except ConnectionError:
                data = None
            # EOF (b'') and transport errors are both reported to wsproto
            # as None, i.e. "the connection is gone".
            ws.receive_bytes(data or None)
            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    # DataReceived signals the end of a message through
                    # ``message_finished``; it has no ``final`` attribute.
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True
            if data is None:
                # Transport error: nothing more can be written.
                break
            if not data:
                # EOF: flush whatever wsproto queued, then stop reading so an
                # EOF before the handshake cannot spin forever.
                closed = True
            try:
                data = ws.bytes_to_send()
                writer.write(data)
                yield from writer.drain()
            except (ConnectionError, OSError):
                closed = True
    finally:
        writer.close()
def new_conn(reader, writer):
    """Echo WebSocket data on one asyncio stream pair (``yield from`` coroutine).

    Logs and increments the module-level ``count`` and accepts the upgrade
    (with permessage-deflate offered). Every data event is echoed back.
    It stops on close, EOF or a transport error, and *writer* is always
    closed before returning.

    :param reader: the connection's ``asyncio.StreamReader``.
    :param writer: the connection's ``asyncio.StreamWriter``.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = yield from reader.read(65535)
            except ConnectionError:
                data = None
            # EOF (b'') and transport errors both mean "connection gone".
            ws.receive_bytes(data or None)
            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True
            if not data:
                closed = True
            try:
                data = ws.bytes_to_send()
                writer.write(data)
                yield from writer.drain()
            except (ConnectionError, OSError):
                closed = True
    finally:
        # Close the transport even if wsproto or a handler raised.
        writer.close()
def get_case_count(server):
    """Fetch the case count from ``<server>/getCaseCount`` (``yield from`` coroutine).

    Opens an asyncio connection, performs the client handshake, and decodes
    the first complete text message as JSON. It then starts the closing
    handshake.

    :param server: base URL of the test server (e.g. ``ws://host:port``).
    :returns: the JSON-decoded payload of the first complete text message,
        or ``None`` if the server closes the connection before one arrives.
    """
    uri = urlparse(server + '/getCaseCount')
    connection = WSConnection(CLIENT, uri.netloc, uri.path)
    reader, writer = yield from asyncio.open_connection(
        uri.hostname, uri.port or 80)
    case_count = None
    try:
        writer.write(connection.bytes_to_send())

        # A text message may span several reads/frames, so keep the
        # accumulator outside the read loop.
        text = ""
        while case_count is None:
            data = yield from reader.read(65535)
            if not data:
                # reader.read() returns b'' at EOF forever; stop instead of
                # spinning.
                break
            connection.receive_bytes(data)
            for event in connection.events():
                if isinstance(event, TextReceived):
                    text += event.data
                    if event.message_finished:
                        case_count = json.loads(text)
                        connection.close()
                        try:
                            writer.write(connection.bytes_to_send())
                            yield from writer.drain()
                        except (ConnectionError, OSError):
                            break
    finally:
        writer.close()
    return case_count
def test_correct_request(self):
    """A well-formed upgrade request is accepted with a valid 101 response.

    Checks the status code, the ``Connection``/``Upgrade`` headers and that
    ``Sec-WebSocket-Accept`` matches the token derived from the nonce.
    """
    test_host = 'frob.nitz'
    test_path = '/fnord'

    ws = WSConnection(SERVER)

    nonce = bytes(random.getrandbits(8) for x in range(0, 16))
    nonce = base64.b64encode(nonce)

    request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n"
    request += b'Host: ' + test_host.encode('ascii') + b'\r\n'
    request += b'Connection: Upgrade\r\n'
    request += b'Upgrade: WebSocket\r\n'
    request += b'Sec-WebSocket-Version: 13\r\n'
    request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n'
    request += b'\r\n'

    ws.receive_bytes(request)
    event = next(ws.events())
    assert isinstance(event, ConnectionRequested)
    ws.accept(event)

    data = ws.bytes_to_send()
    response, headers = data.split(b'\r\n', 1)
    # maxsplit=2: the reason phrase may contain spaces.
    version, code, reason = response.split(b' ', 2)
    headers = parse_headers(headers)

    accept_token = ws._generate_accept_token(nonce)

    assert int(code) == 101
    assert headers['connection'].lower() == 'upgrade'
    assert headers['upgrade'].lower() == 'websocket'
    assert headers['sec-websocket-accept'] == accept_token.decode('ascii')
def test_missing_key(self):
    """An upgrade request without ``Sec-WebSocket-Key`` must fail."""
    host = 'frob.nitz'
    path = '/fnord'

    server = WSConnection(SERVER)

    header_lines = [
        b'GET ' + path.encode('ascii') + b' HTTP/1.1',
        b'Host: ' + host.encode('ascii'),
        b'Connection: Upgrade',
        b'Upgrade: WebSocket',
        b'Sec-WebSocket-Version: 13',
    ]
    # CRLF after each header line plus the blank line that ends the head.
    server.receive_bytes(b'\r\n'.join(header_lines) + b'\r\n\r\n')

    outcome = next(server.events())
    assert isinstance(outcome, ConnectionFailed)
async def ws_adapter(in_q, out_q, client, _):
    """A simple, queue-based Curio-Sans-IO websocket bridge.

    Data received from the websocket peer is put on *in_q*. Items taken
    from *out_q* are sent to the peer. ``None`` is the end-of-stream marker
    in both directions: it is put on *in_q* when the peer goes away, and
    taking ``None`` from *out_q* closes the websocket.

    :param in_q: queue that receives incoming payloads (``None`` at the end).
    :param out_q: queue of payloads to send (``None`` requests a close).
    :param client: the connected client socket.
    :param _: unused; presumably the client address -- TODO confirm.
    """
    client.setsockopt(IPPROTO_TCP, TCP_NODELAY, 1)
    wsconn = WSConnection(SERVER)
    closed = False
    while not closed:
        # Race a socket read against the outbound queue; whichever finishes
        # first is handled and the other task is cancelled.
        wstask = await spawn(client.recv, 65535)
        outqtask = await spawn(out_q.get)
        async with TaskGroup([wstask, outqtask]) as g:
            task = await g.next_done()
            result = await task.join()
            await g.cancel_remaining()

        if task is wstask:
            # recv() returns b'' once the peer has closed the socket; wsproto
            # expects None to signal end-of-stream.
            wsconn.receive_bytes(result or None)
            for event in wsconn.events():
                cl = event.__class__
                if cl in DATA_TYPES:
                    await in_q.put(event.data)
                elif cl is ConnectionRequested:
                    # Auto accept. Maybe consult the handler?
                    wsconn.accept(event)
                elif cl is ConnectionClosed:
                    # The client has closed the connection.
                    await in_q.put(None)
                    closed = True
                else:
                    print(event)
            await client.sendall(wsconn.bytes_to_send())
            if not result and not closed:
                # EOF without a ConnectionClosed event (e.g. during the
                # opening handshake): stop rather than spin on a dead socket,
                # and tell the consumer there is nothing more.
                await in_q.put(None)
                closed = True
        else:
            # We got something from the out queue.
            if result is None:
                # Terminate the connection.
                print("Closing the connection.")
                wsconn.close()
                closed = True
            else:
                wsconn.send_data(result)
            payload = wsconn.bytes_to_send()
            await client.sendall(payload)
    print("Bridge done.")
def test_missing_version(self):
    """An upgrade request without ``Sec-WebSocket-Version`` must fail."""
    host = 'frob.nitz'
    path = '/fnord'

    server = WSConnection(SERVER)

    key = base64.b64encode(bytes(random.getrandbits(8) for _ in range(16)))

    header_lines = [
        b'GET ' + path.encode('ascii') + b' HTTP/1.1',
        b'Host: ' + host.encode('ascii'),
        b'Connection: Upgrade',
        b'Upgrade: WebSocket',
        b'Sec-WebSocket-Key: ' + key,
    ]
    server.receive_bytes(b'\r\n'.join(header_lines) + b'\r\n\r\n')

    outcome = next(server.events())
    assert isinstance(outcome, ConnectionFailed)
def test_correct_request_expanded_connection_header(self):
    """``Connection: keep-alive, Upgrade`` is still recognised as an upgrade."""
    host = 'frob.nitz'
    path = '/fnord'

    server = WSConnection(SERVER)

    key = base64.b64encode(bytes(random.getrandbits(8) for _ in range(16)))

    header_lines = [
        b"GET " + path.encode('ascii') + b" HTTP/1.1",
        b'Host: ' + host.encode('ascii'),
        b'Connection: keep-alive, Upgrade',
        b'Upgrade: WebSocket',
        b'Sec-WebSocket-Version: 13',
        b'Sec-WebSocket-Key: ' + key,
    ]
    server.receive_bytes(b'\r\n'.join(header_lines) + b'\r\n\r\n')

    outcome = next(server.events())
    assert isinstance(outcome, ConnectionRequested)
def test_frame_protocol_somehow_loses_its_mind(self):
    """A frame with an unrecognised opcode yields no event and sends nothing.

    The frame layer is replaced by a stub that always reports a single frame
    whose opcode matches nothing; ``events()`` must then finish immediately
    and no outgoing bytes may be queued.
    """

    class BogusFrame(object):
        # A fresh object() can never equal any real opcode.
        opcode = object()

    class BrokenProtocol(object):
        def receive_bytes(self, data):
            return None

        def received_frames(self):
            return [BogusFrame()]

    conn = WSConnection(CLIENT, host='localhost', resource='foo')
    conn._proto = BrokenProtocol()
    conn._state = ConnectionState.OPEN

    # Throw away the opening handshake queued at construction time.
    conn.bytes_to_send()
    conn.receive_bytes(b'')

    with pytest.raises(StopIteration):
        next(conn.events())

    assert not conn.bytes_to_send()
def update_reports(server, agent):
    """Open ``<server>/updateReports?agent=<agent>`` and close it once established.

    Blocking version. The WebSocket is closed as soon as the handshake
    completes; presumably the server acts on the request at handshake time
    (TODO confirm).

    :param server: base URL of the test server.
    :param agent: agent name passed in the query string.
    """
    uri = urlparse(server + '/updateReports?agent=%s' % agent)
    connection = WSConnection(CLIENT, uri.netloc,
                              '%s?%s' % (uri.path, uri.query))
    sock = socket.socket()
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(connection.bytes_to_send())
        closed = False
        while not closed:
            data = sock.recv(65535)
            if not data:
                # The server hung up before the handshake completed; recv()
                # would keep returning b'', so stop instead of spinning.
                break
            connection.receive_bytes(data)
            for event in connection.events():
                if isinstance(event, ConnectionEstablished):
                    connection.close()
                    sock.sendall(connection.bytes_to_send())
                    closed = True
                    break
    finally:
        # Close the socket on every path, including a failed handshake.
        try:
            sock.close()
        except CONNECTION_EXCEPTIONS:
            pass
def test_data_events(self, text, payload, full_message, full_frame):
    """A single hand-built data frame surfaces as the matching data event.

    :param text: build a text frame (opcode 0x1) instead of binary (0x2).
    :param payload: frame payload (``str`` when *text*, otherwise bytes).
    :param full_message: set the FIN bit on the frame.
    :param full_frame: when false, the length byte claims 100 more bytes
        than are actually supplied, so the frame is incomplete.
    """
    if text:
        raw_opcode, body = 0x01, payload.encode('utf8')
    else:
        raw_opcode, body = 0x02, payload

    fin_bit = 0x80 if full_message else 0x00
    declared_len = len(body) if full_frame else len(body) + 100

    header = bytearray([raw_opcode | fin_bit])
    header += bytearray([declared_len])
    frame = header + body

    conn = WSConnection(CLIENT, host='localhost', resource='foo')
    conn._proto = FrameProtocol(True, [])
    conn._state = ConnectionState.OPEN
    # Throw away the opening handshake queued at construction time.
    conn.bytes_to_send()
    conn.receive_bytes(frame)

    received = next(conn.events())
    expected_type = TextReceived if text else BytesReceived
    assert isinstance(received, expected_type)
    assert received.data == payload
    assert received.frame_finished is full_frame
    assert received.message_finished is full_message
    assert not conn.bytes_to_send()
def test_extension_negotiation_with_our_parameters(self):
    """The server answers an extension offer with its own parameters.

    The client offers ``<ext>; parameter1=value3; parameter2=value4``. The
    extension must see that exact offer, and the 101 response must carry
    the extension's own parameters instead.
    """
    test_host = 'frob.nitz'
    test_path = '/fnord'
    offered_params = 'parameter1=value3; parameter2=value4'
    ext_params = 'parameter1=value1; parameter2=value2'
    ext = FakeExtension(accept_response=ext_params)

    ws = WSConnection(SERVER, extensions=[ext])

    nonce = bytes(random.getrandbits(8) for x in range(0, 16))
    nonce = base64.b64encode(nonce)

    request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n"
    request += b'Host: ' + test_host.encode('ascii') + b'\r\n'
    request += b'Connection: Upgrade\r\n'
    request += b'Upgrade: WebSocket\r\n'
    request += b'Sec-WebSocket-Version: 13\r\n'
    request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n'
    request += b'Sec-WebSocket-Extensions: ' + \
        ext.name.encode('ascii') + b'; ' + \
        offered_params.encode('ascii') + b'\r\n'
    request += b'\r\n'

    ws.receive_bytes(request)
    event = next(ws.events())
    assert isinstance(event, ConnectionRequested)
    ws.accept(event)

    data = ws.bytes_to_send()
    response, headers = data.split(b'\r\n', 1)
    # maxsplit=2: the reason phrase may contain spaces.
    version, code, reason = response.split(b' ', 2)
    headers = parse_headers(headers)

    assert ext.offered == '%s; %s' % (ext.name, offered_params)
    assert headers['sec-websocket-extensions'] == \
        '%s; %s' % (ext.name, ext_params)
def update_reports(server, agent):
    """Open ``<server>/updateReports?agent=<agent>`` and close it once established.

    asyncio ``yield from`` coroutine. The WebSocket is closed as soon as the
    handshake completes; presumably the server acts on the request at
    handshake time (TODO confirm).

    :param server: base URL of the test server.
    :param agent: agent name passed in the query string.
    """
    uri = urlparse(server + '/updateReports?agent=%s' % agent)
    connection = WSConnection(CLIENT, uri.netloc,
                              '%s?%s' % (uri.path, uri.query))
    reader, writer = yield from asyncio.open_connection(
        uri.hostname, uri.port or 80)
    try:
        writer.write(connection.bytes_to_send())
        closed = False
        while not closed:
            data = yield from reader.read(65535)
            if not data:
                # EOF before the handshake completed; read() would keep
                # returning b'', so stop instead of spinning.
                break
            connection.receive_bytes(data)
            for event in connection.events():
                if isinstance(event, ConnectionEstablished):
                    connection.close()
                    writer.write(connection.bytes_to_send())
                    try:
                        yield from writer.drain()
                    except (ConnectionError, OSError):
                        pass
                    closed = True
                    break
    finally:
        # Previously skipped whenever drain() failed; always release it.
        writer.close()
class ClientWebsocket(object):
    """
    Represents a ClientWebsocket.
    """

    def __init__(self, address: Tuple[str, str, bool, str], *, reconnecting: bool = True):
        """
        :param address: A 4-tuple of (host, port, ssl, endpoint). ``ssl`` may
            be falsy (plain TCP), ``True`` (default TLS options), an
            :class:`ssl.SSLContext`, or a dict of options understood by
            :meth:`_create_ssl_ctx`.
        :param reconnecting: If this websocket reconnects automatically after
            the connection is closed or fails. Only useful on the client.
        """
        # these are all used to construct the state object
        self._address = address
        self.state = None  # type: WSConnection
        self._ready = False
        self._reconnecting = reconnecting
        self.sock = None  # type: trio.socket.Socket

    @property
    def closed(self) -> bool:
        """
        :return: If this websocket is closed.
        """
        return self.state.closed

    def _create_ssl_ctx(self, sslp):
        """
        Build an :class:`ssl.SSLContext` from *sslp*.

        An existing ``SSLContext`` is returned unchanged. Otherwise *sslp* is
        a dict that may contain ``ca``, ``capath``, ``check_hostname``,
        ``cert``, ``key`` and ``cipher``. When neither ``ca`` nor ``capath``
        is given, certificate verification and hostname checking are
        disabled.
        """
        if isinstance(sslp, ssl.SSLContext):
            return sslp
        ca = sslp.get('ca')
        capath = sslp.get('capath')
        hasnoca = ca is None and capath is None
        ctx = ssl.create_default_context(cafile=ca, capath=capath)
        # check_hostname must be cleared before verify_mode may be CERT_NONE.
        ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
        ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
        if 'cert' in sslp:
            ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
        if 'cipher' in sslp:
            ctx.set_ciphers(sslp['cipher'])
        ctx.options |= ssl.OP_NO_SSLv2
        ctx.options |= ssl.OP_NO_SSLv3
        return ctx

    async def open_connection(self):
        """
        Open a TCP (optionally TLS) stream and send the WebSocket opening
        handshake request. The server's response is processed by
        :meth:`__aiter__`.
        """
        _ssl = self._address[2]
        _sock = await trio.open_tcp_stream(self._address[0], self._address[1])
        if _ssl:
            if _ssl is True:
                _ssl = {}
            _ssl = self._create_ssl_ctx(_ssl)
            _sock = trio.ssl.SSLStream(_sock, _ssl, server_hostname=self._address[0],
                                       https_compatible=True)
            await _sock.do_handshake()
        self.sock = _sock
        self.state = WSConnection(ConnectionType.CLIENT, host=self._address[0],
                                  resource=self._address[3])
        res = self.state.bytes_to_send()
        await self.sock.send_all(res)

    async def _reconnect(self):
        """Best-effort close of the current stream, then open a new connection."""
        old_sock, self.sock = self.sock, None
        if old_sock is not None:
            try:
                await old_sock.aclose()
            except Exception:
                # The old stream is already dead or dying; failing to close
                # it cleanly must not prevent the reconnect.
                pass
        await self.open_connection()

    async def __aiter__(self):
        """
        Connect, then yield high-level websocket events.

        Yields ``WebsocketConnectionEstablished``, ``WebsocketClosed``,
        ``WebsocketConnectionFailed`` and, once a whole message has arrived,
        ``WebsocketBytesMessage`` / ``WebsocketTextMessage``. After a close or
        failure it reconnects if ``reconnecting`` is set, otherwise it stops.
        """
        # initiate the websocket
        await self.open_connection()

        # setup buffers; message fragments accumulate until message_finished
        buf_bytes = BytesIO()
        buf_text = StringIO()

        while self.sock is not None:
            data = await self.sock.receive_some(4096)
            # receive_some() returns b"" once the peer has closed the stream;
            # wsproto expects None to mean end-of-stream.
            self.state.receive_bytes(data or None)

            # wsproto parses frames lazily inside events(), which is also
            # when it queues automatic replies such as pongs. Drain the events
            # first so those replies are flushed now, not after the next read.
            pending = list(self.state.events())
            to_send = self.state.bytes_to_send()
            if to_send:
                await self.sock.send_all(to_send)

            for event in pending:
                if isinstance(event, events.ConnectionEstablished):
                    self._ready = True
                    yield WebsocketConnectionEstablished(event)
                elif isinstance(event, events.ConnectionClosed):
                    self._ready = False
                    yield WebsocketClosed(event)
                    if self._reconnecting:
                        await self._reconnect()
                    else:
                        return
                elif isinstance(event, events.DataReceived):
                    buf = buf_bytes if isinstance(event, events.BytesReceived) else buf_text
                    buf.write(event.data)
                    # yield events as appropriate
                    if event.message_finished:
                        buf.seek(0)
                        read = buf.read()
                        # empty buffer
                        buf.truncate(0)
                        buf.seek(0)
                        typ = (WebsocketBytesMessage
                               if isinstance(event, events.BytesReceived)
                               else WebsocketTextMessage)
                        yield typ(read)
                elif isinstance(event, events.ConnectionFailed):
                    self._ready = False
                    yield WebsocketConnectionFailed(event)
                    if self._reconnecting:
                        await self._reconnect()
                    else:
                        return
                elif isinstance(event, (events.PingReceived, events.PongReceived)):
                    # Control frames need no action here; wsproto queues the
                    # pong for a ping itself (flushed above).
                    continue
                else:
                    raise RuntimeError("I don't understand this message", event)

    async def send_message(self, data: Union[str, bytes]):
        """
        Sends a message on the websocket.

        :param data: The data to send. Either str or bytes.
        """
        self.state.send_data(data, final=True)
        return await self.sock.send_all(self.state.bytes_to_send())

    async def aclose(self, *, code: int = 1000, reason: str = "No reason",
                     allow_reconnects: bool = False):
        """
        Closes the websocket.

        :param code: The close code to use.
        :param reason: The close reason to use.

        If the websocket is marked as reconnecting:

        :param allow_reconnects: If the websocket can reconnect after this close.
        """
        # do NOT reconnect if we close explicitly and don't allow reconnects
        if not allow_reconnects:
            self._reconnecting = False

        self.state.close(code=code, reason=reason)
        to_send = self.state.bytes_to_send()
        await self.sock.send_all(to_send)
        await self.sock.aclose()
        self.sock = None
        self.state.receive_bytes(None)
async def wsproto_client_demo(host, port, use_ssl):
    '''
    Walk through a complete wsproto client session over trio:

    0) Open TCP connection (optionally upgraded to TLS)
    1) Negotiate WebSocket opening handshake
    2) Send a message and display response
    3) Send ping and display pong
    4) Negotiate WebSocket closing handshake

    :param host: server host name
    :param port: server port
    :param use_ssl: if true, wrap the TCP stream with upgrade_stream_to_ssl()
    :raises Exception: if the server's reply at any step is not the expected
        wsproto event
    '''
    # 0) TCP (and optionally TLS) transport
    print(f'[C] Connecting to {host}:{port}')
    conn = await trio.open_tcp_stream(host, port)
    if use_ssl:
        conn = upgrade_stream_to_ssl(conn, host)

    # 1) Opening handshake. A client WSConnection queues its upgrade request
    # as soon as it is created, so one round trip completes the handshake.
    print('[C] Opening WebSocket')
    ws = WSConnection(ConnectionType.CLIENT, host=host, resource='server')
    event_stream = ws.events()

    def expect(event_type, label):
        # Pull the next wsproto event and insist on its type.
        received = next(event_stream)
        if not isinstance(received, event_type):
            raise Exception(f'Expected {label} event! Got: {received}')
        return received

    await net_send_recv(ws, conn)
    expect(ConnectionEstablished, 'ConnectionEstablished')
    print('[C] WebSocket negotiation complete')

    # 2) Message round trip
    message = "wsproto is great" * 10
    print(f'[C] Sending message: {message}')
    ws.send_data(message)
    await net_send_recv(ws, conn)
    reply = expect(TextReceived, 'TextReceived')
    print(f'[C] Received message: {reply.data}')

    # 3) Ping / pong
    payload = b"table tennis"
    print(f'[C] Sending ping: {payload}')
    ws.ping(payload)
    await net_send_recv(ws, conn)
    pong = expect(PongReceived, 'PongReceived')
    print(f'[C] Received pong: {pong.payload}')

    # 4) Closing handshake. No further events are expected after our close
    # frame; a single net_send_recv() round trip (presumably: send our close
    # frame, read the server's reply) precedes closing the stream.
    print('[C] Closing WebSocket')
    ws.close(code=1000, reason='sample reason')
    await net_send_recv(ws, conn)
    await conn.aclose()
uri.netloc, '%s?%s' % (uri.path, uri.query), extensions=[PerMessageDeflate()]) sock = socket.socket() sock.connect((uri.hostname, uri.port or 80)) sock.sendall(connection.bytes_to_send()) closed = False while not closed: try: data = sock.recv(65535) except CONNECTION_EXCEPTIONS: data = None connection.receive_bytes(data or None) for event in connection.events(): if isinstance(event, DataReceived): connection.send_data(event.data, event.message_finished) elif isinstance(event, ConnectionClosed): closed = True # else: # print("??", event) if data is None: break try: data = connection.bytes_to_send() sock.sendall(data) except CONNECTION_EXCEPTIONS: closed = True break
def handle_connection(stream):
    '''
    Handle a connection.

    The server operates a request/response cycle, so it performs a
    synchronous loop:

    1) Read data from network into wsproto
    2) Handle every wsproto event that data produced
    3) Send data from wsproto to network

    :param stream: a socket stream
    '''
    ws = WSConnection(ConnectionType.SERVER)
    running = True
    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print('Received {} bytes'.format(len(in_data)))
        if not in_data:
            # recv() returns b'' once the peer has closed its end.
            print('Client connection dropped unexpectedly')
            return
        ws.receive_bytes(in_data)

        # 2) Handle events. A fresh events() generator is used for every
        # read: a generator that has raised StopIteration is finished for
        # good and cannot be resumed, and a single read may carry zero
        # events (partial frame) or several.
        for event in ws.events():
            if isinstance(event, ConnectionRequested):
                # Negotiate new WebSocket connection
                print('Accepting WebSocket upgrade')
                ws.accept(event)
            elif isinstance(event, ConnectionClosed):
                # Print log message and stop after flushing our reply
                print('Connection closed: code={}/{} reason={}'.format(
                    event.code.value, event.code.name, event.reason))
                running = False
            elif isinstance(event, TextReceived):
                # Reverse text and send it back to wsproto
                print('Received request and sending response')
                ws.send_data(event.data[::-1])
            elif isinstance(event, PingReceived):
                # wsproto handles ping events for you by placing a pong frame
                # in the outgoing buffer. You should not call pong() unless
                # you want to send an unsolicited pong frame.
                print('Received ping and sending pong')
            else:
                print('Unknown event: {!r}'.format(event))

        # 3) Send data from wsproto to network
        out_data = ws.bytes_to_send()
        print('Sending {} bytes'.format(len(out_data)))
        # NOTE(review): send() may transmit only part of out_data; if stream
        # is a plain socket, sendall() would be safer.
        stream.send(out_data)
def test_h11_somehow_loses_its_mind(self):
    """An unrecognised event from the HTTP layer must yield ``ConnectionFailed``."""

    def unexpected_event():
        # Stand-in for the upgrade connection's next_event(): returns an
        # object that is no known HTTP event.
        return object()

    conn = WSConnection(SERVER)
    conn._upgrade_connection.next_event = unexpected_event
    conn.receive_bytes(b'')

    outcome = next(conn.events())
    assert isinstance(outcome, ConnectionFailed)
class ClientWebsocket(object):
    """
    Represents a ClientWebsocket.
    """

    def __init__(self, address: Tuple[str, str, bool, str], *, reconnecting: bool = True):
        """
        :param address: A 4-tuple of (host, port, ssl, endpoint); ``ssl`` is
            passed through to ``multio.asynclib.open_connection``.
        :param reconnecting: If this websocket reconnects automatically after
            the connection is closed or fails. Only useful on the client.
        """
        # these are all used to construct the state object
        self._address = address
        self.state = None  # type: WSConnection
        self._ready = False
        self._reconnecting = reconnecting
        self.sock = None  # type: multio.SocketWrapper

    @property
    def closed(self) -> bool:
        """
        :return: If this websocket is closed.
        """
        return self.state.closed

    async def open_connection(self):
        """
        Open the TCP connection and send the WebSocket opening handshake
        request. The server's response is processed by :meth:`__aiter__`.
        """
        _sock = await multio.asynclib.open_connection(self._address[0], self._address[1],
                                                      ssl=self._address[2])
        self.sock = multio.SocketWrapper(_sock)
        self.state = WSConnection(ConnectionType.CLIENT, host=self._address[0],
                                  resource=self._address[3])
        res = self.state.bytes_to_send()
        await self.sock.sendall(res)

    async def _reconnect(self):
        """Best-effort close of the current socket, then open a new connection."""
        old_sock, self.sock = self.sock, None
        if old_sock is not None:
            try:
                await old_sock.close()
            except Exception:
                # The old socket is already dead or dying; failing to close
                # it cleanly must not prevent the reconnect.
                pass
        await self.open_connection()

    async def __aiter__(self):
        """
        Connect, then yield high-level websocket events.

        Yields ``WebsocketConnectionEstablished``, ``WebsocketClosed``,
        ``WebsocketConnectionFailed`` and, once a whole message has arrived,
        ``WebsocketBytesMessage`` / ``WebsocketTextMessage``. After a close or
        failure it reconnects if ``reconnecting`` is set, otherwise it stops.
        Iteration also ends once :meth:`close` has been called.
        """
        # initiate the websocket
        await self.open_connection()

        # setup buffers; message fragments accumulate until message_finished
        buf_bytes = BytesIO()
        buf_text = StringIO()

        while self.sock is not None:
            data = await self.sock.recv(4096)
            # An empty read is treated as end-of-stream, which wsproto
            # expects to be signalled with None.
            self.state.receive_bytes(data or None)

            # wsproto parses frames lazily inside events(), which is also
            # when it queues automatic replies such as pongs. Drain the events
            # first so those replies are flushed now, not after the next read.
            pending = list(self.state.events())
            to_send = self.state.bytes_to_send()
            if to_send:
                await self.sock.sendall(to_send)

            for event in pending:
                if isinstance(event, events.ConnectionEstablished):
                    self._ready = True
                    yield WebsocketConnectionEstablished(event)
                elif isinstance(event, events.ConnectionClosed):
                    self._ready = False
                    yield WebsocketClosed(event)
                    if self._reconnecting:
                        await self._reconnect()
                    else:
                        return
                elif isinstance(event, events.DataReceived):
                    buf = buf_bytes if isinstance(
                        event, events.BytesReceived) else buf_text
                    buf.write(event.data)
                    # yield events as appropriate
                    if event.message_finished:
                        buf.seek(0)
                        read = buf.read()
                        # empty buffer
                        buf.truncate(0)
                        buf.seek(0)
                        typ = (WebsocketBytesMessage
                               if isinstance(event, events.BytesReceived)
                               else WebsocketTextMessage)
                        yield typ(read)
                elif isinstance(event, events.ConnectionFailed):
                    self._ready = False
                    yield WebsocketConnectionFailed(event)
                    if self._reconnecting:
                        await self._reconnect()
                    else:
                        return

    async def send_message(self, data: Union[str, bytes]):
        """
        Sends a message on the websocket.

        :param data: The data to send. Either str or bytes.
        """
        self.state.send_data(data, final=True)
        return await self.sock.sendall(self.state.bytes_to_send())

    async def close(self, *, code: int = 1000, reason: str = "No reason",
                    allow_reconnects: bool = False):
        """
        Closes the websocket.

        :param code: The close code to use.
        :param reason: The close reason to use.

        If the websocket is marked as reconnecting:

        :param allow_reconnects: If the websocket can reconnect after this close.
        """
        # do NOT reconnect if we close explicitly and don't allow reconnects
        if not allow_reconnects:
            self._reconnecting = False

        self.state.close(code=code, reason=reason)
        to_send = self.state.bytes_to_send()
        await self.sock.sendall(to_send)
        await self.sock.close()
        # Dropping the socket ends the read loop in __aiter__.
        self.sock = None
        self.state.receive_bytes(None)
def wsproto_demo(host, port):
    '''
    Demonstrate wsproto:

    0) Open TCP connection
    1) Negotiate WebSocket opening handshake
    2) Send a message and display response
    3) Send ping and display pong
    4) Negotiate WebSocket closing handshake

    :param host: server host name or IPv4 address
    :param port: server TCP port
    :raises Exception: if the server's reply at any step is not the expected
        wsproto event
    '''

    # 0) Open TCP connection
    print('Connecting to {}:{}'.format(host, port))
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        conn.connect((host, port))

        # 1) Negotiate WebSocket opening handshake
        print('Opening WebSocket')
        ws = WSConnection(ConnectionType.CLIENT, host=host, resource='server')

        # events is a generator that yields websocket event objects. Usually
        # you would say `for event in ws.events()`, but the synchronous nature
        # of this client requires us to use next(events) instead so that we
        # can interleave the network I/O. This only works because every
        # exchange below produces exactly one event: once the generator has
        # raised StopIteration it is finished for good and cannot be resumed
        # by feeding in more network data.
        events = ws.events()

        # Because this is a client WebSocket, wsproto has automatically queued
        # up a handshake, and we need to send it and wait for a response.
        net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, ConnectionEstablished):
            print('WebSocket negotiation complete')
        else:
            raise Exception('Expected ConnectionEstablished event!')

        # 2) Send a message and display response
        message = "wsproto is great"
        print('Sending message: {}'.format(message))
        ws.send_data(message)
        net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, TextReceived):
            print('Received message: {}'.format(event.data))
        else:
            raise Exception('Expected TextReceived event!')

        # 3) Send ping and display pong
        payload = b"table tennis"
        print('Sending ping: {}'.format(payload))
        ws.ping(payload)
        net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, PongReceived):
            print('Received pong: {}'.format(event.payload))
        else:
            raise Exception('Expected PongReceived event!')

        # 4) Negotiate WebSocket closing handshake
        print('Closing WebSocket')
        ws.close(code=1000, reason='sample reason')
        # After sending the closing frame, we won't get any more events. The
        # server should send a reply and then close the connection, so we
        # need to receive twice:
        net_send_recv(ws, conn)
        conn.shutdown(socket.SHUT_WR)
        net_recv(ws, conn)
    finally:
        # Always release the socket, even if a step above raised.
        conn.close()
class WebSocketEndpoint(AbstractEndpoint):
    """
    Implements websocket endpoints.

    Subclasses override :meth:`on_connect`, :meth:`on_data` and
    :meth:`on_close`. Subprotocol negotiation is currently not supported.
    """

    __slots__ = ('ctx', '_client', '_ws')

    def __init__(self, ctx: Context, client: BaseHTTPClientConnection):
        self.ctx = ctx
        self._client = client
        self._ws = WSConnection(ConnectionType.SERVER)

    def _process_ws_events(self):
        # Dispatch every pending wsproto event to its hook (auto-accepting the
        # upgrade), then flush whatever wsproto queued for the peer.
        for ws_event in self._ws.events():
            if isinstance(ws_event, ConnectionRequested):
                self._ws.accept(ws_event)
                self.on_connect()
            elif isinstance(ws_event, DataReceived):
                self.on_data(ws_event.data)
            elif isinstance(ws_event, ConnectionClosed):
                self.on_close()

        outgoing = self._ws.bytes_to_send()
        if outgoing:
            self._client.write(outgoing)

    def _feed(self, data):
        """Hand raw bytes to wsproto and process the resulting events."""
        self._ws.receive_bytes(data)
        self._process_ws_events()

    def begin_request(self, request: HTTPRequest):
        # Switch the HTTP connection to websocket mode; whatever upgrade()
        # returns (presumably bytes read past the upgrade request) is fed to
        # wsproto as the start of the websocket stream.
        self._feed(self._client.upgrade())

    def receive_body_data(self, data: bytes) -> None:
        self._feed(data)

    def send_message(self, payload: Union[str, bytes]) -> None:
        """
        Send a message to the client.

        :param payload: either a unicode string or a bytestring
        """
        self._ws.send_data(payload)
        self._client.write(self._ws.bytes_to_send())

    def close(self) -> None:
        """Start the websocket closing handshake and flush the close frame."""
        self._ws.close()
        self._process_ws_events()

    def on_connect(self) -> None:
        """Called when the websocket handshake has been done."""

    def on_close(self) -> None:
        """Called when the connection has been closed."""

    def on_data(self, data: bytes) -> None:
        """
        Called when there is new data from the peer.

        ``data`` is presumably ``str`` for text frames and ``bytes`` for
        binary ones. It may be called more than once per message if the
        message arrives fragmented.
        """