コード例 #1
0
    def process_handshake_header(self, headers):
        """
        Read the upgrade handshake's response headers and
        validate them against :rfc:`6455`.

        :param headers: the raw header block as bytes, one header per
            CRLF-separated line, without the HTTP status line.
        :returns: a ``(protocols, extensions)`` tuple of lists holding the
            comma-separated tokens (as bytes) of the
            ``Sec-WebSocket-Protocol`` and ``Sec-WebSocket-Extensions``
            headers; each list is empty when its header is absent.
        :raises HandshakeError: if a line has no ``:`` separator, if the
            ``Upgrade`` header is not ``websocket``, if the ``Connection``
            header carries no ``upgrade`` token, or if the
            ``Sec-WebSocket-Accept`` value does not match the digest
            derived from ``self.key``.

        Only headers that are present are validated; a missing header is
        not reported by this method.
        """
        protocols = []
        extensions = []

        for header_line in headers.strip().split(b'\r\n'):
            if not header_line.strip():
                # An empty header block splits into a single empty line;
                # there is nothing to validate in that case.
                continue

            header, separator, raw_value = header_line.partition(b':')
            if not separator:
                # Report malformed lines with the handshake's own exception
                # type rather than letting an unpacking error escape.
                raise HandshakeError("Invalid header line: %s" % header_line)

            header = header.strip().lower()
            raw_value = raw_value.strip()
            # Header names and the Upgrade/Connection values are compared
            # case-insensitively; the raw value is kept for case-sensitive
            # data such as the base64 accept digest.
            value = raw_value.lower()

            if header == b'upgrade':
                if value != b'websocket':
                    raise HandshakeError("Invalid Upgrade header: %s" % value)

            elif header == b'connection':
                # The header is a comma-separated token list, so values like
                # "keep-alive, Upgrade" must be accepted as well.
                tokens = [token.strip() for token in value.split(b',')]
                if b'upgrade' not in tokens:
                    raise HandshakeError("Invalid Connection header: %s" % value)

            elif header == b'sec-websocket-accept':
                match = b64encode(sha1(self.key + WS_KEY).digest())
                # base64 is case-sensitive: compare the digest exactly.
                if raw_value != match:
                    raise HandshakeError("Invalid challenge response: %s" %
                                         raw_value)

            elif header == b'sec-websocket-protocol':
                protocols.extend(
                    token.strip() for token in raw_value.split(b',')
                    if token.strip())

            elif header == b'sec-websocket-extensions':
                extensions.extend(
                    token.strip() for token in raw_value.split(b',')
                    if token.strip())

        return protocols, extensions
コード例 #2
0
 def process_response_line(self, response_line):
     """
     Ensure that we received a HTTP `101` status code in
     response to our request and if not raises :exc:`HandshakeError`.

     :param response_line: the HTTP status line as bytes, e.g.
         ``b'HTTP/1.1 101 Switching Protocols'``. The reason phrase
         after the status code is optional.
     :raises HandshakeError: if the line holds no status code or the
         status code is not ``101``.
     """
     parts = response_line.split(b' ', 2)
     if len(parts) < 2:
         # No status code at all: report it as a failed handshake instead
         # of letting an unpacking error escape to the caller.
         raise HandshakeError("Invalid response line: %s" % response_line)

     code = parts[1]
     # The reason phrase may legitimately be missing ("HTTP/1.1 101").
     status = parts[2] if len(parts) > 2 else b''
     if code != b'101':
         raise HandshakeError("Invalid response status: %s %s" %
                              (code, status))
コード例 #3
0
    def connect(self):
        """
        Connects this websocket and starts the upgrade handshake
        with the remote endpoint.

        The handshake request is written to the socket, then the response
        is read until the blank line that ends the HTTP header block. The
        status line and headers are validated; on success
        ``handshake_ok()`` is called and any bytes received after the
        header block are handed to ``process()``.

        :raises HandshakeError: if the peer closes the connection before a
            complete header block was received, or if the status line or
            headers are rejected. The connection is closed before the
            exception propagates.
        """
        if self.scheme == "wss":
            # Secure scheme: wrap the socket in TLS before connecting.
            # NOTE(review): ssl.wrap_socket() is deprecated since Python 3.7
            # and removed in 3.12; moving to SSLContext.wrap_socket() needs
            # a matching change to how self.ssl_options is built.
            self.sock = ssl.wrap_socket(self.sock, **self.ssl_options)
            self._is_secure = True

        self.sock.settimeout(10.0)
        self.sock.connect(self.bind_addr)

        self._write(self.handshake_request)

        response = b''
        header_end = b'\r\n\r\n'
        while header_end not in response:
            chunk = self.sock.recv(128)
            if not chunk:
                # Peer closed the connection.
                break
            response += chunk

        if header_end not in response:
            # Nothing was received, or the header block was cut short: a
            # truncated response must not be parsed as a valid handshake.
            self.close_connection()
            raise HandshakeError("Invalid response")

        headers, _, body = response.partition(header_end)
        response_line, _, headers = headers.partition(b'\r\n')

        try:
            self.process_response_line(response_line)
            self.protocols, self.extensions = self.process_handshake_header(
                headers)
        except HandshakeError:
            self.close_connection()
            raise

        self.handshake_ok()
        if body:
            # The last recv() may have read past the header block.
            self.process(body)
コード例 #4
0
    def upgrade(self,
                protocols=None,
                extensions=None,
                version=WS_VERSION,
                handler_cls=WebSocket,
                heartbeat_freq=None):
        """
        Performs the upgrade of the connection to the WebSocket
        protocol.

        The request must be a ``GET`` carrying ``Upgrade: websocket``,
        a ``Connection`` header containing ``upgrade``, a supported
        ``Sec-WebSocket-Version`` and a ``Sec-WebSocket-Key`` that
        decodes to 16 bytes. On success the ``101 Switching Protocols``
        response headers are set and a handler instance is stored on
        ``request.ws_handler``.

        :param protocols: subprotocols supported by this tool instance.
            Subprotocols requested by the client that are not in this
            list are dropped from the response; they do not make the
            upgrade fail.
        :param extensions: extensions supported by this tool instance,
            filtered against the client's request in the same way.
        :param version: collection of supported WebSocket protocol
            versions.
        :param handler_cls: class instantiated with the connection, the
            negotiated protocols and extensions, a copy of the WSGI
            environ and ``heartbeat_freq``.
        :param heartbeat_freq: passed through to ``handler_cls``.
        :raises HandshakeError: if any of the requirements above is not
            met. For an unsupported version, the supported versions are
            advertised in the ``Sec-WebSocket-Version`` response header.
        """
        request = cherrypy.serving.request
        request.process_request_body = False

        ws_protocols = None
        ws_version = version
        ws_extensions = []

        if request.method != 'GET':
            raise HandshakeError('HTTP method must be a GET')

        for header_name, expected_value in [('Upgrade', 'websocket'),
                                            ('Connection', 'upgrade')]:
            actual_value = request.headers.get(header_name, '').lower()
            if not actual_value:
                raise HandshakeError('Header %s is not defined' % header_name)
            if expected_value not in actual_value:
                raise HandshakeError('Illegal value for header %s: %s' %
                                     (header_name, actual_value))

        client_version = request.headers.get('Sec-WebSocket-Version')
        supported_versions = ', '.join([str(v) for v in ws_version])
        version_is_valid = False
        if client_version:
            try:
                client_version = int(client_version)
            except (TypeError, ValueError):
                # Non-numeric version: leave version_is_valid False.
                pass
            else:
                version_is_valid = client_version in ws_version

        if not version_is_valid:
            cherrypy.response.headers[
                'Sec-WebSocket-Version'] = supported_versions
            raise HandshakeError('Unhandled or missing WebSocket version')

        key = request.headers.get('Sec-WebSocket-Key')
        if not key:
            # The accept value below is derived from the key, so the
            # handshake cannot be completed without it.
            raise HandshakeError('Header Sec-WebSocket-Key is not defined')
        try:
            ws_key = base64.b64decode(key.encode('utf-8'))
        except (TypeError, ValueError):
            # b64decode reports malformed input with TypeError on Python 2
            # and a ValueError subclass on Python 3.
            raise HandshakeError("WebSocket key is not valid base64")
        if len(ws_key) != 16:
            raise HandshakeError("WebSocket key's length is invalid")

        protocols = protocols or []
        subprotocols = request.headers.get('Sec-WebSocket-Protocol')
        if subprotocols:
            # Keep only the requested subprotocols that we support.
            candidates = [s.strip() for s in subprotocols.split(',')]
            ws_protocols = [s for s in candidates if s in protocols]

        exts = extensions or []
        requested_extensions = request.headers.get('Sec-WebSocket-Extensions')
        if requested_extensions:
            for ext in requested_extensions.split(','):
                ext = ext.strip()
                if ext in exts:
                    ws_extensions.append(ext)

        accept_value = base64.b64encode(
            sha1(key.encode('utf-8') + WS_KEY).digest())
        if not isinstance(accept_value, str):
            # b64encode returns bytes on Python 3; every other header value
            # set here is text, so decode for consistency.
            accept_value = accept_value.decode('ascii')

        response = cherrypy.serving.response
        response.stream = True
        response.status = '101 Switching Protocols'
        response.headers['Content-Type'] = 'text/plain'
        response.headers['Upgrade'] = 'websocket'
        response.headers['Connection'] = 'Upgrade'
        response.headers['Sec-WebSocket-Version'] = str(client_version)
        response.headers['Sec-WebSocket-Accept'] = accept_value
        if ws_protocols:
            response.headers['Sec-WebSocket-Protocol'] = ', '.join(
                ws_protocols)
        if ws_extensions:
            response.headers['Sec-WebSocket-Extensions'] = ','.join(
                ws_extensions)

        addr = (request.remote.ip, request.remote.port)
        rfile = request.rfile.rfile
        if isinstance(rfile, KnownLengthRFile):
            # Unwrap one more level to reach the underlying file object.
            rfile = rfile.rfile

        # NOTE(review): get_connection() presumably returns the raw socket
        # behind rfile -- confirm against its definition.
        ws_conn = get_connection(rfile)
        request.ws_handler = handler_cls(ws_conn,
                                         ws_protocols,
                                         ws_extensions,
                                         request.wsgi_environ.copy(),
                                         heartbeat_freq=heartbeat_freq)
コード例 #5
0
    def __call__(self, environ, start_response):
        """
        WSGI entry point performing the server side of the WebSocket
        upgrade handshake.

        The request must be a ``GET`` with ``Upgrade: websocket``, a
        ``Connection`` header containing ``upgrade``, a
        ``Sec-WebSocket-Key`` that decodes to 16 bytes and a
        ``Sec-WebSocket-Version`` listed in ``WS_VERSION``. Requested
        subprotocols and extensions are filtered against
        ``self.protocols`` and ``self.extensions``.

        On success ``start_response`` is called with
        ``101 Switching Protocols`` and ``make_websocket()`` is invoked
        with ``environ['ws4py.socket']``, the negotiated protocols and
        extensions, and the environ.

        :param environ: the WSGI environment of the request.
        :param start_response: the WSGI ``start_response`` callable.
        :returns: an empty list (no response body).
        :raises HandshakeError: if any handshake requirement is not met.
        """
        if environ.get('REQUEST_METHOD') != 'GET':
            raise HandshakeError('HTTP method must be a GET')

        for name, expected_value in [('HTTP_UPGRADE', 'websocket'),
                                     ('HTTP_CONNECTION', 'upgrade')]:
            actual_value = environ.get(name, '').lower()
            if not actual_value:
                raise HandshakeError('Header %s is not defined' % name)
            if expected_value not in actual_value:
                raise HandshakeError('Illegal value for header %s: %s' %
                                     (name, actual_value))

        key = environ.get('HTTP_SEC_WEBSOCKET_KEY')
        if not key:
            # The accept value below is derived from the key, so the
            # handshake cannot be completed without it.
            raise HandshakeError(
                'Header HTTP_SEC_WEBSOCKET_KEY is not defined')
        try:
            ws_key = base64.b64decode(key.encode('utf-8'))
        except (TypeError, ValueError):
            # b64decode reports malformed input with TypeError on Python 2
            # and a ValueError subclass on Python 3.
            raise HandshakeError("WebSocket key is not valid base64")
        if len(ws_key) != 16:
            raise HandshakeError("WebSocket key's length is invalid")

        version = environ.get('HTTP_SEC_WEBSOCKET_VERSION')
        version_is_valid = False
        if version:
            try:
                version = int(version)
            except (TypeError, ValueError):
                # Non-numeric version: leave version_is_valid False.
                pass
            else:
                version_is_valid = version in WS_VERSION

        if not version_is_valid:
            # Record the rejected version in the environ before failing.
            environ['websocket.version'] = unicode(version).encode('utf-8')
            raise HandshakeError('Unhandled or missing WebSocket version')

        ws_protocols = []
        protocols = self.protocols or []
        subprotocols = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL')
        if subprotocols:
            # Keep only the requested subprotocols that we support.
            candidates = [s.strip() for s in subprotocols.split(',')]
            ws_protocols = [s for s in candidates if s in protocols]

        ws_extensions = []
        exts = self.extensions or []
        extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS')
        if extensions:
            candidates = [ext.strip() for ext in extensions.split(',')]
            ws_extensions = [ext for ext in candidates if ext in exts]

        accept_value = base64.b64encode(
            sha1(key.encode('utf-8') + WS_KEY).digest())
        if py3k:
            # b64encode returns bytes on Python 3; headers are text there.
            accept_value = accept_value.decode('utf-8')
        upgrade_headers = [
            ('Upgrade', 'websocket'),
            ('Connection', 'Upgrade'),
            ('Sec-WebSocket-Version', '%s' % version),
            ('Sec-WebSocket-Accept', accept_value),
            ]
        if ws_protocols:
            upgrade_headers.append(
                ('Sec-WebSocket-Protocol', ', '.join(ws_protocols)))
        if ws_extensions:
            upgrade_headers.append(
                ('Sec-WebSocket-Extensions', ','.join(ws_extensions)))

        start_response("101 Switching Protocols", upgrade_headers)

        self.make_websocket(environ['ws4py.socket'],
                            ws_protocols,
                            ws_extensions,
                            environ)

        return []