class TestOutputBuffer(unittest.TestCase):
    """Unit tests for ``OutputBuffer`` created with ``enable_buffering=True``.

    Individual tests that need an unbuffered instance replace
    ``self.output_buffer`` themselves.
    """

    def setUp(self):
        self.output_buffer = OutputBuffer(enable_buffering=True)

    def _enqueue_and_flush(self, data):
        """Enqueue ``data`` on the buffer under test and flush immediately."""
        self.output_buffer.enqueue_msgbytes(data)
        self.output_buffer.flush()

    def test_get_buffer(self):
        """``get_buffer`` yields the first flushed message, offset by ``index``."""
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        data1 = bytearray(range(20))
        self._enqueue_and_flush(data1)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        # A second flushed message must not change what get_buffer returns.
        data2 = bytearray(range(20, 40))
        self._enqueue_and_flush(data2)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        new_index = 10
        self.output_buffer.index = new_index
        self.assertEqual(data1[new_index:], self.output_buffer.get_buffer())

    def test_advance_buffer(self):
        """``advance_buffer`` raises on an empty buffer and tracks index/length."""
        with self.assertRaises(ValueError):
            self.output_buffer.advance_buffer(5)

        self._enqueue_and_flush(bytearray(range(20)))
        self._enqueue_and_flush(bytearray(range(20, 40)))

        self.output_buffer.advance_buffer(10)
        self.assertEqual(10, self.output_buffer.index)
        self.assertEqual(30, self.output_buffer.length)

        # Consuming the remainder of the first message resets the index and
        # leaves a single message queued.
        self.output_buffer.advance_buffer(10)
        self.assertEqual(0, self.output_buffer.index)
        self.assertEqual(1, len(self.output_buffer.output_msgs))

    def test_at_msg_boundary(self):
        """``at_msg_boundary`` is true at index 0 and false at index 1."""
        self.assertTrue(self.output_buffer.at_msg_boundary())
        self.output_buffer.index = 1
        self.assertFalse(self.output_buffer.at_msg_boundary())

    def test_enqueue_msgbytes(self):
        """``enqueue_msgbytes`` rejects a ``str`` and queues messages in order."""
        with self.assertRaises(ValueError):
            self.output_buffer.enqueue_msgbytes("f")

        data1 = bytearray(range(20))
        self._enqueue_and_flush(data1)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        data2 = bytearray(range(20, 40))
        self._enqueue_and_flush(data2)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        new_index = 10
        self.output_buffer.index = new_index
        self.assertEqual(data1[new_index:], self.output_buffer.get_buffer())

    def test_prepend_msgbytes(self):
        """``prepend_msgbytes`` rejects a ``str`` and inserts ahead of queued data.

        After the first message has been partially sent, the prepended message
        is expected directly behind it rather than at the very front.
        """
        with self.assertRaises(ValueError):
            self.output_buffer.prepend_msgbytes("f")

        data1 = bytearray(range(20))
        self.output_buffer.prepend_msgbytes(data1)
        data2 = bytearray(range(20, 40))
        self.output_buffer.prepend_msgbytes(data2)

        self.assertEqual(deque([data2, data1]), self.output_buffer.output_msgs)
        self.assertEqual(40, self.output_buffer.length)

        self.output_buffer.advance_buffer(10)

        data3 = bytearray(range(40, 60))
        self.output_buffer.prepend_msgbytes(data3)

        self.assertEqual(
            deque([data2, data3, data1]), self.output_buffer.output_msgs
        )
        self.assertEqual(50, self.output_buffer.length)

    def test_has_more_bytes(self):
        """``has_more_bytes`` is false initially and true once ``length`` is 1."""
        self.assertFalse(self.output_buffer.has_more_bytes())
        self.output_buffer.length = 1
        self.assertTrue(self.output_buffer.has_more_bytes())

    def test_flush_get_buffer_on_time(self):
        """Unflushed data becomes visible once the max hold time has passed."""
        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        # Restore the real clock afterwards; leaving time.time mocked would
        # leak a frozen clock into every test that runs later in the process.
        real_time = time.time
        self.addCleanup(setattr, time, "time", real_time)
        time.time = MagicMock(
            return_value=real_time() + OUTPUT_BUFFER_BATCH_MAX_HOLD_TIME + 0.001
        )

        self.assertEqual(data1, self.output_buffer.get_buffer())

    def test_flush_get_buffer_on_size(self):
        """Enqueuing ``OUTPUT_BUFFER_MIN_SIZE`` more bytes makes data visible."""
        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        data2 = bytearray([1]) * OUTPUT_BUFFER_MIN_SIZE
        self.output_buffer.enqueue_msgbytes(data2)
        self.assertNotEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

    def test_safe_empty(self):
        """``safe_empty`` keeps the unsent rest of a partially sent message."""
        self.output_buffer = OutputBuffer(enable_buffering=False)
        messages = [
            helpers.generate_bytearray(10),
            helpers.generate_bytearray(10),
        ]
        for message in messages:
            self.output_buffer.enqueue_msgbytes(message)

        self.output_buffer.advance_buffer(5)
        self.assertEqual(15, len(self.output_buffer))

        self.output_buffer.safe_empty()
        self.assertEqual(5, len(self.output_buffer))

    def test_safe_empty_no_contents(self):
        """``safe_empty`` on an empty, unbuffered instance must not raise."""
        self.output_buffer = OutputBuffer(enable_buffering=False)
        self.output_buffer.safe_empty()

    def test_safe_empty_buffering(self):
        """``safe_empty`` discards data that was enqueued but never flushed."""
        messages = [
            helpers.generate_bytearray(10),
            helpers.generate_bytearray(10),
        ]
        for message in messages:
            self.output_buffer.enqueue_msgbytes(message)

        self.assertEqual(20, len(self.output_buffer))
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        self.output_buffer.safe_empty()
        self.assertEqual(0, len(self.output_buffer))
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())
class MessageTrackerTest(AbstractTestCase):
    """Tests that ``MessageTracker`` stays in step with an ``OutputBuffer``.

    Each test enqueues ``TxMessage`` bytes on the buffer, registers the same
    messages with the tracker, and then checks the tracker's bookkeeping after
    the buffer has been advanced and/or emptied.
    """

    def setUp(self) -> None:
        # NOTE(review): the base-class setUp is not invoked here; confirm
        # that skipping AbstractTestCase.setUp is intentional.
        self.node = MockNode(
            helpers.get_common_opts(1001, external_ip="128.128.128.128")
        )
        self.tracker = MessageTracker(
            MockConnection(MockSocketConnection(), self.node)
        )
        self.output_buffer = OutputBuffer(enable_buffering=True)

    @staticmethod
    def _make_tx_message(payload_size):
        """Build a ``TxMessage`` carrying ``payload_size`` generated bytes."""
        return TxMessage(
            helpers.generate_object_hash(),
            5,
            tx_val=helpers.generate_bytearray(payload_size),
        )

    def test_empty_bytes_no_bytes_sent(self):
        """Emptying before anything is sent leaves buffer and tracker empty."""
        message = self._make_tx_message(250)
        message_length = len(message.rawbytes())

        self.output_buffer.enqueue_msgbytes(message.rawbytes())
        self.output_buffer.flush()
        self.tracker.append_message(message_length, message)

        self.output_buffer.safe_empty()
        self.tracker.empty_bytes(self.output_buffer.length)

        self.assertEqual(0, self.output_buffer.length)
        self.assertEqual(0, self.tracker.bytes_remaining)
        self.assertEqual(0, len(self.tracker.messages))

    def test_empty_bytes(self):
        """Only the partially sent message survives an empty on both sides."""
        message1 = self._make_tx_message(250)
        message2 = self._make_tx_message(250)
        message3 = self._make_tx_message(250)
        # The same length is registered for all three messages below.
        message_length = len(message1.rawbytes())

        self.output_buffer.enqueue_msgbytes(message1.rawbytes())
        self.output_buffer.flush()
        self.output_buffer.enqueue_msgbytes(message2.rawbytes())
        self.output_buffer.enqueue_msgbytes(message3.rawbytes())

        self.tracker.append_message(message_length, message1)
        self.tracker.append_message(message_length, message2)
        self.tracker.append_message(message_length, message3)

        self.output_buffer.advance_buffer(120)
        self.tracker.advance_bytes(120)

        self.output_buffer.safe_empty()
        self.assertEqual(message_length - 120, self.output_buffer.length)

        self.tracker.empty_bytes(self.output_buffer.length)
        self.assertEqual(1, len(self.tracker.messages))
        self.assertEqual(message_length - 120, self.tracker.bytes_remaining)
        self.assertEqual(120, self.tracker.messages[0].sent_bytes)

    def test_empty_bytes_more_bytes(self):
        """Tracker and buffer agree on remaining bytes after a large batch."""
        for _ in range(100):
            message = self._make_tx_message(2500)
            message_length = len(message.rawbytes())
            self.output_buffer.enqueue_msgbytes(message.rawbytes())
            self.tracker.append_message(message_length, message)

        self.output_buffer.advance_buffer(3500)
        self.tracker.advance_bytes(3500)

        self.output_buffer.safe_empty()
        self.tracker.empty_bytes(self.output_buffer.length)

        self.assertEqual(
            self.output_buffer.length, self.tracker.bytes_remaining
        )