def test_deflate_frame_bad_request_parameters(self):
    """Verify that a malformed deflate-frame offer is rejected.

    Every offer below carries exactly one illegal parameter; the processor
    is expected to produce no extension response for any of them.
    """
    illegal_offers = (
        # max_window_bits less than 8 is illegal.
        ('max_window_bits', '7'),
        # max_window_bits greater than 15 is illegal.
        ('max_window_bits', '16'),
        # Non integer max_window_bits is illegal.
        ('max_window_bits', 'foobar'),
        # no_context_takeover must not have any value.
        ('no_context_takeover', 'foobar'),
    )

    for param_name, param_value in illegal_offers:
        offer = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        offer.add_parameter(param_name, param_value)
        processor = DeflateFrameExtensionProcessor(offer)
        self.assertIsNone(processor.get_extension_response())
def test_registry(self):
    """Both deflate-frame names map to DeflateFrameExtensionProcessor."""
    for extension_name in ('deflate-frame', 'x-webkit-deflate-frame'):
        offer = common.ExtensionParameter(extension_name)
        resolved = extensions.get_extension_processor(offer)
        self.assertIsInstance(resolved,
                              extensions.DeflateFrameExtensionProcessor)
def test_deflate_frame_response_parameters(self):
    """Check the parameters the server side adds to its response."""
    # A window size configured on the processor must show up in the
    # response as max_window_bits.
    processor = DeflateFrameExtensionProcessor(
        common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION))
    processor.set_response_window_bits(8)
    reply = processor.get_extension_response()
    self.assertTrue(reply.has_parameter('max_window_bits'))
    self.assertEqual('8', reply.get_parameter_value('max_window_bits'))

    # no_context_takeover must be present and must carry no value.
    processor = DeflateFrameExtensionProcessor(
        common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION))
    processor.set_response_no_context_takeover(True)
    reply = processor.get_extension_response()
    self.assertTrue(reply.has_parameter('no_context_takeover'))
    self.assertIsNone(reply.get_parameter_value('no_context_takeover'))
def test_response_with_max_window_bits_without_client_permission(self):
    """No response is produced when client_max_window_bits is configured
    on the processor although the offer did not contain that parameter.
    """
    offer = common.ExtensionParameter('permessage-deflate')
    processor = extensions.PerMessageDeflateExtensionProcessor(offer)
    processor.set_client_max_window_bits(10)

    self.assertIsNone(processor.get_extension_response())
def test_receive_message_deflate_frame_client_using_smaller_window(self):
    """Test that frames coming from a client which is using a smaller
    window size than the server are correctly received.
    """
    # Using the smallest window bits of 8 for generating input frames.
    compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -8)

    # Use a frame whose content is bigger than the clients' DEFLATE window
    # size before compression. The content mainly consists of 'a' but
    # repetition of 'b' is put at the head and tail so that if the window
    # size is big, the head is back-referenced but if small, not.
    payload = 'b' * 64 + 'a' * 1024 + 'b' * 64

    # zlib operates on octets, so the text is encoded before compression.
    compressed_payload = compress.compress(payload.encode('UTF-8'))
    compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH)
    # Drop the last 4 octets of the sync-flush output.
    compressed_payload = compressed_payload[:-4]

    # The wire data is kept as bytes, like the sibling tests in this file.
    data = b'\xc1%c' % (len(compressed_payload) | 0x80)
    data += _mask_hybi(compressed_payload)

    # Close frame
    data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye')

    extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
    request = _create_request_from_rawdata(data,
                                           deflate_frame_request=extension)
    self.assertEqual(payload, msgutil.receive_message(request))
    # After the close frame there is no further message.
    self.assertEqual(None, msgutil.receive_message(request))
def test_receive_message_random_section(self):
    """Receive a compressed message that was split into many small
    frames and check that it is reassembled correctly.
    """
    random.seed(a=0)
    payload = b''.join(
        struct.pack('!B', random.randint(0, 255)) for _ in range(1000))

    deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -zlib.MAX_WBITS)
    deflated = deflater.compress(payload)
    deflated += deflater.flush(zlib.Z_SYNC_FLUSH)
    deflated = deflated[:-4]
    total = len(deflated)

    # Fragment the compressed payload into lots of frames.
    wire = b''
    offset = 0
    chunk_sizes = []
    while offset < total:
        # Make sure that
        # - the length of chunks are equal or less than 125 so that we can
        #   use 1 octet length header format for all frames.
        # - at least 10 chunks are created.
        size = random.randint(1, min(125, total // 10, total - offset))
        chunk_sizes.append(size)
        piece = deflated[offset:offset + size]
        offset += size

        # Only the very first frame gets the 0x42 bits; the frame that
        # consumes the last octets additionally gets 0x80.
        header = 0x42 if not wire else 0x00
        if offset == total:
            header |= 0x80

        wire += b'%c%c' % (header, size | 0x80)
        wire += _mask_hybi(piece)

    self.assertTrue(len(chunk_sizes) > 10)

    # Close frame
    wire += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye')

    offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
    request = _create_request_from_rawdata(
        wire, permessage_deflate_request=offer)
    self.assertEqual(payload, msgutil.receive_message(request))
    self.assertEqual(None, msgutil.receive_message(request))
def test_send_message_permessage_compress_deflate(self):
    """Send two messages with permessage-compress (method=deflate) and
    compare the written octets with frames built by hand.
    """
    compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -zlib.MAX_WBITS)
    extension = common.ExtensionParameter(
        common.PERMESSAGE_COMPRESSION_EXTENSION)
    extension.add_parameter('method', 'deflate')
    request = _create_request_from_rawdata(
        b'', permessage_compression_request=extension)
    msgutil.send_message(request, 'Hello')
    msgutil.send_message(request, 'World')

    # Both messages go through the same compressor object, in order. The
    # expected wire data is bytes, like the sibling tests in this file.
    expected = b''
    for message in (b'Hello', b'World'):
        compressed = compress.compress(message)
        compressed += compress.flush(zlib.Z_SYNC_FLUSH)
        # Drop the last 4 octets of the sync-flush output.
        compressed = compressed[:-4]
        expected += b'\xc1%c' % len(compressed)
        expected += compressed

    self.assertEqual(expected, request.connection.written_data())
def test_send_message_fragmented_bfinal(self):
    """Send a two-fragment message with set_bfinal(True) on the processor
    and compare the written octets with frames built by hand.
    """
    offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
    request = _create_request_from_rawdata(
        b'', permessage_deflate_request=offer)
    self.assertEqual(1, len(request.ws_extension_processors))
    request.ws_extension_processors[0].set_bfinal(True)
    msgutil.send_message(request, 'Hello', end=False)
    msgutil.send_message(request, 'World', end=True)

    def finished_block(raw):
        # Each fragment is expected to come from a fresh compressor that
        # is flushed with Z_FINISH, followed by a single zero octet.
        deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
                                    zlib.DEFLATED, -zlib.MAX_WBITS)
        block = deflater.compress(raw)
        block += deflater.flush(zlib.Z_FINISH)
        return block + struct.pack('!B', 0)

    hello_block = finished_block(b'Hello')
    world_block = finished_block(b'World')
    expected = b''.join([
        b'\x41%c' % len(hello_block),
        hello_block,
        b'\x80%c' % len(world_block),
        world_block,
    ])

    self.assertEqual(expected, request.connection.written_data())
def test_send_message_deflate_frame_comp_bit(self):
    """Send 'Hello' three times while toggling outgoing compression and
    check that only the first and third frames are compressed.
    """
    compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -zlib.MAX_WBITS)

    extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
    request = _create_request_from_rawdata(b'',
                                           deflate_frame_request=extension)
    # assertEqual rather than the deprecated assertEquals alias.
    self.assertEqual(1, len(request.ws_extension_processors))
    deflate_frame_processor = request.ws_extension_processors[0]

    msgutil.send_message(request, 'Hello')
    deflate_frame_processor.disable_outgoing_compression()
    msgutil.send_message(request, 'Hello')
    deflate_frame_processor.enable_outgoing_compression()
    msgutil.send_message(request, 'Hello')

    # The expected wire data is bytes, like the sibling tests in this file.
    expected = b''

    compressed_hello = compress.compress(b'Hello')
    compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
    # Drop the last 4 octets of the sync-flush output.
    compressed_hello = compressed_hello[:-4]
    expected += b'\xc1%c' % len(compressed_hello)
    expected += compressed_hello

    # The second frame was sent with compression disabled.
    expected += b'\x81\x05Hello'

    # The third frame reuses the same compressor object as the first.
    compressed_2nd_hello = compress.compress(b'Hello')
    compressed_2nd_hello += compress.flush(zlib.Z_SYNC_FLUSH)
    compressed_2nd_hello = compressed_2nd_hello[:-4]
    expected += b'\xc1%c' % len(compressed_2nd_hello)
    expected += compressed_2nd_hello

    self.assertEqual(expected, request.connection.written_data())
def test_create_method_desc_with_parameters(self):
    """A value containing a comma and space is quoted in the description,
    while a plain token value is emitted as is.
    """
    method = common.ExtensionParameter('foo')
    for key, value in (('x', 'Hello, World'), ('y', '10')):
        method.add_parameter(key, value)

    formatted = extensions._create_accepted_method_desc(
        'foo', method.get_parameters())
    self.assertEqual('foo; x="Hello, World"; y=10', formatted)
def test_send_message_no_context_takeover_parameter(self):
    """With server_no_context_takeover offered, sending the same
    two-fragment message three times must yield three identical byte
    sequences.
    """
    offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
    offer.add_parameter('server_no_context_takeover', None)
    request = _create_request_from_rawdata(
        b'', permessage_deflate_request=offer)

    for _ in range(3):
        msgutil.send_message(request, 'Hello', end=False)
        msgutil.send_message(request, 'Hello', end=True)

    # One message: both fragments share a single compressor object.
    deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -zlib.MAX_WBITS)
    head = deflater.compress(b'Hello')
    head += deflater.flush(zlib.Z_SYNC_FLUSH)
    tail = deflater.compress(b'Hello')
    tail += deflater.flush(zlib.Z_SYNC_FLUSH)
    # Only the final fragment has its last 4 octets removed.
    tail = tail[:-4]

    one_message = b''.join([
        b'\x41%c' % len(head),
        head,
        b'\x80%c' % len(tail),
        tail,
    ])

    self.assertEqual(one_message * 3, request.connection.written_data())
def test_send_message_using_small_window(self):
    """Send a long message with server_max_window_bits=8 and check the
    frame size, the header and that the payload inflates back to the
    original text.
    """
    common_part = 'abcdefghijklmnopqrstuvwxyz'
    test_message = common_part + '-' * 30000 + common_part

    offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
    offer.add_parameter('server_max_window_bits', '8')
    request = _create_request_from_rawdata(
        b'', permessage_deflate_request=offer)
    msgutil.send_message(request, test_message)

    header_size = 2
    payload_size = 91
    frame = request.connection.written_data()
    self.assertEqual(header_size + payload_size, len(frame))

    header, body = frame[:header_size], frame[header_size:]
    self.assertEqual(b'\xc1%c' % payload_size, header)

    # Re-append the 4 octets 00 00 ff ff before inflating the payload.
    inflater = zlib.decompressobj(-8)
    inflated = inflater.decompress(body + b'\x00\x00\xff\xff')
    inflated += inflater.flush()

    self.assertEqual(test_message, inflated.decode('UTF-8'))
    self.assertEqual(0, len(inflater.unused_data))
    self.assertEqual(0, len(inflater.unconsumed_tail))
def test_send_message_fragmented(self):
    """Send one message in three fragments and compare the written
    octets with frames built by hand from a single compressor.
    """
    offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
    request = _create_request_from_rawdata(
        b'', permessage_deflate_request=offer)
    msgutil.send_message(request, 'Hello', end=False)
    msgutil.send_message(request, 'Goodbye', end=False)
    msgutil.send_message(request, 'World')

    deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -zlib.MAX_WBITS)

    def sync_flushed(raw):
        # Compress one fragment and sync-flush the shared compressor.
        return deflater.compress(raw) + deflater.flush(zlib.Z_SYNC_FLUSH)

    hello = sync_flushed(b'Hello')
    goodbye = sync_flushed(b'Goodbye')
    # Only the final fragment has its last 4 octets removed.
    world = sync_flushed(b'World')[:-4]

    expected = b''.join([
        b'\x41%c' % len(hello),
        hello,
        b'\x00%c' % len(goodbye),
        goodbye,
        b'\x80%c' % len(world),
        world,
    ])

    self.assertEqual(expected, request.connection.written_data())
def test_receive_message_deflate_frame_comp_bit(self):
    """Receive a compressed, an uncompressed and another compressed
    'Hello' frame and check that each one is delivered as 'Hello'.
    """
    compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -zlib.MAX_WBITS)

    # The wire data is kept as bytes, like the sibling tests in this file.
    data = b''

    compressed_hello = compress.compress(b'Hello')
    compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
    # Drop the last 4 octets of the sync-flush output.
    compressed_hello = compressed_hello[:-4]
    data += b'\xc1%c' % (len(compressed_hello) | 0x80)
    data += _mask_hybi(compressed_hello)

    # Second frame: 'Hello' sent without compression.
    data += b'\x81\x85' + _mask_hybi(b'Hello')

    # Third frame: compressed again, this time with a fresh compressor.
    compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -zlib.MAX_WBITS)

    compressed_2nd_hello = compress.compress(b'Hello')
    compressed_2nd_hello += compress.flush(zlib.Z_SYNC_FLUSH)
    compressed_2nd_hello = compressed_2nd_hello[:-4]
    data += b'\xc1%c' % (len(compressed_2nd_hello) | 0x80)
    data += _mask_hybi(compressed_2nd_hello)

    extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
    request = _create_request_from_rawdata(data,
                                           deflate_frame_request=extension)
    # range instead of the Python 2-only xrange; the index is unused.
    for _ in range(3):
        self.assertEqual('Hello', msgutil.receive_message(request))
def _create_accepted_method_desc(method_name, method_params):
    """Build the accepted-method-desc string for a compression method.

    Args:
        method_name: name used for the ExtensionParameter that is
            formatted.
        method_params: iterable of (name, value) pairs added to it in
            order.

    Returns:
        Whatever common.format_extension() returns for the assembled
        ExtensionParameter.
    """
    method_desc = common.ExtensionParameter(method_name)
    for param in method_params:
        param_name, param_value = param
        method_desc.add_parameter(param_name, param_value)
    return common.format_extension(method_desc)
def get_extension_response(self):
    """Return the deflate-stream response, or None if the offer carries
    any parameter.
    """
    offered_parameter_count = len(self._request.get_parameter_names())
    if offered_parameter_count:
        # Any parameter on the offer causes it to be declined.
        return None

    self._logger.debug('Enable %s extension',
                       common.DEFLATE_STREAM_EXTENSION)

    return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION)
def test_offer_with_unknown_parameter(self):
    """An unknown parameter on the offer is ignored: the extension is
    still accepted and the response carries no parameters.
    """
    offer = common.ExtensionParameter('perframe-deflate')
    offer.add_parameter('foo', 'bar')

    reply = extensions.DeflateFrameExtensionProcessor(
        offer).get_extension_response()

    self.assertEqual('perframe-deflate', reply.name())
    self.assertEqual(0, len(reply.get_parameters()))
def test_response_with_false_for_no_context_takeover(self):
    """Setting client_no_context_takeover to False adds nothing to the
    response parameters.
    """
    offer = common.ExtensionParameter('permessage-deflate')
    processor = extensions.PerMessageDeflateExtensionProcessor(offer)
    processor.set_client_no_context_takeover(False)

    reply = processor.get_extension_response()

    self.assertEqual('permessage-deflate', reply.name())
    self.assertEqual(0, len(reply.get_parameters()))
class MuxExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket multiplexing extension processor."""

    # Name of the optional offer parameter holding the quota.
    _QUOTA_PARAM = 'quota'

    def __init__(self, request):
        """Initialize the processor for the given extension offer."""
        ExtensionProcessorInterface.__init__(self, request)
        self._quota = 0
        self._extensions = []

    def name(self):
        """Return the name of the mux extension."""
        return common.MUX_EXTENSION

    def check_consistency_with_other_processors(self, processors):
        """Deactivate mux when it conflicts with another active extension.

        Args:
            processors: extension processors in order; the entry whose
                name equals this processor's name marks the position of
                mux in that order.
        """
        # Extensions tied to individual frames.
        per_frame_extensions = (
            common.PERFRAME_COMPRESSION_EXTENSION,
            common.DEFLATE_FRAME_EXTENSION,
            common.X_WEBKIT_DEFLATE_FRAME_EXTENSION,
        )
        # Per-frame extensions plus the per-message compression ones.
        compression_extensions = per_frame_extensions + (
            common.PERMESSAGE_COMPRESSION_EXTENSION,
            common.X_WEBKIT_PERMESSAGE_COMPRESSION_EXTENSION,
        )

        before_mux = True
        for processor in processors:
            name = processor.name()
            if name == self.name():
                before_mux = False
                continue
            if not processor.is_active():
                continue
            if before_mux:
                # Mux extension cannot be used after extensions
                # that depend on frame boundary, extension data field, or
                # any reserved bits which are attributed to each frame.
                if name in per_frame_extensions:
                    self.set_active(False)
                    return
            else:
                # Mux extension should not be applied before any
                # history-based compression extension.
                if name in compression_extensions:
                    self.set_active(False)
                    return

    def _get_extension_response_internal(self):
        """Validate the quota parameter and build the mux response.

        Returns:
            An ExtensionParameter for the mux extension, or None when the
            quota is not an integer or lies outside [0, 2**32). The
            processor is left inactive in the None case.
        """
        self._active = False
        quota = self._request.get_parameter_value(self._QUOTA_PARAM)
        if quota is not None:
            try:
                quota = int(quota)
            except ValueError:
                # Form of except that parses on both Python 2 and 3.
                return None
            if quota < 0 or quota >= 2**32:
                return None
            self._quota = quota

        self._active = True
        return common.ExtensionParameter(common.MUX_EXTENSION)
class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
    """WebSocket Per-frame DEFLATE extension processor."""

    _WINDOW_BITS_PARAM = 'window_bits'
    _NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover'

    def __init__(self, request):
        """Remember the extension offer and reset the response options."""
        self._logger = util.get_class_logger(self)

        self._request = request

        # Values to advertise in the response; None/False mean "omit".
        self._response_window_bits = None
        self._response_no_context_takeover = False

    def get_extension_response(self):
        """Validate the offer and build the extension response.

        Returns:
            An ExtensionParameter for deflate-frame, or None when
            no_context_takeover carries a value or window_bits is not an
            integer in the range 8..15.
        """
        # Any unknown parameter will be just ignored.

        window_bits = self._request.get_parameter_value(
            self._WINDOW_BITS_PARAM)
        no_context_takeover = self._request.has_parameter(
            self._NO_CONTEXT_TAKEOVER_PARAM)
        # no_context_takeover is a flag: it must not have a value.
        if (no_context_takeover and
                self._request.get_parameter_value(
                    self._NO_CONTEXT_TAKEOVER_PARAM) is not None):
            return None

        if window_bits is not None:
            try:
                window_bits = int(window_bits)
            except ValueError:
                # Form of except that parses on both Python 2 and 3.
                return None
            if window_bits < 8 or window_bits > 15:
                return None

        # The compression state is only created once the offer is valid.
        self._deflater = util._RFC1979Deflater(
            window_bits, no_context_takeover)

        self._inflater = util._RFC1979Inflater()

        self._compress_outgoing = True

        response = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)

        if self._response_window_bits is not None:
            response.add_parameter(
                self._WINDOW_BITS_PARAM, str(self._response_window_bits))
        if self._response_no_context_takeover:
            response.add_parameter(
                self._NO_CONTEXT_TAKEOVER_PARAM, None)

        # Log the extension that is actually being enabled (deflate-frame);
        # arguments are passed separately so formatting happens only when
        # debug logging is enabled.
        self._logger.debug(
            'Enable %s extension ('
            'request: window_bits=%s; no_context_takeover=%r, '
            'response: window_bits=%s; no_context_takeover=%r)',
            common.DEFLATE_FRAME_EXTENSION,
            window_bits,
            no_context_takeover,
            self._response_window_bits,
            self._response_no_context_takeover)

        return response
def test_offer_with_no_context_takeover(self):
    """An offer with no_context_takeover is accepted without response
    parameters and turns the flag on in the deflater.
    """
    offer = common.ExtensionParameter('perframe-deflate')
    offer.add_parameter('no_context_takeover', None)
    processor = extensions.DeflateFrameExtensionProcessor(offer)

    reply = processor.get_extension_response()

    self.assertEqual('perframe-deflate', reply.name())
    self.assertEqual(0, len(reply.get_parameters()))
    self.assertTrue(processor._rfc1979_deflater._no_context_takeover)
def test_offer_with_max_window_bits(self):
    """An offer with max_window_bits=10 is accepted without response
    parameters and configures the deflater window to 10 bits.
    """
    offer = common.ExtensionParameter('perframe-deflate')
    offer.add_parameter('max_window_bits', '10')
    processor = extensions.DeflateFrameExtensionProcessor(offer)

    reply = processor.get_extension_response()

    self.assertEqual('perframe-deflate', reply.name())
    self.assertEqual(0, len(reply.get_parameters()))
    self.assertEqual(10, processor._rfc1979_deflater._window_bits)
def test_response_with_max_window_bits(self):
    """When the offer contains client_max_window_bits, the value set on
    the processor is returned in the response.
    """
    offer = common.ExtensionParameter('permessage-deflate')
    offer.add_parameter('client_max_window_bits', None)
    processor = extensions.PerMessageDeflateExtensionProcessor(offer)
    processor.set_client_max_window_bits(10)

    reply = processor.get_extension_response()

    self.assertEqual('permessage-deflate', reply.name())
    self.assertEqual([('client_max_window_bits', '10')],
                     reply.get_parameters())
def test_response_with_true_for_no_context_takeover(self):
    """Setting c2s_no_context_takeover to True puts a valueless
    c2s_no_context_takeover parameter into the response.
    """
    offer = common.ExtensionParameter('permessage-deflate')
    processor = extensions.PerMessageDeflateExtensionProcessor(offer)
    processor.set_c2s_no_context_takeover(True)

    reply = processor.get_extension_response()

    self.assertEqual('permessage-deflate', reply.name())
    self.assertEqual([('c2s_no_context_takeover', None)],
                     reply.get_parameters())
def test_minimal_offer(self):
    """A parameterless offer is accepted with default deflater settings:
    zlib.MAX_WBITS window bits and context takeover enabled.
    """
    offer = common.ExtensionParameter('permessage-deflate')
    processor = extensions.PerMessageDeflateExtensionProcessor(offer)

    reply = processor.get_extension_response()

    self.assertEqual('permessage-deflate', reply.name())
    self.assertEqual(0, len(reply.get_parameters()))

    deflater = processor._rfc1979_deflater
    self.assertEqual(zlib.MAX_WBITS, deflater._window_bits)
    self.assertFalse(deflater._no_context_takeover)
def test_offer_with_max_window_bits(self):
    """An offer with server_max_window_bits=10 is echoed in the response
    and configures the deflater window to 10 bits.
    """
    offer = common.ExtensionParameter('permessage-deflate')
    offer.add_parameter('server_max_window_bits', '10')
    processor = extensions.PerMessageDeflateExtensionProcessor(offer)

    reply = processor.get_extension_response()

    self.assertEqual('permessage-deflate', reply.name())
    self.assertEqual([('server_max_window_bits', '10')],
                     reply.get_parameters())
    self.assertEqual(10, processor._rfc1979_deflater._window_bits)
def _get_extension_response_internal(self):
    """Build the response that names the accepted compression method.

    Returns:
        None when the compression processor yields no response;
        otherwise an ExtensionParameter named after the request whose
        method parameter holds the accepted-method-desc.
    """
    inner_response = self._get_compression_processor_response()
    if inner_response is None:
        return None

    reply = common.ExtensionParameter(self._request.name())
    method_desc = _create_accepted_method_desc(
        self._compression_method_name, inner_response.get_parameters())
    reply.add_parameter(self._METHOD_PARAM, method_desc)

    self._logger.debug(
        'Enable %s extension (method: %s)' %
        (self._request.name(), self._compression_method_name))
    return reply
def _get_extension_response_internal(self):
    """Validate the optional quota parameter and build the mux response.

    Returns:
        An ExtensionParameter for the mux extension, or None when the
        quota is not an integer or lies outside [0, 2 ** 32). The
        processor stays inactive in the None case.
    """
    self._active = False

    offered_quota = self._request.get_parameter_value(self._QUOTA_PARAM)
    if offered_quota is not None:
        try:
            parsed_quota = int(offered_quota)
        except ValueError:
            return None
        if not 0 <= parsed_quota < 2 ** 32:
            return None
        self._quota = parsed_quota

    self._active = True
    return common.ExtensionParameter(common.MUX_EXTENSION)
def test_response_parameters(self):
    """Check which parameters end up in the permessage-deflate response."""
    # server_no_context_takeover from the offer is returned without value.
    offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
    offer.add_parameter('server_no_context_takeover', None)
    reply = PerMessageDeflateExtensionProcessor(
        offer).get_extension_response()
    self.assertTrue(reply.has_parameter('server_no_context_takeover'))
    self.assertEqual(
        None, reply.get_parameter_value('server_no_context_takeover'))

    # Client-side settings made on the processor are put in the response
    # when the offer contains client_max_window_bits.
    offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
    offer.add_parameter('client_max_window_bits', None)
    processor = PerMessageDeflateExtensionProcessor(offer)
    processor.set_client_max_window_bits(8)
    processor.set_client_no_context_takeover(True)
    reply = processor.get_extension_response()
    self.assertEqual(
        '8', reply.get_parameter_value('client_max_window_bits'))
    self.assertTrue(reply.has_parameter('client_no_context_takeover'))
    self.assertEqual(
        None, reply.get_parameter_value('client_no_context_takeover'))
def test_send_message(self):
    """Send a single message with permessage-deflate and compare the
    written octets with a frame built by hand.
    """
    offer = common.ExtensionParameter(common.PERMESSAGE_DEFLATE_EXTENSION)
    request = _create_request_from_rawdata(
        b'', permessage_deflate_request=offer)
    msgutil.send_message(request, 'Hello')

    deflater = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
                                -zlib.MAX_WBITS)
    body = deflater.compress(b'Hello')
    body += deflater.flush(zlib.Z_SYNC_FLUSH)
    # Drop the last 4 octets of the sync-flush output.
    body = body[:-4]

    expected = b''.join([b'\xc1%c' % len(body), body])
    self.assertEqual(expected, request.connection.written_data())