    def test_batch_done(self):
        tp0 = TopicPartition("test-topic", 0)
        tp1 = TopicPartition("test-topic", 1)
        tp2 = TopicPartition("test-topic", 2)
        tp3 = TopicPartition("test-topic", 3)

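        # Leader lookup used by drain_by_nodes(): tp0/tp1 map to real node
        # ids, tp2 reports -1 (no leader elected) and tp3 has no mapping at
        # all; the drains below are expected to fail those messages with
        # LeaderNotAvailableError and NotLeaderForPartitionError respectively.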
        def mocked_leader_for_partition(tp):
            if tp == tp0:
                return 0
            if tp == tp1:
                return 1
            if tp == tp2:
                return -1
            return None

        cluster = ClusterMetadata(metadata_max_age_ms=10000)
        cluster.leader_for_partition = mock.MagicMock()
        cluster.leader_for_partition.side_effect = mocked_leader_for_partition

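        # Accumulator with 1000-byte batches and (judging by the expired()
        # checks below) a 1-second batch ttl, so the large tp1 messages below
        # quickly fill a batch and the batch expires while we wait.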
        ma = MessageAccumulator(cluster, 1000, None, 1, self.loop)
        fut1 = yield from ma.add_message(
            tp2, None, b'msg for tp@2', timeout=2)
        fut2 = yield from ma.add_message(
            tp3, None, b'msg for tp@3', timeout=2)
        yield from ma.add_message(tp1, None, b'0123456789'*70, timeout=2)
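        # The 700-byte message above nearly fills the 1000-byte batch for tp1,
        # so another one cannot be appended; with nothing draining the
        # accumulator the call is expected to time out after 2 seconds.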
        with self.assertRaises(KafkaTimeoutError):
            yield from ma.add_message(tp1, None, b'0123456789'*70, timeout=2)
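        # The 2-second wait above exceeded the 1-second batch ttl, so the
        # drained tp1 batch reports expired(); draining also fails the
        # messages queued for partitions without a usable leader.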
        batches, _ = ma.drain_by_nodes(ignore_nodes=[])
        self.assertEqual(batches[1][tp1].expired(), True)
        with self.assertRaises(LeaderNotAvailableError):
            yield from fut1
        with self.assertRaises(NotLeaderForPartitionError):
            yield from fut2

        fut01 = yield from ma.add_message(
            tp0, b'key0', b'value#0', timeout=2)
        fut02 = yield from ma.add_message(
            tp0, b'key1', b'value#1', timeout=2)
        fut10 = yield from ma.add_message(
            tp1, None, b'0123456789'*70, timeout=2)
        batches, _ = ma.drain_by_nodes(ignore_nodes=[])
        self.assertEqual(batches[0][tp0].expired(), False)
        self.assertEqual(batches[1][tp1].expired(), False)
        batch_data = batches[0][tp0].get_data_buffer()
        self.assertEqual(type(batch_data), io.BytesIO)
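        # done(base_offset=10) assigns offsets sequentially from the base,
        # so the two pending futures resolve with offsets 10 and 11.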
        batches[0][tp0].done(base_offset=10)

        class TestException(Exception):
            pass

        batches[1][tp1].done(exception=TestException())

        res = yield from fut01
        self.assertEqual(res.topic, "test-topic")
        self.assertEqual(res.partition, 0)
        self.assertEqual(res.offset, 10)
        res = yield from fut02
        self.assertEqual(res.topic, "test-topic")
        self.assertEqual(res.partition, 0)
        self.assertEqual(res.offset, 11)
        with self.assertRaises(TestException):
            yield from fut10

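        # done(base_offset=None) resolves the pending future with None
        # instead of record metadata.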
        fut01 = yield from ma.add_message(
            tp0, b'key0', b'value#0', timeout=2)
        batches, _ = ma.drain_by_nodes(ignore_nodes=[])
        batches[0][tp0].done(base_offset=None)
        res = yield from fut01
        self.assertEqual(res, None)

        # cancelling future
        fut01 = yield from ma.add_message(
            tp0, b'key0', b'value#2', timeout=2)
        batches, _ = ma.drain_by_nodes(ignore_nodes=[])
        fut01.cancel()
        batches[0][tp0].done(base_offset=21)  # must not raise even though the future was cancelled

    @run_until_complete
    def test_basic(self):
        cluster = ClusterMetadata(metadata_max_age_ms=10000)
        ma = MessageAccumulator(cluster, 1000, None, 30, self.loop)
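        # data_waiter() only resolves once there is data ready to drain, so
        # the 0.2s wait on an empty accumulator should time out.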
        data_waiter = ma.data_waiter()
        done, _ = yield from asyncio.wait(
            [data_waiter], timeout=0.2, loop=self.loop)
        self.assertFalse(bool(done))  # no data in accumulator yet...

        tp0 = TopicPartition("test-topic", 0)
        tp1 = TopicPartition("test-topic", 1)
        yield from ma.add_message(tp0, b'key', b'value', timeout=2)
        yield from ma.add_message(tp1, None, b'value without key', timeout=2)

        done, _ = yield from asyncio.wait(
            [data_waiter], timeout=0.2, loop=self.loop)
        self.assertTrue(bool(done))

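        # No leader metadata is available yet, so nothing can be assigned to
        # a node: the drain returns no batches and reports unknown leaders.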
        batches, unknown_leaders_exist = ma.drain_by_nodes(ignore_nodes=[])
        self.assertEqual(batches, {})
        self.assertEqual(unknown_leaders_exist, True)

        def mocked_leader_for_partition(tp):
            if tp == tp0:
                return 0
            if tp == tp1:
                return 1
            return -1

        cluster.leader_for_partition = mock.MagicMock()
        cluster.leader_for_partition.side_effect = mocked_leader_for_partition
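        # With leaders mocked, draining groups the buffered messages into one
        # batch per leader node.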
        batches, unknown_leaders_exist = ma.drain_by_nodes(ignore_nodes=[])
        self.assertEqual(len(batches), 2)
        self.assertEqual(unknown_leaders_exist, False)
        m_set0 = batches[0].get(tp0)
        self.assertEqual(type(m_set0), MessageBatch)
        m_set1 = batches[1].get(tp1)
        self.assertEqual(type(m_set1), MessageBatch)
        self.assertEqual(m_set0.expired(), False)

        data_waiter = ensure_future(ma.data_waiter(), loop=self.loop)
        done, _ = yield from asyncio.wait(
            [data_waiter], timeout=0.2, loop=self.loop)
        self.assertFalse(bool(done))  # no data in accumulator again...

        # testing batch overflow
        tp2 = TopicPartition("test-topic", 2)
        yield from ma.add_message(
            tp0, None, b'some short message', timeout=2)
        yield from ma.add_message(
            tp0, None, b'some other short message', timeout=2)
        yield from ma.add_message(
            tp1, None, b'0123456789' * 70, timeout=2)
        yield from ma.add_message(
            tp2, None, b'message to unknown leader', timeout=2)
        # next we try to add a message of length 500; since buffer_size=1000
        # the coroutine will block until the buffered data is drained
        add_task = ensure_future(
            ma.add_message(tp1, None, b'0123456789' * 50, timeout=2),
            loop=self.loop)
        done, _ = yield from asyncio.wait(
            [add_task], timeout=0.2, loop=self.loop)
        self.assertFalse(bool(done))

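        # Draining while ignoring nodes 1 and 2 only flushes tp0; the tp1
        # data stays buffered, so the blocked add_message cannot finish yet.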
        batches, unknown_leaders_exist = ma.drain_by_nodes(ignore_nodes=[1, 2])
        self.assertEqual(unknown_leaders_exist, True)
        m_set0 = batches[0].get(tp0)
        self.assertEqual(m_set0._builder._relative_offset, 2)
        m_set1 = batches[1].get(tp1)
        self.assertEqual(m_set1, None)

        done, _ = yield from asyncio.wait(
            [add_task], timeout=0.1, loop=self.loop)
        self.assertFalse(bool(done))  # still have not drained data for tp1

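        # Draining without an ignore list flushes tp1 as well, which frees
        # enough space for the blocked add_message to complete.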
        batches, unknown_leaders_exist = ma.drain_by_nodes(ignore_nodes=[])
        self.assertEqual(unknown_leaders_exist, True)
        m_set0 = batches[0].get(tp0)
        self.assertEqual(m_set0, None)
        m_set1 = batches[1].get(tp1)
        self.assertEqual(m_set1._builder._relative_offset, 1)

        done, _ = yield from asyncio.wait(
            [add_task], timeout=0.2, loop=self.loop)
        self.assertTrue(bool(done))
        batches, unknown_leaders_exist = ma.drain_by_nodes(ignore_nodes=[])
        self.assertEqual(unknown_leaders_exist, True)
        m_set1 = batches[1].get(tp1)
        self.assertEqual(m_set1._builder._relative_offset, 1)

    @run_until_complete
    def test_batch_pending_batch_list(self):
        # The message accumulator keeps a _pending_batches list that holds
        # batches while they are being delivered to a node. We must be sure we
        # never lose a batch during retries and that we don't leave duplicate
        # links to the batch behind in the process.

        tp0 = TopicPartition("test-topic", 0)

        def mocked_leader_for_partition(tp):
            if tp == tp0:
                return 0
            return None

        cluster = ClusterMetadata(metadata_max_age_ms=10000)
        cluster.leader_for_partition = mock.MagicMock()
        cluster.leader_for_partition.side_effect = mocked_leader_for_partition

        ma = MessageAccumulator(cluster, 1000, 0, 1, loop=self.loop)
        fut1 = yield from ma.add_message(
            tp0, b'key', b'value', timeout=2)

        # Drain and reenqueue
        batches, _ = ma.drain_by_nodes(ignore_nodes=[])
        batch = batches[0][tp0]
        self.assertIn(batch, ma._pending_batches)
        self.assertFalse(ma._batches)
        self.assertFalse(fut1.done())

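        # reenqueue() returns the batch from _pending_batches to _batches so
        # that the next drain picks it up again for a retry.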
        ma.reenqueue(batch)
        self.assertEqual(batch.retry_count, 1)
        self.assertFalse(ma._pending_batches)
        self.assertIn(batch, ma._batches[tp0])
        self.assertFalse(fut1.done())

        # Drain and reenqueue again to check that repeated calls are handled
        batches, _ = ma.drain_by_nodes(ignore_nodes=[])
        self.assertEqual(batches[0][tp0], batch)
        self.assertEqual(batch.retry_count, 2)
        self.assertIn(batch, ma._pending_batches)
        self.assertFalse(ma._batches)
        self.assertFalse(fut1.done())

        ma.reenqueue(batch)
        self.assertEqual(batch.retry_count, 2)
        self.assertFalse(ma._pending_batches)
        self.assertIn(batch, ma._batches[tp0])
        self.assertFalse(fut1.done())

        # Drain and mark as done. Check that no link to the batch remains
        batches, _ = ma.drain_by_nodes(ignore_nodes=[])
        self.assertEqual(batches[0][tp0], batch)
        self.assertEqual(batch.retry_count, 3)
        self.assertIn(batch, ma._pending_batches)
        self.assertFalse(ma._batches)
        self.assertFalse(fut1.done())

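        # After several drain/reenqueue cycles the batch future should still
        # carry a single callback, i.e. no duplicate links were registered
        # (some event loops don't expose _callbacks, hence the hasattr guard).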
        if hasattr(batch.future, "_callbacks"):  # Vanilla asyncio
            self.assertEqual(len(batch.future._callbacks), 1)

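        # done_noack() marks the batch delivered without waiting for an
        # acknowledgement; after letting the loop run its callbacks the
        # accumulator must hold no remaining references to the batch.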
        batch.done_noack()
        yield from asyncio.sleep(0.01, loop=self.loop)
        self.assertEqual(batch.retry_count, 3)
        self.assertFalse(ma._pending_batches)
        self.assertFalse(ma._batches)