def test_block_headers_request(self):
    """Check how GetBlockHeaders requests from the local node are handled.

    The first request is answered directly with an empty BlockHeaders message
    (fast sync mode support) and is not forwarded. The second, identical
    request is proxied to the remote blockchain node, and the remote node's
    BlockHeaders reply is then relayed to the local node connection.
    """

    def deliver(message, fileno):
        # Split ``message`` into frames, encrypt each one and feed it to the
        # gateway as if it arrived on the connection identified by ``fileno``.
        # NOTE(review): messages from the remote node are encrypted with
        # eth_node_cipher too; presumably the fixture shares one cipher for
        # both connections -- confirm.
        raw_frames = frame_utils.get_frames(message.msg_type, message.rawbytes())
        for raw_frame in raw_frames:
            helpers.receive_node_message(
                self.gateway_node, fileno, self.eth_node_cipher.encrypt_frame(raw_frame)
            )

    get_headers = GetBlockHeadersEthProtocolMessage(None, self.BLOCK_HASH, 111, 222, 0)

    # First request: answered locally with empty headers, nothing sent upstream.
    deliver(get_headers, self.local_node_fileno)
    self.eth_remote_node_connection.enqueue_msg.assert_not_called()
    self.eth_node_connection.enqueue_msg.assert_called_once_with(
        BlockHeadersEthProtocolMessage(None, [])
    )

    # Second request: proxied to the remote blockchain node.
    deliver(get_headers, self.local_node_fileno)
    self.eth_remote_node_connection.enqueue_msg.assert_called_once_with(get_headers)

    # Remote reply: relayed back to the local node.
    remote_headers = BlockHeadersEthProtocolMessage(
        None,
        [
            mock_eth_messages.get_dummy_block_header(1),
            mock_eth_messages.get_dummy_block_header(2),
        ],
    )
    deliver(remote_headers, self.remote_node_fileno)
    self.eth_node_connection.enqueue_msg.assert_called_with(remote_headers)
def test_normal_frame(self):
    """Read back a single-frame message followed by a partial frame header.

    The trailing bytes are shorter than a frame header, so they must still be
    left in the input buffer after the full message has been extracted.
    """
    sender_cipher, receiver_cipher = self.setup_ciphers()
    expected_msg_type = 1
    payload = helpers.generate_bytearray(123)
    protocol_id = 0

    frame_list = frame_utils.get_frames(
        expected_msg_type,
        memoryview(payload),
        protocol_id,
        eth_common_constants.DEFAULT_FRAME_SIZE,
    )

    input_buffer = InputBuffer()
    for plain_frame in frame_list:
        ciphertext = sender_cipher.encrypt_frame(plain_frame)
        input_buffer.add_bytes(bytearray(rlp_utils.str_to_bytes(ciphertext)))

    # Append garbage shorter than a frame header; it must not be consumed.
    trailing_len = eth_common_constants.FRAME_HDR_TOTAL_LEN // 2
    input_buffer.add_bytes(helpers.generate_bytearray(trailing_len))

    framed_buffer = FramedInputBuffer(receiver_cipher)

    full, peeked_type = framed_buffer.peek_message(input_buffer)
    self.assertTrue(full)
    self.assertEqual(peeked_type, expected_msg_type)

    body, body_type = framed_buffer.get_full_message()
    self.assertEqual(body, payload)
    self.assertEqual(body_type, expected_msg_type)
    self.assertEqual(input_buffer.length, trailing_len)
def test_get_frames__chunked(self):
    """Check that a payload larger than the window becomes sequential chunked frames.

    Only the first chunk carries the message type and the total payload size
    (payload length plus the message-type prefix). Every chunk has a header,
    a body and a non-empty payload. Concatenating the chunk payloads in order
    reproduces the original payload.
    """
    msg_type = 1
    expected_frames_count = 4
    # Generate the whole payload at once, as the other chunked tests do.
    # Repeating one TEST_FRAME_SIZE block would make the payload periodic, and
    # the final reassembly check could then miss chunks that are duplicated or
    # reordered on a period boundary.
    dummy_payload = memoryview(
        helpers.generate_bytearray(self.TEST_FRAME_SIZE * (expected_frames_count - 1))
    )
    dummy_protocol_id = 0

    frames = frame_utils.get_frames(
        msg_type, dummy_payload, dummy_protocol_id, window_size=self.TEST_FRAME_SIZE
    )
    self.assertTrue(frames)
    self.assertEqual(len(frames), expected_frames_count)

    all_payload = bytearray(0)
    for expected_sequence_id, frame in enumerate(frames):
        self.assertTrue(frame.get_payload())
        self.assertEqual(frame.get_protocol_id(), dummy_protocol_id)
        self.assertEqual(frame.get_sequence_id(), expected_sequence_id)
        self.assertTrue(frame.is_chunked())

        # Only the first chunked frame carries the message type and the total
        # payload length.
        if expected_sequence_id == 0:
            self.assertEqual(frame.get_msg_type(), msg_type)
            expected_total_size = (
                len(dummy_payload) + eth_common_constants.FRAME_MSG_TYPE_LEN
            )
            self.assertEqual(frame.get_total_payload_size(), expected_total_size)
        else:
            self.assertIsNone(frame.get_msg_type())
            self.assertIsNone(frame.get_total_payload_size())

        self.assertTrue(frame.get_header())
        self.assertTrue(frame.get_body())
        all_payload += frame.get_payload()

    self.assertEqual(all_payload, dummy_payload)
def test_chunked_frames(self):
    """Peek a chunked message whose bytes arrive one header-sized piece at a time.

    The encrypted frames (plus trailing bytes shorter than a header) are fed
    into the input buffer in FRAME_HDR_TOTAL_LEN-sized pieces. Feeding stops
    as soon as the framed buffer reports a full message, which must carry the
    original message type.
    """
    sender_cipher, receiver_cipher = self.setup_ciphers()
    expected_msg_type = 10
    expected_frames_count = 10
    payload = helpers.generate_bytearray(self.TEST_FRAME_SIZE * (expected_frames_count - 1))
    protocol_id = 0

    frame_list = frame_utils.get_frames(
        expected_msg_type, memoryview(payload), protocol_id, self.TEST_FRAME_SIZE
    )
    self.assertEqual(len(frame_list), expected_frames_count)

    # Encrypt frames strictly in order; the ciphers are used as a pair.
    stream = bytearray().join(
        bytearray(rlp_utils.str_to_bytes(sender_cipher.encrypt_frame(plain_frame)))
        for plain_frame in frame_list
    )
    # Trailing garbage shorter than a frame header.
    stream += helpers.generate_bytearray(eth_common_constants.FRAME_HDR_TOTAL_LEN // 2)

    input_buffer = InputBuffer()
    framed_buffer = FramedInputBuffer(receiver_cipher)

    piece_len = eth_common_constants.FRAME_HDR_TOTAL_LEN
    is_full = False
    msg_type = None
    for offset in range(0, len(stream), piece_len):
        input_buffer.add_bytes(stream[offset:offset + piece_len])
        is_full, msg_type = framed_buffer.peek_message(input_buffer)
        if is_full:
            break

    self.assertTrue(is_full)
    self.assertEqual(msg_type, expected_msg_type)
def get_message_bytes(self, msg):
    """Yield the bytes to put on the wire for ``msg``.

    A ``RawEthProtocolMessage`` is yielded as its raw bytes unchanged. Any
    other message is split into frames with ``frame_utils.get_frames`` (using
    the default protocol id and frame size), and each frame is encrypted with
    ``self.rlpx_cipher`` and yielded in order.

    Serialization time and total encryption time are reported to
    ``eth_gateway_stats_service``. The encryption stat is logged only after
    the last frame has been consumed.
    """
    if isinstance(msg, RawEthProtocolMessage):
        yield msg.rawbytes()
        return

    serialization_start_time = time.perf_counter()
    frames = frame_utils.get_frames(
        msg.msg_type,
        msg.rawbytes(),
        eth_common_constants.DEFAULT_FRAME_PROTOCOL_ID,
        eth_common_constants.DEFAULT_FRAME_SIZE,
    )
    eth_gateway_stats_service.log_serialized_message(
        time.perf_counter() - serialization_start_time
    )

    # Internal invariant: framing a message always produces at least one frame.
    assert frames
    self.connection.log_trace("Broke message into {} frames", len(frames))

    # Time only the encrypt_frame calls. This is a generator, so a single
    # timer around the whole loop would also count the time the consumer
    # spends between yields (e.g. writing to the socket).
    encryption_time = 0.0
    for frame in frames:
        encrypt_start_time = time.perf_counter()
        encrypted_frame = self.rlpx_cipher.encrypt_frame(frame)
        encryption_time += time.perf_counter() - encrypt_start_time
        yield encrypted_frame

    eth_gateway_stats_service.log_encrypted_message(encryption_time)
def test_encrypt_decrypt_frame__normal_frame(self):
    """Round-trip a single, non-chunked frame through the cipher pair from setup_ciphers()."""
    encryptor, decryptor = self.setup_ciphers()
    msg_type = 1
    protocol_id = 0
    payload = memoryview(helpers.generate_bytearray(123))

    frame_list = frame_utils.get_frames(
        msg_type, payload, protocol_id, window_size=self.TEST_FRAME_SIZE
    )
    self.assertTrue(frame_list)
    self.assertEqual(len(frame_list), 1)
    (original_frame,) = frame_list

    ciphertext = memoryview(encryptor.encrypt_frame(original_frame))
    self.assertTrue(ciphertext)

    self._assert_frames_equal(self._decrypt_frame(ciphertext, decryptor), original_frame)
def test_encrypt_decrypt_frame__chunked_frame(self):
    """Round-trip every chunk of a multi-frame message through the cipher pair.

    Each frame is encrypted and then immediately decrypted before moving on
    to the next, so both ciphers advance in lockstep.
    """
    encryptor, decryptor = self.setup_ciphers()
    msg_type = 1
    protocol_id = 0
    expected_frames_count = 3
    payload = memoryview(
        helpers.generate_bytearray(self.TEST_FRAME_SIZE * (expected_frames_count - 1))
    )

    frame_list = frame_utils.get_frames(
        msg_type, payload, protocol_id, window_size=self.TEST_FRAME_SIZE
    )
    self.assertTrue(frame_list)
    self.assertEqual(len(frame_list), expected_frames_count)

    for original_frame in frame_list:
        ciphertext = memoryview(encryptor.encrypt_frame(original_frame))
        self.assertTrue(ciphertext)
        round_tripped = self._decrypt_frame(ciphertext, decryptor)
        self._assert_frames_equal(round_tripped, original_frame)
def test_get_frames__normal_frame(self):
    """Check that a payload smaller than the window yields exactly one unchunked frame.

    That frame carries the message type, payload and protocol id, has no
    sequence id, and has a non-empty header and body.
    """
    expected_msg_type = 1
    expected_protocol_id = 0
    payload = memoryview(helpers.generate_bytearray(123))

    frame_list = frame_utils.get_frames(
        expected_msg_type, payload, expected_protocol_id, window_size=self.TEST_FRAME_SIZE
    )
    self.assertTrue(frame_list)
    self.assertEqual(len(frame_list), 1)
    (only_frame,) = frame_list

    self.assertTrue(only_frame)
    self.assertEqual(only_frame.get_msg_type(), expected_msg_type)
    self.assertEqual(only_frame.get_payload(), payload)
    self.assertEqual(only_frame.get_protocol_id(), expected_protocol_id)
    self.assertIsNone(only_frame.get_sequence_id())
    self.assertFalse(only_frame.is_chunked())
    self.assertTrue(only_frame.get_header())
    self.assertTrue(only_frame.get_body())