Example #1
    async def test_no_concurrent_send_on_connection(self):
        client = AIOKafkaClient(bootstrap_servers=self.hosts,
                                metadata_max_age_ms=10000)
        await client.bootstrap()
        self.add_cleanup(client.close)

        await self.wait_topic(client, self.topic)

        node_id = client.get_random_node()
        wait_request = FetchRequest_v0(
            -1,  # replica_id
            500,  # max_wait_ms
            1024 * 1024,  # min_bytes
            [(self.topic, [(0, 0, 1024)])])
        vanila_request = MetadataRequest([])

        loop = get_running_loop()
        send_time = loop.time()
        long_task = create_task(client.send(node_id, wait_request))
        await asyncio.sleep(0.0001)
        self.assertFalse(long_task.done())

        await client.send(node_id, vanila_request)
        resp_time = loop.time()
        fetch_resp = await long_task
        # Check error code like resp->topics[0]->partitions[0]->error_code
        self.assertEqual(fetch_resp.topics[0][1][0][1], 0)

        # Check that the vanilla request was executed only after the wait request
        self.assertGreaterEqual(resp_time - send_time, 0.5)
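The test above depends on a standard asyncio pattern: start the slow request as a task, yield control so it actually begins, then observe that a second call queues behind it. A minimal self-contained sketch of that pattern (illustrative names only, no aiokafka involved):

import asyncio

async def slow_request():
    # Stand-in for a request that occupies the connection for a while.
    await asyncio.sleep(0.5)
    return "slow"

async def main():
    loop = asyncio.get_running_loop()
    started = loop.time()
    long_task = asyncio.create_task(slow_request())
    await asyncio.sleep(0)       # yield so the task gets scheduled
    assert not long_task.done()
    await long_task              # a second request would wait behind this one
    assert loop.time() - started >= 0.5

asyncio.run(main())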
Example #2
    def __init__(self, topics: Iterable[str], loop=None):
        if loop is None:
            loop = get_running_loop()

        self._topics = frozenset(topics)  # type: FrozenSet[str]
        self._assignment = None  # type: Assignment
        self.unsubscribe_future = loop.create_future()  # type: Future
        self._reassignment_in_progress = True
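Each of these constructors falls back to a get_running_loop() helper when no explicit loop is given. A plausible sketch of such a helper, assuming Python 3.7+; the actual helper used by these snippets may differ in its error handling:

import asyncio

def get_running_loop() -> asyncio.AbstractEventLoop:
    # Assumed behaviour: return the loop that is currently running and fail
    # loudly when called outside of a coroutine.
    loop = asyncio.get_event_loop()
    if not loop.is_running():
        raise RuntimeError(
            "The object should be created within an async function, "
            "or an explicit loop should be provided")
    return loop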
Example #3
 async def test_global_loop_for_create_conn(self):
     loop = get_running_loop()
     host, port = self.kafka_host, self.kafka_port
     conn = await create_conn(host, port)
     self.assertIs(conn._loop, loop)
     conn.close()
     # make sure second closing does nothing and we have full coverage
     # of *if self._reader:* condition
     conn.close()
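The second close() call above exercises the connection's double-close guard. A generic sketch of that guard (assumed shape, not the actual AIOKafkaConnection code):

class _ConnSketch:
    def __init__(self, reader, writer):
        self._reader = reader
        self._writer = writer

    def close(self):
        # The *if self._reader:* condition makes repeated close() calls a no-op.
        if self._reader is not None:
            self._writer.close()
            self._reader = self._writer = None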
Example #4
    def __init__(self, loop=None):
        if loop is None:
            loop = get_running_loop()
        self._loop = loop

        self._subscription_waiters = []  # type: List[Future]
        self._assignment_waiters = []  # type: List[Future]

        # Fetch contexts
        self._fetch_count = 0
        self._last_fetch_ended = time.monotonic()
Example #5
    async def test_add_batch_builder(self):
        tp0 = TopicPartition("test-topic", 0)
        tp1 = TopicPartition("test-topic", 1)

        def mocked_leader_for_partition(tp):
            if tp == tp0:
                return 0
            if tp == tp1:
                return 1
            return None

        cluster = ClusterMetadata(metadata_max_age_ms=10000)
        cluster.leader_for_partition = mock.MagicMock()
        cluster.leader_for_partition.side_effect = mocked_leader_for_partition

        ma = MessageAccumulator(cluster, 1000, 0, 1)
        builder0 = ma.create_builder()
        builder1_1 = ma.create_builder()
        builder1_2 = ma.create_builder()

        # only one batch may be queued per TP at a time
        self.assertFalse(ma._wait_data_future.done())
        await ma.add_batch(builder0, tp0, 1)
        self.assertTrue(ma._wait_data_future.done())
        self.assertEqual(len(ma._batches[tp0]), 1)

        await ma.add_batch(builder1_1, tp1, 1)
        self.assertEqual(len(ma._batches[tp1]), 1)
        with self.assertRaises(KafkaTimeoutError):
            await ma.add_batch(builder1_2, tp1, 0.1)
        self.assertTrue(ma._wait_data_future.done())
        self.assertEqual(len(ma._batches[tp1]), 1)

        # second batch gets added once the others are cleared out
        get_running_loop().call_later(0.1, ma.drain_by_nodes, [])
        await ma.add_batch(builder1_2, tp1, 1)
        self.assertTrue(ma._wait_data_future.done())
        self.assertEqual(len(ma._batches[tp0]), 0)
        self.assertEqual(len(ma._batches[tp1]), 1)
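The call_later(0.1, ma.drain_by_nodes, []) line is a common testing trick: schedule the operation that frees up capacity a moment into the future, then await the call that is currently blocked on it. A generic, aiokafka-free illustration:

import asyncio

async def main():
    loop = asyncio.get_running_loop()
    capacity_freed = asyncio.Event()
    # Schedule the "drain" (here just setting an event) to run in 0.1s ...
    loop.call_later(0.1, capacity_freed.set)
    # ... then await the operation that cannot proceed until it happens.
    await asyncio.wait_for(capacity_freed.wait(), timeout=1)

asyncio.run(main())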
Example #6
    def __init__(
            self, cluster, batch_size, compression_type, batch_ttl, *,
            txn_manager=None, loop=None):
        if loop is None:
            loop = get_running_loop()
        self._loop = loop
        self._batches = collections.defaultdict(collections.deque)
        self._pending_batches = set()
        self._cluster = cluster
        self._batch_size = batch_size
        self._compression_type = compression_type
        self._batch_ttl = batch_ttl
        self._wait_data_future = loop.create_future()
        self._closed = False
        self._api_version = (0, 9)
        self._txn_manager = txn_manager

        self._exception = None  # Critical exception
Example #7
    async def start(self):
        """Connect to Kafka cluster and check server version"""
        assert self._loop is get_running_loop(), (
            "Please create objects with the same loop as running with")
        log.debug("Starting the Kafka producer")  # trace
        await self.client.bootstrap()

        if self._compression_type == 'lz4':
            assert self.client.api_version >= (0, 8, 2), \
                'LZ4 Requires >= Kafka 0.8.2 Brokers'

        if self._txn_manager is not None and self.client.api_version < (0, 11):
            raise UnsupportedVersionError(
                "Idempotent producer available only for Broker version 0.11"
                " and above")

        await self._sender.start()
        self._message_accumulator.set_api_version(self.client.api_version)
        self._producer_magic = 0 if self.client.api_version < (0, 10) else 1
        log.debug("Kafka producer started")
Example #8
    async def test_concurrent_send_on_different_connection_groups(self):
        client = AIOKafkaClient(bootstrap_servers=self.hosts,
                                metadata_max_age_ms=10000)
        await client.bootstrap()
        self.add_cleanup(client.close)

        await self.wait_topic(client, self.topic)

        node_id = client.get_random_node()
        broker = client.cluster.broker_metadata(node_id)
        client.cluster.add_coordinator(node_id,
                                       broker.host,
                                       broker.port,
                                       rack=None,
                                       purpose=(CoordinationType.GROUP, ""))

        wait_request = FetchRequest_v0(
            -1,  # replica_id
            500,  # max_wait_ms
            1024 * 1024,  # min_bytes
            [(self.topic, [(0, 0, 1024)])])
        vanila_request = MetadataRequest([])

        loop = get_running_loop()
        send_time = loop.time()
        long_task = create_task(client.send(node_id, wait_request))
        await asyncio.sleep(0.0001)
        self.assertFalse(long_task.done())

        await client.send(node_id,
                          vanila_request,
                          group=ConnectionGroup.COORDINATION)
        resp_time = loop.time()
        self.assertFalse(long_task.done())

        fetch_resp = await long_task
        # Check error code like resp->topics[0]->partitions[0]->error_code
        self.assertEqual(fetch_resp.topics[0][1][0][1], 0)

        # Check that the vanilla request did not wait for the fetch request,
        # since it went over a separate connection group
        self.assertLess(resp_time - send_time, 0.5)
Example #9
    async def test_sender__handler_base_do(self):
        sender = await self._setup_sender()

        class MockHandler(BaseHandler):
            def create_request(self):
                return MetadataRequest[0]([])

        mock_handler = MockHandler(sender)
        mock_handler.handle_response = mock.Mock(return_value=0.1)
        success = await mock_handler.do(node_id=0)
        self.assertFalse(success)

        MockHandler.return_value = None
        mock_handler.handle_response = mock.Mock(return_value=None)
        success = await mock_handler.do(node_id=0)
        self.assertTrue(success)

        loop = get_running_loop()
        time = loop.time()
        sender.client.send = mock.Mock(side_effect=UnknownError())
        success = await mock_handler.do(node_id=0)
        self.assertFalse(success)
        self.assertAlmostEqual(loop.time() - time, 0.1, 1)
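From the assertions above, the do() contract looks like: a numeric return from handle_response means "back off and retry" (do() returns False), None means success, and a send error also fails after sleeping the retry backoff. A hedged sketch of that contract; the real BaseHandler implementation may differ, and the 0.1s backoff is assumed to match the timing asserted above:

import asyncio

class _HandlerSketch:
    retry_backoff = 0.1  # assumed sender retry backoff

    def __init__(self, send, create_request, handle_response):
        self._send = send
        self.create_request = create_request
        self.handle_response = handle_response

    async def do(self, node_id):
        try:
            resp = await self._send(node_id, self.create_request())
        except Exception:
            await asyncio.sleep(self.retry_backoff)
            return False
        backoff = self.handle_response(resp)
        if backoff is not None:
            await asyncio.sleep(backoff)
            return False
        return True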
Example #10
    async def test_fetcher__update_fetch_positions(self):
        client = AIOKafkaClient(bootstrap_servers=[])
        subscriptions = SubscriptionState()
        fetcher = Fetcher(client, subscriptions)
        self.add_cleanup(fetcher.close)
        # Disable background task
        fetcher._fetch_task.cancel()
        try:
            await fetcher._fetch_task
        except asyncio.CancelledError:
            pass
        fetcher._fetch_task = create_task(asyncio.sleep(1000000))

        partition = TopicPartition('test', 0)
        offsets = {partition: OffsetAndTimestamp(12, -1)}

        async def _proc_offset_request(node_id, topic_data):
            return offsets

        fetcher._proc_offset_request = mock.Mock()
        fetcher._proc_offset_request.side_effect = _proc_offset_request

        def reset_assignment():
            subscriptions.assign_from_user({partition})
            assignment = subscriptions.subscription.assignment
            tp_state = assignment.state_value(partition)
            return assignment, tp_state

        assignment, tp_state = reset_assignment()

        self.assertIsNone(tp_state._position)

        # CASE: reset from committed
        # In basic case we will need to wait for committed
        update_task = create_task(
            fetcher._update_fetch_positions(assignment, 0, [partition]))
        await asyncio.sleep(0.1)
        self.assertFalse(update_task.done())
        # Will continue only after committed is resolved
        tp_state.update_committed(OffsetAndMetadata(4, ""))
        needs_wakeup = await update_task
        self.assertFalse(needs_wakeup)
        self.assertEqual(tp_state._position, 4)
        self.assertEqual(fetcher._proc_offset_request.call_count, 0)

        # CASE: will not query committed if position already present
        await fetcher._update_fetch_positions(assignment, 0, [partition])
        self.assertEqual(tp_state._position, 4)
        self.assertEqual(fetcher._proc_offset_request.call_count, 0)

        # CASE: awaiting_reset for the partition
        tp_state.await_reset(OffsetResetStrategy.LATEST)
        self.assertIsNone(tp_state._position)
        await fetcher._update_fetch_positions(assignment, 0, [partition])
        self.assertEqual(tp_state._position, 12)
        self.assertEqual(fetcher._proc_offset_request.call_count, 1)

        # CASE: seeked while waiting for committed to be resolved
        assignment, tp_state = reset_assignment()
        update_task = create_task(
            fetcher._update_fetch_positions(assignment, 0, [partition]))
        await asyncio.sleep(0.1)
        self.assertFalse(update_task.done())

        tp_state.seek(8)
        tp_state.update_committed(OffsetAndMetadata(4, ""))
        await update_task
        self.assertEqual(tp_state._position, 8)
        self.assertEqual(fetcher._proc_offset_request.call_count, 1)

        # CASE: awaiting_reset during waiting for committed
        assignment, tp_state = reset_assignment()
        update_task = create_task(
            fetcher._update_fetch_positions(assignment, 0, [partition]))
        await asyncio.sleep(0.1)
        self.assertFalse(update_task.done())

        tp_state.await_reset(OffsetResetStrategy.LATEST)
        tp_state.update_committed(OffsetAndMetadata(4, ""))
        await update_task
        self.assertEqual(tp_state._position, 12)
        self.assertEqual(fetcher._proc_offset_request.call_count, 2)

        # CASE: reset using default strategy if committed offset undefined
        assignment, tp_state = reset_assignment()
        loop = get_running_loop()
        loop.call_later(0.01, tp_state.update_committed,
                        OffsetAndMetadata(-1, ""))
        await fetcher._update_fetch_positions(assignment, 0, [partition])
        self.assertEqual(tp_state._position, 12)
        self.assertEqual(fetcher._records, {})

        # CASE: set error if _default_reset_strategy = OffsetResetStrategy.NONE
        assignment, tp_state = reset_assignment()
        loop.call_later(0.01, tp_state.update_committed,
                        OffsetAndMetadata(-1, ""))
        fetcher._default_reset_strategy = OffsetResetStrategy.NONE
        needs_wakeup = await fetcher._update_fetch_positions(
            assignment, 0, [partition])
        self.assertTrue(needs_wakeup)
        self.assertIsNone(tp_state._position)
        self.assertIsInstance(fetcher._records[partition], FetchError)
        fetcher._records.clear()

        # CASE: if _proc_offset_request errored, we will retry on another spin
        fetcher._proc_offset_request.side_effect = UnknownError()
        assignment, tp_state = reset_assignment()
        tp_state.await_reset(OffsetResetStrategy.LATEST)
        await fetcher._update_fetch_positions(assignment, 0, [partition])
        self.assertIsNone(tp_state._position)
        self.assertTrue(tp_state.awaiting_reset)

        # CASE: reset 2 partitions separately, 1 will raise, 1 will get
        #       committed
        fetcher._proc_offset_request.side_effect = _proc_offset_request
        partition2 = TopicPartition('test', 1)
        subscriptions.assign_from_user({partition, partition2})
        assignment = subscriptions.subscription.assignment
        tp_state = assignment.state_value(partition)
        tp_state2 = assignment.state_value(partition2)
        tp_state.await_reset(OffsetResetStrategy.LATEST)
        loop.call_later(0.01, tp_state2.update_committed,
                        OffsetAndMetadata(5, ""))
        await fetcher._update_fetch_positions(assignment, 0,
                                              [partition, partition2])
        self.assertEqual(tp_state.position, 12)
        self.assertEqual(tp_state2.position, 5)
Example #11
    def __init__(self,
                 host,
                 port,
                 *,
                 client_id='aiokafka',
                 request_timeout_ms=40000,
                 api_version=(0, 8, 2),
                 ssl_context=None,
                 security_protocol='PLAINTEXT',
                 max_idle_ms=None,
                 on_close=None,
                 sasl_mechanism=None,
                 sasl_plain_password=None,
                 sasl_plain_username=None,
                 sasl_kerberos_service_name='kafka',
                 sasl_kerberos_domain_name=None,
                 sasl_oauth_token_provider=None,
                 version_hint=None):
        loop = get_running_loop()

        if sasl_mechanism == "GSSAPI":
            assert gssapi is not None, "gssapi library required"

        if sasl_mechanism == "OAUTHBEARER":
            if sasl_oauth_token_provider is None or \
                    not isinstance(
                        sasl_oauth_token_provider, AbstractTokenProvider):
                raise ValueError("sasl_oauth_token_provider needs to be \
                    provided implementing aiokafka.abc.AbstractTokenProvider")
            assert callable(getattr(
                sasl_oauth_token_provider, "token", None)), (
                    'sasl_oauth_token_provider must implement method #token()')

        self._loop = loop
        self._host = host
        self._port = port
        self._request_timeout = request_timeout_ms / 1000
        self._api_version = api_version
        self._client_id = client_id
        self._ssl_context = ssl_context
        self._security_protocol = security_protocol
        self._sasl_mechanism = sasl_mechanism
        self._sasl_plain_username = sasl_plain_username
        self._sasl_plain_password = sasl_plain_password
        self._sasl_kerberos_service_name = sasl_kerberos_service_name
        self._sasl_kerberos_domain_name = sasl_kerberos_domain_name
        self._sasl_oauth_token_provider = sasl_oauth_token_provider

        # Version hint is the version determined by initial client bootstrap
        self._version_hint = version_hint
        self._version_info = VersionInfo({})

        self._reader = self._writer = self._protocol = None
        # Even at small sizes a deque seems to be a bit faster than a list
        # (~2x at size 2 on Python 3.6).
        self._requests = collections.deque()
        self._read_task = None
        self._correlation_id = 0
        self._closed_fut = None

        self._max_idle_ms = max_idle_ms
        self._last_action = time.monotonic()
        self._idle_handle = None

        self._on_close_cb = on_close

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))
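The OAUTHBEARER check above only requires an aiokafka.abc.AbstractTokenProvider subclass with a callable token(). A minimal sketch of such a provider; whether token() must be a coroutine depends on the library version, and the static value is a placeholder:

from aiokafka.abc import AbstractTokenProvider

class StaticTokenProvider(AbstractTokenProvider):
    # A real provider would fetch and refresh an OAuth bearer token.
    def __init__(self, token):
        self._token = token

    async def token(self):
        return self._token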
Example #12
    def __init__(self,
                 *,
                 loop=None,
                 bootstrap_servers='localhost',
                 client_id=None,
                 metadata_max_age_ms=300000,
                 request_timeout_ms=40000,
                 api_version='auto',
                 acks=_missing,
                 key_serializer=None,
                 value_serializer=None,
                 compression_type=None,
                 max_batch_size=16384,
                 partitioner=DefaultPartitioner(),
                 max_request_size=1048576,
                 linger_ms=0,
                 send_backoff_ms=100,
                 retry_backoff_ms=100,
                 security_protocol="PLAINTEXT",
                 ssl_context=None,
                 connections_max_idle_ms=540000,
                 enable_idempotence=False,
                 transactional_id=None,
                 transaction_timeout_ms=60000,
                 sasl_mechanism="PLAIN",
                 sasl_plain_password=None,
                 sasl_plain_username=None,
                 sasl_kerberos_service_name='kafka',
                 sasl_kerberos_domain_name=None):
        if loop is None:
            loop = get_running_loop()

        if acks not in (0, 1, -1, 'all', _missing):
            raise ValueError("Invalid ACKS parameter")
        if compression_type not in ('gzip', 'snappy', 'lz4', None):
            raise ValueError("Invalid compression type!")
        if compression_type:
            checker, compression_attrs = self._COMPRESSORS[compression_type]
            if not checker():
                raise RuntimeError(
                    "Compression library for {} not found".format(
                        compression_type))
        else:
            compression_attrs = 0

        if transactional_id is not None:
            enable_idempotence = True
        else:
            transaction_timeout_ms = INTEGER_MAX_VALUE

        if enable_idempotence:
            if acks is _missing:
                acks = -1
            elif acks not in ('all', -1):
                raise ValueError(
                    "acks={} not supported if enable_idempotence=True".format(
                        acks))
            self._txn_manager = TransactionManager(transactional_id,
                                                   transaction_timeout_ms,
                                                   loop=loop)
        else:
            self._txn_manager = None

        if acks is _missing:
            acks = 1
        elif acks == 'all':
            acks = -1

        AIOKafkaProducer._PRODUCER_CLIENT_ID_SEQUENCE += 1
        if client_id is None:
            client_id = 'aiokafka-producer-%s' % \
                AIOKafkaProducer._PRODUCER_CLIENT_ID_SEQUENCE

        self._key_serializer = key_serializer
        self._value_serializer = value_serializer
        self._compression_type = compression_type
        self._partitioner = partitioner
        self._max_request_size = max_request_size
        self._request_timeout_ms = request_timeout_ms

        self.client = AIOKafkaClient(
            loop=loop,
            bootstrap_servers=bootstrap_servers,
            client_id=client_id,
            metadata_max_age_ms=metadata_max_age_ms,
            request_timeout_ms=request_timeout_ms,
            retry_backoff_ms=retry_backoff_ms,
            api_version=api_version,
            security_protocol=security_protocol,
            ssl_context=ssl_context,
            connections_max_idle_ms=connections_max_idle_ms,
            sasl_mechanism=sasl_mechanism,
            sasl_plain_username=sasl_plain_username,
            sasl_plain_password=sasl_plain_password,
            sasl_kerberos_service_name=sasl_kerberos_service_name,
            sasl_kerberos_domain_name=sasl_kerberos_domain_name)
        self._metadata = self.client.cluster
        self._message_accumulator = MessageAccumulator(
            self._metadata,
            max_batch_size,
            compression_attrs,
            self._request_timeout_ms / 1000,
            txn_manager=self._txn_manager,
            loop=loop)
        self._sender = Sender(self.client,
                              acks=acks,
                              txn_manager=self._txn_manager,
                              retry_backoff_ms=retry_backoff_ms,
                              linger_ms=linger_ms,
                              message_accumulator=self._message_accumulator,
                              request_timeout_ms=request_timeout_ms,
                              loop=loop)

        self._loop = loop
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))
        self._closed = False
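The acks handling above means that passing a transactional_id implicitly enables idempotence and defaults acks to -1 ("all"), while any explicit acks other than all/-1 raises ValueError. A hypothetical construction exercising that path (broker address and id are placeholders):

from aiokafka import AIOKafkaProducer

async def make_transactional_producer():
    # enable_idempotence is forced to True by transactional_id; acks is left
    # unset so it defaults to -1 per the logic above.
    return AIOKafkaProducer(
        bootstrap_servers="localhost:9092",
        transactional_id="my-transactional-id",
    )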
Example #13
    def __init__(self,
                 *,
                 loop=None,
                 bootstrap_servers='localhost',
                 client_id='aiokafka-' + __version__,
                 metadata_max_age_ms=300000,
                 request_timeout_ms=40000,
                 retry_backoff_ms=100,
                 ssl_context=None,
                 security_protocol='PLAINTEXT',
                 api_version='auto',
                 connections_max_idle_ms=540000,
                 sasl_mechanism='PLAIN',
                 sasl_plain_username=None,
                 sasl_plain_password=None,
                 sasl_kerberos_service_name='kafka',
                 sasl_kerberos_domain_name=None,
                 sasl_oauth_token_provider=None):
        if loop is None:
            loop = get_running_loop()

        if security_protocol not in ('SSL', 'PLAINTEXT', 'SASL_PLAINTEXT',
                                     'SASL_SSL'):
            raise ValueError("`security_protocol` should be SSL or PLAINTEXT")
        if security_protocol in ["SSL", "SASL_SSL"] and ssl_context is None:
            raise ValueError(
                "`ssl_context` is mandatory if security_protocol is "
                "SSL or SASL_SSL")
        if security_protocol in ["SASL_SSL", "SASL_PLAINTEXT"]:
            if sasl_mechanism not in ("PLAIN", "GSSAPI", "SCRAM-SHA-256",
                                      "SCRAM-SHA-512", "OAUTHBEARER"):
                raise ValueError("only `PLAIN`, `GSSAPI`, `SCRAM-SHA-256`, "
                                 "`SCRAM-SHA-512` and `OAUTHBEARER`"
                                 "sasl_mechanism are supported "
                                 "at the moment")
            if sasl_mechanism == "PLAIN" and \
               (sasl_plain_username is None or sasl_plain_password is None):
                raise ValueError(
                    "sasl_plain_username and sasl_plain_password required for "
                    "PLAIN sasl")

        self._bootstrap_servers = bootstrap_servers
        self._client_id = client_id
        self._metadata_max_age_ms = metadata_max_age_ms
        self._request_timeout_ms = request_timeout_ms
        if api_version != "auto":
            api_version = parse_kafka_version(api_version)
        self._api_version = api_version
        self._security_protocol = security_protocol
        self._ssl_context = ssl_context
        self._retry_backoff = retry_backoff_ms / 1000
        self._connections_max_idle_ms = connections_max_idle_ms
        self._sasl_mechanism = sasl_mechanism
        self._sasl_plain_username = sasl_plain_username
        self._sasl_plain_password = sasl_plain_password
        self._sasl_kerberos_service_name = sasl_kerberos_service_name
        self._sasl_kerberos_domain_name = sasl_kerberos_domain_name
        self._sasl_oauth_token_provider = sasl_oauth_token_provider

        self.cluster = ClusterMetadata(metadata_max_age_ms=metadata_max_age_ms)

        self._topics = set()  # empty set will fetch all topic metadata
        self._conns = {}
        self._loop = loop
        self._sync_task = None

        self._md_update_fut = None
        self._md_update_waiter = loop.create_future()
        self._get_conn_lock_value = None
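Per the validation above, SASL_PLAINTEXT or SASL_SSL with the PLAIN mechanism requires both a username and a password. A hypothetical construction that passes those checks (import path and credential values are placeholders):

from aiokafka.client import AIOKafkaClient

async def make_sasl_client():
    return AIOKafkaClient(
        bootstrap_servers="localhost:9092",
        security_protocol="SASL_PLAINTEXT",
        sasl_mechanism="PLAIN",
        sasl_plain_username="user",
        sasl_plain_password="secret",
    )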
Example #14
    async def bootstrap(self):
        """Try to to bootstrap initial cluster metadata"""
        assert self._loop is get_running_loop(), (
            "Please create objects with the same loop as running with")
        # using request v0 for bootstrap if not sure v1 is available
        if self._api_version == "auto" or self._api_version < (0, 10):
            metadata_request = MetadataRequest[0]([])
        else:
            metadata_request = MetadataRequest[1]([])

        version_hint = None
        if self._api_version != "auto":
            version_hint = self._api_version

        for host, port, _ in self.hosts:
            log.debug("Attempting to bootstrap via node at %s:%s", host, port)

            try:
                bootstrap_conn = await create_conn(
                    host,
                    port,
                    client_id=self._client_id,
                    request_timeout_ms=self._request_timeout_ms,
                    ssl_context=self._ssl_context,
                    security_protocol=self._security_protocol,
                    max_idle_ms=self._connections_max_idle_ms,
                    sasl_mechanism=self._sasl_mechanism,
                    sasl_plain_username=self._sasl_plain_username,
                    sasl_plain_password=self._sasl_plain_password,
                    sasl_kerberos_service_name=(
                        self._sasl_kerberos_service_name),
                    sasl_kerberos_domain_name=self._sasl_kerberos_domain_name,
                    sasl_oauth_token_provider=self._sasl_oauth_token_provider,
                    version_hint=version_hint)
            except (OSError, asyncio.TimeoutError) as err:
                log.error('Unable to connect to "%s:%s": %s', host, port, err)
                continue

            try:
                metadata = await bootstrap_conn.send(metadata_request)
            except (KafkaError, asyncio.TimeoutError) as err:
                log.warning('Unable to request metadata from "%s:%s": %s',
                            host, port, err)
                bootstrap_conn.close()
                continue

            self.cluster.update_metadata(metadata)

            # A cluster with no topics can return no broker metadata...
            # In that case, we should keep the bootstrap connection till
            # we get a normal cluster layout.
            if not len(self.cluster.brokers()):
                bootstrap_id = ('bootstrap', ConnectionGroup.DEFAULT)
                self._conns[bootstrap_id] = bootstrap_conn
            else:
                bootstrap_conn.close()

            log.debug('Received cluster metadata: %s', self.cluster)
            break
        else:
            raise KafkaConnectionError('Unable to bootstrap from {}'.format(
                self.hosts))

        # detect api version if need
        if self._api_version == 'auto':
            self._api_version = await self.check_version()

        if self._sync_task is None:
            # starting metadata synchronizer task
            self._sync_task = create_task(self._md_synchronizer())
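The host loop above relies on Python's for/else: the else clause runs only if the loop finished without a break, i.e. when every bootstrap host failed. A generic illustration of the idiom (not aiokafka code; try_connect is a stand-in callable):

def first_reachable(hosts, try_connect):
    for host in hosts:
        try:
            conn = try_connect(host)   # may raise OSError
        except OSError:
            continue
        break                          # success: the else clause is skipped
    else:
        raise ConnectionError("Unable to bootstrap from {}".format(hosts))
    return conn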