def test_parse_with_allow_quoted_string(self):
        """Token and quoted-string data both parse when quoting is allowed."""
        for test_data in (_TEST_TOKEN_EXTENSION_DATA,
                          _TEST_QUOTED_EXTENSION_DATA):
            for formatted_string, definition in test_data:
                parsed = parse_extensions(formatted_string,
                                          allow_quoted_string=True)
                self._verify_extension_list(definition, parsed)
    def test_parse_with_allow_quoted_string(self):
        """Parsing with allow_quoted_string=True accepts both data sets."""
        for text, expected in _TEST_TOKEN_EXTENSION_DATA:
            actual = parse_extensions(text, allow_quoted_string=True)
            self._verify_extension_list(expected, actual)

        for text, expected in _TEST_QUOTED_EXTENSION_DATA:
            actual = parse_extensions(text, allow_quoted_string=True)
            self._verify_extension_list(expected, actual)
    def test_parse_redundant_data_with_allow_quoted_string(self):
        """Redundant token and quoted data parse when quoting is allowed."""
        for test_data in (_TEST_REDUNDANT_TOKEN_EXTENSION_DATA,
                          _TEST_REDUNDANT_QUOTED_EXTENSION_DATA):
            for formatted_string, definition in test_data:
                self._verify_extension_list(
                    definition,
                    parse_extensions(formatted_string,
                                     allow_quoted_string=True))
    def test_parse_redundant_data_with_allow_quoted_string(self):
        """Redundant inputs parse with allow_quoted_string=True."""
        for header_text, expected in _TEST_REDUNDANT_TOKEN_EXTENSION_DATA:
            result = parse_extensions(header_text, allow_quoted_string=True)
            self._verify_extension_list(expected, result)

        for header_text, expected in _TEST_REDUNDANT_QUOTED_EXTENSION_DATA:
            result = parse_extensions(header_text, allow_quoted_string=True)
            self._verify_extension_list(expected, result)
    def test_parse(self):
        """Token data parses without quoting; quoted data is rejected."""
        for text, expected in _TEST_TOKEN_EXTENSION_DATA:
            actual = parse_extensions(text, allow_quoted_string=False)
            self._verify_extension_list(expected, actual)

        for text, _ in _TEST_QUOTED_EXTENSION_DATA:
            # The second positional argument disables quoted strings.
            self.assertRaises(
                ExtensionParsingException, parse_extensions, text, False)
    def test_parse(self):
        """With quoting disabled, only token-form extension data parses."""
        for formatted, expected_definition in _TEST_TOKEN_EXTENSION_DATA:
            self._verify_extension_list(
                expected_definition,
                parse_extensions(formatted, allow_quoted_string=False))

        for formatted, unused_definition in _TEST_QUOTED_EXTENSION_DATA:
            self.assertRaises(ExtensionParsingException,
                              parse_extensions, formatted, False)
    def test_parse_redundant_data(self):
        """Redundant token data parses without quoting; redundant quoted
        data raises ExtensionParsingException."""
        for text, expected in _TEST_REDUNDANT_TOKEN_EXTENSION_DATA:
            actual = parse_extensions(text, allow_quoted_string=False)
            self._verify_extension_list(expected, actual)

        for text, unused_expected in _TEST_REDUNDANT_QUOTED_EXTENSION_DATA:
            self.assertRaises(
                ExtensionParsingException, parse_extensions, text, False)
    def test_parse_redundant_data(self):
        """Redundant inputs: tokens parse, quoted strings are rejected."""
        token_cases = _TEST_REDUNDANT_TOKEN_EXTENSION_DATA
        for formatted_string, definition in token_cases:
            parsed = parse_extensions(formatted_string,
                                      allow_quoted_string=False)
            self._verify_extension_list(definition, parsed)

        quoted_cases = _TEST_REDUNDANT_QUOTED_EXTENSION_DATA
        for formatted_string, _ in quoted_cases:
            self.assertRaises(ExtensionParsingException, parse_extensions,
                              formatted_string, False)
Esempio n. 9
0
    def _parse_extensions(self):
        """Parse the Sec-WebSocket-Extensions request header, if present.

        Sets self._request.ws_requested_extensions to the parsed result, or
        to None when the header is missing or empty.

        Raises:
            HandshakeException: the header value could not be parsed.
        """
        header_value = self._request.headers_in.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
        if header_value:
            try:
                self._request.ws_requested_extensions = (
                    common.parse_extensions(header_value))
            except common.ExtensionParsingException as error:
                raise HandshakeException(
                    'Failed to parse Sec-WebSocket-Extensions header: %r'
                    % error)
        else:
            self._request.ws_requested_extensions = None
Esempio n. 10
0
    def _parse_extensions(self):
        """Store the requested extensions parsed from the request headers.

        A missing or empty Sec-WebSocket-Extensions header results in
        self._request.ws_requested_extensions being None.

        Raises:
            HandshakeException: if common.parse_extensions rejects the header.
        """
        request = self._request
        raw_header = request.headers_in.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
        if not raw_header:
            request.ws_requested_extensions = None
            return

        try:
            request.ws_requested_extensions = common.parse_extensions(
                raw_header)
        except common.ExtensionParsingException as parse_error:
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: %r' %
                parse_error)
Esempio n. 11
0
    def _parse_extensions(self):
        """Parse the Sec-WebSocket-Extensions request header.

        Stores the result on self._request.ws_requested_extensions, which is
        set to None when the header is absent or empty. Quoted-string
        parameter values are accepted unless the request's ws_version equals
        common.VERSION_HYBI08.

        Raises:
            HandshakeException: the header value could not be parsed.
        """
        extensions_header = self._request.headers_in.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
        if not extensions_header:
            self._request.ws_requested_extensions = None
            return

        # Compare by value: `is` checks object identity and only works for
        # ints by accident of CPython's small-integer cache.
        allow_quoted_string = (
            self._request.ws_version != common.VERSION_HYBI08)
        try:
            self._request.ws_requested_extensions = common.parse_extensions(
                extensions_header, allow_quoted_string=allow_quoted_string)
        except common.ExtensionParsingException as e:
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: %r' % e)
Esempio n. 12
0
File: hybi.py Progetto: GhostQ/GitSB
    def _parse_extensions(self):
        """Parse the Sec-WebSocket-Extensions request header.

        Sets self._request.ws_requested_extensions to the parsed list, or to
        None when the header is absent or empty. HyBi 08 requests
        (common.VERSION_HYBI08) are parsed with quoted strings disallowed;
        all other versions allow them.

        Raises:
            HandshakeException: the header value could not be parsed.
        """
        extensions_header = self._request.headers_in.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
        if not extensions_header:
            self._request.ws_requested_extensions = None
            return

        # Use equality, not identity (`is`), to compare version numbers.
        if self._request.ws_version == common.VERSION_HYBI08:
            allow_quoted_string = False
        else:
            allow_quoted_string = True
        try:
            self._request.ws_requested_extensions = common.parse_extensions(
                extensions_header, allow_quoted_string=allow_quoted_string)
        except common.ExtensionParsingException as e:
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: %r' % e)
Esempio n. 13
0
    def _parse_extensions(self):
        """Parse the client's requested extensions from the request headers.

        Returns early with self._request.ws_requested_extensions set to None
        when the 'sec-websocket-extensions' header is missing or empty.
        Otherwise stores the parsed list and logs the extension names at
        debug level.

        Raises:
            HandshakeException: the header value could not be parsed.
        """
        header_value = self._request.headers_in.get('sec-websocket-extensions')
        if not header_value:
            self._request.ws_requested_extensions = None
            return

        try:
            requested = common.parse_extensions(header_value)
        except common.ExtensionParsingException as error:
            raise HandshakeException(
                'Failed to parse sec-websocket-extensions header: %r' % error)
        self._request.ws_requested_extensions = requested

        names = [common.ExtensionParameter.name(extension)
                 for extension in self._request.ws_requested_extensions]
        self._logger.debug('Extensions requested: %r', names)
Esempio n. 14
0
    def _parse_extensions(self):
        """Read and parse the Sec-WebSocket-Extensions request header.

        The parsed list is stored on self._request.ws_requested_extensions
        (None if the header is absent or empty) and the requested extension
        names are logged at debug level.

        Raises:
            HandshakeException: common.parse_extensions rejected the header.
        """
        request = self._request
        raw_header = request.headers_in.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
        if not raw_header:
            request.ws_requested_extensions = None
            return

        try:
            request.ws_requested_extensions = common.parse_extensions(
                raw_header)
        except common.ExtensionParsingException as parse_error:
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: %r' %
                parse_error)

        get_name = common.ExtensionParameter.name
        self._logger.debug(
            'Extensions requested: %r',
            [get_name(item) for item in request.ws_requested_extensions])
Esempio n. 15
0
def _parse_compression_method(data):
    """Parse a "method" extension parameter value.

    Quoted-string parameter values are accepted. Returns whatever
    common.parse_extensions returns for *data*.
    """
    parsed = common.parse_extensions(data, allow_quoted_string=True)
    return parsed
Esempio n. 16
0
class ClientHandshakeProcessor(ClientHandshakeBase):
    """WebSocket opening handshake processor for
    draft-ietf-hybi-thewebsocketprotocol-06 and later.

    Sends the client's opening handshake request over a socket. It then reads
    and validates the server's response: the status line, Upgrade and
    Connection, Sec-WebSocket-Accept, and Sec-WebSocket-Extensions.
    """

    def __init__(self, socket, options):
        """Store the socket and options and create a class logger.

        Args:
            socket: object the request is written to with sendall() and the
                response is read from.
            options: client configuration. The attributes read are
                resource, server_host, server_port, use_tls, origin,
                protocol_version, version_header, deflate_frame and
                use_permessage_deflate.
        """
        super(ClientHandshakeProcessor, self).__init__()

        self._socket = socket
        self._options = options

        self._logger = util.get_class_logger(self)

    def handshake(self):
        """Performs opening handshake on the specified socket.

        On success, self._options.deflate_frame and/or
        self._options.use_permessage_deflate are replaced by the
        processor/framer objects built for the accepted extensions.

        Raises:
            ClientHandshakeError: handshake failed.
        """

        request_line = _build_method_line(self._options.resource)
        self._logger.debug('Client\'s opening handshake Request-Line: %r',
                           request_line)
        self._socket.sendall(request_line)

        fields = []
        fields.append(_format_host_header(
            self._options.server_host,
            self._options.server_port,
            self._options.use_tls))
        fields.append(_UPGRADE_HEADER)
        fields.append(_CONNECTION_HEADER)
        if self._options.origin is not None:
            # HyBi 08 sends the origin in Sec-WebSocket-Origin; other
            # protocol versions use the Origin header.
            if self._options.protocol_version == _PROTOCOL_VERSION_HYBI08:
                fields.append(_origin_header(
                    common.SEC_WEBSOCKET_ORIGIN_HEADER,
                    self._options.origin))
            else:
                fields.append(_origin_header(common.ORIGIN_HEADER,
                                             self._options.origin))

        # The key is 16 random bytes, base64-encoded.
        original_key = os.urandom(16)
        self._key = base64.b64encode(original_key)
        self._logger.debug(
            '%s: %r (%s)',
            common.SEC_WEBSOCKET_KEY_HEADER,
            self._key,
            util.hexify(original_key))
        fields.append(
            '%s: %s\r\n' % (common.SEC_WEBSOCKET_KEY_HEADER, self._key))

        # An explicit version_header overrides the version implied by
        # protocol_version.
        if self._options.version_header > 0:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          self._options.version_header))
        elif self._options.protocol_version == _PROTOCOL_VERSION_HYBI08:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          common.VERSION_HYBI08))
        else:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          common.VERSION_HYBI_LATEST))

        extensions_to_request = []

        if self._options.deflate_frame:
            extensions_to_request.append(
                common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION))

        if self._options.use_permessage_deflate:
            extension = common.ExtensionParameter(
                common.PERMESSAGE_DEFLATE_EXTENSION)
            # Accept the client_max_window_bits extension parameter by default.
            extension.add_parameter(
                PerMessageDeflateExtensionProcessor.
                _CLIENT_MAX_WINDOW_BITS_PARAM,
                None)
            extensions_to_request.append(extension)

        if len(extensions_to_request) != 0:
            fields.append(
                '%s: %s\r\n' %
                (common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
                 common.format_extensions(extensions_to_request)))

        for field in fields:
            self._socket.sendall(field)

        # An empty line terminates the request headers.
        self._socket.sendall('\r\n')

        self._logger.debug('Sent client\'s opening handshake headers: %r',
                           fields)
        self._logger.debug('Start reading Status-Line')

        # Read the status line one byte at a time up to and including LF.
        status_line = ''
        while True:
            ch = _receive_bytes(self._socket, 1)
            status_line += ch
            if ch == '\n':
                break

        m = re.match(r'HTTP/\d+\.\d+ (\d\d\d) .*\r\n', status_line)
        if m is None:
            raise ClientHandshakeError(
                'Wrong status line format: %r' % status_line)
        status_code = m.group(1)
        if status_code != '101':
            self._logger.debug('Unexpected status code %s with following '
                               'headers: %r', status_code, self._read_fields())
            raise ClientHandshakeError(
                'Expected HTTP status code 101 but found %r' % status_code)

        self._logger.debug('Received valid Status-Line')
        self._logger.debug('Start reading headers until we see an empty line')

        fields = self._read_fields()

        ch = _receive_bytes(self._socket, 1)
        if ch != '\n':  # 0x0A
            # The original message referenced undefined names (`value`,
            # `name`), which turned this error into a NameError.
            raise ClientHandshakeError(
                'Expected LF but found %r after response headers' % ch)

        self._logger.debug('Received an empty line')
        self._logger.debug('Server\'s opening handshake headers: %r', fields)

        _validate_mandatory_header(
            fields,
            common.UPGRADE_HEADER,
            common.WEBSOCKET_UPGRADE_TYPE,
            False)

        _validate_mandatory_header(
            fields,
            common.CONNECTION_HEADER,
            common.UPGRADE_CONNECTION_TYPE,
            False)

        accept = _get_mandatory_header(
            fields, common.SEC_WEBSOCKET_ACCEPT_HEADER)

        # Validate
        try:
            binary_accept = base64.b64decode(accept)
        except (TypeError, ValueError):
            # Python 2 raises TypeError for malformed base64; Python 3
            # raises binascii.Error, a subclass of ValueError.
            raise ClientHandshakeError(
                'Illegal value for header %s: %r' %
                (common.SEC_WEBSOCKET_ACCEPT_HEADER, accept))

        if len(binary_accept) != 20:
            raise ClientHandshakeError(
                'Decoded value of %s is not 20-byte long' %
                common.SEC_WEBSOCKET_ACCEPT_HEADER)

        self._logger.debug(
            'Response for challenge : %r (%s)',
            accept, util.hexify(binary_accept))

        binary_expected_accept = util.sha1_hash(
            self._key + common.WEBSOCKET_ACCEPT_UUID).digest()
        expected_accept = base64.b64encode(binary_expected_accept)

        self._logger.debug(
            'Expected response for challenge: %r (%s)',
            expected_accept, util.hexify(binary_expected_accept))

        if accept != expected_accept:
            raise ClientHandshakeError(
                'Invalid %s header: %r (expected: %s)' %
                (common.SEC_WEBSOCKET_ACCEPT_HEADER, accept, expected_accept))

        deflate_frame_accepted = False
        permessage_deflate_accepted = False

        extensions_header = fields.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER.lower())
        accepted_extensions = []
        if extensions_header is not None and len(extensions_header) != 0:
            accepted_extensions = common.parse_extensions(extensions_header[0])

        # TODO(bashi): Support the new style perframe compression extension.
        for extension in accepted_extensions:
            extension_name = extension.name()
            if (extension_name == common.DEFLATE_FRAME_EXTENSION and
                    self._options.deflate_frame):
                deflate_frame_accepted = True
                processor = DeflateFrameExtensionProcessor(extension)
                unused_extension_response = processor.get_extension_response()
                self._options.deflate_frame = processor
                continue
            elif (extension_name == common.PERMESSAGE_DEFLATE_EXTENSION and
                  self._options.use_permessage_deflate):
                permessage_deflate_accepted = True

                framer = _get_permessage_deflate_framer(extension)
                framer.set_compress_outgoing_enabled(True)
                self._options.use_permessage_deflate = framer
                continue

            # Anything the client did not request is a protocol error.
            raise ClientHandshakeError(
                'Unexpected extension %r' % extension_name)

        if (self._options.deflate_frame and not deflate_frame_accepted):
            raise ClientHandshakeError(
                'Requested %s, but the server rejected it' %
                common.DEFLATE_FRAME_EXTENSION)

        if (self._options.use_permessage_deflate and
                not permessage_deflate_accepted):
            raise ClientHandshakeError(
                'Requested %s, but the server rejected it' %
                common.PERMESSAGE_DEFLATE_EXTENSION)
Esempio n. 17
0
def _parse_compression_method(data):
    """Parse the value of the "method" extension parameter.

    Delegates to common.parse_extensions with quoted strings allowed.
    """
    return common.parse_extensions(
        data,
        allow_quoted_string=True,
    )
Esempio n. 18
0
 def test_parse_quoted_data(self):
     """Quoted-string extension data parses with default options."""
     for text, expected in _TEST_QUOTED_EXTENSION_DATA:
         actual = parse_extensions(text)
         self._verify_extension_list(expected, actual)
Esempio n. 19
0
 def test_parse_redundant_quoted_data(self):
     """Redundant quoted-string data parses with default options."""
     cases = _TEST_REDUNDANT_QUOTED_EXTENSION_DATA
     for formatted_string, definition in cases:
         parsed = parse_extensions(formatted_string)
         self._verify_extension_list(definition, parsed)
Esempio n. 20
0
def _parse_compression_method(data):
    """Parse a "method" extension parameter value with default options.

    Returns the result of common.parse_extensions(data).
    """
    parsed = common.parse_extensions(data)
    return parsed
Esempio n. 21
0
def _parse_compression_method(data):
    """Return *data* (a "method" parameter value) parsed by
    common.parse_extensions using its default options."""
    return common.parse_extensions(
        data)
Esempio n. 22
0
 def test_parse(self):
     """Token-form extension data parses with default options."""
     for text, expected in _TEST_TOKEN_EXTENSION_DATA:
         actual = parse_extensions(text)
         self._verify_extension_list(expected, actual)
Esempio n. 23
0
 def test_parse_quoted_data(self):
     """Each quoted-string case parses into its expected definition."""
     cases = _TEST_QUOTED_EXTENSION_DATA
     for formatted_string, definition in cases:
         self._verify_extension_list(
             definition,
             parse_extensions(formatted_string),
         )
Esempio n. 24
0
 def test_parse_redundant_quoted_data(self):
     """Redundant quoted-string inputs parse into the expected lists."""
     for text, expected in _TEST_REDUNDANT_QUOTED_EXTENSION_DATA:
         result = parse_extensions(text)
         self._verify_extension_list(expected, result)
Esempio n. 25
0
 def test_parse(self):
     """Each token-form case parses into its expected definition."""
     cases = _TEST_TOKEN_EXTENSION_DATA
     for formatted_string, definition in cases:
         parsed = parse_extensions(formatted_string)
         self._verify_extension_list(definition, parsed)
Esempio n. 26
0
    def handshake(self, socket):
        """Handshake WebSocket.

        Sends the opening handshake request on *socket*, then reads and
        validates the server's response. Each accepted permessage-deflate
        extension is passed to self._options.check_permessage_deflate when
        that checker is set.

        Args:
            socket: object the request is written to with sendall() and the
                response is read from; stored on self._socket.

        Raises:
            HttpStatusException: the response status code is not 101.
            HandshakeException: Sec-WebSocket-Accept is not valid base64 or
                does not decode to 20 bytes.
            Exception: handshake failed for any other reason.
        """

        self._socket = socket

        request_line = _method_line(self._options.resource)
        self._logger.debug('Opening handshake Request-Line: %r', request_line)
        self._socket.sendall(request_line.encode('UTF-8'))

        fields = []
        fields.append(_UPGRADE_HEADER)
        fields.append(_CONNECTION_HEADER)

        fields.append(
            _format_host_header(self._options.server_host,
                                self._options.server_port,
                                self._options.use_tls))

        # Compare by value; `is 8` tests object identity, which is only
        # accidentally true for CPython's cached small ints.
        if self._options.version == 8:
            fields.append(_sec_origin_header(self._options.origin))
        else:
            fields.append(_origin_header(self._options.origin))

        # The key is 16 random bytes, base64-encoded.
        original_key = os.urandom(16)
        key = base64.b64encode(original_key)
        self._logger.debug('Sec-WebSocket-Key: %s (%s)', key,
                           util.hexify(original_key))
        fields.append(u'Sec-WebSocket-Key: %s\r\n' % key.decode('UTF-8'))

        fields.append(u'Sec-WebSocket-Version: %d\r\n' % self._options.version)

        # Setting up extensions.
        if len(self._options.extensions) > 0:
            fields.append(u'Sec-WebSocket-Extensions: %s\r\n' %
                          ', '.join(self._options.extensions))

        self._logger.debug('Opening handshake request headers: %r', fields)

        for field in fields:
            self._socket.sendall(field.encode('UTF-8'))
        # An empty line terminates the request headers.
        self._socket.sendall(b'\r\n')

        self._logger.info('Sent opening handshake request')

        # Read the status line one byte at a time up to and including LF.
        field = b''
        while True:
            ch = receive_bytes(self._socket, 1)
            field += ch
            if ch == b'\n':
                break

        self._logger.debug('Opening handshake Response-Line: %r', field)

        # Latin-1 maps every byte, so decoding for error messages never fails.
        if len(field) < 7 or not field.endswith(b'\r\n'):
            raise Exception('Wrong status line: %s' % field.decode('Latin-1'))
        m = re.match(b'[^ ]* ([^ ]*) .*', field)
        if m is None:
            raise Exception('No HTTP status code found in status line: %s' %
                            field.decode('Latin-1'))
        code = m.group(1)
        if not re.match(b'[0-9][0-9][0-9]$', code):
            raise Exception(
                'HTTP status code %s is not three digit in status line: %s' %
                (code.decode('Latin-1'), field.decode('Latin-1')))
        if code != b'101':
            raise HttpStatusException(
                'Expected HTTP status code 101 but found %s in status line: '
                '%r' % (code.decode('Latin-1'), field.decode('Latin-1')),
                int(code))
        fields = _read_fields(self._socket)
        ch = receive_bytes(self._socket, 1)
        if ch != b'\n':  # 0x0A
            raise Exception('Expected LF but found: %r' % ch)

        self._logger.debug('Opening handshake response headers: %r', fields)

        # Check /fields/
        if len(fields['upgrade']) != 1:
            raise Exception('Multiple Upgrade headers found: %s' %
                            fields['upgrade'])
        if len(fields['connection']) != 1:
            raise Exception('Multiple Connection headers found: %s' %
                            fields['connection'])
        # RFC 6455 section 4.1 requires an ASCII case-insensitive match for
        # both Upgrade ("websocket") and Connection ("Upgrade").
        if fields['upgrade'][0].lower() != 'websocket':
            raise Exception('Unexpected Upgrade header value: %s' %
                            fields['upgrade'][0])
        if fields['connection'][0].lower() != 'upgrade':
            raise Exception('Unexpected Connection header value: %s' %
                            fields['connection'][0])

        if len(fields['sec-websocket-accept']) != 1:
            raise Exception('Multiple Sec-WebSocket-Accept headers found: %s' %
                            fields['sec-websocket-accept'])

        accept = fields['sec-websocket-accept'][0]

        # Validate
        try:
            decoded_accept = base64.b64decode(accept)
        except (TypeError, ValueError):
            # Python 3 raises binascii.Error (a ValueError subclass) for
            # malformed input; TypeError is what Python 2 raises.
            raise HandshakeException(
                'Illegal value for header Sec-WebSocket-Accept: ' + accept)

        if len(decoded_accept) != 20:
            raise HandshakeException(
                'Decoded value of Sec-WebSocket-Accept is not 20-byte long')

        self._logger.debug('Actual Sec-WebSocket-Accept: %r (%s)', accept,
                           util.hexify(decoded_accept))

        original_expected_accept = sha1(key + WEBSOCKET_ACCEPT_UUID).digest()
        expected_accept = base64.b64encode(original_expected_accept)

        self._logger.debug('Expected Sec-WebSocket-Accept: %r (%s)',
                           expected_accept,
                           util.hexify(original_expected_accept))

        expected_accept_text = expected_accept.decode('UTF-8')
        if accept != expected_accept_text:
            # Arguments follow the template order: expected first, actual
            # second.
            raise Exception(
                'Invalid Sec-WebSocket-Accept header: %r (expected) != %r '
                '(actual)' % (expected_accept_text, accept))

        server_extensions_header = fields.get('sec-websocket-extensions')
        accepted_extensions = []
        if server_extensions_header is not None:
            accepted_extensions = common.parse_extensions(
                ', '.join(server_extensions_header))

        # Scan accepted extension list to check if there is any unrecognized
        # extensions or extensions we didn't request in it. Then, for
        # extensions we request, parse them and store parameters. They will be
        # used later by each extension.
        for extension in accepted_extensions:
            if extension.name() == _PERMESSAGE_DEFLATE_EXTENSION:
                checker = self._options.check_permessage_deflate
                if checker:
                    checker(extension)
                    continue

            raise Exception('Received unrecognized extension: %s' %
                            extension.name())