class TestOutputBuffer(unittest.TestCase):
    """Tests for ``OutputBuffer``; ``setUp`` builds it with buffering enabled."""

    def setUp(self):
        self.output_buffer = OutputBuffer(enable_buffering=True)

    def _enqueue_and_flush(self, data):
        """Enqueue *data* on the buffer under test and flush it immediately."""
        self.output_buffer.enqueue_msgbytes(data)
        self.output_buffer.flush()

    def _check_get_buffer_serves_head_message(self):
        """Flush two messages and verify get_buffer() exposes only the first.

        Moving ``index`` into the head message must make get_buffer() return
        just the unread tail of that message.
        """
        data1 = bytearray(range(20))
        self._enqueue_and_flush(data1)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        # A second flushed message is queued behind the first, not merged.
        data2 = bytearray(range(20, 40))
        self._enqueue_and_flush(data2)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        new_index = 10
        self.output_buffer.index = new_index
        self.assertEqual(data1[new_index:], self.output_buffer.get_buffer())

    def test_get_buffer(self):
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())
        self._check_get_buffer_serves_head_message()

    def test_advance_buffer(self):
        # Advancing an empty buffer is rejected.
        with self.assertRaises(ValueError):
            self.output_buffer.advance_buffer(5)

        self._enqueue_and_flush(bytearray(range(20)))
        self._enqueue_and_flush(bytearray(range(20, 40)))

        # Partially consuming the head message moves the index inside it.
        self.output_buffer.advance_buffer(10)
        self.assertEqual(10, self.output_buffer.index)
        self.assertEqual(30, self.output_buffer.length)

        # Consuming the rest of the head message drops it and resets index.
        self.output_buffer.advance_buffer(10)
        self.assertEqual(0, self.output_buffer.index)
        self.assertEqual(1, len(self.output_buffer.output_msgs))

    def test_at_msg_boundary(self):
        self.assertTrue(self.output_buffer.at_msg_boundary())
        self.output_buffer.index = 1
        self.assertFalse(self.output_buffer.at_msg_boundary())

    def test_enqueue_msgbytes(self):
        # A str payload is rejected with ValueError.
        with self.assertRaises(ValueError):
            self.output_buffer.enqueue_msgbytes("f")

        self._check_get_buffer_serves_head_message()

    def test_prepend_msgbytes(self):
        # A str payload is rejected with ValueError.
        with self.assertRaises(ValueError):
            self.output_buffer.prepend_msgbytes("f")

        data1 = bytearray(range(20))
        self.output_buffer.prepend_msgbytes(data1)

        data2 = bytearray(range(20, 40))
        self.output_buffer.prepend_msgbytes(data2)

        # Prepending puts the newest message at the front of the queue.
        self.assertEqual(deque([data2, data1]), self.output_buffer.output_msgs)
        self.assertEqual(40, self.output_buffer.length)

        self.output_buffer.advance_buffer(10)

        # Once the head message is partially sent, a prepend is placed right
        # behind it rather than in front of it.
        data3 = bytearray(range(40, 60))
        self.output_buffer.prepend_msgbytes(data3)

        self.assertEqual(deque([data2, data3, data1]),
                         self.output_buffer.output_msgs)
        self.assertEqual(50, self.output_buffer.length)

    def test_has_more_bytes(self):
        self.assertFalse(self.output_buffer.has_more_bytes())
        self.output_buffer.length = 1
        self.assertTrue(self.output_buffer.has_more_bytes())

    def test_flush_get_buffer_on_time(self):
        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        # Fake a clock that is past the batch hold time. Restore the real
        # time.time when the test ends so the frozen mock cannot leak into
        # any other test in the process.
        real_time = time.time
        self.addCleanup(setattr, time, "time", real_time)
        time.time = MagicMock(
            return_value=real_time() + OUTPUT_BUFFER_BATCH_MAX_HOLD_TIME + 0.001
        )
        self.assertEqual(data1, self.output_buffer.get_buffer())

    def test_flush_get_buffer_on_size(self):
        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        # Enqueuing OUTPUT_BUFFER_MIN_SIZE more bytes must make data available
        # without an explicit flush().
        data2 = bytearray(b"\x01") * OUTPUT_BUFFER_MIN_SIZE
        self.output_buffer.enqueue_msgbytes(data2)
        self.assertNotEqual(OutputBuffer.EMPTY,
                            self.output_buffer.get_buffer())

    def test_safe_empty(self):
        self.output_buffer = OutputBuffer(enable_buffering=False)
        messages = [
            helpers.generate_bytearray(10),
            helpers.generate_bytearray(10)
        ]
        for message in messages:
            self.output_buffer.enqueue_msgbytes(message)

        self.output_buffer.advance_buffer(5)
        self.assertEqual(15, len(self.output_buffer))

        # Only the unsent remainder of the partially sent message survives.
        self.output_buffer.safe_empty()
        self.assertEqual(5, len(self.output_buffer))

    def test_safe_empty_no_contents(self):
        self.output_buffer = OutputBuffer(enable_buffering=False)
        self.output_buffer.safe_empty()
        # Emptying an already empty buffer must not raise and leaves it empty.
        self.assertEqual(0, len(self.output_buffer))

    def test_safe_empty_buffering(self):
        messages = [
            helpers.generate_bytearray(10),
            helpers.generate_bytearray(10)
        ]
        for message in messages:
            self.output_buffer.enqueue_msgbytes(message)

        # Enqueued-but-unflushed bytes count toward len() but are not served.
        self.assertEqual(20, len(self.output_buffer))
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        self.output_buffer.safe_empty()
        self.assertEqual(0, len(self.output_buffer))
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())
class MessageTrackerTest(AbstractTestCase):
    def setUp(self) -> None:
        # Build a mock node with a fixed external IP, wrap it in a mock
        # connection for the tracker, and create a buffering output buffer.
        node_opts = helpers.get_common_opts(1001, external_ip="128.128.128.128")
        self.node = MockNode(node_opts)
        socket_connection = MockSocketConnection()
        connection = MockConnection(socket_connection, self.node)
        self.tracker = MessageTracker(connection)
        self.output_buffer = OutputBuffer(enable_buffering=True)

    def test_empty_bytes_no_bytes_sent(self):
        """With no bytes sent, safe_empty() clears the buffer and emptying
        the tracker by the remaining length leaves it with nothing."""
        tx_message = TxMessage(
            helpers.generate_object_hash(),
            5,
            tx_val=helpers.generate_bytearray(250),
        )
        raw = tx_message.rawbytes()
        size = len(raw)

        out = self.output_buffer
        out.enqueue_msgbytes(raw)
        out.flush()
        self.tracker.append_message(size, tx_message)

        out.safe_empty()
        self.tracker.empty_bytes(out.length)

        self.assertEqual(0, out.length)
        self.assertEqual(0, self.tracker.bytes_remaining)
        self.assertEqual(0, len(self.tracker.messages))

    def test_empty_bytes(self):
        """safe_empty() keeps only the unsent tail of a partially sent
        message, and the tracker agrees after empty_bytes()."""
        first, second, third = [
            TxMessage(
                helpers.generate_object_hash(),
                5,
                tx_val=helpers.generate_bytearray(250),
            )
            for _ in range(3)
        ]
        msg_size = len(first.rawbytes())
        sent = 120

        # Only the first message is explicitly flushed; the other two are
        # just enqueued.
        self.output_buffer.enqueue_msgbytes(first.rawbytes())
        self.output_buffer.flush()
        for pending in (second, third):
            self.output_buffer.enqueue_msgbytes(pending.rawbytes())

        for tracked in (first, second, third):
            self.tracker.append_message(msg_size, tracked)

        # Advance buffer and tracker by the same number of bytes.
        self.output_buffer.advance_buffer(sent)
        self.tracker.advance_bytes(sent)

        self.output_buffer.safe_empty()
        self.assertEqual(msg_size - sent, self.output_buffer.length)

        self.tracker.empty_bytes(self.output_buffer.length)

        self.assertEqual(1, len(self.tracker.messages))
        self.assertEqual(msg_size - sent, self.tracker.bytes_remaining)
        self.assertEqual(sent, self.tracker.messages[0].sent_bytes)

    def test_empty_bytes_more_bytes(self):
        """Partially send a long run of enqueued messages, then safe_empty().

        After the buffer is trimmed, emptying the tracker by the buffer's
        remaining length must leave the tracker's remaining byte count equal
        to that length.
        """
        # NOTE(review): total_bytes is accumulated below but never read in the
        # lines shown here — either assert on it or drop it.
        total_bytes = 0
        for _ in range(100):
            message = TxMessage(
                helpers.generate_object_hash(),
                5,
                tx_val=helpers.generate_bytearray(2500),
            )
            message_length = len(message.rawbytes())
            total_bytes += message_length
            # No explicit flush() here although setUp enables buffering;
            # presumably the size-triggered flush (see
            # test_flush_get_buffer_on_size) makes these bytes advanceable —
            # TODO confirm.
            self.output_buffer.enqueue_msgbytes(message.rawbytes())
            self.tracker.append_message(message_length, message)

        # Advance buffer and tracker by the same number of bytes so they agree.
        self.output_buffer.advance_buffer(3500)
        self.tracker.advance_bytes(3500)

        self.output_buffer.safe_empty()
        self.tracker.empty_bytes(self.output_buffer.length)

        self.assertEqual(self.output_buffer.length,
                         self.tracker.bytes_remaining)