예제 #1
0
    def __init__(self, ctx, handshake_flow):
        """Create the two wsproto endpoints and replay the captured handshake.

        One ``WSConnection`` plays the server role towards the proxied client,
        the other plays the client role towards the upstream server. The
        opening handshake recorded in ``handshake_flow`` is replayed in memory:
        bytes produced by one endpoint are fed straight into the other, so both
        finish in the accepted state.
        """
        super().__init__(ctx)
        self.handshake_flow = handshake_flow
        self.flow: WebSocketFlow = None

        self.client_frame_buffer = []
        self.server_frame_buffer = []

        self.connections: dict[object, WSConnection] = {}

        ext_header = 'Sec-WebSocket-Extensions'
        response_headers = handshake_flow.response.headers
        # Only enable permessage-deflate if the upstream response agreed to it.
        use_deflate = (
            ext_header in response_headers
            and PerMessageDeflate.name in response_headers[ext_header]
        )
        client_extensions = [PerMessageDeflate()] if use_deflate else []
        server_extensions = [PerMessageDeflate()] if use_deflate else []

        # The endpoint keyed by client_conn acts as a WebSocket server,
        # the endpoint keyed by server_conn acts as a WebSocket client.
        self.connections[self.client_conn] = towards_client = WSConnection(ConnectionType.SERVER)
        self.connections[self.server_conn] = towards_server = WSConnection(ConnectionType.CLIENT)

        if use_deflate:
            agreed = response_headers[ext_header]
            client_extensions[0].finalize(agreed)
            server_extensions[0].finalize(agreed)

        opening = Request(
            extensions=client_extensions,
            host=handshake_flow.request.host,
            target=handshake_flow.request.path,
        )
        towards_client.receive_data(towards_server.send(opening))

        first_event = next(towards_client.events())
        assert isinstance(first_event, events.Request)

        acceptance = towards_client.send(AcceptConnection(extensions=server_extensions))
        towards_server.receive_data(acceptance)
        assert isinstance(next(towards_server.events()), events.AcceptConnection)
예제 #2
0
def new_conn(sock):
    """Serve one echo WebSocket connection on ``sock`` until it ends.

    Uses the legacy wsproto byte API (``receive_bytes`` / ``bytes_to_send``).
    Handshake requests are accepted, every data event is echoed back, and the
    loop stops on a close event, EOF, a failed ``recv`` or a failed ``sendall``.
    The socket is closed on every exit path, including unexpected exceptions.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = sock.recv(65535)
            except socket.error:
                data = None

            # EOF (b"") and a failed recv (None) are both passed as None.
            ws.receive_bytes(data or None)

            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True

            if not data:
                closed = True

            # Flush whatever the events above queued, even on the last pass.
            try:
                sock.sendall(ws.bytes_to_send())
            except socket.error:
                closed = True
    finally:
        sock.close()
예제 #3
0
 async def asgi_send(self, message: dict) -> None:
     """Called by the ASGI instance to send a message.

     Dispatches on ``message["type"]`` and the current ``self.state``:

     * ``websocket.accept`` during the handshake accepts the upgrade with a
       ``PerMessageDeflate`` extension, moves to CONNECTED and logs a 101.
     * ``websocket.http.response.start`` during the handshake stores the
       response and logs it; ``websocket.http.response.body`` (handshake or
       response state) is passed to ``_asgi_send_rejection``.
     * ``websocket.send`` while connected sends a binary message when
       ``bytes`` is not None, otherwise a text message from ``text``.
     * ``websocket.close`` during the handshake sends an HTTP 403; otherwise
       it sends a close frame with ``code`` (ASGI default: 1000).

     Raises:
         TypeError: if a ``websocket.send`` message has no ``bytes`` and its
             ``text`` is missing or not a ``str``.
         UnexpectedMessage: if the message type is not valid in this state.
     """
     if message["type"] == "websocket.accept" and self.state == ASGIWebsocketState.HANDSHAKE:
         await self.asend(AcceptConnection(extensions=[PerMessageDeflate()]))
         self.state = ASGIWebsocketState.CONNECTED
         self.config.access_logger.access(
             self.scope, {"status": 101, "headers": []}, time() - self.start_time
         )
     elif (
         message["type"] == "websocket.http.response.start"
         and self.state == ASGIWebsocketState.HANDSHAKE
     ):
         self.response = message
         self.config.access_logger.access(self.scope, self.response, time() - self.start_time)
     elif message["type"] == "websocket.http.response.body" and self.state in {
         ASGIWebsocketState.HANDSHAKE,
         ASGIWebsocketState.RESPONSE,
     }:
         await self._asgi_send_rejection(message)
     elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED:
         # Both keys are optional in ASGI; use .get so a missing "text" is
         # reported by the TypeError below rather than as a KeyError.
         text = message.get("text")
         if message.get("bytes") is not None:
             await self.asend(BytesMessage(data=bytes(message["bytes"])))
         elif not isinstance(text, str):
             raise TypeError(f"{text} should be a str")
         else:
             await self.asend(TextMessage(data=text))
     elif message["type"] == "websocket.close" and self.state == ASGIWebsocketState.HANDSHAKE:
         await self.send_http_error(403)
         self.state = ASGIWebsocketState.HTTPCLOSED
     elif message["type"] == "websocket.close":
         # "code" is optional in the ASGI spec and defaults to 1000.
         await self.asend(CloseConnection(code=int(message.get("code", 1000))))
         self.state = ASGIWebsocketState.CLOSED
     else:
         raise UnexpectedMessage(self.state, message["type"])
예제 #4
0
    def accept(
        self, subprotocol: Optional[str]
    ) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]:
        """Build the response that accepts this WebSocket handshake.

        Args:
            subprotocol: Subprotocol chosen by the application, or ``None``.

        Returns:
            ``(status_code, headers, connection)``. The status is 101 with
            ``upgrade``/``connection`` headers when ``self.http_version`` is
            ``"1.1"``, otherwise 200. The connection is a server-side
            ``Connection`` given a ``PerMessageDeflate`` extension.

        Raises:
            Exception: ``"Invalid Subprotocol"`` if ``subprotocol`` is not in
                ``self.subprotocols`` (including when that is ``None``).
        """
        headers = []
        if subprotocol is not None:
            # self.subprotocols may be None (presumably: client offered none);
            # ``x in None`` would otherwise raise an unrelated TypeError.
            if self.subprotocols is None or subprotocol not in self.subprotocols:
                raise Exception("Invalid Subprotocol")
            headers.append((b"sec-websocket-protocol", subprotocol.encode()))

        extensions = [PerMessageDeflate()]
        accepts = None
        # NOTE(review): ``False and ...`` disables extension negotiation, so no
        # sec-websocket-extensions header is ever sent, although the
        # PerMessageDeflate instance is still handed to Connection. This looks
        # deliberate and is kept as-is; confirm before re-enabling.
        if False and self.extensions is not None:
            accepts = server_extensions_handshake(self.extensions, extensions)

        if accepts:
            headers.append((b"sec-websocket-extensions", accepts))

        if self.key is not None:
            headers.append((b"sec-websocket-accept", generate_accept_token(self.key)))

        status_code = 200
        if self.http_version == "1.1":
            headers.extend([(b"upgrade", b"WebSocket"), (b"connection", b"Upgrade")])
            status_code = 101

        return status_code, headers, Connection(ConnectionType.SERVER, extensions)
예제 #5
0
    def __init__(self, ws_handler, host=None, port=None, *,
                 path=None, create_protocol=None,
                 timeout=10, max_size=2 ** 20,
                 origins=None, extensions=None, subprotocols=None,
                 extra_headers=None, compression='deflate', ssl=None):
        """Record server settings and build the per-connection protocol factory.

        ``self.factory`` is a zero-argument callable that creates a
        ``create_protocol`` instance with these settings (``ssl`` is passed as
        ``secure``).

        Args:
            create_protocol: Protocol class/factory; defaults to
                ``WebSocketServerProtocol``.
            extensions: Sequence of extension objects. It is never modified.
                When ``compression == 'deflate'`` and no entry has the
                ``PerMessageDeflate`` name, a new list with an added
                ``PerMessageDeflate(client_max_window_bits=True)`` is used.
            compression: ``'deflate'`` or ``None``.

        Raises:
            ValueError: if ``compression`` is neither ``'deflate'`` nor ``None``.
        """
        if create_protocol is None:
            create_protocol = WebSocketServerProtocol

        if compression == 'deflate':
            if extensions is None:
                extensions = []
            if not any(
                extension_factory.name == PerMessageDeflate.name
                for extension_factory in extensions
            ):
                # Build a new list rather than appending in place, so a list
                # supplied by the caller is not modified as a side effect
                # (this also makes tuples work).
                extensions = list(extensions) + [PerMessageDeflate(
                    client_max_window_bits=True,
                )]
        elif compression is not None:
            raise ValueError("Unsupported compression: {}".format(compression))

        self.factory = lambda: create_protocol(
            ws_handler,
            host=host, port=port, secure=ssl,
            timeout=timeout, max_size=max_size,
            origins=origins, extensions=extensions, subprotocols=subprotocols,
            extra_headers=extra_headers,
        )

        self._port = port
        self._host = host
        self.path = path
        self.ssl = ssl
예제 #6
0
def new_conn(reader, writer):
    """Serve one echo WebSocket connection over an asyncio stream pair.

    Generator-based coroutine (uses ``yield from``) built on the legacy
    wsproto byte API. Each pass reads one chunk and handles all resulting
    events: handshakes are accepted, data is echoed, and a close event ends
    the loop. All queued bytes are then flushed once. EOF or a stream error
    also ends the loop. The writer is closed on every exit path.
    """
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = yield from reader.read(65535)
            except ConnectionError:
                data = None

            # EOF (b"") and a failed read (None) are both passed as None.
            ws.receive_bytes(data or None)

            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    # DataReceived exposes ``message_finished`` (not ``final``).
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True

            # Stop after this pass on EOF/read failure, even if wsproto
            # produced no event (otherwise the loop would keep re-reading EOF).
            if not data:
                closed = True

            # Flush once per read, after every event has been handled.
            try:
                writer.write(ws.bytes_to_send())
                yield from writer.drain()
            except (ConnectionError, OSError):
                closed = True
    finally:
        writer.close()
예제 #7
0
def new_conn(reader, writer):
    """Serve one echo WebSocket connection over an asyncio stream pair.

    Generator-based coroutine (uses ``yield from``) built on the legacy
    wsproto byte API. Increments the module-level ``count`` and prints it,
    accepts the handshake, echoes every data event, and stops on a close
    event, EOF, or a stream error. The writer is closed on every exit path,
    including cancellation or other unexpected exceptions.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER, extensions=[PerMessageDeflate()])
    closed = False
    try:
        while not closed:
            try:
                data = yield from reader.read(65535)
            except ConnectionError:
                data = None

            # EOF (b"") and a failed read (None) are both passed as None.
            ws.receive_bytes(data or None)

            for event in ws.events():
                if isinstance(event, ConnectionRequested):
                    ws.accept(event)
                elif isinstance(event, DataReceived):
                    ws.send_data(event.data, event.message_finished)
                elif isinstance(event, ConnectionClosed):
                    closed = True

            if not data:
                closed = True

            try:
                writer.write(ws.bytes_to_send())
                yield from writer.drain()
            except (ConnectionError, OSError):
                closed = True
    finally:
        writer.close()
예제 #8
0
    def __init__(self, ctx, handshake_flow):
        """Create both wsproto endpoints and replay the handshake between them.

        Uses the legacy wsproto API. The endpoint stored under
        ``self.client_conn`` is a server-role connection, and the one under
        ``self.server_conn`` is a client-role connection built from the
        request's host and path. Both share one extension list. The
        client-role endpoint's opening bytes are fed to the server-role
        endpoint, which accepts, and its reply is fed back.
        """
        super().__init__(ctx)
        self.handshake_flow = handshake_flow
        self.flow: WebSocketFlow = None

        self.client_frame_buffer = []
        self.server_frame_buffer = []

        self.connections: dict[object, WSConnection] = {}

        headers = handshake_flow.response.headers
        deflate_agreed = (
            'Sec-WebSocket-Extensions' in headers
            and PerMessageDeflate.name in headers['Sec-WebSocket-Extensions']
        )
        # A single list, shared by both connections below.
        extensions = [PerMessageDeflate()] if deflate_agreed else []

        facing_client = WSConnection(ConnectionType.SERVER, extensions=extensions)
        self.connections[self.client_conn] = facing_client
        facing_server = WSConnection(
            ConnectionType.CLIENT,
            host=handshake_flow.request.host,
            resource=handshake_flow.request.path,
            extensions=extensions,
        )
        self.connections[self.server_conn] = facing_server

        if extensions:
            agreed = headers['Sec-WebSocket-Extensions']
            for conn in self.connections.values():
                conn.extensions[0].finalize(conn, agreed)

        facing_client.receive_bytes(facing_server.bytes_to_send())

        request_event = next(facing_client.events())
        assert isinstance(request_event, events.ConnectionRequested)

        facing_client.accept(request_event)
        facing_server.receive_bytes(facing_client.bytes_to_send())
        assert isinstance(next(facing_server.events()),
                          events.ConnectionEstablished)
예제 #9
0
def new_conn(sock: socket.socket) -> None:
    """Serve one echo WebSocket connection on ``sock`` (wsproto event API).

    Increments the module-level ``count`` and prints it. Incoming requests
    are accepted with ``PerMessageDeflate``, messages are echoed back, pings
    get their pong, and a close is answered unless the connection is already
    CLOSED. Replies produced by one read are sent with a single ``sendall``.
    The loop ends on close, EOF, or a socket error, and the socket is always
    closed on exit.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER)
    closed = False
    try:
        while not closed:
            try:
                data: Optional[bytes] = sock.recv(65535)
            except socket.error:
                data = None

            # EOF (b"") and a failed recv (None) are both passed as None.
            ws.receive_data(data or None)

            # Collect chunks and join once instead of repeated bytes +=.
            outgoing = []
            for event in ws.events():
                if isinstance(event, Request):
                    outgoing.append(ws.send(
                        AcceptConnection(extensions=[PerMessageDeflate()])))
                elif isinstance(event, Message):
                    outgoing.append(ws.send(
                        Message(data=event.data,
                                message_finished=event.message_finished)))
                elif isinstance(event, Ping):
                    outgoing.append(ws.send(event.response()))
                elif isinstance(event, CloseConnection):
                    closed = True
                    if ws.state is not ConnectionState.CLOSED:
                        outgoing.append(ws.send(event.response()))

            if not data:
                closed = True

            try:
                sock.sendall(b"".join(outgoing))
            except socket.error:
                closed = True
    finally:
        sock.close()
예제 #10
0
    def accept(
        self,
        subprotocol: Optional[str],
        additional_headers: Iterable[Tuple[bytes, bytes]],
    ) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]:
        """Build the response that accepts this WebSocket handshake.

        Args:
            subprotocol: Subprotocol chosen by the application, or ``None``.
            additional_headers: Extra ``(name, value)`` headers to append.

        Returns:
            ``(status_code, headers, connection)``. The status is 101 with
            ``upgrade``/``connection`` headers when ``self.http_version`` is
            ``"1.1"``, otherwise 200. ``sec-websocket-extensions`` is added
            when ``server_extensions_handshake`` returns a value. The
            connection is a server-side ``Connection`` with the
            ``PerMessageDeflate`` extension list.

        Raises:
            Exception: ``"Invalid Subprotocol"`` if ``subprotocol`` is not in
                ``self.subprotocols`` (including when that is ``None``), or
                ``"Invalid additional header, ..."`` for an additional header
                named ``sec-websocket-protocol`` or starting with ``:``.
        """
        headers = []
        if subprotocol is not None:
            # self.subprotocols may be None (presumably: client offered none);
            # ``x in None`` would otherwise raise an unrelated TypeError.
            if self.subprotocols is None or subprotocol not in self.subprotocols:
                raise Exception("Invalid Subprotocol")
            headers.append(
                (b"sec-websocket-protocol", subprotocol.encode()))

        extensions: List[Extension] = [PerMessageDeflate()]
        accepts = None
        if self.extensions is not None:
            accepts = server_extensions_handshake(self.extensions, extensions)

        if accepts:
            headers.append((b"sec-websocket-extensions", accepts))

        if self.key is not None:
            headers.append(
                (b"sec-websocket-accept", generate_accept_token(self.key)))

        status_code = 200
        if self.http_version == "1.1":
            headers.extend([(b"upgrade", b"WebSocket"),
                            (b"connection", b"Upgrade")])
            status_code = 101

        for name, value in additional_headers:
            # The subprotocol header is managed above, and ":"-prefixed
            # names are pseudo-headers, so neither may be overridden.
            if b"sec-websocket-protocol" == name or name.startswith(b":"):
                raise Exception(f"Invalid additional header, {name.decode()}")

            headers.append((name, value))

        return status_code, headers, Connection(ConnectionType.SERVER,
                                                extensions)
예제 #11
0
    async def send(self, message):
        """Send one ASGI WebSocket message over the wsproto connection.

        Waits on ``self.writable``, then acts on ``message["type"]``:

        * Before the handshake completes, ``websocket.accept`` logs and sends
          ``AcceptConnection`` with the optional ``subprotocol`` and a
          ``PerMessageDeflate`` extension. ``websocket.close`` queues a
          ``websocket.disconnect``, logs a 403, rejects the upgrade with
          HTTP 403 and closes the transport.
        * After the handshake, ``websocket.send`` sends ``bytes`` unless it
          is None, otherwise ``text``. ``websocket.close`` queues a
          disconnect, sends a close frame with ``code`` (default 1000) and
          closes the transport if it is not already closing.

        Raises:
            RuntimeError: if the message type is not valid in the current
                state, or if any message arrives after a close was sent.
        """
        await self.writable.wait()

        message_type = message["type"]

        if not self.handshake_complete:
            if message_type == "websocket.accept":
                self.logger.info(
                    '%s - "WebSocket %s" [accepted]',
                    self.scope["client"],
                    self.scope["root_path"] + self.scope["path"],
                )
                self.handshake_complete = True
                subprotocol = message.get("subprotocol")
                output = self.conn.send(
                    wsproto.events.AcceptConnection(
                        subprotocol=subprotocol,
                        extensions=[PerMessageDeflate()]))
                self.transport.write(output)

            elif message_type == "websocket.close":
                self.queue.put_nowait({
                    "type": "websocket.disconnect",
                    "code": None
                })
                self.logger.info(
                    '%s - "WebSocket %s" 403',
                    self.scope["client"],
                    self.scope["root_path"] + self.scope["path"],
                )
                self.handshake_complete = True
                self.close_sent = True
                # self.conn is the wsproto connection (see AcceptConnection
                # above). During the handshake it only accepts handshake
                # events, so raw h11.Response/EndOfMessage objects are refused.
                # A body-less RejectConnection emits the full 403 response.
                output = self.conn.send(
                    wsproto.events.RejectConnection(status_code=403, headers=[]))
                self.transport.write(output)
                self.transport.close()

            else:
                msg = "Expected ASGI message 'websocket.accept' or 'websocket.close', but got '%s'."
                raise RuntimeError(msg % message_type)

        elif not self.close_sent:
            if message_type == "websocket.send":
                bytes_data = message.get("bytes")
                text_data = message.get("text")
                data = text_data if bytes_data is None else bytes_data
                output = self.conn.send(wsproto.events.Message(data=data))
                if not self.transport.is_closing():
                    self.transport.write(output)

            elif message_type == "websocket.close":
                self.close_sent = True
                code = message.get("code", 1000)
                self.queue.put_nowait({
                    "type": "websocket.disconnect",
                    "code": code
                })
                output = self.conn.send(
                    wsproto.events.CloseConnection(code=code))
                if not self.transport.is_closing():
                    self.transport.write(output)
                    self.transport.close()

            else:
                msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
                raise RuntimeError(msg % message_type)

        else:
            msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
            raise RuntimeError(msg % message_type)
예제 #12
0
                    connection.close()
            try:
                sock.sendall(connection.bytes_to_send())
            except CONNECTION_EXCEPTIONS:
                break

    sock.close()
    return case_count


def run_case(server, case, agent):
    uri = urlparse(server + '/runCase?case=%d&agent=%s' % (case, agent))
    connection = WSConnection(CLIENT,
                              uri.netloc,
                              '%s?%s' % (uri.path, uri.query),
                              extensions=[PerMessageDeflate()])
    sock = socket.socket()
    sock.connect((uri.hostname, uri.port or 80))

    sock.sendall(connection.bytes_to_send())
    closed = False

    while not closed:
        try:
            data = sock.recv(65535)
        except CONNECTION_EXCEPTIONS:
            data = None
        connection.receive_bytes(data or None)
        for event in connection.events():
            if isinstance(event, DataReceived):
                connection.send_data(event.data, event.message_finished)
+from wsproto import events, WSConnection
+from wsproto.connection import ConnectionType
+from wsproto.events import AcceptConnection, CloseConnection, Message, Ping, Request
 from wsproto.extensions import PerMessageDeflate
 
 from mitmproxy import exceptions
@@ -52,51 +53,52 @@ class WebSocketLayer(base.Layer):
 
         self.connections: dict[object, WSConnection] = {}
 
-        extensions = []
+        client_extensions = []
+        server_extensions = []
         if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
             if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
-                extensions = [PerMessageDeflate()]
-        self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
-                                                          extensions=extensions)
-        self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
-                                                          host=handshake_flow.request.host,
-                                                          resource=handshake_flow.request.path,
-                                                          extensions=extensions)
-        if extensions:
-            for conn in self.connections.values():
-                conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions'])
+                client_extensions = [PerMessageDeflate()]
+                server_extensions = [PerMessageDeflate()]
+        self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER)
+        self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT)
 
-        data = self.connections[self.server_conn].bytes_to_send()