Ejemplo n.º 1
0
    def create_connection(self):
        """Return a client/server WSConnection pair that has finished the handshake.

        The client's upgrade request is fed to the server, the server accepts
        it, and the server's 101 response is fed back to the client.

        :return: ``(client, server)``.
        """
        server = WSConnection(SERVER)
        client = WSConnection(CLIENT, host='localhost', resource='foo')

        # Client -> server: the opening handshake request.
        server.receive_bytes(client.bytes_to_send())
        requested = next(server.events())
        assert isinstance(requested, ConnectionRequested)
        server.accept(requested)

        # Server -> client: the handshake response.
        client.receive_bytes(server.bytes_to_send())
        established = next(client.events())
        assert isinstance(established, ConnectionEstablished)

        return client, server
Ejemplo n.º 2
0
    def test_accept_subprotocol(self):
        """Accepting with an offered subprotocol echoes it in the response.

        The hand-built upgrade request offers ``one, two``; the server accepts
        with ``two`` and the 101 response must carry
        ``Sec-WebSocket-Protocol: two``.
        """
        test_host = 'frob.nitz'
        test_path = '/fnord'

        ws = WSConnection(SERVER)

        # Sec-WebSocket-Key: base64 encoding of 16 random bytes.
        nonce = bytes(random.getrandbits(8) for _ in range(16))
        nonce = base64.b64encode(nonce)

        request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n'
        request += b'Host: ' + test_host.encode('ascii') + b'\r\n'
        request += b'Connection: Upgrade\r\n'
        request += b'Upgrade: WebSocket\r\n'
        request += b'Sec-WebSocket-Version: 13\r\n'
        request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n'
        request += b'Sec-WebSocket-Protocol: one, two\r\n'
        request += b'\r\n'

        ws.receive_bytes(request)
        event = next(ws.events())
        assert isinstance(event, ConnectionRequested)
        assert event.proposed_subprotocols == ['one', 'two']

        ws.accept(event, 'two')

        data = ws.bytes_to_send()
        response, headers = data.split(b'\r\n', 1)
        # maxsplit=2: a reason phrase such as "Switching Protocols" contains
        # spaces and would otherwise break the three-way unpacking.
        version, code, reason = response.split(b' ', 2)
        headers = parse_headers(headers)

        assert int(code) == 101
        assert headers['sec-websocket-protocol'] == 'two'
Ejemplo n.º 3
0
    def test_accept_wrong_subprotocol(self):
        """Accepting with a subprotocol the client never offered raises ValueError."""
        host = 'frob.nitz'
        path = '/fnord'

        ws = WSConnection(SERVER)

        key = base64.b64encode(
            bytes(random.getrandbits(8) for _ in range(0, 16)))

        request = b''.join([
            b'GET ', path.encode('ascii'), b' HTTP/1.1\r\n',
            b'Host: ', host.encode('ascii'), b'\r\n',
            b'Connection: Upgrade\r\n',
            b'Upgrade: WebSocket\r\n',
            b'Sec-WebSocket-Version: 13\r\n',
            b'Sec-WebSocket-Key: ', key, b'\r\n',
            b'Sec-WebSocket-Protocol: one, two\r\n',
            b'\r\n',
        ])
        ws.receive_bytes(request)

        requested = next(ws.events())
        assert isinstance(requested, ConnectionRequested)
        assert requested.proposed_subprotocols == ['one', 'two']

        # 'three' is not among the proposed subprotocols.
        with pytest.raises(ValueError):
            ws.accept(requested, 'three')
Ejemplo n.º 4
0
    def test_unwanted_extension_negotiation(self):
        """An extension that declines the offer is left out of the response.

        ``FakeExtension(accept_response=False)`` declines, so the server's
        handshake response must not contain ``Sec-WebSocket-Extensions``.
        """
        test_host = 'frob.nitz'
        test_path = '/fnord'
        ext = FakeExtension(accept_response=False)

        ws = WSConnection(SERVER, extensions=[ext])

        # Sec-WebSocket-Key: base64 encoding of 16 random bytes.
        nonce = bytes(random.getrandbits(8) for _ in range(16))
        nonce = base64.b64encode(nonce)

        request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n"
        request += b'Host: ' + test_host.encode('ascii') + b'\r\n'
        request += b'Connection: Upgrade\r\n'
        request += b'Upgrade: WebSocket\r\n'
        request += b'Sec-WebSocket-Version: 13\r\n'
        request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n'
        request += b'Sec-WebSocket-Extensions: pretend\r\n'
        request += b'\r\n'

        ws.receive_bytes(request)
        event = next(ws.events())
        assert isinstance(event, ConnectionRequested)
        ws.accept(event)

        data = ws.bytes_to_send()
        response, headers = data.split(b'\r\n', 1)
        # maxsplit=2: the reason phrase may itself contain spaces.
        version, code, reason = response.split(b' ', 2)
        headers = parse_headers(headers)

        assert 'sec-websocket-extensions' not in headers
Ejemplo n.º 5
0
    def test_not_an_http_request_at_all(self):
        """Bytes that are not an HTTP request make the server emit ConnectionFailed."""
        garbage = b'<xml>Good god, what is this?</xml>\r\n\r\n'
        ws = WSConnection(SERVER)

        ws.receive_bytes(garbage)
        first_event = next(ws.events())
        assert isinstance(first_event, ConnectionFailed)
Ejemplo n.º 6
0
def get_case_count(server):
    """Fetch the test-case count from ``<server>/getCaseCount``.

    Opens a client WebSocket, waits for one complete text message, decodes it
    as JSON and then starts the closing handshake.

    :param server: base URL of the server (scheme, host and optional port).
    :return: the decoded JSON value, or ``None`` if the peer hung up or a send
        failed before a complete message arrived.
    """
    uri = urlparse(server + '/getCaseCount')
    connection = WSConnection(CLIENT, uri.netloc, uri.path)
    sock = socket.socket()
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(connection.bytes_to_send())

        case_count = None
        # Accumulate text across reads: one message may span several recv()s.
        message = ""
        while case_count is None:
            data = sock.recv(65535)
            if not data:
                # Peer closed the connection; recv() would return b'' forever.
                break
            connection.receive_bytes(data)
            for event in connection.events():
                if isinstance(event, TextReceived):
                    message += event.data
                    if event.message_finished:
                        case_count = json.loads(message)
                        connection.close()
            # Flush anything wsproto queued (e.g. our close frame).
            try:
                sock.sendall(connection.bytes_to_send())
            except CONNECTION_EXCEPTIONS:
                break
    finally:
        sock.close()
    return case_count
Ejemplo n.º 7
0
def new_conn(sock):
    """Serve one client on *sock*, echoing back every data frame.

    Accepts the upgrade (with permessage-deflate available) and echoes each
    data event with the same message-finished flag. Stops when the peer
    closes, hangs up, or a send fails. The socket is always closed on exit.

    :param sock: a connected socket for one client.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = sock.recv(65535)
            except socket.error:
                data = None

            # An empty or failed read is passed on as None (end of stream).
            ws.receive_bytes(data or None)

            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True

            if not data:
                closed = True

            try:
                sock.sendall(ws.bytes_to_send())
            except socket.error:
                closed = True
    finally:
        # Release the socket even if wsproto raises mid-loop.
        sock.close()
Ejemplo n.º 8
0
def new_conn(reader, writer):
    """Serve one asyncio client, echoing back every data frame.

    Generator-based coroutine (``yield from``). Accepts the upgrade (with
    permessage-deflate available), echoes each data event with the same
    message-finished flag, and stops when the peer closes, hangs up, or a
    write fails. The writer is always closed on exit.

    :param reader: the client's ``asyncio.StreamReader``.
    :param writer: the client's ``asyncio.StreamWriter``.
    """
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = yield from reader.read(65535)
            except ConnectionError:
                data = None

            # reader.read() returns b'' at EOF; pass that on as None.
            ws.receive_bytes(data or None)

            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    # DataReceived exposes message_finished (not ``final``).
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True

            # Checked once per read (not per event) so EOF and errors with
            # no pending events still end the loop.
            if not data:
                closed = True

            try:
                writer.write(ws.bytes_to_send())
                yield from writer.drain()
            except (ConnectionError, OSError):
                closed = True
    finally:
        writer.close()
Ejemplo n.º 9
0
def new_conn(reader, writer):
    """Serve one asyncio client, echoing back every data frame.

    Generator-based coroutine (``yield from``). Logs and increments the global
    connection counter, accepts the upgrade (with permessage-deflate
    available), and echoes each data event with the same message-finished
    flag. Stops when the peer closes, hangs up, or a write fails. The writer
    is always closed on exit.

    :param reader: the client's ``asyncio.StreamReader``.
    :param writer: the client's ``asyncio.StreamWriter``.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = yield from reader.read(65535)
            except ConnectionError:
                data = None

            # reader.read() returns b'' at EOF; pass that on as None.
            ws.receive_bytes(data or None)

            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True

            if not data:
                closed = True

            try:
                writer.write(ws.bytes_to_send())
                yield from writer.drain()
            except (ConnectionError, OSError):
                closed = True
    finally:
        # Close the transport even if wsproto raises mid-loop.
        writer.close()
Ejemplo n.º 10
0
def get_case_count(server):
    """Fetch the test-case count from ``<server>/getCaseCount`` (asyncio).

    Generator-based coroutine. Opens a client WebSocket, waits for one
    complete text message, decodes it as JSON and starts the closing
    handshake.

    :param server: base URL of the server (scheme, host and optional port).
    :return: the decoded JSON value, or ``None`` if the peer hung up or a
        write failed before a complete message arrived.
    """
    uri = urlparse(server + '/getCaseCount')
    connection = WSConnection(CLIENT, uri.netloc, uri.path)
    reader, writer = yield from asyncio.open_connection(
        uri.hostname, uri.port or 80)

    case_count = None
    # Accumulate text across reads: one message may span several read()s.
    message = ""
    try:
        writer.write(connection.bytes_to_send())
        while case_count is None:
            data = yield from reader.read(65535)
            if not data:
                # EOF: read() would keep returning b''.
                break
            connection.receive_bytes(data)
            for event in connection.events():
                if isinstance(event, TextReceived):
                    message += event.data
                    if event.message_finished:
                        case_count = json.loads(message)
                        connection.close()
            # Flush anything wsproto queued (e.g. our close frame).
            try:
                writer.write(connection.bytes_to_send())
                yield from writer.drain()
            except (ConnectionError, OSError):
                break
    finally:
        writer.close()
    return case_count
Ejemplo n.º 11
0
    def test_correct_request(self):
        """A well-formed upgrade request gets a valid 101 handshake response.

        Checks the status code, the Connection/Upgrade headers and that
        Sec-WebSocket-Accept matches the token derived from our key.
        """
        test_host = 'frob.nitz'
        test_path = '/fnord'

        ws = WSConnection(SERVER)

        # Sec-WebSocket-Key: base64 encoding of 16 random bytes.
        nonce = bytes(random.getrandbits(8) for _ in range(16))
        nonce = base64.b64encode(nonce)

        request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n"
        request += b'Host: ' + test_host.encode('ascii') + b'\r\n'
        request += b'Connection: Upgrade\r\n'
        request += b'Upgrade: WebSocket\r\n'
        request += b'Sec-WebSocket-Version: 13\r\n'
        request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n'
        request += b'\r\n'

        ws.receive_bytes(request)
        event = next(ws.events())
        assert isinstance(event, ConnectionRequested)
        ws.accept(event)

        data = ws.bytes_to_send()
        response, headers = data.split(b'\r\n', 1)
        # maxsplit=2: the reason phrase may itself contain spaces.
        version, code, reason = response.split(b' ', 2)
        headers = parse_headers(headers)

        accept_token = ws._generate_accept_token(nonce)

        assert int(code) == 101
        assert headers['connection'].lower() == 'upgrade'
        assert headers['upgrade'].lower() == 'websocket'
        assert headers['sec-websocket-accept'] == accept_token.decode('ascii')
Ejemplo n.º 12
0
    def test_missing_key(self):
        """An upgrade request without Sec-WebSocket-Key yields ConnectionFailed."""
        host, path = 'frob.nitz', '/fnord'

        ws = WSConnection(SERVER)

        header_lines = [
            b'GET ' + path.encode('ascii') + b' HTTP/1.1',
            b'Host: ' + host.encode('ascii'),
            b'Connection: Upgrade',
            b'Upgrade: WebSocket',
            b'Sec-WebSocket-Version: 13',
            b'',
            b'',
        ]
        # Joining with CRLF (plus two empty entries) terminates every header
        # line and adds the blank line that ends the request head.
        ws.receive_bytes(b'\r\n'.join(header_lines))

        outcome = next(ws.events())
        assert isinstance(outcome, ConnectionFailed)
Ejemplo n.º 13
0
async def ws_adapter(in_q, out_q, client, _):
    """A simple, queue-based Curio-Sans-IO websocket bridge.

    Runs the server side of a WebSocket over the socket *client*:

    * bytes read from *client* are fed to wsproto; the payload of each event
      whose class is in ``DATA_TYPES`` is put on *in_q*, and ``None`` is put
      on *in_q* when the peer closes the connection;
    * items taken from *out_q* are sent to the peer with ``send_data``; a
      ``None`` item starts the closing handshake and ends the bridge.

    The fourth argument is ignored (presumably the client address supplied
    by a curio TCP server -- TODO confirm).
    """
    # Disable Nagle's algorithm so small frames are not delayed.
    client.setsockopt(IPPROTO_TCP, TCP_NODELAY, 1)
    wsconn = WSConnection(SERVER)
    closed = False

    while not closed:
        # Race a socket read against the outbound queue: whichever task
        # finishes first is handled below and the other one is cancelled.
        wstask = await spawn(client.recv, 65535)
        outqtask = await spawn(out_q.get)

        async with TaskGroup([wstask, outqtask]) as g:
            task = await g.next_done()
            result = await task.join()
            await g.cancel_remaining()
        # NOTE(review): confirm that cancelling a pending out_q.get() can
        # never drop an item that was already dequeued.

        if task is wstask:
            # NOTE(review): an empty read (peer hang-up) is passed through as
            # b'' rather than None; confirm this cannot spin forever at EOF.
            wsconn.receive_bytes(result)

            for event in wsconn.events():
                # Exact-class dispatch: subclasses of the listed types are
                # not matched here.
                cl = event.__class__
                if cl in DATA_TYPES:
                    await in_q.put(event.data)
                elif cl is ConnectionRequested:
                    # Auto accept. Maybe consult the handler?
                    wsconn.accept(event)
                elif cl is ConnectionClosed:
                    # The client has closed the connection.
                    await in_q.put(None)
                    closed = True
                else:
                    print(event)
            # Flush whatever wsproto queued while handling events
            # (handshake response, replies, ...).
            await client.sendall(wsconn.bytes_to_send())
        else:
            # We got something from the out queue.
            if result is None:
                # Terminate the connection.
                print("Closing the connection.")
                wsconn.close()
                closed = True
            else:
                wsconn.send_data(result)
            payload = wsconn.bytes_to_send()
            await client.sendall(payload)
    print("Bridge done.")
Ejemplo n.º 14
0
    def test_missing_version(self):
        """An upgrade request without Sec-WebSocket-Version yields ConnectionFailed."""
        host, path = 'frob.nitz', '/fnord'

        ws = WSConnection(SERVER)

        key = base64.b64encode(
            bytes(random.getrandbits(8) for _ in range(0, 16)))

        header_lines = [
            b'GET ' + path.encode('ascii') + b' HTTP/1.1',
            b'Host: ' + host.encode('ascii'),
            b'Connection: Upgrade',
            b'Upgrade: WebSocket',
            b'Sec-WebSocket-Key: ' + key,
            b'',
            b'',
        ]
        # CRLF-join terminates each header and appends the blank end line.
        ws.receive_bytes(b'\r\n'.join(header_lines))

        outcome = next(ws.events())
        assert isinstance(outcome, ConnectionFailed)
Ejemplo n.º 15
0
    def test_correct_request_expanded_connection_header(self):
        """``Connection: keep-alive, Upgrade`` is still recognised as an upgrade."""
        host, path = 'frob.nitz', '/fnord'

        ws = WSConnection(SERVER)

        key = base64.b64encode(
            bytes(random.getrandbits(8) for _ in range(0, 16)))

        header_lines = [
            b'GET ' + path.encode('ascii') + b' HTTP/1.1',
            b'Host: ' + host.encode('ascii'),
            # Upgrade appears as one token in a comma-separated list.
            b'Connection: keep-alive, Upgrade',
            b'Upgrade: WebSocket',
            b'Sec-WebSocket-Version: 13',
            b'Sec-WebSocket-Key: ' + key,
            b'',
            b'',
        ]
        ws.receive_bytes(b'\r\n'.join(header_lines))

        requested = next(ws.events())
        assert isinstance(requested, ConnectionRequested)
Ejemplo n.º 16
0
    def test_frame_protocol_somehow_loses_its_mind(self):
        """A frame with an unknown opcode produces no event and no output."""

        class BogusFrame(object):
            # A fresh sentinel that cannot equal any real opcode.
            opcode = object()

        class BrokenFrameProtocol(object):
            """Stub frame layer: ignores input, always reports one bogus frame."""

            def receive_bytes(self, data):
                return None

            def received_frames(self):
                return [BogusFrame()]

        connection = WSConnection(CLIENT, host='localhost', resource='foo')
        # Skip the handshake by installing the stub and forcing OPEN state.
        connection._proto = BrokenFrameProtocol()
        connection._state = ConnectionState.OPEN
        # Discard whatever the client had already queued.
        connection.bytes_to_send()

        connection.receive_bytes(b'')
        with pytest.raises(StopIteration):
            next(connection.events())
        assert not connection.bytes_to_send()
Ejemplo n.º 17
0
def update_reports(server, agent):
    """Hit ``<server>/updateReports?agent=<agent>`` and close immediately.

    Opens a client WebSocket; as soon as the handshake completes, sends a
    close frame (best effort) and stops. The socket is always closed, even
    if the peer hangs up before the handshake completes.

    :param server: base URL of the server (scheme, host and optional port).
    :param agent: agent name placed in the query string.
    """
    uri = urlparse(server + '/updateReports?agent=%s' % agent)
    connection = WSConnection(CLIENT, uri.netloc,
                              '%s?%s' % (uri.path, uri.query))
    sock = socket.socket()
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(connection.bytes_to_send())

        closed = False
        while not closed:
            data = sock.recv(65535)
            if not data:
                # Peer hung up; recv() would return b'' forever.
                break
            connection.receive_bytes(data)
            for event in connection.events():
                if isinstance(event, ConnectionEstablished):
                    connection.close()
                    # The close frame is a courtesy; ignore a broken socket.
                    try:
                        sock.sendall(connection.bytes_to_send())
                    except CONNECTION_EXCEPTIONS:
                        pass
                    closed = True
                    break
    finally:
        sock.close()
Ejemplo n.º 18
0
    def test_data_events(self, text, payload, full_message, full_frame):
        """A single raw data frame surfaces as the matching DataReceived subclass.

        :param text: build a text frame (opcode 0x1) instead of binary (0x2).
        :param payload: frame payload (str for text, bytes for binary).
        :param full_message: set the FIN bit, marking the message complete.
        :param full_frame: when false, advertise 100 more bytes than are sent
            so the frame is left incomplete.
        """
        base_opcode = 0x01 if text else 0x02
        encoded_payload = payload.encode('utf8') if text else payload

        # 0x80 is the FIN bit in the first header byte.
        first_byte = (base_opcode | 0x80) if full_message else base_opcode
        declared_length = (len(encoded_payload) if full_frame
                           else len(encoded_payload) + 100)

        frame = (bytearray([first_byte]) + bytearray([declared_length])
                 + encoded_payload)

        connection = WSConnection(CLIENT, host='localhost', resource='foo')
        # Bypass the handshake: plug in a frame layer and force OPEN state.
        connection._proto = FrameProtocol(True, [])
        connection._state = ConnectionState.OPEN
        connection.bytes_to_send()

        connection.receive_bytes(frame)
        event = next(connection.events())
        expected_type = TextReceived if text else BytesReceived
        assert isinstance(event, expected_type)
        assert event.data == payload
        assert event.frame_finished is full_frame
        assert event.message_finished is full_message

        assert not connection.bytes_to_send()
Ejemplo n.º 19
0
    def test_extension_negotiation_with_our_parameters(self):
        """The server answers an extension offer with the extension's own params.

        The client offers ``parameter1=value3; parameter2=value4``; the fake
        extension sees that offer verbatim and the response header carries
        the extension's chosen ``parameter1=value1; parameter2=value2``.
        """
        test_host = 'frob.nitz'
        test_path = '/fnord'
        offered_params = 'parameter1=value3; parameter2=value4'
        ext_params = 'parameter1=value1; parameter2=value2'
        ext = FakeExtension(accept_response=ext_params)

        ws = WSConnection(SERVER, extensions=[ext])

        # Sec-WebSocket-Key: base64 encoding of 16 random bytes.
        nonce = bytes(random.getrandbits(8) for _ in range(16))
        nonce = base64.b64encode(nonce)

        request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n"
        request += b'Host: ' + test_host.encode('ascii') + b'\r\n'
        request += b'Connection: Upgrade\r\n'
        request += b'Upgrade: WebSocket\r\n'
        request += b'Sec-WebSocket-Version: 13\r\n'
        request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n'
        request += b'Sec-WebSocket-Extensions: ' + \
            ext.name.encode('ascii') + b'; ' + \
            offered_params.encode('ascii') + b'\r\n'
        request += b'\r\n'

        ws.receive_bytes(request)
        event = next(ws.events())
        assert isinstance(event, ConnectionRequested)
        ws.accept(event)

        data = ws.bytes_to_send()
        response, headers = data.split(b'\r\n', 1)
        # maxsplit=2: the reason phrase may itself contain spaces.
        version, code, reason = response.split(b' ', 2)
        headers = parse_headers(headers)

        assert ext.offered == '%s; %s' % (ext.name, offered_params)
        assert headers['sec-websocket-extensions'] == \
            '%s; %s' % (ext.name, ext_params)
Ejemplo n.º 20
0
def update_reports(server, agent):
    """Hit ``<server>/updateReports?agent=<agent>`` and close immediately (asyncio).

    Generator-based coroutine. Opens a client WebSocket; once the handshake
    completes, sends a close frame (best effort) and stops. The writer is
    always closed, even if the peer hangs up first or ``drain`` fails.

    :param server: base URL of the server (scheme, host and optional port).
    :param agent: agent name placed in the query string.
    """
    uri = urlparse(server + '/updateReports?agent=%s' % agent)
    connection = WSConnection(CLIENT, uri.netloc,
                              '%s?%s' % (uri.path, uri.query))
    reader, writer = yield from asyncio.open_connection(
        uri.hostname, uri.port or 80)

    try:
        writer.write(connection.bytes_to_send())
        closed = False

        while not closed:
            data = yield from reader.read(65535)
            if not data:
                # EOF: read() would keep returning b''.
                break
            connection.receive_bytes(data)
            for event in connection.events():
                if isinstance(event, ConnectionEstablished):
                    connection.close()
                    writer.write(connection.bytes_to_send())
                    # The close frame is a courtesy; ignore a broken transport.
                    try:
                        yield from writer.drain()
                    except (ConnectionError, OSError):
                        pass
                    closed = True
                    break
    finally:
        writer.close()
Ejemplo n.º 21
0
class ClientWebsocket(object):
    """
    A client-side WebSocket built on trio streams and a wsproto WSConnection.

    Iterate it with ``async for`` to open the connection and receive
    ``Websocket*`` event objects; optionally reconnects after the connection
    closes or fails.
    """

    def __init__(self, address: Tuple[str, str, bool, str], *, reconnecting: bool = True):
        """
        :param address: A 4-tuple of (host, port, ssl, endpoint). ``ssl`` is
            falsy for plain TCP, ``True`` for default TLS settings, or a dict
            of options / an :class:`ssl.SSLContext` (see ``_create_ssl_ctx``).
            ``endpoint`` is the resource requested in the opening handshake.
        :param reconnecting: If this websocket reconnects automatically after
            the connection closes or fails.
        """

        # connection parameters, read by open_connection()
        self._address = address

        self.state = None  # type: WSConnection
        # True between ConnectionEstablished and the next close/failure.
        self._ready = False
        self._reconnecting = reconnecting
        # Set by open_connection(): a trio TCP stream, or an SSLStream
        # wrapping it when TLS is enabled.
        self.sock = None  # type: trio.socket.Socket

    @property
    def closed(self) -> bool:
        """
        :return: If this websocket is closed.

        NOTE: ``state`` is ``None`` until :meth:`open_connection` has run, so
        this raises AttributeError before then.
        """
        return self.state.closed

    def _create_ssl_ctx(self, sslp):
        """
        Build an :class:`ssl.SSLContext` from *sslp*.

        *sslp* is either an existing SSLContext (returned unchanged) or a dict
        with optional keys ``ca``, ``capath``, ``check_hostname``, ``cert``,
        ``key`` and ``cipher``. SSLv2 and SSLv3 are always disabled.

        NOTE(review): when neither ``ca`` nor ``capath`` is given, certificate
        verification and hostname checking are turned off (CERT_NONE) instead
        of falling back to the system trust store -- confirm this is intended.
        """
        if isinstance(sslp, ssl.SSLContext):
            return sslp
        ca = sslp.get('ca')
        capath = sslp.get('capath')
        hasnoca = ca is None and capath is None
        ctx = ssl.create_default_context(cafile=ca, capath=capath)
        # check_hostname must be set before verify_mode: the ssl module
        # rejects CERT_NONE while check_hostname is still True.
        ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
        ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
        if 'cert' in sslp:
            ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
        if 'cipher' in sslp:
            ctx.set_ciphers(sslp['cipher'])
        ctx.options |= ssl.OP_NO_SSLv2
        ctx.options |= ssl.OP_NO_SSLv3
        return ctx

    async def open_connection(self):
        """
        Opens a connection, and performs the initial handshake.

        Connects over TCP (wrapping in TLS when ``address[2]`` is truthy),
        creates a fresh client WSConnection and sends its queued opening
        request. The server's response is processed later by ``__aiter__``.

        NOTE(review): when called again to reconnect, the previous
        ``self.sock`` is replaced without being closed -- possible leak.
        """
        _ssl = self._address[2]
        _sock = await trio.open_tcp_stream(self._address[0], self._address[1])
        if _ssl:
            if _ssl is True:
                _ssl = {}
            _ssl = self._create_ssl_ctx(_ssl)
            _sock = trio.ssl.SSLStream(_sock, _ssl, server_hostname=self._address[0], https_compatible=True)
            await _sock.do_handshake()
        self.sock = _sock
        self.state = WSConnection(ConnectionType.CLIENT, host=self._address[0],
                                  resource=self._address[3])
        res = self.state.bytes_to_send()
        await self.sock.send_all(res)

    async def __aiter__(self):
        """
        Open the connection and yield Websocket* events as they arrive.

        Data frames are buffered per type (bytes/text) until the message is
        finished, then yielded as a single message object. On close or
        failure the iterator reconnects if ``reconnecting`` is set, otherwise
        it stops. Unknown wsproto events raise RuntimeError.
        """
        # initiate the websocket
        await self.open_connection()
        # setup buffers
        buf_bytes = BytesIO()
        buf_text = StringIO()

        # aclose() sets self.sock to None, which ends the loop.
        while self.sock is not None:
            data = await self.sock.receive_some(4096)
            self.state.receive_bytes(data)

            # do ping/pongs if needed
            # NOTE(review): this flush happens *before* events() is consumed;
            # if wsproto queues pong/close-reply frames while events are
            # iterated, they are only sent after the next read -- confirm.
            to_send = self.state.bytes_to_send()
            if to_send:
                await self.sock.send_all(to_send)

            # NOTE(review): the events() generator is bound to the state that
            # existed when the loop started; after a reconnect below, any
            # remaining events still come from the old connection.
            for event in self.state.events():
                if isinstance(event, events.ConnectionEstablished):
                    self._ready = True
                    yield WebsocketConnectionEstablished(event)

                elif isinstance(event, events.ConnectionClosed):
                    self._ready = False
                    yield WebsocketClosed(event)

                    if self._reconnecting:
                        await self.open_connection()
                    else:
                        return

                elif isinstance(event, events.DataReceived):
                    buf = buf_bytes if isinstance(event, events.BytesReceived) else buf_text
                    buf.write(event.data)

                    # yield events as appropriate
                    if event.message_finished:
                        buf.seek(0)
                        read = buf.read()
                        # empty buffer
                        buf.truncate(0)
                        buf.seek(0)

                        typ = WebsocketBytesMessage if isinstance(event, events.BytesReceived) \
                            else WebsocketTextMessage
                        yield typ(read)

                elif isinstance(event, events.ConnectionFailed):
                    self._ready = False

                    yield WebsocketConnectionFailed(event)

                    if self._reconnecting:
                        await self.open_connection()
                    else:
                        return

                else:
                    # NOTE(review): any other event type (e.g. ping/pong
                    # events, if this wsproto version emits them) ends up here.
                    raise RuntimeError("I don't understand this message", event)

    async def send_message(self, data: Union[str, bytes]):
        """
        Sends a message on the websocket.

        The data is sent as one complete message (``final=True``).

        :param data: The data to send. Either str or bytes.
        """
        self.state.send_data(data, final=True)
        return await self.sock.send_all(self.state.bytes_to_send())

    async def aclose(self, *, code: int = 1000, reason: str = "No reason",
                    allow_reconnects: bool = False):
        """
        Closes the websocket.

        Sends a close frame, closes the stream, clears ``self.sock`` (which
        ends ``__aiter__``'s read loop) and tells wsproto the stream is gone.

        :param code: The close code to use.
        :param reason: The close reason to use.

        If the websocket is marked as reconnecting:

        :param allow_reconnects: If the websocket can reconnect after this close.
        """
        # do NOT reconnect if we close explicitly and don't allow reconnects
        if not allow_reconnects:
            self._reconnecting = False

        self.state.close(code=code, reason=reason)
        to_send = self.state.bytes_to_send()
        await self.sock.send_all(to_send)
        await self.sock.aclose()
        self.sock = None
        # None tells wsproto the underlying stream has ended.
        self.state.receive_bytes(None)
Ejemplo n.º 22
0
async def wsproto_client_demo(host, port, use_ssl):
    '''
    Demonstrate wsproto:
    0) Open TCP connection
    1) Negotiate WebSocket opening handshake
    2) Send a message and display response
    3) Send ping and display pong
    4) Negotiate WebSocket closing handshake

    :param host: hostname of the WebSocket server
    :param port: TCP port of the WebSocket server
    :param use_ssl: if true, wrap the TCP stream in TLS before the handshake
    :raises Exception: if a step does not produce the expected wsproto event
    '''

    # 0) Open TCP connection
    print(f'[C] Connecting to {host}:{port}')
    conn = await trio.open_tcp_stream(host, port)
    # Close the stream on every exit path, including a failed step below.
    try:
        if use_ssl:
            conn = upgrade_stream_to_ssl(conn, host)

        # 1) Negotiate WebSocket opening handshake
        print('[C] Opening WebSocket')
        ws = WSConnection(ConnectionType.CLIENT, host=host, resource='server')

        events = ws.events()

        # Because this is a client WebSocket, wsproto has automatically queued
        # up a handshake, and we need to send it and wait for a response.
        await net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, ConnectionEstablished):
            print('[C] WebSocket negotiation complete')
        else:
            raise Exception(f'Expected ConnectionEstablished event! Got: {event}')

        # 2) Send a message and display response
        message = "wsproto is great" * 10
        print(f'[C] Sending message: {message}')
        ws.send_data(message)
        await net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, TextReceived):
            print(f'[C] Received message: {event.data}')
        else:
            raise Exception(f'Expected TextReceived event! Got: {event}')

        # 3) Send ping and display pong
        payload = b"table tennis"
        print(f'[C] Sending ping: {payload}')
        ws.ping(payload)
        await net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, PongReceived):
            print(f'[C] Received pong: {event.payload}')
        else:
            raise Exception(f'Expected PongReceived event! Got: {event}')

        # 4) Negotiate WebSocket closing handshake
        print('[C] Closing WebSocket')
        ws.close(code=1000, reason='sample reason')
        # Send our close frame and read the server's reply. No further events
        # are expected after that; the server should then drop the TCP
        # connection, so we close our end below.
        await net_send_recv(ws, conn)
    finally:
        await conn.aclose()
Ejemplo n.º 23
0
                              uri.netloc,
                              '%s?%s' % (uri.path, uri.query),
                              extensions=[PerMessageDeflate()])
    sock = socket.socket()
    sock.connect((uri.hostname, uri.port or 80))

    sock.sendall(connection.bytes_to_send())
    closed = False

    while not closed:
        try:
            data = sock.recv(65535)
        except CONNECTION_EXCEPTIONS:
            data = None
        connection.receive_bytes(data or None)
        for event in connection.events():
            if isinstance(event, DataReceived):
                connection.send_data(event.data, event.message_finished)
            elif isinstance(event, ConnectionClosed):
                closed = True
            # else:
            #     print("??", event)
        if data is None:
            break
        try:
            data = connection.bytes_to_send()
            sock.sendall(data)
        except CONNECTION_EXCEPTIONS:
            closed = True
            break
Ejemplo n.º 24
0
def handle_connection(stream):
    '''
    Serve one client with a strict request/response cycle.

    Each loop iteration:

    1) reads one chunk from the network into wsproto,
    2) takes the next wsproto event,
    3) reacts to it (accept, echo reversed text, log ping/close),
    4) writes whatever wsproto queued back to the network.

    Returns when the client closes the WebSocket or the event stream runs dry.

    :param stream: a socket stream
    '''
    ws = WSConnection(ConnectionType.SERVER)

    # ws.events() returns a generator of event objects. Instead of a for-loop
    # we keep one generator and pull from it with next(events) so that each
    # network read is followed by handling exactly one event.
    events = ws.events()
    keep_serving = True

    while keep_serving:
        # 1) Network -> wsproto
        received = stream.recv(RECEIVE_BYTES)
        print('Received {} bytes'.format(len(received)))
        ws.receive_bytes(received)

        # 2) Next wsproto event; an exhausted generator means the client left.
        event = next(events, None)
        if event is None:
            print('Client connection dropped unexpectedly')
            return

        # 3) React to the event
        if isinstance(event, ConnectionRequested):
            print('Accepting WebSocket upgrade')
            ws.accept(event)
        elif isinstance(event, ConnectionClosed):
            print('Connection closed: code={}/{} reason={}'.format(
                event.code.value, event.code.name, event.reason))
            keep_serving = False
        elif isinstance(event, TextReceived):
            # Echo the text back reversed.
            print('Received request and sending response')
            ws.send_data(event.data[::-1])
        elif isinstance(event, PingReceived):
            # wsproto already placed the pong in its outgoing buffer; only
            # call pong() yourself to send an unsolicited pong.
            print('Received ping and sending pong')
        else:
            print('Unknown event: {!r}'.format(event))

        # 4) wsproto -> network
        outgoing = ws.bytes_to_send()
        print('Sending {} bytes'.format(len(outgoing)))
        stream.send(outgoing)
Ejemplo n.º 25
0
    def test_h11_somehow_loses_its_mind(self):
        """An unrecognisable object from h11 makes the server emit ConnectionFailed."""
        ws = WSConnection(SERVER)

        def bogus_next_event():
            return object()

        # Make the HTTP parser hand back something wsproto cannot handle.
        ws._upgrade_connection.next_event = bogus_next_event

        ws.receive_bytes(b'')
        first_event = next(ws.events())
        assert isinstance(first_event, ConnectionFailed)
Ejemplo n.º 26
0
class ClientWebsocket(object):
    """
    Represents a ClientWebsocket.

    Wraps a client-side :class:`WSConnection` and a ``multio`` socket. Iterate
    it with ``async for`` to open the connection and receive websocket events;
    it can optionally reconnect when the connection ends.
    """
    def __init__(self,
                 address: Tuple[str, str, bool, str],
                 *,
                 reconnecting: bool = True):
        """
        :param address: A 4-tuple of (host, port, ssl, endpoint), where
            ``endpoint`` is the resource requested in the handshake.
        :param reconnecting: If this websocket reconnects automatically after
            the connection is closed, fails, or is dropped by the peer.
        """

        # these are all used to construct the state object
        self._address = address

        self.state = None  # type: WSConnection
        self._ready = False
        self._reconnecting = reconnecting
        self.sock = None  # type: multio.SocketWrapper

    @property
    def closed(self) -> bool:
        """
        :return: If this websocket is closed. A websocket that has never been
            opened (no state object yet) is reported as closed.
        """
        if self.state is None:
            return True
        return self.state.closed

    async def open_connection(self):
        """
        Opens the TCP connection and sends the websocket handshake request.

        The server's handshake response is processed by :meth:`__aiter__`.
        """
        _sock = await multio.asynclib.open_connection(self._address[0],
                                                      self._address[1],
                                                      ssl=self._address[2])
        self.sock = multio.SocketWrapper(_sock)
        self.state = WSConnection(ConnectionType.CLIENT,
                                  host=self._address[0],
                                  resource=self._address[3])
        res = self.state.bytes_to_send()
        await self.sock.sendall(res)

    async def _flush(self):
        """Send any bytes wsproto has queued on the current socket."""
        to_send = self.state.bytes_to_send()
        if to_send:
            await self.sock.sendall(to_send)

    async def __aiter__(self):
        """
        Open the connection and yield websocket events as they arrive.

        Yields ``WebsocketConnectionEstablished``, ``WebsocketClosed`` and
        ``WebsocketConnectionFailed`` wrappers around the matching wsproto
        events, plus ``WebsocketBytesMessage`` / ``WebsocketTextMessage`` once
        a (possibly fragmented) message has been received completely.

        When the connection ends (close, failure, or the peer closing the TCP
        stream), a new connection is opened if this websocket is reconnecting;
        otherwise iteration stops.
        """
        # initiate the websocket
        await self.open_connection()
        # buffers used to reassemble fragmented binary / text messages
        buf_bytes = BytesIO()
        buf_text = StringIO()

        while True:
            data = await self.sock.recv(4096)
            # An empty read means the peer closed the stream. Report EOF to
            # wsproto with ``None`` (as close() does); feeding b'' instead
            # would keep this loop spinning on a dead socket.
            eof = not data
            self.state.receive_bytes(None if eof else data)
            await self._flush()

            # True once the current connection is over and must not be used again
            finished = eof
            for event in self.state.events():
                if isinstance(event, events.ConnectionEstablished):
                    self._ready = True
                    yield WebsocketConnectionEstablished(event)

                elif isinstance(event, events.ConnectionClosed):
                    self._ready = False
                    yield WebsocketClosed(event)
                    finished = True
                    # stop reading events from the finished state object
                    break

                elif isinstance(event, events.DataReceived):
                    is_bytes = isinstance(event, events.BytesReceived)
                    buf = buf_bytes if is_bytes else buf_text
                    buf.write(event.data)

                    # yield a message only once its last fragment has arrived
                    if event.message_finished:
                        read = buf.getvalue()
                        # empty buffer
                        buf.seek(0)
                        buf.truncate(0)

                        typ = WebsocketBytesMessage if is_bytes else WebsocketTextMessage
                        yield typ(read)

                elif isinstance(event, events.ConnectionFailed):
                    self._ready = False
                    yield WebsocketConnectionFailed(event)
                    finished = True
                    break

            # Send whatever wsproto queued while the events were processed,
            # e.g. pong frames answering pings.
            await self._flush()

            if not finished:
                continue

            self._ready = False
            if not self._reconnecting:
                return

            # Discard partial messages from the old connection and release its
            # socket before opening a replacement.
            for buf in (buf_bytes, buf_text):
                buf.seek(0)
                buf.truncate(0)
            await self.sock.close()
            await self.open_connection()

    async def send_message(self, data: Union[str, bytes]):
        """
        Sends a message on the websocket.

        :param data: The data to send. Either str or bytes.
        """
        self.state.send_data(data, final=True)
        return await self.sock.sendall(self.state.bytes_to_send())

    async def close(self,
                    *,
                    code: int = 1000,
                    reason: str = "No reason",
                    allow_reconnects: bool = False):
        """
        Closes the websocket.

        Sends a close frame, closes the socket and marks the wsproto state as
        having reached end of stream.

        :param code: The close code to use.
        :param reason: The close reason to use.

        If the websocket is marked as reconnecting:

        :param allow_reconnects: If the websocket can reconnect after this close.
        """
        # do NOT reconnect if we close explicitly and don't allow reconnects
        if not allow_reconnects:
            self._reconnecting = False

        self.state.close(code=code, reason=reason)
        to_send = self.state.bytes_to_send()
        await self.sock.sendall(to_send)
        await self.sock.close()
        self.state.receive_bytes(None)
Ejemplo n.º 27
0
def wsproto_demo(host, port):
    '''
    Demonstrate wsproto:

    0) Open TCP connection
    1) Negotiate WebSocket opening handshake
    2) Send a message and display response
    3) Send ping and display pong
    4) Negotiate WebSocket closing handshake

    The TCP socket is closed when the demo finishes, including when one of
    the steps raises.

    :param host: hostname or IPv4 address of the WebSocket server
    :param port: TCP port of the WebSocket server
    :raises Exception: if the server does not answer a step with the
        expected event
    '''

    # 0) Open TCP connection. The context manager guarantees the socket is
    # closed at the end of the demo, even if a step below raises.
    print('Connecting to {}:{}'.format(host, port))
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.connect((host, port))

        # 1) Negotiate WebSocket opening handshake
        print('Opening WebSocket')
        ws = WSConnection(ConnectionType.CLIENT, host=host, resource='server')

        # events is a generator that yields websocket event objects. Usually
        # you would say `for event in ws.events()`, but the synchronous nature
        # of this client requires us to use next(events) instead so that we
        # can interleave the network I/O. A generator that has raised
        # StopIteration is exhausted and cannot be resumed, so this script
        # only calls next() after net_send_recv() has exchanged new network
        # data with the server.
        events = ws.events()

        # Because this is a client WebSocket, wsproto has automatically queued
        # up a handshake, and we need to send it and wait for a response.
        net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, ConnectionEstablished):
            print('WebSocket negotiation complete')
        else:
            raise Exception('Expected ConnectionEstablished event!')

        # 2) Send a message and display response
        message = "wsproto is great"
        print('Sending message: {}'.format(message))
        ws.send_data(message)
        net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, TextReceived):
            print('Received message: {}'.format(event.data))
        else:
            raise Exception('Expected TextReceived event!')

        # 3) Send ping and display pong
        payload = b"table tennis"
        print('Sending ping: {}'.format(payload))
        ws.ping(payload)
        net_send_recv(ws, conn)
        event = next(events)
        if isinstance(event, PongReceived):
            print('Received pong: {}'.format(event.payload))
        else:
            raise Exception('Expected PongReceived event!')

        # 4) Negotiate WebSocket closing handshake
        print('Closing WebSocket')
        ws.close(code=1000, reason='sample reason')
        # After sending the closing frame, we won't get any more events. The
        # server should send a reply and then close the connection, so we
        # need to receive twice:
        net_send_recv(ws, conn)
        conn.shutdown(socket.SHUT_WR)
        net_recv(ws, conn)
Ejemplo n.º 28
0
class WebSocketEndpoint(AbstractEndpoint):
    """
    Websocket endpoint driven by a server-side wsproto connection.

    Bytes coming from the HTTP client are fed into the wsproto state machine,
    the resulting events are dispatched to the ``on_*`` hooks, and any bytes
    wsproto queues in response are written back to the client.

    Subprotocol negotiation is currently not supported.
    """

    __slots__ = ('ctx', '_client', '_ws')

    def __init__(self, ctx: Context, client: BaseHTTPClientConnection):
        self.ctx = ctx
        self._client = client
        self._ws = WSConnection(ConnectionType.SERVER)

    def _process_ws_events(self):
        # Dispatch every pending wsproto event to its hook, then flush
        # whatever wsproto queued for the peer in a single write.
        ws = self._ws
        for evt in ws.events():
            if isinstance(evt, ConnectionRequested):
                # Accept before on_connect() so the handshake response is
                # already queued if the hook sends data right away.
                ws.accept(evt)
                self.on_connect()
            elif isinstance(evt, DataReceived):
                self.on_data(evt.data)
            elif isinstance(evt, ConnectionClosed):
                self.on_close()

        outgoing = ws.bytes_to_send()
        if outgoing:
            self._client.write(outgoing)

    def begin_request(self, request: HTTPRequest):
        # Feed what the client's upgrade() returns (presumably any bytes
        # received beyond the HTTP request - confirm) into wsproto.
        leftover = self._client.upgrade()
        self._ws.receive_bytes(leftover)
        self._process_ws_events()

    def receive_body_data(self, data: bytes) -> None:
        ws = self._ws
        ws.receive_bytes(data)
        self._process_ws_events()

    def send_message(self, payload: Union[str, bytes]) -> None:
        """
        Send a message to the client.

        :param payload: either a unicode string or a bytestring

        """
        ws = self._ws
        ws.send_data(payload)
        self._client.write(ws.bytes_to_send())

    def close(self) -> None:
        """Close the websocket, flushing the close frame to the client."""
        self._ws.close()
        self._process_ws_events()

    def on_connect(self) -> None:
        """Hook invoked once the websocket handshake has been accepted."""

    def on_close(self) -> None:
        """Hook invoked when wsproto reports that the connection was closed."""

    def on_data(self, data: bytes) -> None:
        """Hook invoked with the payload of each data event from the peer."""