def test_parse_with_allow_quoted_string(self):
    """Parse token and quoted-string samples with quoted strings allowed.

    Every sample from both data sets is parsed with
    allow_quoted_string=True and the result is checked against the
    sample's expected definition.
    """
    # Both data sets go through exactly the same verification, so they
    # share one loop body.
    for samples in (_TEST_TOKEN_EXTENSION_DATA,
                    _TEST_QUOTED_EXTENSION_DATA):
        for formatted_string, definition in samples:
            parsed = parse_extensions(
                formatted_string, allow_quoted_string=True)
            self._verify_extension_list(definition, parsed)
def test_parse_redundant_data_with_allow_quoted_string(self):
    """Parse the redundant-data samples with quoted strings allowed.

    Both the token and the quoted-string variants of the redundant samples
    are parsed with allow_quoted_string=True and compared against their
    expected definitions.
    """
    all_sample_sets = (_TEST_REDUNDANT_TOKEN_EXTENSION_DATA,
                       _TEST_REDUNDANT_QUOTED_EXTENSION_DATA)
    for sample_set in all_sample_sets:
        for formatted_string, definition in sample_set:
            parsed = parse_extensions(
                formatted_string, allow_quoted_string=True)
            self._verify_extension_list(definition, parsed)
def test_parse(self):
    """Parse with quoted strings disallowed.

    Token-only samples must parse to their expected definitions, while
    every quoted-string sample must be rejected with
    ExtensionParsingException.
    """
    for formatted_string, definition in _TEST_TOKEN_EXTENSION_DATA:
        parsed = parse_extensions(formatted_string,
                                  allow_quoted_string=False)
        self._verify_extension_list(definition, parsed)

    # The expected definition is irrelevant here: parsing must fail.
    for formatted_string, _definition in _TEST_QUOTED_EXTENSION_DATA:
        with self.assertRaises(ExtensionParsingException):
            parse_extensions(formatted_string, False)
def test_parse(self):
    """Check parse_extensions() behavior when quoted strings are disabled.

    Samples made only of tokens are verified against their definitions;
    samples containing quoted strings are expected to raise
    ExtensionParsingException.
    """
    for token_sample in _TEST_TOKEN_EXTENSION_DATA:
        formatted_string, definition = token_sample
        self._verify_extension_list(
            definition,
            parse_extensions(formatted_string, allow_quoted_string=False))

    for quoted_sample in _TEST_QUOTED_EXTENSION_DATA:
        # Only the formatted string is needed; the definition is ignored.
        formatted_string, _ignored_definition = quoted_sample
        with self.assertRaises(ExtensionParsingException):
            parse_extensions(formatted_string, False)
def test_parse_redundant_data(self):
    """Parse the redundant-data samples with quoted strings disallowed.

    Redundant token samples must parse to their expected definitions;
    redundant quoted-string samples must raise ExtensionParsingException.
    """
    for formatted_string, definition in _TEST_REDUNDANT_TOKEN_EXTENSION_DATA:
        parsed = parse_extensions(formatted_string,
                                  allow_quoted_string=False)
        self._verify_extension_list(definition, parsed)

    # The definition half of each quoted sample is not used: the parse is
    # expected to fail before any result exists.
    for formatted_string, _definition in _TEST_REDUNDANT_QUOTED_EXTENSION_DATA:
        with self.assertRaises(ExtensionParsingException):
            parse_extensions(formatted_string, False)
def test_parse_redundant_data(self):
    """Check redundant-data parsing when quoted strings are disabled.

    Token-only redundant samples are verified against their definitions;
    quoted redundant samples are expected to raise
    ExtensionParsingException.
    """
    for token_sample in _TEST_REDUNDANT_TOKEN_EXTENSION_DATA:
        formatted_string, definition = token_sample
        self._verify_extension_list(
            definition,
            parse_extensions(formatted_string, allow_quoted_string=False))

    for quoted_sample in _TEST_REDUNDANT_QUOTED_EXTENSION_DATA:
        # Only the formatted string matters for the failure check.
        formatted_string, _ignored_definition = quoted_sample
        with self.assertRaises(ExtensionParsingException):
            parse_extensions(formatted_string, False)
def _parse_extensions(self):
    """Parse the Sec-WebSocket-Extensions request header, if any.

    The parsed result is stored in self._request.ws_requested_extensions.
    When the header is missing or empty, that attribute is set to None.

    Raises:
        HandshakeException: the header value could not be parsed.
    """
    extensions_header = self._request.headers_in.get(
        common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
    if not extensions_header:
        self._request.ws_requested_extensions = None
        return

    try:
        self._request.ws_requested_extensions = common.parse_extensions(
            extensions_header)
    except common.ExtensionParsingException as e:
        # "except X as e" is valid on Python 2.6+ and Python 3, whereas the
        # older "except X, e" spelling is a syntax error on Python 3.
        raise HandshakeException(
            'Failed to parse Sec-WebSocket-Extensions header: %r' % e)
def _parse_extensions(self):
    """Parse the Sec-WebSocket-Extensions request header, if any.

    The parsed result is stored in self._request.ws_requested_extensions.
    When the header is missing or empty, that attribute is set to None.
    Quoted-string parameter values are rejected when the request's
    ws_version equals common.VERSION_HYBI08 and accepted otherwise.

    Raises:
        HandshakeException: the header value could not be parsed.
    """
    extensions_header = self._request.headers_in.get(
        common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
    if not extensions_header:
        self._request.ws_requested_extensions = None
        return

    # Compare by value. "is" tests object identity, which for numbers only
    # happens to work through interpreter-specific caching.
    allow_quoted_string = (
        self._request.ws_version != common.VERSION_HYBI08)

    try:
        self._request.ws_requested_extensions = common.parse_extensions(
            extensions_header, allow_quoted_string=allow_quoted_string)
    except common.ExtensionParsingException as e:
        # "except X as e" is valid on Python 2.6+ and Python 3, whereas the
        # older "except X, e" spelling is a syntax error on Python 3.
        raise HandshakeException(
            'Failed to parse Sec-WebSocket-Extensions header: %r' % e)
def _parse_extensions(self):
    """Parse the Sec-WebSocket-Extensions request header, if any.

    The parsed result is stored in self._request.ws_requested_extensions.
    When the header is missing or empty, that attribute is set to None.
    Quoted-string parameter values are rejected when the request's
    ws_version equals common.VERSION_HYBI08 and accepted otherwise.

    Raises:
        HandshakeException: the header value could not be parsed.
    """
    extensions_header = self._request.headers_in.get(
        common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
    if not extensions_header:
        self._request.ws_requested_extensions = None
        return

    # Value comparison instead of "is": identity of equal numbers is an
    # interpreter implementation detail and must not decide protocol
    # behavior.
    if self._request.ws_version == common.VERSION_HYBI08:
        allow_quoted_string = False
    else:
        allow_quoted_string = True

    try:
        self._request.ws_requested_extensions = common.parse_extensions(
            extensions_header, allow_quoted_string=allow_quoted_string)
    except common.ExtensionParsingException as e:
        # "except X as e" is valid on Python 2.6+ and Python 3, whereas the
        # older "except X, e" spelling is a syntax error on Python 3.
        raise HandshakeException(
            'Failed to parse Sec-WebSocket-Extensions header: %r' % e)
def _parse_extensions(self):
    """Parse the 'sec-websocket-extensions' request header, if any.

    Stores the parsed extensions in
    self._request.ws_requested_extensions (None when the header is missing
    or empty) and logs the requested extension names at debug level.

    Raises:
        HandshakeException: the header value could not be parsed.
    """
    request = self._request
    header_value = request.headers_in.get('sec-websocket-extensions')
    if not header_value:
        request.ws_requested_extensions = None
        return

    try:
        request.ws_requested_extensions = common.parse_extensions(
            header_value)
    except common.ExtensionParsingException as e:
        raise HandshakeException(
            'Failed to parse sec-websocket-extensions header: %r' % e)

    requested_names = [
        common.ExtensionParameter.name(extension)
        for extension in request.ws_requested_extensions
    ]
    self._logger.debug('Extensions requested: %r', requested_names)
def _parse_extensions(self):
    """Parse the Sec-WebSocket-Extensions request header, if any.

    Stores the parsed extensions in
    self._request.ws_requested_extensions (None when the header is missing
    or empty) and logs the requested extension names at debug level.

    Raises:
        HandshakeException: the header value could not be parsed.
    """
    request = self._request
    header_value = request.headers_in.get(
        common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
    if not header_value:
        request.ws_requested_extensions = None
        return

    try:
        request.ws_requested_extensions = common.parse_extensions(
            header_value)
    except common.ExtensionParsingException as e:
        raise HandshakeException(
            'Failed to parse Sec-WebSocket-Extensions header: %r' % e)

    requested_names = [
        common.ExtensionParameter.name(extension)
        for extension in request.ws_requested_extensions
    ]
    self._logger.debug('Extensions requested: %r', requested_names)
def _parse_compression_method(data):
    """Parse the value of the "method" extension parameter.

    The value is handed to common.parse_extensions() with quoted strings
    allowed, and that call's result is returned unchanged.
    """
    parsed_methods = common.parse_extensions(data, allow_quoted_string=True)
    return parsed_methods
class ClientHandshakeProcessor(ClientHandshakeBase):
    """WebSocket opening handshake processor for
    draft-ietf-hybi-thewebsocketprotocol-06 and later.
    """

    def __init__(self, socket, options):
        """Remember the socket and options used by handshake().

        Args:
            socket: object used to send the request and read the response.
            options: handshake settings read by handshake() (resource,
                server_host, server_port, use_tls, origin,
                protocol_version, version_header, deflate_frame,
                use_permessage_deflate).
        """
        super(ClientHandshakeProcessor, self).__init__()

        self._socket = socket
        self._options = options

        self._logger = util.get_class_logger(self)

    def handshake(self):
        """Performs opening handshake on the specified socket.

        Sends the request line and headers, reads the status line and the
        response headers, validates the Upgrade, Connection and
        Sec-WebSocket-Accept headers, and finally checks the extensions the
        server accepted against the ones that were requested. On success,
        self._options.deflate_frame / self._options.use_permessage_deflate
        are replaced by the processor / framer built for an accepted
        extension.

        Raises:
            ClientHandshakeError: handshake failed.
        """

        request_line = _build_method_line(self._options.resource)
        self._logger.debug('Client\'s opening handshake Request-Line: %r',
                           request_line)
        self._socket.sendall(request_line)

        fields = []
        fields.append(_format_host_header(
            self._options.server_host,
            self._options.server_port,
            self._options.use_tls))
        fields.append(_UPGRADE_HEADER)
        fields.append(_CONNECTION_HEADER)
        if self._options.origin is not None:
            # HyBi 08 uses a dedicated origin header name.
            if self._options.protocol_version == _PROTOCOL_VERSION_HYBI08:
                fields.append(_origin_header(
                    common.SEC_WEBSOCKET_ORIGIN_HEADER,
                    self._options.origin))
            else:
                fields.append(_origin_header(common.ORIGIN_HEADER,
                                             self._options.origin))

        # The key is kept on the instance because it is needed again below
        # to compute the expected Sec-WebSocket-Accept value.
        original_key = os.urandom(16)
        self._key = base64.b64encode(original_key)
        self._logger.debug(
            '%s: %r (%s)',
            common.SEC_WEBSOCKET_KEY_HEADER,
            self._key,
            util.hexify(original_key))
        fields.append(
            '%s: %s\r\n' % (common.SEC_WEBSOCKET_KEY_HEADER, self._key))

        # An explicit version_header option wins over the version implied
        # by protocol_version.
        if self._options.version_header > 0:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          self._options.version_header))
        elif self._options.protocol_version == _PROTOCOL_VERSION_HYBI08:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          common.VERSION_HYBI08))
        else:
            fields.append('%s: %d\r\n' % (common.SEC_WEBSOCKET_VERSION_HEADER,
                                          common.VERSION_HYBI_LATEST))

        extensions_to_request = []

        if self._options.deflate_frame:
            extensions_to_request.append(
                common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION))

        if self._options.use_permessage_deflate:
            extension = common.ExtensionParameter(
                common.PERMESSAGE_DEFLATE_EXTENSION)
            # Accept the client_max_window_bits extension parameter by
            # default.
            extension.add_parameter(
                PerMessageDeflateExtensionProcessor.
                _CLIENT_MAX_WINDOW_BITS_PARAM,
                None)
            extensions_to_request.append(extension)

        if len(extensions_to_request) != 0:
            fields.append(
                '%s: %s\r\n' %
                (common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
                 common.format_extensions(extensions_to_request)))

        for field in fields:
            self._socket.sendall(field)

        # Empty line terminating the request headers.
        self._socket.sendall('\r\n')

        self._logger.debug('Sent client\'s opening handshake headers: %r',
                           fields)
        self._logger.debug('Start reading Status-Line')

        # Read one byte at a time up to and including the first LF.
        status_line = ''
        while True:
            ch = _receive_bytes(self._socket, 1)
            status_line += ch
            if ch == '\n':
                break

        # Raw string so that "\." and "\d" reach the regex engine without
        # relying on Python passing unknown string escapes through.
        m = re.match(r'HTTP/\d+\.\d+ (\d\d\d) .*\r\n', status_line)
        if m is None:
            raise ClientHandshakeError(
                'Wrong status line format: %r' % status_line)
        status_code = m.group(1)
        if status_code != '101':
            self._logger.debug('Unexpected status code %s with following '
                               'headers: %r', status_code, self._read_fields())
            raise ClientHandshakeError(
                'Expected HTTP status code 101 but found %r' % status_code)

        self._logger.debug('Received valid Status-Line')
        self._logger.debug('Start reading headers until we see an empty line')

        fields = self._read_fields()

        ch = _receive_bytes(self._socket, 1)
        if ch != '\n':  # 0x0A
            # Only names defined in this method may appear in the message;
            # referring to anything else would raise NameError instead of
            # the intended ClientHandshakeError.
            raise ClientHandshakeError(
                'Expected LF but found %r while reading the empty line that '
                'terminates the headers' % ch)

        self._logger.debug('Received an empty line')
        self._logger.debug('Server\'s opening handshake headers: %r', fields)

        _validate_mandatory_header(
            fields,
            common.UPGRADE_HEADER,
            common.WEBSOCKET_UPGRADE_TYPE,
            False)

        _validate_mandatory_header(
            fields,
            common.CONNECTION_HEADER,
            common.UPGRADE_CONNECTION_TYPE,
            False)

        accept = _get_mandatory_header(
            fields, common.SEC_WEBSOCKET_ACCEPT_HEADER)

        # Validate
        try:
            binary_accept = base64.b64decode(accept)
        except (TypeError, ValueError):
            # Python 2 reports malformed base64 as TypeError; Python 3
            # raises binascii.Error, which is a ValueError subclass.
            # Raised as ClientHandshakeError to match the documented
            # contract and every other failure path of this method.
            raise ClientHandshakeError(
                'Illegal value for header %s: %r' %
                (common.SEC_WEBSOCKET_ACCEPT_HEADER, accept))

        if len(binary_accept) != 20:
            raise ClientHandshakeError(
                'Decoded value of %s is not 20-byte long' %
                common.SEC_WEBSOCKET_ACCEPT_HEADER)

        self._logger.debug(
            'Response for challenge : %r (%s)',
            accept, util.hexify(binary_accept))

        binary_expected_accept = util.sha1_hash(
            self._key + common.WEBSOCKET_ACCEPT_UUID).digest()
        expected_accept = base64.b64encode(binary_expected_accept)

        self._logger.debug(
            'Expected response for challenge: %r (%s)',
            expected_accept, util.hexify(binary_expected_accept))

        if accept != expected_accept:
            raise ClientHandshakeError(
                'Invalid %s header: %r (expected: %s)' %
                (common.SEC_WEBSOCKET_ACCEPT_HEADER, accept, expected_accept))

        deflate_frame_accepted = False
        permessage_deflate_accepted = False

        extensions_header = fields.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER.lower())
        accepted_extensions = []
        if extensions_header is not None and len(extensions_header) != 0:
            # Only the first Sec-WebSocket-Extensions header value is parsed.
            accepted_extensions = common.parse_extensions(extensions_header[0])

        # TODO(bashi): Support the new style perframe compression extension.
        for extension in accepted_extensions:
            extension_name = extension.name()
            if (extension_name == common.DEFLATE_FRAME_EXTENSION and
                    self._options.deflate_frame):
                deflate_frame_accepted = True
                processor = DeflateFrameExtensionProcessor(extension)
                # The response object itself is not needed; the call is
                # kept for its effect on the processor.
                unused_extension_response = processor.get_extension_response()
                self._options.deflate_frame = processor
                continue
            elif (extension_name == common.PERMESSAGE_DEFLATE_EXTENSION and
                  self._options.use_permessage_deflate):
                permessage_deflate_accepted = True

                framer = _get_permessage_deflate_framer(extension)
                framer.set_compress_outgoing_enabled(True)
                self._options.use_permessage_deflate = framer
                continue

            # Anything not requested (or not recognized) is a protocol
            # violation by the server.
            raise ClientHandshakeError(
                'Unexpected extension %r' % extension_name)

        if (self._options.deflate_frame and not deflate_frame_accepted):
            raise ClientHandshakeError(
                'Requested %s, but the server rejected it' %
                common.DEFLATE_FRAME_EXTENSION)
        if (self._options.use_permessage_deflate and
                not permessage_deflate_accepted):
            raise ClientHandshakeError(
                'Requested %s, but the server rejected it' %
                common.PERMESSAGE_DEFLATE_EXTENSION)
def test_parse_quoted_data(self):
    """Parse each quoted-string sample and check it against its expected
    definition.
    """
    for sample in _TEST_QUOTED_EXTENSION_DATA:
        formatted_string, definition = sample
        parsed = parse_extensions(formatted_string)
        self._verify_extension_list(definition, parsed)
def test_parse_redundant_quoted_data(self):
    """Parse each redundant quoted-string sample and check it against its
    expected definition.
    """
    for sample in _TEST_REDUNDANT_QUOTED_EXTENSION_DATA:
        formatted_string, definition = sample
        parsed = parse_extensions(formatted_string)
        self._verify_extension_list(definition, parsed)
def _parse_compression_method(data):
    """Parse the value of the "method" extension parameter.

    The value is handed to common.parse_extensions() and that call's
    result is returned unchanged.
    """
    parsed_methods = common.parse_extensions(data)
    return parsed_methods
def test_parse(self):
    """Parse each token-only sample and check it against its expected
    definition.
    """
    for sample in _TEST_TOKEN_EXTENSION_DATA:
        formatted_string, definition = sample
        parsed = parse_extensions(formatted_string)
        self._verify_extension_list(definition, parsed)
def test_parse_quoted_data(self):
    """Verify that every quoted-string sample parses to its expected
    definition.
    """
    for formatted_string, definition in _TEST_QUOTED_EXTENSION_DATA:
        actual_extensions = parse_extensions(formatted_string)
        self._verify_extension_list(definition, actual_extensions)
def test_parse_redundant_quoted_data(self):
    """Verify that every redundant quoted-string sample parses to its
    expected definition.
    """
    for formatted_string, definition in _TEST_REDUNDANT_QUOTED_EXTENSION_DATA:
        actual_extensions = parse_extensions(formatted_string)
        self._verify_extension_list(definition, actual_extensions)
def test_parse(self):
    """Verify that every token-only sample parses to its expected
    definition.
    """
    for formatted_string, definition in _TEST_TOKEN_EXTENSION_DATA:
        actual_extensions = parse_extensions(formatted_string)
        self._verify_extension_list(definition, actual_extensions)
def handshake(self, socket):
    """Handshake WebSocket.

    Sends the opening handshake request over `socket`, then reads and
    validates the status line, the Upgrade / Connection /
    Sec-WebSocket-Accept response headers and the extensions accepted by
    the server. The socket is stored on the instance as self._socket.

    Args:
        socket: object used to send the request and read the response.

    Raises:
        HttpStatusException: the response status code is not 101.
        HandshakeException: the Sec-WebSocket-Accept value is not valid
            base64 or does not decode to 20 bytes.
        Exception: any other handshake failure.
    """

    self._socket = socket

    request_line = _method_line(self._options.resource)
    self._logger.debug('Opening handshake Request-Line: %r', request_line)
    self._socket.sendall(request_line.encode('UTF-8'))

    fields = []
    fields.append(_UPGRADE_HEADER)
    fields.append(_CONNECTION_HEADER)

    fields.append(
        _format_host_header(self._options.server_host,
                            self._options.server_port,
                            self._options.use_tls))

    # Compare by value: "is" with an int literal tests object identity,
    # which is implementation-dependent and triggers a SyntaxWarning on
    # Python 3.8+.
    if self._options.version == 8:
        fields.append(_sec_origin_header(self._options.origin))
    else:
        fields.append(_origin_header(self._options.origin))

    original_key = os.urandom(16)
    key = base64.b64encode(original_key)
    self._logger.debug('Sec-WebSocket-Key: %s (%s)', key,
                       util.hexify(original_key))
    fields.append(u'Sec-WebSocket-Key: %s\r\n' % key.decode('UTF-8'))

    fields.append(u'Sec-WebSocket-Version: %d\r\n' % self._options.version)

    # Setting up extensions.
    if len(self._options.extensions) > 0:
        fields.append(u'Sec-WebSocket-Extensions: %s\r\n' %
                      ', '.join(self._options.extensions))

    self._logger.debug('Opening handshake request headers: %r', fields)

    for field in fields:
        self._socket.sendall(field.encode('UTF-8'))
    # Empty line terminating the request headers.
    self._socket.sendall(b'\r\n')

    self._logger.info('Sent opening handshake request')

    # Read the status line one byte at a time, up to and including LF.
    field = b''
    while True:
        ch = receive_bytes(self._socket, 1)
        field += ch
        if ch == b'\n':
            break

    self._logger.debug('Opening handshake Response-Line: %r', field)

    # Latin-1 maps every byte value to a code point, so decoding the raw
    # line for the error messages below cannot fail.
    if len(field) < 7 or not field.endswith(b'\r\n'):
        raise Exception('Wrong status line: %s' % field.decode('Latin-1'))
    m = re.match(b'[^ ]* ([^ ]*) .*', field)
    if m is None:
        raise Exception('No HTTP status code found in status line: %s' %
                        field.decode('Latin-1'))
    code = m.group(1)
    if not re.match(b'[0-9][0-9][0-9]$', code):
        raise Exception(
            'HTTP status code %s is not three digit in status line: %s' %
            (code.decode('Latin-1'), field.decode('Latin-1')))
    if code != b'101':
        raise HttpStatusException(
            'Expected HTTP status code 101 but found %s in status line: '
            '%r' % (code.decode('Latin-1'), field.decode('Latin-1')),
            int(code))
    fields = _read_fields(self._socket)
    ch = receive_bytes(self._socket, 1)
    if ch != b'\n':  # 0x0A
        raise Exception('Expected LF but found: %r' % ch)
    self._logger.debug('Opening handshake response headers: %r', fields)

    # Check /fields/
    if len(fields['upgrade']) != 1:
        raise Exception('Multiple Upgrade headers found: %s' %
                        fields['upgrade'])
    if len(fields['connection']) != 1:
        raise Exception('Multiple Connection headers found: %s' %
                        fields['connection'])
    if fields['upgrade'][0] != 'websocket':
        raise Exception('Unexpected Upgrade header value: %s' %
                        fields['upgrade'][0])
    if fields['connection'][0].lower() != 'upgrade':
        raise Exception('Unexpected Connection header value: %s' %
                        fields['connection'][0])

    if len(fields['sec-websocket-accept']) != 1:
        raise Exception('Multiple Sec-WebSocket-Accept headers found: %s' %
                        fields['sec-websocket-accept'])

    accept = fields['sec-websocket-accept'][0]

    # Validate
    try:
        decoded_accept = base64.b64decode(accept)
    except (TypeError, ValueError):
        # On Python 3 malformed base64 raises binascii.Error (a ValueError
        # subclass) and non-ASCII text raises ValueError; TypeError is
        # kept for inputs of an unsupported type.
        raise HandshakeException(
            'Illegal value for header Sec-WebSocket-Accept: ' + accept)

    if len(decoded_accept) != 20:
        raise HandshakeException(
            'Decoded value of Sec-WebSocket-Accept is not 20-byte long')

    self._logger.debug('Actual Sec-WebSocket-Accept: %r (%s)', accept,
                       util.hexify(decoded_accept))

    original_expected_accept = sha1(key + WEBSOCKET_ACCEPT_UUID).digest()
    expected_accept = base64.b64encode(original_expected_accept)

    self._logger.debug('Expected Sec-WebSocket-Accept: %r (%s)',
                       expected_accept,
                       util.hexify(original_expected_accept))

    if accept != expected_accept.decode('UTF-8'):
        # Argument order matches the labels in the message: the locally
        # computed value is "expected", the server's value is "actual".
        raise Exception(
            'Invalid Sec-WebSocket-Accept header: %r (expected) != %r '
            '(actual)' % (expected_accept, accept))

    server_extensions_header = fields.get('sec-websocket-extensions')
    accepted_extensions = []
    if server_extensions_header is not None:
        accepted_extensions = common.parse_extensions(
            ', '.join(server_extensions_header))

    # Scan accepted extension list to check if there is any unrecognized
    # extensions or extensions we didn't request in it. Then, for
    # extensions we request, parse them and store parameters. They will be
    # used later by each extension.
    for extension in accepted_extensions:
        if extension.name() == _PERMESSAGE_DEFLATE_EXTENSION:
            checker = self._options.check_permessage_deflate
            if checker:
                checker(extension)
                continue

        raise Exception('Received unrecognized extension: %s' %
                        extension.name())