Example #1
    def __init__(self,
                 app: AppT,
                 tables: TableManagerT,
                 **kwargs: Any) -> None:
        self.app = app
        self.tables = cast(_TableManager, tables)

        self.standby_tps = set()
        self.active_tps = set()

        self.tp_to_table = {}
        self.active_offsets = Counter()
        self.standby_offsets = Counter()

        self.active_highwaters = Counter()
        self.standby_highwaters = Counter()
        self.completed = Event()

        self.buffers = defaultdict(list)
        self.buffer_sizes = {}
        self.recovery_delay = self.app.conf.stream_recovery_delay

        self.actives_for_table = defaultdict(set)
        self.standbys_for_table = defaultdict(set)

        super().__init__(**kwargs)
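
A quick illustration (plain standard-library behaviour, not Faust-specific) of
why Counter and defaultdict are used for the offset and buffer maps above:
unseen keys read as offset 0 and buffers are created on first access, so the
recovery code needs no existence checks.

from collections import Counter, defaultdict

offsets = Counter()
buffers = defaultdict(list)

tp = ('orders-changelog', 0)               # stand-in for a faust.types.TP
assert offsets[tp] == 0                    # unknown partition reads as 0
offsets[tp] = 42                           # updated as changelog records apply

buffers['orders_table'].append('event')    # list created on first use
assert len(buffers['orders_table']) == 1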
Example #2
async def test_wait__multiple_events():
    async with X() as service:
        event1 = Event()
        event2 = Event()

        async def loser():
            await asyncio.sleep(1)
            event1.set()

        async def winner():
            await asyncio.sleep(0.1)
            event2.set()

        fut1 = asyncio.ensure_future(loser())
        try:
            fut2 = asyncio.ensure_future(winner())
            try:
                result = await service.wait_first(event1, event2)
                assert not result.stopped
                assert event2 in result.done
                assert event1 not in result.done
            finally:
                fut2.cancel()
        finally:
            fut1.cancel()
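
What the test above exercises, in isolation (a sketch: it assumes X is a
mode.Service subclass, since wait_first() and the result's done/stopped
attributes used in the assertions come from that Service API):

async def first_of(service, *events):
    # Return the events that fired first, or [] if the service stopped first.
    result = await service.wait_first(*events)
    if result.stopped:
        return []
    return [ev for ev in events if ev in result.done]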
Example #3
 def on_init(self) -> None:
     app = self.transport.app
     transport = cast(Transport, self.transport)
     self._rebalance_listener = self.RebalanceListener(self)
     if app.client_only:
         self._consumer = self._create_client_consumer(app, transport)
     else:
         self._consumer = self._create_worker_consumer(app, transport)
     self._active_partitions = None
     self._paused_partitions = set()
     self.can_resume_flow = Event()
Example #4
 def __init__(self,
              transport: TransportT,
              callback: ConsumerCallback,
              on_partitions_revoked: PartitionsRevokedCallback,
              on_partitions_assigned: PartitionsAssignedCallback,
              *,
              commit_interval: float = None,
              commit_livelock_soft_timeout: float = None,
              loop: asyncio.AbstractEventLoop = None,
              **kwargs: Any) -> None:
     assert callback is not None
     self.transport = transport
     self.app = self.transport.app
     self.in_transaction = self.app.in_transaction
     self.callback = callback
     self._on_message_in = self.app.sensors.on_message_in
     self._on_partitions_revoked = on_partitions_revoked
     self._on_partitions_assigned = on_partitions_assigned
     self._commit_every = self.app.conf.broker_commit_every
     self.scheduler = self.app.conf.ConsumerScheduler()
     self.commit_interval = (commit_interval
                             or self.app.conf.broker_commit_interval)
     self.commit_livelock_soft_timeout = (
         commit_livelock_soft_timeout
         or self.app.conf.broker_commit_livelock_soft_timeout)
     self._gap = defaultdict(list)
     self._acked = defaultdict(list)
     self._acked_index = defaultdict(set)
     self._read_offset = defaultdict(lambda: None)
     self._committed_offset = defaultdict(lambda: None)
     self._unacked_messages = WeakSet()
     self._waiting_for_ack = None
     self._time_start = monotonic()
     self._last_batch = Counter()
     self._end_offset_monitor_interval = self.commit_interval * 2
     self.randomly_assigned_topics = set()
     self.can_resume_flow = Event()
     self._reset_state()
     super().__init__(loop=loop or self.transport.loop, **kwargs)
     self.transactions = self.transport.create_transaction_manager(
         consumer=self,
         producer=self.app.producer,
         beacon=self.beacon,
         loop=self.loop,
     )
Example #5
class Consumer(Service, ConsumerT):
    """Base Consumer."""

    app: AppT

    logger = logger

    #: Tuple of exception types that may be raised when the
    #: underlying consumer driver is stopped.
    consumer_stopped_errors: ClassVar[Tuple[Type[BaseException], ...]] = ()

    # Mapping of TP to list of gaps in offsets.
    _gap: MutableMapping[TP, List[int]]

    # Mapping of TP to list of acked offsets.
    _acked: MutableMapping[TP, List[int]]

    #: Fast lookup to see if tp+offset was acked.
    _acked_index: MutableMapping[TP, Set[int]]

    #: Keeps track of the currently read offset in each TP
    _read_offset: MutableMapping[TP, Optional[int]]

    #: Keeps track of the currently committed offset in each TP.
    _committed_offset: MutableMapping[TP, Optional[int]]

    #: The consumer.wait_empty() method will set this to be notified
    #: when something acks a message.
    _waiting_for_ack: Optional[asyncio.Future] = None

    #: Used by .commit to ensure only one coroutine is committing at a time.
    #: Any other coroutine that starts a commit while one is already active
    #: will wait for the original request to finish, and do nothing.
    _commit_fut: Optional[asyncio.Future] = None

    #: Set of unacked messages: that is, messages that we started processing
    #: and that we MUST attempt to finish processing before
    #: shutting down or resuming a rebalance.
    _unacked_messages: MutableSet[Message]

    #: Time of last record batch received, per TP.
    #: Only set when not already set, and cleared by commit(), so it
    #: effectively tracks how long ago we received a record that has
    #: not yet been committed.
    _last_batch: Counter[TP]

    #: Time of when the consumer was started.
    _time_start: float

    # How often to poll and track log end offsets.
    _end_offset_monitor_interval: float

    _commit_every: Optional[int]
    _n_acked: int = 0

    _active_partitions: Optional[Set[TP]]
    _paused_partitions: Set[TP]

    flow_active: bool = True
    can_resume_flow: Event

    def __init__(self,
                 transport: TransportT,
                 callback: ConsumerCallback,
                 on_partitions_revoked: PartitionsRevokedCallback,
                 on_partitions_assigned: PartitionsAssignedCallback,
                 *,
                 commit_interval: float = None,
                 commit_livelock_soft_timeout: float = None,
                 loop: asyncio.AbstractEventLoop = None,
                 **kwargs: Any) -> None:
        assert callback is not None
        self.transport = transport
        self.app = self.transport.app
        self.in_transaction = self.app.in_transaction
        self.callback = callback
        self._on_message_in = self.app.sensors.on_message_in
        self._on_partitions_revoked = on_partitions_revoked
        self._on_partitions_assigned = on_partitions_assigned
        self._commit_every = self.app.conf.broker_commit_every
        self.scheduler = self.app.conf.ConsumerScheduler()
        self.commit_interval = (commit_interval
                                or self.app.conf.broker_commit_interval)
        self.commit_livelock_soft_timeout = (
            commit_livelock_soft_timeout
            or self.app.conf.broker_commit_livelock_soft_timeout)
        self._gap = defaultdict(list)
        self._acked = defaultdict(list)
        self._acked_index = defaultdict(set)
        self._read_offset = defaultdict(lambda: None)
        self._committed_offset = defaultdict(lambda: None)
        self._unacked_messages = WeakSet()
        self._waiting_for_ack = None
        self._time_start = monotonic()
        self._last_batch = Counter()
        self._end_offset_monitor_interval = self.commit_interval * 2
        self.randomly_assigned_topics = set()
        self.can_resume_flow = Event()
        self._reset_state()
        super().__init__(loop=loop or self.transport.loop, **kwargs)
        self.transactions = self.transport.create_transaction_manager(
            consumer=self,
            producer=self.app.producer,
            beacon=self.beacon,
            loop=self.loop,
        )

    def on_init_dependencies(self) -> Iterable[ServiceT]:
        """Return list of services this consumer depends on."""
        # We start the TransactionManager only if
        # processing_guarantee='exactly_once'
        if self.in_transaction:
            return [self.transactions]
        return []

    def _reset_state(self) -> None:
        self._active_partitions = None
        self._paused_partitions = set()
        self.can_resume_flow.clear()
        self.flow_active = True
        self._last_batch.clear()
        self._time_start = monotonic()

    async def on_restart(self) -> None:
        """Call when the consumer is restarted."""
        self._reset_state()
        self.on_init()

    def _get_active_partitions(self) -> Set[TP]:
        tps = self._active_partitions
        if tps is None:
            return self._set_active_tps(self.assignment())
        assert all(isinstance(x, TP) for x in tps)
        return tps

    def _set_active_tps(self, tps: Set[TP]) -> Set[TP]:
        xtps = self._active_partitions = ensure_TPset(tps)  # copy
        xtps.difference_update(self._paused_partitions)
        return xtps

    @abc.abstractmethod
    async def _commit(self, offsets: Mapping[TP,
                                             int]) -> bool:  # pragma: no cover
        ...

    async def perform_seek(self) -> None:
        """Seek all partitions to their current committed position."""
        read_offset = self._read_offset
        _committed_offsets = await self.seek_to_committed()
        read_offset.update({
            tp: offset if offset is not None and offset >= 0 else None
            for tp, offset in _committed_offsets.items()
        })
        committed_offsets = {
            ensure_TP(tp): offset if offset else None
            for tp, offset in _committed_offsets.items() if offset is not None
        }
        self._committed_offset.update(committed_offsets)

    @abc.abstractmethod
    async def seek_to_committed(self) -> Mapping[TP, int]:
        """Seek all partitions to their committed offsets."""
        ...

    async def seek(self, partition: TP, offset: int) -> None:
        """Seek partition to specific offset."""
        self.log.dev('SEEK %r -> %r', partition, offset)
        # reset livelock detection
        self._last_batch.pop(partition, None)
        await self._seek(partition, offset)
        # set new read offset so we will reread messages
        self._read_offset[ensure_TP(partition)] = offset if offset else None

    @abc.abstractmethod
    async def _seek(self, partition: TP, offset: int) -> None:
        ...

    def stop_flow(self) -> None:
        """Block consumer from processing any more messages."""
        self.flow_active = False
        self.can_resume_flow.clear()

    def resume_flow(self) -> None:
        """Allow consumer to process messages."""
        self.flow_active = True
        self.can_resume_flow.set()

    def pause_partitions(self, tps: Iterable[TP]) -> None:
        """Pause fetching from partitions."""
        tpset = ensure_TPset(tps)
        self._get_active_partitions().difference_update(tpset)
        self._paused_partitions.update(tpset)

    def resume_partitions(self, tps: Iterable[TP]) -> None:
        """Resume fetching from partitions."""
        tpset = ensure_TPset(tps)
        self._get_active_partitions().update(tps)
        self._paused_partitions.difference_update(tpset)

    @abc.abstractmethod
    def _new_topicpartition(self, topic: str,
                            partition: int) -> TP:  # pragma: no cover
        ...

    def _is_changelog_tp(self, tp: TP) -> bool:
        return tp.topic in self.app.tables.changelog_topics

    @Service.transitions_to(CONSUMER_PARTITIONS_REVOKED)
    async def on_partitions_revoked(self, revoked: Set[TP]) -> None:
        """Call during rebalancing when partitions are being revoked."""
        # NOTE:
        # The ConsumerRebalanceListener is responsible for calling
        # app.on_rebalance_start(), and this must have happened
        # before we get to this point (see aiokafka implementation).
        span = self.app._start_span_from_rebalancing('on_partitions_revoked')
        T = traced_from_parent_span(span)
        with span:
            # see comment in on_partitions_assigned
            # remove revoked partitions from active + paused tps.
            if self._active_partitions is not None:
                self._active_partitions.difference_update(revoked)
            self._paused_partitions.difference_update(revoked)
            await T(self._on_partitions_revoked, partitions=revoked)(revoked)

    @Service.transitions_to(CONSUMER_PARTITIONS_ASSIGNED)
    async def on_partitions_assigned(self, assigned: Set[TP]) -> None:
        """Call during rebalancing when partitions are being assigned."""
        span = self.app._start_span_from_rebalancing('on_partitions_assigned')
        T = traced_from_parent_span(span)
        with span:
            # remove recently revoked tps from set of paused tps.
            self._paused_partitions.intersection_update(assigned)
            # cache set of assigned partitions
            self._set_active_tps(assigned)
            # start callback chain of assigned callbacks.
            #   need to copy set at this point, since we cannot have
            #   the callbacks mutate our active list.
            self._last_batch.clear()
            await T(self._on_partitions_assigned,
                    partitions=assigned)(assigned)
        self.app.on_rebalance_return()

    @abc.abstractmethod
    async def _getmany(self, active_partitions: Optional[Set[TP]],
                       timeout: float) -> RecordMap:
        ...

    async def getmany(self,
                      timeout: float) -> AsyncIterator[Tuple[TP, Message]]:
        """Fetch batch of messages from server."""
        # `records` contains a mapping from TP to a list of messages.
        # if there are two agents, consuming from topics t1 and t2,
        # normal order of iteration would be to process each
        # tp in the dict:
        #    for tp, messages in records.items():
        #        for message in messages:
        #           yield tp, message
        #
        # The problem with this is that if we have prefetched 16k records
        # for one partition, the other partitions won't even start processing
        # before those 16k records are completed.
        #
        # So we try round-robin between the tps instead:
        #
        #    iterators: Dict[TP, Iterator] = {
        #        tp: iter(messages)
        #        for tp, messages in records.items()
        #    }
        #    while iterators:
        #        for tp, messages in iterators.items():
        #            yield tp, next(messages)
        #            # remove from iterators if empty.
        #
        # The problem with this implementation is that
        # the records mapping is ordered by TP, so records.keys()
        # will look like this:
        #
        #  TP(topic='bar', partition=0)
        #  TP(topic='bar', partition=1)
        #  TP(topic='bar', partition=2)
        #  TP(topic='bar', partition=3)
        #  TP(topic='foo', partition=0)
        #  TP(topic='foo', partition=1)
        #  TP(topic='foo', partition=2)
        #  TP(topic='foo', partition=3)
        #
        # If there are 100 partitions for each topic,
        # it will process 100 items in the first topic, then 100 items
        # in the other topic.  It gets even worse when partition counts
        # vary greatly: if t1 has 1000 partitions and t2 has only one,
        # then t2 will end up being starved most of the time.
        #
        # We solve this by going round-robin through each topic.
        records, active_partitions = await self._wait_next_records(timeout)
        if records is None or self.should_stop:
            return

        records_it = self.scheduler.iterate(records)
        to_message = self._to_message  # localize
        if self.flow_active:
            for tp, record in records_it:
                if not self.flow_active:
                    break
                if active_partitions is None or tp in active_partitions:
                    highwater_mark = self.highwater(tp)
                    self.app.monitor.track_tp_end_offset(tp, highwater_mark)
                    # convert timestamp to seconds from int milliseconds.
                    yield tp, to_message(tp, record)

    async def _wait_next_records(
            self,
            timeout: float) -> Tuple[Optional[RecordMap], Optional[Set[TP]]]:
        if not self.flow_active:
            await self.wait(self.can_resume_flow)
        # Implementation for the Fetcher service.

        is_client_only = self.app.client_only

        active_partitions: Optional[Set[TP]]
        if is_client_only:
            active_partitions = None
        else:
            active_partitions = self._get_active_partitions()

        records: RecordMap = {}
        if is_client_only or active_partitions:
            # Fetch records only if there are active partitions, to avoid the
            # risk of fetching all partitions at startup while none of the
            # partitions has been paused/resumed yet.
            records = await self._getmany(
                active_partitions=active_partitions,
                timeout=timeout,
            )
        else:
            # We should still release to the event loop
            await self.sleep(1)
        return records, active_partitions

    @abc.abstractmethod
    def _to_message(self, tp: TP, record: Any) -> ConsumerMessage:
        ...

    def track_message(self, message: Message) -> None:
        """Track message and mark it as pending ack."""
        # add to set of pending messages that must be acked for graceful
        # shutdown.  This is called by transport.Conductor,
        # before delivering messages to streams.
        self._unacked_messages.add(message)
        # call sensors
        self._on_message_in(message.tp, message.offset, message)

    def ack(self, message: Message) -> bool:
        """Mark message as being acknowledged by stream."""
        if not message.acked:
            message.acked = True
            tp = message.tp
            offset = message.offset
            if self.app.topics.acks_enabled_for(message.topic):
                committed = self._committed_offset[tp]
                try:
                    if committed is None or offset > committed:
                        acked_index = self._acked_index[tp]
                        if offset not in acked_index:
                            self._unacked_messages.discard(message)
                            acked_index.add(offset)
                            acked_for_tp = self._acked[tp]
                            acked_for_tp.append(offset)
                            self._n_acked += 1
                            return True
                finally:
                    notify(self._waiting_for_ack)
        return False

    async def _wait_for_ack(self, timeout: float) -> None:
        # arm future so that `ack()` can wake us up
        self._waiting_for_ack = asyncio.Future(loop=self.loop)
        try:
            # wait for `ack()` to wake us up
            await asyncio.wait_for(self._waiting_for_ack,
                                   loop=self.loop,
                                   timeout=1)
        except (asyncio.TimeoutError,
                asyncio.CancelledError):  # pragma: no cover
            pass
        finally:
            self._waiting_for_ack = None

    @Service.transitions_to(CONSUMER_WAIT_EMPTY)
    async def wait_empty(self) -> None:
        """Wait for all messages that started processing to be acked."""
        wait_count = 0
        T = traced_from_parent_span()
        while not self.should_stop and self._unacked_messages:
            wait_count += 1
            if not wait_count % 10:  # pragma: no cover
                remaining = [(m.refcount, m) for m in self._unacked_messages]
                self.log.warning('wait_empty: Waiting for %r tasks', remaining)
            self.log.dev('STILL WAITING FOR ALL STREAMS TO FINISH')
            self.log.dev('WAITING FOR %r EVENTS', len(self._unacked_messages))
            gc.collect()
            await T(self.commit)()
            if not self._unacked_messages:
                break
            await T(self._wait_for_ack)(timeout=1)

        self.log.dev('COMMITTING AGAIN AFTER STREAMS DONE')
        await T(self.commit_and_end_transactions)()

    async def commit_and_end_transactions(self) -> None:
        """Commit all safe offsets and end transaction."""
        await self.commit(start_new_transaction=False)

    async def on_stop(self) -> None:
        """Call when consumer is stopping."""
        if self.app.conf.stream_wait_empty:
            await self.wait_empty()
        else:
            await self.commit_and_end_transactions()

        self._last_batch.clear()

    @Service.task
    async def _commit_handler(self) -> None:
        interval = self.commit_interval

        await self.sleep(interval)
        async for sleep_time in self.itertimer(interval, name='commit'):
            await self.commit()

    @Service.task
    async def _commit_livelock_detector(self) -> None:  # pragma: no cover
        soft_timeout = self.commit_livelock_soft_timeout
        interval: float = self.commit_interval * 2.5
        acks_enabled_for = self.app.topics.acks_enabled_for
        await self.sleep(interval)
        async for sleep_time in self.itertimer(interval, name='livelock'):
            for tp, last_batch_time in self._last_batch.items():
                if last_batch_time and acks_enabled_for(tp.topic):
                    s_since_batch = monotonic() - last_batch_time
                    if s_since_batch > soft_timeout:
                        self.log.warning(
                            'Possible livelock: '
                            'COMMIT OFFSET NOT ADVANCING FOR %r', tp)

    async def commit(self,
                     topics: TPorTopicSet = None,
                     start_new_transaction: bool = True) -> bool:
        """Maybe commit the offset for all or specific topics.

        Arguments:
            topics: Set containing topics and/or TopicPartitions to commit.
        """
        if self.app.client_only:
            # A client-only app cannot commit, as its consumer has no group_id.
            return False
        if await self.maybe_wait_for_commit_to_finish():
            # original commit finished, return False as we did not commit
            return False

        self._commit_fut = asyncio.Future(loop=self.loop)
        try:
            return await self.force_commit(
                topics,
                start_new_transaction=start_new_transaction,
            )
        finally:
            # set commit_fut to None so that next call will commit.
            fut, self._commit_fut = self._commit_fut, None
            # notify followers that the commit is done.
            notify(fut)

    async def maybe_wait_for_commit_to_finish(self) -> bool:
        """Wait for any existing commit operation to finish."""
        # Only one coroutine allowed to commit at a time,
        # and other coroutines should wait for the original commit to finish
        # then do nothing.
        if self._commit_fut is not None:
            # something is already committing so wait for that future.
            try:
                await self._commit_fut
            except asyncio.CancelledError:
                # if future is cancelled we have to start new commit
                pass
            else:
                return True
        return False

    @Service.transitions_to(CONSUMER_COMMITTING)
    async def force_commit(self,
                           topics: TPorTopicSet = None,
                           start_new_transaction: bool = True) -> bool:
        """Force offset commit."""
        sensor_state = self.app.sensors.on_commit_initiated(self)

        # Go over the ack list in each topic/partition
        commit_tps = list(self._filter_tps_with_pending_acks(topics))
        did_commit = await self._commit_tps(
            commit_tps, start_new_transaction=start_new_transaction)

        self.app.sensors.on_commit_completed(self, sensor_state)
        return did_commit

    async def _commit_tps(self, tps: Iterable[TP],
                          start_new_transaction: bool) -> bool:
        commit_offsets = self._filter_committable_offsets(tps)
        if commit_offsets:
            try:
                # send all messages attached to the new offset
                await self._handle_attached(commit_offsets)
            except ProducerSendError as exc:
                await self.crash(exc)
            else:
                return await self._commit_offsets(
                    commit_offsets,
                    start_new_transaction=start_new_transaction)
        return False

    def _filter_committable_offsets(self, tps: Iterable[TP]) -> Dict[TP, int]:
        commit_offsets = {}
        for tp in tps:
            # Find the latest offset we can commit in this tp
            offset = self._new_offset(tp)
            # check if we can commit to this offset
            if offset is not None and self._should_commit(tp, offset):
                commit_offsets[tp] = offset
        return commit_offsets

    async def _handle_attached(self, commit_offsets: Mapping[TP, int]) -> None:
        for tp, offset in commit_offsets.items():
            app = cast(_App, self.app)
            attachments = app._attachments
            producer = app.producer
            # Start publishing the messages and return a list of pending
            # futures.
            pending = await attachments.publish_for_tp_offset(tp, offset)
            # then we wait for either
            #  1) all the attached messages to be published, or
            #  2) the producer crashing
            #
            # If the producer crashes we will not be able to send any messages
            # and it only crashes when there's an irrecoverable error.
            #
            # If we cannot commit it means the events will be processed again,
            # so conforms to at-least-once semantics.
            if pending:
                await producer.wait_many(pending)

    async def _commit_offsets(self,
                              offsets: Mapping[TP, int],
                              start_new_transaction: bool = True) -> bool:
        table = terminal.logtable(
            [(str(tp), str(offset)) for tp, offset in offsets.items()],
            title='Commit Offsets',
            headers=['TP', 'Offset'],
        )
        self.log.dev('COMMITTING OFFSETS:\n%s', table)
        assignment = self.assignment()
        committable_offsets: Dict[TP, int] = {}
        revoked: Dict[TP, int] = {}
        for tp, offset in offsets.items():
            if tp in assignment:
                committable_offsets[tp] = offset
            else:
                revoked[tp] = offset
        if revoked:
            self.log.info(
                'Discarded commit for revoked partitions that '
                'will be eventually processed again: %r',
                revoked,
            )
        if not committable_offsets:
            return False
        with flight_recorder(self.log, timeout=300.0) as on_timeout:
            did_commit = False
            on_timeout.info('+consumer.commit()')
            if self.in_transaction:
                did_commit = await self.transactions.commit(
                    committable_offsets,
                    start_new_transaction=start_new_transaction,
                )
            else:
                did_commit = await self._commit(committable_offsets)
            on_timeout.info('-consumer.commit()')
            if did_commit:
                on_timeout.info('+tables.on_commit')
                self.app.tables.on_commit(committable_offsets)
                on_timeout.info('-tables.on_commit')
        self._committed_offset.update(committable_offsets)
        self.app.monitor.on_tp_commit(committable_offsets)
        for tp in offsets:
            self._last_batch.pop(tp, None)
        return did_commit

    def _filter_tps_with_pending_acks(self,
                                      topics: TPorTopicSet = None
                                      ) -> Iterator[TP]:
        return (tp for tp in self._acked
                if topics is None or tp in topics or tp.topic in topics)

    def _should_commit(self, tp: TP, offset: int) -> bool:
        committed = self._committed_offset[tp]
        return committed is None or (bool(offset) and offset > committed)

    def _new_offset(self, tp: TP) -> Optional[int]:
        # get the new offset for this tp, by going through
        # its list of acked messages.
        acked = self._acked[tp]

        # We iterate over it until we find a gap
        # then return the offset before that.
        # For example if acked[tp] is:
        #   1 2 3 4 5 6 7 8 9
        # the return value will be: 9
        # If acked[tp] is:
        #  34 35 36 40 41 42 43 44
        #          ^--- gap
        # the return value will be: 36
        if acked:
            max_offset = max(acked)
            gap_for_tp = self._gap[tp]
            if gap_for_tp:
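                # Offsets in the gap list were never received from the broker
                # (see _add_gap, called from _drain_messages), so those at or
                # below max(acked) are merged into the acked list here; this
                # lets the commit offset advance past the hole instead of
                # stalling on it forever.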
                gap_index = next(
                    (i for i, x in enumerate(gap_for_tp) if x > max_offset),
                    len(gap_for_tp))
                gaps = gap_for_tp[:gap_index]
                acked.extend(gaps)
                gap_for_tp[:gap_index] = []
            acked.sort()
            # Note: acked is always kept sorted.
            # find first list of consecutive numbers
            batch = next(consecutive_numbers(acked))
            # remove them from the list to clean up.
            acked[:len(batch) - 1] = []
            self._acked_index[tp].difference_update(batch)
            # return the highest commit offset
            return batch[-1]
        return None

    async def on_task_error(self, exc: BaseException) -> None:
        """Call when processing a message failed."""
        await self.commit()

    def _add_gap(self, tp: TP, offset_from: int, offset_to: int) -> None:
        committed = self._committed_offset[tp]
        gap_for_tp = self._gap[tp]
        for offset in range(offset_from, offset_to):
            if committed is None or offset > committed:
                gap_for_tp.append(offset)

    async def _drain_messages(self,
                              fetcher: ServiceT) -> None:  # pragma: no cover
        # This is the background thread started by Fetcher, used to
        # constantly read messages using Consumer.getmany.
        # It takes Fetcher as argument, because we must be able to
        # stop it using `await Fetcher.stop()`.
        callback = self.callback
        getmany = self.getmany
        consumer_should_stop = self._stopped.is_set
        fetcher_should_stop = fetcher._stopped.is_set

        get_read_offset = self._read_offset.__getitem__
        set_read_offset = self._read_offset.__setitem__
        get_commit_offset = self._committed_offset.__getitem__
        flag_consumer_fetching = CONSUMER_FETCHING
        set_flag = self.diag.set_flag
        unset_flag = self.diag.unset_flag
        commit_every = self._commit_every
        acks_enabled_for = self.app.topics.acks_enabled_for

        try:
            while not (consumer_should_stop() or fetcher_should_stop()):
                set_flag(flag_consumer_fetching)
                ait = cast(AsyncIterator, getmany(timeout=5.0))
                last_batch = self._last_batch

                # Sleep because sometimes getmany is called in a loop
                # and never releases control to the event loop.
                await self.sleep(0)
                if not self.should_stop:
                    async for tp, message in ait:
                        offset = message.offset
                        r_offset = get_read_offset(tp)
                        committed_offset = get_commit_offset(tp)
                        if committed_offset != r_offset:
                            last_batch[tp] = monotonic()
                        if r_offset is None or offset > r_offset:
                            gap = offset - (r_offset or 0)
                            # We have a gap in incoming messages
                            if gap > 1 and r_offset:
                                acks_enabled = acks_enabled_for(message.topic)
                                if acks_enabled:
                                    self._add_gap(tp, r_offset + 1, offset)
                            if commit_every is not None:
                                if self._n_acked >= commit_every:
                                    self._n_acked = 0
                                    await self.commit()
                            await callback(message)
                            set_read_offset(tp, offset)
                        else:
                            self.log.dev('DROPPED MESSAGE ROFF %r: k=%r v=%r',
                                         offset, message.key, message.value)
                    unset_flag(flag_consumer_fetching)

        except self.consumer_stopped_errors:
            if self.transport.app.should_stop:
                # we're already stopping so ignore
                self.log.info('Broker stopped consumer, shutting down...')
                return
            raise
        except asyncio.CancelledError:
            if self.transport.app.should_stop:
                # we're already stopping so ignore
                self.log.info('Consumer shutting down for user cancel.')
                return
            raise
        except Exception as exc:
            self.log.exception('Drain messages raised: %r', exc)
            raise
        finally:
            unset_flag(flag_consumer_fetching)

    def close(self) -> None:
        """Close consumer for graceful shutdown."""
        ...

    @property
    def unacked(self) -> Set[Message]:
        """Return the set of currently unacknowledged messages."""
        return cast(Set[Message], self._unacked_messages)
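
The topic round-robin described in the getmany() comment above, as a
standalone sketch (illustrative only; the real consumer delegates this to
self.scheduler.iterate(), i.e. the configured ConsumerScheduler):

from collections import defaultdict, namedtuple
from itertools import chain

TP = namedtuple('TP', 'topic partition')  # stand-in for faust.types.TP


def roundrobin_records(records):
    """Yield (tp, record) pairs, taking one record from each topic per pass."""
    grouped = defaultdict(list)
    for tp, messages in records.items():
        grouped[tp.topic].append([(tp, msg) for msg in messages])
    per_topic = {
        topic: chain.from_iterable(chunks)
        for topic, chunks in grouped.items()
    }
    while per_topic:
        for topic in list(per_topic):
            try:
                yield next(per_topic[topic])
            except StopIteration:
                del per_topic[topic]  # this topic is fully drained


records = {TP('t1', p): [f't1-{p}-{i}' for i in range(3)] for p in range(2)}
records[TP('t2', 0)] = ['only message in t2']
order = [tp.topic for tp, _ in roundrobin_records(records)]
assert order[:3] == ['t1', 't2', 't1']  # t2 is not starved behind t1's backlog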
Example #6
 def __init__(self, app: AppT, **kwargs: Any) -> None:
     self.app = app
     self.data = OrderedDict()
     self._by_topic = defaultdict(WeakSet)
     self._agents_started = Event()
     Service.__init__(self, **kwargs)
Example #7
class AgentManager(Service, AgentManagerT, ManagedUserDict):
    """Agent manager."""

    _by_topic: MutableMapping[str, MutableSet[AgentT]]

    def __init__(self, app: AppT, **kwargs: Any) -> None:
        self.app = app
        self.data = OrderedDict()
        self._by_topic = defaultdict(WeakSet)
        self._agents_started = Event()
        Service.__init__(self, **kwargs)

    def __hash__(self) -> int:
        return object.__hash__(self)

    async def on_start(self) -> None:
        """Call when agents are being started."""
        self.update_topic_index()
        for agent in self.values():
            await agent.maybe_start()
        self._agents_started.set()

    async def wait_until_agents_started(self) -> None:
        """Wait for all agents to fully start."""
        await self.wait_for_stopped(self._agents_started)

    def service_reset(self) -> None:
        """Reset service state on restart."""
        [agent.service_reset() for agent in self.values()]
        super().service_reset()

    async def on_stop(self) -> None:
        """Call when agents are being stopped."""
        for agent in self.values():
            try:
                await asyncio.shield(agent.stop())
            except asyncio.CancelledError:
                pass

    async def stop(self) -> None:
        """Stop all running agents."""
        # Cancel first so _execute_actor sees we are not stopped.
        self.cancel()
        # Then stop the agents
        await super().stop()

    def cancel(self) -> None:
        """Cancel all running agents."""
        [agent.cancel() for agent in self.values()]

    def update_topic_index(self) -> None:
        """Update indices."""
        # keep mapping from topic name to set of agents.
        by_topic_index = self._by_topic
        for agent in self.values():
            for topic in agent.get_topic_names():
                by_topic_index[topic].add(agent)

    async def on_rebalance(self, revoked: Set[TP],
                           newly_assigned: Set[TP]) -> None:
        """Call when a rebalance is needed."""
        T = traced_from_parent_span()
        # for isolated_partitions agents we stop agents for revoked
        # partitions.
        for agent, tps in self._collect_agents_for_update(revoked).items():
            await T(agent.on_partitions_revoked)(tps)
        # for isolated_partitions agents we start agents for newly
        # assigned partitions
        for agent, tps in T(
                self._collect_agents_for_update)(newly_assigned).items():
            await T(agent.on_partitions_assigned)(tps)

    def _collect_agents_for_update(self,
                                   tps: Set[TP]) -> Dict[AgentT, Set[TP]]:
        by_agent: Dict[AgentT, Set[TP]] = defaultdict(set)
        for topic, tps in tp_set_to_map(tps).items():
            for agent in self._by_topic[topic]:
                by_agent[agent].update(tps)
        return by_agent
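
A minimal sketch of the grouping tp_set_to_map is assumed to perform above,
inferred from how its result is used (topic name -> set of TPs for that
topic); the real helper lives elsewhere in Faust:

from collections import defaultdict

def tp_set_to_map_sketch(tps):
    """Group a set of TPs by topic name."""
    by_topic = defaultdict(set)
    for tp in tps:
        by_topic[tp.topic].add(tp)
    return by_topic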
Example #8
class Consumer(base.Consumer):
    """Kafka consumer using :pypi:`aiokafka`."""

    logger = logger

    RebalanceListener: ClassVar[Type[ConsumerRebalanceListener]]
    RebalanceListener = ConsumerRebalanceListener

    _consumer: aiokafka.AIOKafkaConsumer
    _rebalance_listener: ConsumerRebalanceListener
    _active_partitions: Optional[Set[_TopicPartition]]
    _paused_partitions: Set[_TopicPartition]
    fetch_timeout: float = 10.0

    consumer_stopped_errors: ClassVar[Tuple[Type[BaseException], ...]] = (
        ConsumerStoppedError,
    )

    flow_active: bool = True
    can_resume_flow: Event

    def on_init(self) -> None:
        app = self.transport.app
        transport = cast(Transport, self.transport)
        self._rebalance_listener = self.RebalanceListener(self)
        if app.client_only:
            self._consumer = self._create_client_consumer(app, transport)
        else:
            self._consumer = self._create_worker_consumer(app, transport)
        self._active_partitions = None
        self._paused_partitions = set()
        self.can_resume_flow = Event()

    async def on_restart(self) -> None:
        self.on_init()

    def _get_active_partitions(self) -> Set[_TopicPartition]:
        tps = self._active_partitions
        if tps is None:
            # need aiokafka._TopicPartition, not faust.TP
            return self._set_active_tps(self._consumer.assignment())
        return tps

    def _set_active_tps(self,
                        tps: Set[_TopicPartition]) -> Set[_TopicPartition]:
        tps = self._active_partitions = set(tps)  # copy!
        tps.difference_update(self._paused_partitions)
        return tps

    def _create_worker_consumer(
            self,
            app: AppT,
            transport: 'Transport') -> aiokafka.AIOKafkaConsumer:
        conf = app.conf
        self._assignor = self.app.assignor
        return aiokafka.AIOKafkaConsumer(
            loop=self.loop,
            client_id=conf.broker_client_id,
            group_id=conf.id,
            bootstrap_servers=server_list(
                transport.url, transport.default_port),
            partition_assignment_strategy=[self._assignor],
            enable_auto_commit=False,
            auto_offset_reset='earliest',
            max_poll_records=conf.broker_max_poll_records,
            max_partition_fetch_bytes=1048576 * 4,
            fetch_max_wait_ms=1500,
            check_crcs=conf.broker_check_crcs,
            session_timeout_ms=int(conf.broker_session_timeout * 1000.0),
            heartbeat_interval_ms=int(conf.broker_heartbeat_interval * 1000.0),
            security_protocol="SSL" if conf.ssl_context else "PLAINTEXT",
            ssl_context=conf.ssl_context,
        )

    def _create_client_consumer(
            self,
            app: AppT,
            transport: 'Transport') -> aiokafka.AIOKafkaConsumer:
        return aiokafka.AIOKafkaConsumer(
            loop=self.loop,
            client_id=app.conf.broker_client_id,
            bootstrap_servers=server_list(
                transport.url, transport.default_port),
            enable_auto_commit=True,
            max_poll_records=app.conf.broker_max_poll_records,
            auto_offset_reset='earliest',
            check_crcs=app.conf.broker_check_crcs,
            security_protocol="SSL" if app.conf.ssl_context else "PLAINTEXT",
            ssl_context=app.conf.ssl_context,
        )

    async def create_topic(self,
                           topic: str,
                           partitions: int,
                           replication: int,
                           *,
                           config: Mapping[str, Any] = None,
                           timeout: Seconds = 30.0,
                           retention: Seconds = None,
                           compacting: bool = None,
                           deleting: bool = None,
                           ensure_created: bool = False) -> None:
        await cast(Transport, self.transport)._create_topic(
            self,
            self._consumer._client,
            topic,
            partitions,
            replication,
            config=config,
            timeout=int(want_seconds(timeout) * 1000.0),
            retention=(int(want_seconds(retention) * 1000.0)
                       if retention is not None else None),
            compacting=compacting,
            deleting=deleting,
            ensure_created=ensure_created,
        )

    async def on_start(self) -> None:
        self.beacon.add(self._consumer)
        await self._consumer.start()

    async def subscribe(self, topics: Iterable[str]) -> None:
        # XXX pattern does not work :/
        self._consumer.subscribe(
            topics=set(topics),
            listener=self._rebalance_listener,
        )

    async def getmany(self,
                      timeout: float) -> AsyncIterator[Tuple[TP, Message]]:
        # Implementation for the Fetcher service.
        _consumer = self._consumer
        fetcher = _consumer._fetcher
        if _consumer._closed or fetcher._closed:
            raise ConsumerStoppedError()
        active_partitions = self._get_active_partitions()
        _next = next

        records: RecordMap = {}
        if not self.flow_active:
            await self.wait(self.can_resume_flow)
        if active_partitions:
            # Fetch records only if there are active partitions, to avoid the
            # risk of fetching all partitions at startup while none of the
            # partitions has been paused/resumed yet.
            records = await fetcher.fetched_records(
                active_partitions,
                timeout=timeout,
            )
        else:
            # We should still release to the event loop
            await self.sleep(1)
            if self.should_stop:
                return
        create_message = ConsumerMessage  # localize

        # `records` contains a mapping from TP to a list of messages.
        # if there are two agents, consuming from topics t1 and t2,
        # normal order of iteration would be to process each
        # tp in the dict:
        #    for tp, messages in records.items():
        #        for message in messages:
        #           yield tp, message
        #
        # The problem with this is that if we have prefetched 16k records
        # for one partition, the other partitions won't even start processing
        # before those 16k records are completed.
        #
        # So we try round-robin between the tps instead:
        #
        #    iterators: Dict[TP, Iterator] = {
        #        tp: iter(messages)
        #        for tp, messages in records.items()
        #    }
        #    while iterators:
        #        for tp, messages in iterators.items():
        #            yield tp, next(messages)
        #            # remove from iterators if empty.
        #
        # The problem with this implementation is that
        # the records mapping is ordered by TP, so records.keys()
        # will look like this:
        #
        #  TP(topic='bar', partition=0)
        #  TP(topic='bar', partition=1)
        #  TP(topic='bar', partition=2)
        #  TP(topic='bar', partition=3)
        #  TP(topic='foo', partition=0)
        #  TP(topic='foo', partition=1)
        #  TP(topic='foo', partition=2)
        #  TP(topic='foo', partition=3)
        #
        # If there are 100 partitions for each topic,
        # it will process 100 items in the first topic, then 100 items
        # in the other topic.  It gets even worse when partition counts
        # vary greatly: if t1 has 1000 partitions and t2 has only one,
        # then t2 will end up being starved most of the time.
        #
        # We solve this by going round-robin through each topic.
        topic_index = self._records_to_topic_index(records, active_partitions)
        to_remove: Set[str] = set()
        sentinel = object()
        while topic_index:
            if not self.flow_active:
                break
            for topic in to_remove:
                topic_index.pop(topic, None)
            for topic, messages in topic_index.items():
                if not self.flow_active:
                    break
                item = _next(messages, sentinel)
                if item is sentinel:
                    # this topic is now empty,
                    # but we cannot remove from dict while iterating over it,
                    # so move that to the outer loop.
                    to_remove.add(topic)
                    continue
                tp, record = item  # type: ignore
                if tp in active_partitions:
                    highwater_mark = self._consumer.highwater(tp)
                    self.app.monitor.track_tp_end_offset(tp, highwater_mark)
                    # convert timestamp to seconds from int milliseconds.
                    timestamp: Optional[int] = record.timestamp
                    timestamp_s: float = cast(float, None)
                    if timestamp is not None:
                        timestamp_s = timestamp / 1000.0
                    yield tp, create_message(
                        record.topic,
                        record.partition,
                        record.offset,
                        timestamp_s,
                        record.timestamp_type,
                        record.key,
                        record.value,
                        record.checksum,
                        record.serialized_key_size,
                        record.serialized_value_size,
                        tp,
                    )

    def _records_to_topic_index(self,
                                records: RecordMap,
                                active_partitions: Set[TP]) -> TopicIndexMap:
        topic_index: TopicIndexMap = {}
        for tp, messages in records.items():
            try:
                entry = topic_index[tp.topic]
            except KeyError:
                entry = topic_index[tp.topic] = _TopicBuffer()
            entry.add(tp, messages)
        return topic_index

    def _new_topicpartition(self, topic: str, partition: int) -> TP:
        return cast(TP, _TopicPartition(topic, partition))

    def _new_offsetandmetadata(self, offset: int, meta: str) -> Any:
        return OffsetAndMetadata(offset, meta)

    async def on_stop(self) -> None:
        await super().on_stop()  # wait_empty
        await self.commit()
        await self._consumer.stop()
        transport = cast(Transport, self.transport)
        transport._topic_waiters.clear()

    async def perform_seek(self) -> None:
        await self.transition_with(CONSUMER_SEEKING, self._perform_seek())

    async def _perform_seek(self) -> None:
        read_offset = self._read_offset
        _committed_offsets = await self._consumer.seek_to_committed()
        committed_offsets = {
            _ensure_TP(tp): offset
            for tp, offset in _committed_offsets.items()
            if offset is not None
        }
        read_offset.update({
            tp: offset if offset else None
            for tp, offset in committed_offsets.items()
        })
        self._committed_offset.update(committed_offsets)

    async def _commit(self, offsets: Mapping[TP, Tuple[int, str]]) -> bool:
        table = terminal.logtable(
            [(str(tp), str(offset), meta)
             for tp, (offset, meta) in offsets.items()],
            title='Commit Offsets',
            headers=['TP', 'Offset', 'Metadata'],
        )
        self.log.dev('COMMITTING OFFSETS:\n%s', table)
        try:
            assignment = self.assignment()
            commitable: Dict[TP, OffsetAndMetadata] = {}
            revoked: Dict[TP, OffsetAndMetadata] = {}
            commitable_offsets: Dict[TP, int] = {}
            for tp, (offset, meta) in offsets.items():
                offset_and_metadata = self._new_offsetandmetadata(offset, meta)
                if tp in assignment:
                    commitable_offsets[tp] = offset
                    commitable[tp] = offset_and_metadata
                else:
                    revoked[tp] = offset_and_metadata
            if revoked:
                self.log.info(
                    'Discarded commit for revoked partitions that '
                    'will be eventually processed again: %r',
                    revoked,
                )
            if not commitable:
                return False
            with flight_recorder(self.log, timeout=300.0) as on_timeout:
                on_timeout.info('+aiokafka_consumer.commit()')
                await self._consumer.commit(commitable)
                on_timeout.info('-aiokafka._consumer.commit()')
            self._committed_offset.update(commitable_offsets)
            self.app.monitor.on_tp_commit(commitable_offsets)
            self._last_batch = None
            return True
        except CommitFailedError as exc:
            if 'already rebalanced' in str(exc):
                return False
            self.log.exception('Committing raised exception: %r', exc)
            await self.crash(exc)
            return False
        except IllegalStateError as exc:
            self.log.exception(f'Got exception: {exc}\n'
                               f'Current assignment: {self.assignment()}')
            await self.crash(exc)
            return False

    def stop_flow(self) -> None:
        self.flow_active = False
        self.can_resume_flow.clear()

    def resume_flow(self) -> None:
        self.flow_active = True
        self.can_resume_flow.set()

    def pause_partitions(self, tps: Iterable[TP]) -> None:
        tpset = set(tps)
        self._get_active_partitions().difference_update(tpset)
        self._paused_partitions.update(tpset)

    def resume_partitions(self, tps: Iterable[TP]) -> None:
        tpset = set(tps)
        self._get_active_partitions().update(tps)
        self._paused_partitions.difference_update(tpset)

    async def position(self, tp: TP) -> Optional[int]:
        return await self._consumer.position(tp)

    async def _seek_to_beginning(self, *partitions: TP) -> None:
        self.log.dev('SEEK TO BEGINNING: %r', partitions)
        self._read_offset.update((_ensure_TP(tp), None) for tp in partitions)
        await self._consumer.seek_to_beginning(*(
            self._new_topicpartition(tp.topic, tp.partition)
            for tp in partitions
        ))

    async def seek(self, partition: TP, offset: int) -> None:
        self.log.dev('SEEK %r -> %r', partition, offset)
        # reset livelock detection
        self._last_batch = None
        # set new read offset so we will reread messages
        self._read_offset[_ensure_TP(partition)] = offset if offset else None
        self._consumer.seek(partition, offset)

    def assignment(self) -> Set[TP]:
        return cast(Set[TP], self._consumer.assignment())

    def highwater(self, tp: TP) -> int:
        return self._consumer.highwater(tp)

    async def earliest_offsets(self,
                               *partitions: TP) -> MutableMapping[TP, int]:
        return await self._consumer.beginning_offsets(partitions)

    async def highwaters(self, *partitions: TP) -> MutableMapping[TP, int]:
        return await self._consumer.end_offsets(partitions)

    def close(self) -> None:
        self._consumer.set_close()
        self._consumer._coordinator.set_close()
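
The _TopicBuffer helper used by _records_to_topic_index() above is not shown
in this example.  Below is a minimal stand-in for the interface getmany()
relies on (add(tp, messages), plus iteration with next() yielding (tp, record)
pairs while alternating between the topic's partitions); this is an
assumption, not the driver's actual class:

from collections import OrderedDict


class TopicBufferSketch:
    """Round-robin buffer over the record lists of one topic's partitions."""

    def __init__(self):
        self._buffers = OrderedDict()   # tp -> iterator of records
        self._it = None

    def add(self, tp, buffer):
        self._buffers[tp] = iter(buffer)

    def __iter__(self):
        return self

    def __next__(self):
        if self._it is None:
            self._it = self._round_robin()
        return next(self._it)

    def _round_robin(self):
        # Take one record from each partition per pass until all are drained.
        while self._buffers:
            for tp in list(self._buffers):
                try:
                    yield tp, next(self._buffers[tp])
                except StopIteration:
                    del self._buffers[tp]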
Example #9
 def signal_recovery_reset(self) -> Event:
     """Event used to signal that recovery is restarting."""
     if self._signal_recovery_reset is None:
         self._signal_recovery_reset = Event(loop=self.loop)
     return self._signal_recovery_reset
Example #10
 def signal_recovery_end(self) -> Event:
     """Event used to signal that recovery has ended."""
     if self._signal_recovery_end is None:
         self._signal_recovery_end = Event(loop=self.loop)
     return self._signal_recovery_end
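
Both properties above create the Event lazily, presumably so it is only
constructed once the service's loop is available (it is built with
loop=self.loop; for reference, plain asyncio.Event no longer accepts a loop
argument in recent Python versions).  The same lazy-initialization pattern in
isolation:

import asyncio
from typing import Optional


class LazyFlag:
    """Create an asyncio.Event only on first access."""

    _flag: Optional[asyncio.Event] = None

    @property
    def flag(self) -> asyncio.Event:
        if self._flag is None:
            self._flag = asyncio.Event()   # created lazily, on first use
        return self._flag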
Example #11
class Recovery(Service):
    """Service responsible for recovering tables from changelog topics."""

    app: AppT

    tables: TableManager

    stats_interval: float = 5.0

    #: Set of standby tps.
    standby_tps: Set[TP]

    #: Set of active tps.
    active_tps: Set[TP]

    #: Mapping from TP to table
    tp_to_table: MutableMapping[TP, CollectionT]

    #: Active offset by TP.
    active_offsets: Counter[TP]

    #: Standby offset by TP.
    standby_offsets: Counter[TP]

    #: Mapping of highwaters by tp.
    highwaters: Counter[TP]

    #: Active highwaters by TP.
    active_highwaters: Counter[TP]

    #: Standby highwaters by TP.
    standby_highwaters: Counter[TP]

    _signal_recovery_start: Optional[Event] = None
    _signal_recovery_end: Optional[Event] = None
    _signal_recovery_reset: Optional[Event] = None

    completed: Event
    in_recovery: bool = False
    recovery_delay: float

    #: Changelog event buffers by table.
    #: These are filled by background task `_slurp_changelog`,
    #: and need to be flushed before starting a new recovery or stopping.
    buffers: MutableMapping[CollectionT, List[EventT]]

    #: Cache of buffer size by TopicPartition.
    buffer_sizes: MutableMapping[TP, int]

    def __init__(self, app: AppT, tables: TableManagerT,
                 **kwargs: Any) -> None:
        self.app = app
        self.tables = cast(TableManager, tables)

        self.standby_tps = set()
        self.active_tps = set()

        self.tp_to_table = {}
        self.active_offsets = Counter()
        self.standby_offsets = Counter()

        self.active_highwaters = Counter()
        self.standby_highwaters = Counter()
        self.completed = Event()

        self.buffers = defaultdict(list)
        self.buffer_sizes = {}
        self.recovery_delay = self.app.conf.stream_recovery_delay

        super().__init__(**kwargs)

    @property
    def signal_recovery_start(self) -> Event:
        if self._signal_recovery_start is None:
            self._signal_recovery_start = Event(loop=self.loop)
        return self._signal_recovery_start

    @property
    def signal_recovery_end(self) -> Event:
        if self._signal_recovery_end is None:
            self._signal_recovery_end = Event(loop=self.loop)
        return self._signal_recovery_end

    @property
    def signal_recovery_reset(self) -> Event:
        if self._signal_recovery_reset is None:
            self._signal_recovery_reset = Event(loop=self.loop)
        return self._signal_recovery_reset

    async def on_stop(self) -> None:
        # Flush buffers when stopping.
        self.flush_buffers()

    def add_active(self, table: CollectionT, tp: TP) -> None:
        self.active_tps.add(tp)
        self._add(table, tp, self.active_offsets)

    def add_standby(self, table: CollectionT, tp: TP) -> None:
        table = self.tables._changelogs[tp.topic]
        self.standby_tps.add(tp)
        self._add(table, tp, self.standby_offsets)

    def _add(self, table: CollectionT, tp: TP, offsets: Counter[TP]) -> None:
        self.tp_to_table[tp] = table
        persisted_offset = table.persisted_offset(tp)
        if persisted_offset is not None:
            offsets[tp] = persisted_offset
        offsets.setdefault(tp, -1)
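        # Example (illustrative): a table that last persisted offset 42 for a
        # changelog partition resumes recovery from 42, while a partition never
        # seen before defaults to -1, i.e. recover from the very beginning.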

    def revoke(self, tp: TP) -> None:
        self.standby_offsets.pop(tp, None)
        self.standby_highwaters.pop(tp, None)
        self.active_offsets.pop(tp, None)
        self.active_highwaters.pop(tp, None)

    async def on_partitions_revoked(self, revoked: Set[TP]) -> None:
        self.flush_buffers()
        self.signal_recovery_reset.set()

    async def on_rebalance(self, assigned: Set[TP], revoked: Set[TP],
                           newly_assigned: Set[TP]) -> None:
        assigned_standbys = self.app.assignor.assigned_standbys()
        assigned_actives = self.app.assignor.assigned_actives()

        for tp in revoked:
            self.revoke(tp)

        self.standby_tps.clear()
        self.active_tps.clear()

        for tp in assigned_standbys:
            table = self.tables._changelogs.get(tp.topic)
            if table is not None:
                self.add_standby(table, tp)
        for tp in assigned_actives:
            table = self.tables._changelogs.get(tp.topic)
            if table is not None:
                self.add_active(table, tp)

        active_offsets = {
            tp: offset
            for tp, offset in self.active_offsets.items()
            if tp in self.active_tps
        }
        self.active_offsets.clear()
        self.active_offsets.update(active_offsets)

        self.signal_recovery_reset.clear()
        self.signal_recovery_start.set()

    async def _resume_streams(self) -> None:
        app = self.app
        consumer = app.consumer
        await app.on_rebalance_complete.send()
        # Resume partitions and start fetching.
        self.log.info('Resuming flow...')
        consumer.resume_flow()
        app.flow_control.resume()
        self.log.info('Seek stream partitions to committed offsets.')
        await self._wait(consumer.perform_seek())
        self.completed.set()
        assignment = consumer.assignment()
        self.log.dev('Resume stream partitions')
        consumer.resume_partitions(assignment)
        # finally make sure the fetcher is running.
        await app._fetcher.maybe_start()
        app.rebalancing = False
        self.log.info('Worker ready')

    @Service.task
    async def _restart_recovery(self) -> None:
        consumer = self.app.consumer
        active_tps = self.active_tps
        standby_tps = self.standby_tps
        standby_offsets = self.standby_offsets
        standby_highwaters = self.standby_highwaters
        assigned_active_tps = self.active_tps
        assigned_standby_tps = self.standby_tps
        active_offsets = self.active_offsets
        standby_offsets = self.standby_offsets
        active_highwaters = self.active_highwaters
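        # Note: assigned_active_tps/assigned_standby_tps alias the very
        # same set objects as active_tps/standby_tps; the locals simply
        # avoid repeated attribute lookups in the loop below.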

        while not self.should_stop:
            self.log.dev('WAITING FOR NEXT RECOVERY TO START')
            self.signal_recovery_reset.clear()
            self.in_recovery = False
            if await self.wait_for_stopped(self.signal_recovery_start):
                self.signal_recovery_start.clear()
                break  # service was stopped
            self.signal_recovery_start.clear()

            try:
                await self._wait(asyncio.sleep(self.recovery_delay))

                if not self.tables:
                    # If there are no tables -- simply resume streams
                    await self._resume_streams()
                    continue

                self.in_recovery = True
                # Must flush any buffers before starting rebalance.
                self.flush_buffers()
                await self._wait(self.app._producer.flush())

                self.log.dev('Build highwaters for active partitions')
                await self._wait(
                    self._build_highwaters(consumer, assigned_active_tps,
                                           active_highwaters, 'active'))

                self.log.dev('Build offsets for active partitions')
                await self._wait(
                    self._build_offsets(consumer, assigned_active_tps,
                                        active_offsets, 'active'))

                self.log.dev('Build offsets for standby partitions')
                await self._wait(
                    self._build_offsets(consumer, assigned_standby_tps,
                                        standby_offsets, 'standby'))

                self.log.dev('Seek offsets for active partitions')
                await self._wait(
                    self._seek_offsets(consumer, assigned_active_tps,
                                       active_offsets, 'active'))

                if self.need_recovery():
                    self.log.info('Restoring state from changelog topics...')
                    consumer.resume_partitions(active_tps)
                    # Resume partitions and start fetching.
                    self.log.info('Resuming flow...')
                    consumer.resume_flow()
                    await self.app._fetcher.maybe_start()
                    self.app.flow_control.resume()

                    # Wait for actives to be up to date.
                    # This signal will be set by _slurp_changelogs
                    self.signal_recovery_end.clear()
                    await self._wait(self.signal_recovery_end)

                    # recovery done.
                    self.log.info('Done reading from changelog topics')
                    consumer.pause_partitions(active_tps)
                else:
                    self.log.info('Resuming flow...')
                    consumer.resume_flow()
                    self.app.flow_control.resume()

                self.log.info('Recovery complete')
                self.in_recovery = False

                if standby_tps:
                    self.log.info('Starting standby partitions...')

                    self.log.dev('Seek standby offsets')
                    await self._wait(
                        self._seek_offsets(consumer, standby_tps,
                                           standby_offsets, 'standby'))

                    self.log.dev('Build standby highwaters')
                    await self._wait(
                        self._build_highwaters(
                            consumer,
                            standby_tps,
                            standby_highwaters,
                            'standby',
                        ), )

                    self.log.dev('Resume standby partitions')
                    consumer.resume_partitions(standby_tps)

                # Pause all our topic partitions,
                # to make sure we don't fetch any more records from them.
                await self._wait(asyncio.sleep(0.1))  # still needed?
                await self._wait(self.on_recovery_completed())
            except RebalanceAgain:
                self.log.dev('RAISED REBALANCE AGAIN')
                continue  # another rebalance started
            except ServiceStopped:
                self.log.dev('RAISED SERVICE STOPPED')
                break  # service was stopped
            # restart - wait for next rebalance.
        self.in_recovery = False

    async def _wait(self, coro: WaitArgT) -> None:
        wait_result = await self.wait_first(
            coro,
            self.signal_recovery_reset,
            self.signal_recovery_start,
        )
        if wait_result.stopped:
            # service was stopped.
            raise ServiceStopped()
        elif self.signal_recovery_start in wait_result.done:
            # another rebalance started
            raise RebalanceAgain()
        elif self.signal_recovery_reset in wait_result.done:
            raise RebalanceAgain()
        else:
            return None

    async def on_recovery_completed(self) -> None:
        consumer = self.app.consumer
        self.log.info('Restore complete!')
        await self.app.on_rebalance_complete.send()
        # This needs to happen if all goes well
        callback_coros = []
        for table in self.tables.values():
            callback_coros.append(table.call_recover_callbacks())
        if callback_coros:
            await asyncio.wait(callback_coros)
        self.log.info('Seek stream partitions to committed offsets.')
        await consumer.perform_seek()
        self.completed.set()
        assignment = consumer.assignment()
        self.log.dev('Resume stream partitions')
        consumer.resume_partitions(
            {tp
             for tp in assignment if not self._is_changelog_tp(tp)})
        # finally make sure the fetcher is running.
        await self.app._fetcher.maybe_start()
        self.app.rebalancing = False
        self.log.info('Worker ready')

    async def _build_highwaters(self, consumer: ConsumerT, tps: Set[TP],
                                destination: Counter[TP], title: str) -> None:
        # -- Build highwater
        highwaters = await consumer.highwaters(*tps)
        highwaters = {
            # FIXME the -1 here is because of the way we commit offsets
            tp: value - 1
            for tp, value in highwaters.items()
        }
        table = terminal.logtable(
            [[k.topic, str(k.partition), str(v)]
             for k, v in highwaters.items()],
            title=f'Highwater - {title.capitalize()}',
            headers=['topic', 'partition', 'highwater'],
        )
        self.log.info('Highwater for %s changelog partitions:\n%s', title,
                      table)
        destination.clear()
        destination.update(highwaters)

    async def _build_offsets(self, consumer: ConsumerT, tps: Set[TP],
                             destination: Counter[TP], title: str) -> None:
        # -- Update offsets
        # Offsets may have been compacted, need to get to the recent ones
        earliest = await consumer.earliest_offsets(*tps)
        # FIXME To be consistent with the offset -1 logic
        earliest = {tp: offset - 1 for tp, offset in earliest.items()}
        for tp in tps:
            destination[tp] = max(destination[tp], earliest[tp])
        table = terminal.logtable(
            [(k.topic, k.partition, v) for k, v in destination.items()],
            title=f'Reading Starts At - {title.capitalize()}',
            headers=['topic', 'partition', 'offset'],
        )
        self.log.info('%s offsets at start of reading:\n%s', title, table)

    async def _seek_offsets(self, consumer: ConsumerT, tps: Set[TP],
                            offsets: Counter[TP], title: str) -> None:
        # Seek to new offsets
        for tp in tps:
            offset = offsets[tp]
            if offset == -1:
                offset = 0
            # FIXME Remove check when fixed offset-1 discrepancy
            await consumer.seek(tp, offset)
            assert await consumer.position(tp) == offset

    @Service.task
    async def _slurp_changelogs(self) -> None:
        changelog_queue = self.tables.changelog_queue
        tp_to_table = self.tp_to_table

        active_tps = self.active_tps
        standby_tps = self.standby_tps
        active_offsets = self.active_offsets
        standby_offsets = self.standby_offsets

        buffers = self.buffers
        buffer_sizes = self.buffer_sizes

        while not self.should_stop:
            event: EventT = await changelog_queue.get()
            message = event.message
            tp = message.tp
            offset = message.offset

            offsets: Counter[TP]
            bufsize = buffer_sizes.get(tp)
            if tp in active_tps:
                table = tp_to_table[tp]
                offsets = active_offsets
                if bufsize is None:
                    bufsize = buffer_sizes[tp] = table.recovery_buffer_size
            elif tp in standby_tps:
                table = tp_to_table[tp]
                offsets = standby_offsets
                if bufsize is None:
                    bufsize = buffer_sizes[tp] = table.standby_buffer_size
            else:
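                # Changelog event for a partition that is neither active
                # nor standby (e.g. just revoked); ignore it.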
                continue

            seen_offset = offsets.get(tp, -1)
            if offset > seen_offset:
                offsets[tp] = offset
                buf = buffers[table]
                buf.append(event)
                await table.on_changelog_event(event)
                if len(buf) >= bufsize:
                    table.apply_changelog_batch(buf)
                    buf.clear()
            if self.in_recovery and not self.active_remaining_total():
                # apply anything stuck in the buffers
                self.flush_buffers()
                self.in_recovery = False
                self.signal_recovery_end.set()

    def flush_buffers(self) -> None:
        for table, buffer in self.buffers.items():
            table.apply_changelog_batch(buffer)
            buffer.clear()

    def need_recovery(self) -> bool:
        return self.active_highwaters != self.active_offsets

    def active_remaining(self) -> Counter[TP]:
        return self.active_highwaters - self.active_offsets

    def standby_remaining(self) -> Counter[TP]:
        return self.standby_highwaters - self.standby_offsets

    def active_remaining_total(self) -> int:
        return sum(self.active_remaining().values())

    def standby_remaining_total(self) -> int:
        return sum(self.standby_remaining().values())

    def active_stats(self) -> MutableMapping[TP, Tuple[int, int, int]]:
        offsets = self.active_offsets
        return {
            tp: (highwater, offsets[tp], highwater - offsets[tp])
            for tp, highwater in self.active_highwaters.items()
            if highwater - offsets[tp] != 0
        }

    def standby_stats(self) -> MutableMapping[TP, Tuple[int, int, int]]:
        offsets = self.standby_offsets
        return {
            tp: (highwater, offsets[tp], highwater - offsets[tp])
            for tp, highwater in self.standby_highwaters.items()
            if highwater - offsets[tp] != 0
        }

    @Service.task
    async def _publish_stats(self) -> None:
        while not self.should_stop:
            if self.in_recovery:
                self.log.info('Still fetching. Remaining: %s',
                              self.active_stats())
            await self.sleep(self.stats_interval)

    def _is_changelog_tp(self, tp: TP) -> bool:
        return tp.topic in self.tables.changelog_topics
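
The bookkeeping in the example above reduces to plain Counter arithmetic: a partition still needs recovery while its highwater exceeds the offset restored so far. A minimal, self-contained sketch of what need_recovery(), active_remaining() and active_remaining_total() compute (the TP namedtuple and the offsets below are hypothetical stand-ins for Faust's own types and real values):

from collections import Counter, namedtuple

TP = namedtuple('TP', 'topic partition')

active_highwaters = Counter({TP('orders-changelog', 0): 100,
                             TP('orders-changelog', 1): 50})
active_offsets = Counter({TP('orders-changelog', 0): 100,
                          TP('orders-changelog', 1): 20})

# Counter subtraction keeps only positive differences, i.e. partitions
# with changelog records left to restore.
remaining = active_highwaters - active_offsets
print(remaining)                # Counter({TP(topic='orders-changelog', partition=1): 30})
print(sum(remaining.values()))  # 30 -> active_remaining_total()
print(active_highwaters != active_offsets)  # True -> need_recovery()
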
Exemple #12
0
def test_repr():
    ev = Event()
    assert repr(ev)
    ev._waiters = [1, 2, 3]
    assert repr(ev)
Exemple #13
0
 def _new_shutdown_done_event(self) -> Event:
     return Event(loop=self._loop)
Exemple #14
0
 def _new_force_kill_event(self) -> Event:
     return Event(loop=self._loop)
Exemple #15
0
 def signal_recovery_start(self) -> Event:
     """Event used to signal that recovery has started."""
     if self._signal_recovery_start is None:
         self._signal_recovery_start = Event(loop=self.loop)
     return self._signal_recovery_start
Exemple #16
0
 def signal_recovery_end(self) -> Event:
     """Event used to signal that recovery has ended."""
     if self._signal_recovery_end is None:
         self._signal_recovery_end = Event(loop=self.loop)
     return self._signal_recovery_end
Exemple #17
0
 def signal_recovery_reset(self) -> Event:
     if self._signal_recovery_reset is None:
         self._signal_recovery_reset = Event(loop=self.loop)
     return self._signal_recovery_reset
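
Examples #15, #16 and #17 above all follow the same lazy-initialization pattern: the Event is created on first access rather than in __init__, so it is bound to the loop the service actually runs on. A minimal sketch of that pattern, using asyncio.Event as a stand-in for the mode Event used above (which additionally accepts a loop argument):

import asyncio
from typing import Optional


class Signals:
    _started: Optional[asyncio.Event] = None

    @property
    def started(self) -> asyncio.Event:
        # Create the event lazily so it is not tied to whatever loop
        # happened to exist when the object was constructed.
        if self._started is None:
            self._started = asyncio.Event()
        return self._started
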
Exemple #18
0
class Recovery(Service):
    """Service responsible for recovering tables from changelog topics."""

    app: AppT

    tables: _TableManager

    stats_interval: float = 5.0

    #: Set of standby topic partitions.
    standby_tps: Set[TP]

    #: Set of active topic partitions.
    active_tps: Set[TP]

    actives_for_table: MutableMapping[CollectionT, Set[TP]]
    standbys_for_table: MutableMapping[CollectionT, Set[TP]]

    #: Mapping from topic partition to table
    tp_to_table: MutableMapping[TP, CollectionT]

    #: Active offset by topic partition.
    active_offsets: Counter[TP]

    #: Standby offset by topic partition.
    standby_offsets: Counter[TP]

    #: Mapping of highwaters by topic partition.
    highwaters: Counter[TP]

    #: Active highwaters by topic partition.
    active_highwaters: Counter[TP]

    #: Standby highwaters by topic partition.
    standby_highwaters: Counter[TP]

    _signal_recovery_start: Optional[Event] = None
    _signal_recovery_end: Optional[Event] = None
    _signal_recovery_reset: Optional[Event] = None

    completed: Event
    in_recovery: bool = False
    standbys_pending: bool = False
    recovery_delay: float

    #: Changelog event buffers by table.
    #: These are filled by background task `_slurp_changelogs`,
    #: and need to be flushed before starting new recovery/stopping.
    buffers: MutableMapping[CollectionT, List[EventT]]

    #: Cache of buffer size by topic partition.
    buffer_sizes: MutableMapping[TP, int]

    _recovery_span: Optional[opentracing.Span] = None
    _actives_span: Optional[opentracing.Span] = None
    _standbys_span: Optional[opentracing.Span] = None

    def __init__(self, app: AppT, tables: TableManagerT,
                 **kwargs: Any) -> None:
        self.app = app
        self.tables = cast(_TableManager, tables)

        self.standby_tps = set()
        self.active_tps = set()

        self.tp_to_table = {}
        self.active_offsets = Counter()
        self.standby_offsets = Counter()

        self.active_highwaters = Counter()
        self.standby_highwaters = Counter()
        self.completed = Event()

        self.buffers = defaultdict(list)
        self.buffer_sizes = {}
        self.recovery_delay = self.app.conf.stream_recovery_delay

        self.actives_for_table = defaultdict(set)
        self.standbys_for_table = defaultdict(set)

        super().__init__(**kwargs)

    @property
    def signal_recovery_start(self) -> Event:
        """Event used to signal that recovery has started."""
        if self._signal_recovery_start is None:
            self._signal_recovery_start = Event(loop=self.loop)
        return self._signal_recovery_start

    @property
    def signal_recovery_end(self) -> Event:
        """Event used to signal that recovery has ended."""
        if self._signal_recovery_end is None:
            self._signal_recovery_end = Event(loop=self.loop)
        return self._signal_recovery_end

    @property
    def signal_recovery_reset(self) -> Event:
        """Event used to signal that recovery is restarting."""
        if self._signal_recovery_reset is None:
            self._signal_recovery_reset = Event(loop=self.loop)
        return self._signal_recovery_reset

    async def on_stop(self) -> None:
        """Call when recovery service stops."""
        # Flush buffers when stopping.
        self.flush_buffers()

    def add_active(self, table: CollectionT, tp: TP) -> None:
        """Add changelog partition to be used for active recovery."""
        self.active_tps.add(tp)
        self.actives_for_table[table].add(tp)
        self._add(table, tp, self.active_offsets)

    def add_standby(self, table: CollectionT, tp: TP) -> None:
        """Add changelog partition to be used for standby recovery."""
        self.standby_tps.add(tp)
        self.standbys_for_table[table].add(tp)
        self._add(table, tp, self.standby_offsets)

    def _add(self, table: CollectionT, tp: TP, offsets: Counter[TP]) -> None:
        self.tp_to_table[tp] = table
        persisted_offset = table.persisted_offset(tp)
        if persisted_offset is not None:
            offsets[tp] = persisted_offset
        offsets.setdefault(tp, None)  # type: ignore

    def revoke(self, tp: TP) -> None:
        """Revoke assignment of table changelog partition."""
        self.standby_offsets.pop(tp, None)
        self.standby_highwaters.pop(tp, None)
        self.active_offsets.pop(tp, None)
        self.active_highwaters.pop(tp, None)

    def on_partitions_revoked(self, revoked: Set[TP]) -> None:
        """Call when rebalancing and partitions are revoked."""
        T = traced_from_parent_span()
        T(self.flush_buffers)()
        self.signal_recovery_reset.set()

    async def on_rebalance(self, assigned: Set[TP], revoked: Set[TP],
                           newly_assigned: Set[TP]) -> None:
        """Call when cluster is rebalancing."""
        app = self.app
        assigned_standbys = app.assignor.assigned_standbys()
        assigned_actives = app.assignor.assigned_actives()

        for tp in revoked:
            self.revoke(tp)

        self.standby_tps.clear()
        self.active_tps.clear()
        self.actives_for_table.clear()
        self.standbys_for_table.clear()

        for tp in assigned_standbys:
            table = self.tables._changelogs.get(tp.topic)
            if table is not None:
                self.add_standby(table, tp)
        for tp in assigned_actives:
            table = self.tables._changelogs.get(tp.topic)
            if table is not None:
                self.add_active(table, tp)

        active_offsets = {
            tp: offset
            for tp, offset in self.active_offsets.items()
            if tp in self.active_tps
        }
        self.active_offsets.clear()
        self.active_offsets.update(active_offsets)

        rebalancing_span = cast(_App, self.app)._rebalancing_span
        if app.tracer and rebalancing_span:
            self._recovery_span = app.tracer.get_tracer('_faust').start_span(
                'recovery',
                child_of=rebalancing_span,
            )
            app._span_add_default_tags(self._recovery_span)
        self.signal_recovery_reset.clear()
        self.signal_recovery_start.set()

    async def _resume_streams(self) -> None:
        app = self.app
        consumer = app.consumer
        await app.on_rebalance_complete.send()
        # Resume partitions and start fetching.
        self.log.info('Resuming flow...')
        consumer.resume_flow()
        app.flow_control.resume()
        assignment = consumer.assignment()
        if assignment:
            self.log.info('Seek stream partitions to committed offsets.')
            await self._wait(consumer.perform_seek())
            self.log.dev('Resume stream partitions')
            consumer.resume_partitions(assignment)
        else:
            self.log.info('Resuming streams with empty assignment')
        self.completed.set()
        # finally make sure the fetcher is running.
        await cast(_App, app)._fetcher.maybe_start()
        self.tables.on_actives_ready()
        self.tables.on_standbys_ready()
        app.on_rebalance_end()
        self.log.info('Worker ready')

    @Service.task
    async def _restart_recovery(self) -> None:
        consumer = self.app.consumer
        active_tps = self.active_tps
        standby_tps = self.standby_tps
        standby_offsets = self.standby_offsets
        standby_highwaters = self.standby_highwaters
        assigned_active_tps = self.active_tps
        assigned_standby_tps = self.standby_tps
        active_offsets = self.active_offsets
        standby_offsets = self.standby_offsets
        active_highwaters = self.active_highwaters

        while not self.should_stop:
            self.log.dev('WAITING FOR NEXT RECOVERY TO START')
            self.signal_recovery_reset.clear()
            self.in_recovery = False
            if await self.wait_for_stopped(self.signal_recovery_start):
                self.signal_recovery_start.clear()
                break  # service was stopped
            self.signal_recovery_start.clear()

            span: Any = None
            spans: list = []
            tracer: Optional[opentracing.Tracer] = None
            if self.app.tracer:
                tracer = self.app.tracer.get_tracer('_faust')
            if tracer is not None and self._recovery_span:
                span = tracer.start_span('recovery-thread',
                                         child_of=self._recovery_span)
                self.app._span_add_default_tags(span)
                spans.extend([span, self._recovery_span])
            T = traced_from_parent_span(span)

            try:
                await self._wait(T(asyncio.sleep)(self.recovery_delay))

                if not self.tables:
                    # If there are no tables -- simply resume streams
                    await T(self._resume_streams)()
                    for _span in spans:
                        finish_span(_span)
                    continue

                self.in_recovery = True
                self.standbys_pending = True
                # Must flush any buffers before starting rebalance.
                T(self.flush_buffers)()
                producer = cast(_App, self.app)._producer
                if producer is not None:
                    await self._wait(T(producer.flush)())

                self.log.dev('Build highwaters for active partitions')
                await self._wait(
                    T(self._build_highwaters)(consumer, assigned_active_tps,
                                              active_highwaters, 'active'))

                self.log.dev('Build offsets for active partitions')
                await self._wait(
                    T(self._build_offsets)(consumer, assigned_active_tps,
                                           active_offsets, 'active'))

                for tp in assigned_active_tps:
                    if active_offsets[tp] > active_highwaters[tp]:
                        raise ConsistencyError(
                            E_PERSISTED_OFFSET.format(
                                tp,
                                active_offsets[tp],
                                active_highwaters[tp],
                            ), )

                self.log.dev('Build offsets for standby partitions')
                await self._wait(
                    T(self._build_offsets)(consumer, assigned_standby_tps,
                                           standby_offsets, 'standby'))

                self.log.dev('Seek offsets for active partitions')
                await self._wait(
                    T(self._seek_offsets)(consumer, assigned_active_tps,
                                          active_offsets, 'active'))

                if self.need_recovery():
                    self.log.info('Restoring state from changelog topics...')
                    T(consumer.resume_partitions)(active_tps)
                    # Resume partitions and start fetching.
                    self.log.info('Resuming flow...')
                    T(consumer.resume_flow)()
                    await T(cast(_App, self.app)._fetcher.maybe_start)()
                    T(self.app.flow_control.resume)()

                    # Wait for actives to be up to date.
                    # This signal will be set by _slurp_changelogs
                    if tracer is not None and span:
                        self._actives_span = tracer.start_span(
                            'recovery-actives',
                            child_of=span,
                            tags={'Active-Stats': self.active_stats()},
                        )
                        self.app._span_add_default_tags(span)
                    try:
                        self.signal_recovery_end.clear()
                        await self._wait(self.signal_recovery_end)
                    except Exception as exc:
                        finish_span(self._actives_span, error=exc)
                    else:
                        finish_span(self._actives_span)
                    finally:
                        self._actives_span = None

                    # recovery done.
                    self.log.info('Done reading from changelog topics')
                    T(consumer.pause_partitions)(active_tps)
                else:
                    self.log.info('Resuming flow...')
                    T(consumer.resume_flow)()
                    T(self.app.flow_control.resume)()

                self.log.info('Recovery complete')
                if span:
                    span.set_tag('Recovery-Completed', True)
                self.in_recovery = False

                if standby_tps:
                    self.log.info('Starting standby partitions...')

                    self.log.dev('Seek standby offsets')
                    await self._wait(
                        T(self._seek_offsets)(consumer, standby_tps,
                                              standby_offsets, 'standby'))

                    self.log.dev('Build standby highwaters')
                    await self._wait(
                        T(self._build_highwaters)(
                            consumer,
                            standby_tps,
                            standby_highwaters,
                            'standby',
                        ), )

                    for tp in standby_tps:
                        if standby_offsets[tp] > standby_highwaters[tp]:
                            raise ConsistencyError(
                                E_PERSISTED_OFFSET.format(
                                    tp,
                                    standby_offsets[tp],
                                    standby_highwaters[tp],
                                ), )

                    if tracer is not None and span:
                        self._standbys_span = tracer.start_span(
                            'recovery-standbys',
                            child_of=span,
                            tags={'Standby-Stats': self.standby_stats()},
                        )
                        self.app._span_add_default_tags(span)
                    self.log.dev('Resume standby partitions')
                    T(consumer.resume_partitions)(standby_tps)

                # Pause all our topic partitions,
                # to make sure we don't fetch any more records from them.
                await self._wait(asyncio.sleep(0.1))  # still needed?
                await self._wait(T(self.on_recovery_completed)())
            except RebalanceAgain as exc:
                self.log.dev('RAISED REBALANCE AGAIN')
                for _span in spans:
                    finish_span(_span, error=exc)
                continue  # another rebalance started
            except ServiceStopped as exc:
                self.log.dev('RAISED SERVICE STOPPED')
                for _span in spans:
                    finish_span(_span, error=exc)
                break  # service was stopped
            except Exception as exc:
                for _span in spans:
                    finish_span(_span, error=exc)
                raise
            else:
                for _span in spans:
                    finish_span(_span)
            # restart - wait for next rebalance.
        self.in_recovery = False

    async def _wait(self, coro: WaitArgT) -> None:
        wait_result = await self.wait_first(
            coro,
            self.signal_recovery_reset,
            self.signal_recovery_start,
        )
        if wait_result.stopped:
            # service was stopped.
            raise ServiceStopped()
        elif self.signal_recovery_start in wait_result.done:
            # another rebalance started
            raise RebalanceAgain()
        elif self.signal_recovery_reset in wait_result.done:
            raise RebalanceAgain()
        else:
            return None

    async def on_recovery_completed(self) -> None:
        """Call when active table recovery is completed."""
        consumer = self.app.consumer
        self.log.info('Restore complete!')
        await self.app.on_rebalance_complete.send()
        # This needs to happen if all goes well
        callback_coros = [
            table.on_recovery_completed(
                self.actives_for_table[table],
                self.standbys_for_table[table],
            ) for table in self.tables.values()
        ]
        if callback_coros:
            await asyncio.wait(callback_coros)
        assignment = consumer.assignment()
        if assignment:
            self.log.info('Seek stream partitions to committed offsets.')
            await consumer.perform_seek()
        self.completed.set()
        self.log.dev('Resume stream partitions')
        consumer.resume_partitions(
            {tp
             for tp in assignment if not self._is_changelog_tp(tp)})
        # finally make sure the fetcher is running.
        await cast(_App, self.app)._fetcher.maybe_start()
        self.tables.on_actives_ready()
        if not self.app.assignor.assigned_standbys():
            self.tables.on_standbys_ready()
        self.app.on_rebalance_end()
        self.log.info('Worker ready')

    async def _build_highwaters(self, consumer: ConsumerT, tps: Set[TP],
                                destination: Counter[TP], title: str) -> None:
        # -- Build highwater
        highwaters = await consumer.highwaters(*tps)
        highwaters = {
            # FIXME the -1 here is because of the way we commit offsets
            tp: value - 1 if value is not None else -1
            for tp, value in highwaters.items()
        }
        table = terminal.logtable(
            [[k.topic, str(k.partition), str(v)]
             for k, v in highwaters.items()],
            title=f'Highwater - {title.capitalize()}',
            headers=['topic', 'partition', 'highwater'],
        )
        self.log.info('Highwater for %s changelog partitions:\n%s', title,
                      table)
        destination.clear()
        destination.update(highwaters)

    async def _build_offsets(self, consumer: ConsumerT, tps: Set[TP],
                             destination: Counter[TP], title: str) -> None:
        # -- Update offsets
        # Offsets may have been compacted, need to get to the recent ones
        earliest = await consumer.earliest_offsets(*tps)
        # FIXME To be consistent with the offset -1 logic
        earliest = {tp: offset - 1 for tp, offset in earliest.items()}
        for tp in tps:
            last_value = destination[tp]
            new_value = earliest[tp]

            if last_value is None:
                destination[tp] = new_value
            elif new_value is None:
                destination[tp] = last_value
            else:
                destination[tp] = max(last_value, new_value)
        table = terminal.logtable(
            [[k.topic, str(k.partition), str(v)]
             for k, v in destination.items()],
            title=f'Reading Starts At - {title.capitalize()}',
            headers=['topic', 'partition', 'offset'],
        )
        self.log.info('%s offsets at start of reading:\n%s', title, table)

    async def _seek_offsets(self, consumer: ConsumerT, tps: Set[TP],
                            offsets: Counter[TP], title: str) -> None:
        # Seek to new offsets
        new_offsets = {}
        for tp in tps:
            offset = offsets[tp]
            if offset == -1:
                offset = 0
            new_offsets[tp] = offset
        # FIXME Remove check when fixed offset-1 discrepancy
        await consumer.seek_wait(new_offsets)

    @Service.task
    async def _slurp_changelogs(self) -> None:
        changelog_queue = self.tables.changelog_queue
        tp_to_table = self.tp_to_table

        active_tps = self.active_tps
        standby_tps = self.standby_tps
        active_offsets = self.active_offsets
        standby_offsets = self.standby_offsets

        buffers = self.buffers
        buffer_sizes = self.buffer_sizes

        while not self.should_stop:
            event: EventT = await changelog_queue.get()
            message = event.message
            tp = message.tp
            offset = message.offset

            offsets: Counter[TP]
            bufsize = buffer_sizes.get(tp)
            if tp in active_tps:
                table = tp_to_table[tp]
                offsets = active_offsets
                if bufsize is None:
                    bufsize = buffer_sizes[tp] = table.recovery_buffer_size
            elif tp in standby_tps:
                table = tp_to_table[tp]
                offsets = standby_offsets
                if bufsize is None:
                    bufsize = buffer_sizes[tp] = table.standby_buffer_size
            else:
                continue

            seen_offset = offsets.get(tp, None)
            if seen_offset is None or offset > seen_offset:
                offsets[tp] = offset
                buf = buffers[table]
                buf.append(event)
                await table.on_changelog_event(event)
                if len(buf) >= bufsize:
                    table.apply_changelog_batch(buf)
                    buf.clear()
            if self.in_recovery and not self.active_remaining_total():
                # apply anything stuck in the buffers
                self.flush_buffers()
                self.in_recovery = False
                if self._actives_span is not None:
                    self._actives_span.set_tag('Actives-Ready', True)
                self.signal_recovery_end.set()
            if self.standbys_pending and not self.standby_remaining_total():
                if self._standbys_span:
                    finish_span(self._standbys_span)
                    self._standbys_span = None
                self.tables.on_standbys_ready()

    def flush_buffers(self) -> None:
        """Flush changelog buffers."""
        for table, buffer in self.buffers.items():
            table.apply_changelog_batch(buffer)
            buffer.clear()

    def need_recovery(self) -> bool:
        """Return :const:`True` if recovery is required."""
        return any(v for v in self.active_remaining().values())

    def active_remaining(self) -> Counter[TP]:
        """Return counter of remaining changes by active partition."""
        highwaters = self.active_highwaters
        offsets = self.active_offsets
        return Counter({
            tp: highwater - offsets[tp]
            for tp, highwater in highwaters.items()
            if highwater is not None and offsets[tp] is not None
        })

    def standby_remaining(self) -> Counter[TP]:
        """Return counter of remaining changes by standby partition."""
        highwaters = self.standby_highwaters
        offsets = self.standby_offsets
        return Counter({
            tp: highwater - offsets[tp]
            for tp, highwater in highwaters.items()
            if highwater >= 0 and offsets[tp] >= 0
        })

    def active_remaining_total(self) -> int:
        """Return number of changes remaining for actives to be up-to-date."""
        return sum(self.active_remaining().values())

    def standby_remaining_total(self) -> int:
        """Return number of changes remaining for standbys to be up-to-date."""
        return sum(self.standby_remaining().values())

    def active_stats(self) -> MutableMapping[TP, Tuple[int, int, int]]:
        """Return current active recovery statistics."""
        offsets = self.active_offsets
        return {
            tp: (highwater, offsets[tp], highwater - offsets[tp])
            for tp, highwater in self.active_highwaters.items()
            if offsets[tp] is not None and highwater - offsets[tp] != 0
        }

    def standby_stats(self) -> MutableMapping[TP, Tuple[int, int, int]]:
        """Return current standby recovery statistics."""
        offsets = self.standby_offsets
        return {
            tp: (highwater, offsets[tp], highwater - offsets[tp])
            for tp, highwater in self.standby_highwaters.items()
            if offsets[tp] is not None and highwater - offsets[tp] != 0
        }

    @Service.task
    async def _publish_stats(self) -> None:
        interval = self.stats_interval
        await self.sleep(interval)
        async for sleep_time in self.itertimer(interval,
                                               name='Recovery.stats'):
            if self.in_recovery:
                stats = self.active_stats()
                if stats:
                    self.log.info('Still fetching. Remaining: %s', stats)

    def _is_changelog_tp(self, tp: TP) -> bool:
        return tp.topic in self.tables.changelog_topics
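
In _build_offsets above, the starting offset for each partition is the more advanced of two partially-known values: the offset persisted in the local table store (None when nothing has been persisted) and the earliest offset still present in the possibly compacted changelog topic. A small stand-alone sketch of that merge rule (merge_start_offset and the numbers are hypothetical):

from typing import Optional


def merge_start_offset(persisted: Optional[int],
                       earliest: Optional[int]) -> Optional[int]:
    # Whichever position is further ahead wins; a missing value simply
    # defers to the other one.
    if persisted is None:
        return earliest
    if earliest is None:
        return persisted
    return max(persisted, earliest)


print(merge_start_offset(None, 41))  # 41  -> no local state, start at earliest
print(merge_start_offset(100, 41))   # 100 -> local store is ahead of the compaction point
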
Exemple #19
0
class Recovery(Service):
    """Service responsible for recovering tables from changelog topics."""

    app: AppT

    tables: _TableManager

    stats_interval: float = 5.0

    #: Set of standby topic partitions.
    standby_tps: Set[TP]

    #: Set of active topic partitions.
    active_tps: Set[TP]

    actives_for_table: MutableMapping[CollectionT, Set[TP]]
    standbys_for_table: MutableMapping[CollectionT, Set[TP]]

    #: Mapping from topic partition to table
    tp_to_table: MutableMapping[TP, CollectionT]

    #: Active offset by topic partition.
    active_offsets: Counter[TP]

    #: Standby offset by topic partition.
    standby_offsets: Counter[TP]

    #: Mapping of highwaters by topic partition.
    highwaters: Counter[TP]

    #: Active highwaters by topic partition.
    active_highwaters: Counter[TP]

    #: Standby highwaters by topic partition.
    standby_highwaters: Counter[TP]

    _signal_recovery_start: Optional[Event] = None
    _signal_recovery_end: Optional[Event] = None
    _signal_recovery_reset: Optional[Event] = None

    completed: Event
    in_recovery: bool = False
    standbys_pending: bool = False
    recovery_delay: float

    #: Changelog event buffers by table.
    #: These are filled by background task `_slurp_changelogs`,
    #: and need to be flushed before starting new recovery/stopping.
    buffers: MutableMapping[CollectionT, List[EventT]]

    #: Cache of max buffer size by topic partition.
    buffer_sizes: MutableMapping[TP, int]

    #: Time in seconds after which we warn that no flush has happened.
    flush_timeout_secs: float = 120.0

    #: Time in seconds after which we warn that no events have been received.
    event_timeout_secs: float = 30.0

    #: Time of last event received by active TP
    _active_events_received_at: MutableMapping[TP, float]

    #: Time of last event received by standby TP
    _standby_events_received_at: MutableMapping[TP, float]

    #: Time of last event received (for any active TP)
    _last_active_event_processed_at: Optional[float]

    #: Time of last buffer flush
    _last_flush_at: Optional[float] = None

    #: Time when recovery last started
    _recovery_started_at: Optional[float] = None

    #: Time when recovery last ended
    _recovery_ended_at: Optional[float] = None

    _recovery_span: Optional[opentracing.Span] = None
    _actives_span: Optional[opentracing.Span] = None
    _standbys_span: Optional[opentracing.Span] = None

    #: List of last 100 processing timestamps (monotonic).
    #: Updated after processing every changelog record,
    #: used to estimate time remaining.
    _processing_times: Deque[float]

    #: Number of entries in _processing_times before
    #: we can give an estimate of time remaining.
    num_samples_required_for_estimate = 1000

    def __init__(self,
                 app: AppT,
                 tables: TableManagerT,
                 **kwargs: Any) -> None:
        self.app = app
        self.tables = cast(_TableManager, tables)

        self.standby_tps = set()
        self.active_tps = set()

        self.tp_to_table = {}
        self.active_offsets = Counter()
        self.standby_offsets = Counter()

        self.active_highwaters = Counter()
        self.standby_highwaters = Counter()
        self.completed = Event()

        self.buffers = defaultdict(list)
        self.buffer_sizes = {}
        self.recovery_delay = self.app.conf.stream_recovery_delay

        self.actives_for_table = defaultdict(set)
        self.standbys_for_table = defaultdict(set)

        self._active_events_received_at = {}
        self._standby_events_received_at = {}
        self._processing_times = deque()

        super().__init__(**kwargs)

    @property
    def signal_recovery_start(self) -> Event:
        """Event used to signal that recovery has started."""
        if self._signal_recovery_start is None:
            self._signal_recovery_start = Event(loop=self.loop)
        return self._signal_recovery_start

    @property
    def signal_recovery_end(self) -> Event:
        """Event used to signal that recovery has ended."""
        if self._signal_recovery_end is None:
            self._signal_recovery_end = Event(loop=self.loop)
        return self._signal_recovery_end

    @property
    def signal_recovery_reset(self) -> Event:
        """Event used to signal that recovery is restarting."""
        if self._signal_recovery_reset is None:
            self._signal_recovery_reset = Event(loop=self.loop)
        return self._signal_recovery_reset

    async def on_stop(self) -> None:
        """Call when recovery service stops."""
        # Flush buffers when stopping.
        self.flush_buffers()

    def add_active(self, table: CollectionT, tp: TP) -> None:
        """Add changelog partition to be used for active recovery."""
        self.active_tps.add(tp)
        self.actives_for_table[table].add(tp)
        self._add(table, tp, self.active_offsets)

    def add_standby(self, table: CollectionT, tp: TP) -> None:
        """Add changelog partition to be used for standby recovery."""
        self.standby_tps.add(tp)
        self.standbys_for_table[table].add(tp)
        self._add(table, tp, self.standby_offsets)

    def _add(self, table: CollectionT, tp: TP, offsets: Counter[TP]) -> None:
        self.tp_to_table[tp] = table
        persisted_offset = table.persisted_offset(tp)
        if persisted_offset is not None:
            offsets[tp] = persisted_offset
        offsets.setdefault(tp, None)  # type: ignore

    def revoke(self, tp: TP) -> None:
        """Revoke assignment of table changelog partition."""
        self.standby_offsets.pop(tp, None)
        self.standby_highwaters.pop(tp, None)
        self.active_offsets.pop(tp, None)
        self.active_highwaters.pop(tp, None)

    def on_partitions_revoked(self, revoked: Set[TP]) -> None:
        """Call when rebalancing and partitions are revoked."""
        T = traced_from_parent_span()
        T(self.flush_buffers)()
        self.signal_recovery_reset.set()

    async def on_rebalance(self,
                           assigned: Set[TP],
                           revoked: Set[TP],
                           newly_assigned: Set[TP]) -> None:
        """Call when cluster is rebalancing."""
        app = self.app
        assigned_standbys = app.assignor.assigned_standbys()
        assigned_actives = app.assignor.assigned_actives()

        for tp in revoked:
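            # sleep(0) yields to the event loop between iterations so a
            # large revocation set cannot starve other tasks.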
            await asyncio.sleep(0)
            self.revoke(tp)

        self.standby_tps.clear()
        self.active_tps.clear()
        self.actives_for_table.clear()
        self.standbys_for_table.clear()

        for tp in assigned_standbys:
            table = self.tables._changelogs.get(tp.topic)
            if table is not None:
                self.add_standby(table, tp)
            await asyncio.sleep(0)
        await asyncio.sleep(0)
        for tp in assigned_actives:
            table = self.tables._changelogs.get(tp.topic)
            if table is not None:
                self.add_active(table, tp)
            await asyncio.sleep(0)
        await asyncio.sleep(0)

        active_offsets = {
            tp: offset
            for tp, offset in self.active_offsets.items()
            if tp in self.active_tps
        }
        self.active_offsets.clear()
        self.active_offsets.update(active_offsets)

        await asyncio.sleep(0)

        rebalancing_span = cast(_App, self.app)._rebalancing_span
        if app.tracer and rebalancing_span:
            self._recovery_span = app.tracer.get_tracer('_faust').start_span(
                'recovery',
                child_of=rebalancing_span,
            )
            app._span_add_default_tags(self._recovery_span)
        self.signal_recovery_reset.clear()
        self.signal_recovery_start.set()

    async def _resume_streams(self) -> None:
        app = self.app
        consumer = app.consumer
        await app.on_rebalance_complete.send()
        # Resume partitions and start fetching.
        self.log.info('Resuming flow...')
        consumer.resume_flow()
        app.flow_control.resume()
        assignment = consumer.assignment()
        if assignment:
            self.log.info('Seek stream partitions to committed offsets.')
            await self._wait(consumer.perform_seek())
            self.log.dev('Resume stream partitions')
            consumer.resume_partitions(assignment)
        else:
            self.log.info('Resuming streams with empty assignment')
        self.completed.set()
        # finally make sure the fetcher is running.
        await cast(_App, app)._fetcher.maybe_start()
        self.tables.on_actives_ready()
        self.tables.on_standbys_ready()
        app.on_rebalance_end()
        self.log.info('Worker ready')

    @Service.task
    async def _restart_recovery(self) -> None:
        consumer = self.app.consumer
        active_tps = self.active_tps
        standby_tps = self.standby_tps
        standby_offsets = self.standby_offsets
        standby_highwaters = self.standby_highwaters
        assigned_active_tps = self.active_tps
        assigned_standby_tps = self.standby_tps
        active_offsets = self.active_offsets
        standby_offsets = self.standby_offsets
        active_highwaters = self.active_highwaters

        while not self.should_stop:
            self.log.dev('WAITING FOR NEXT RECOVERY TO START')
            self.signal_recovery_reset.clear()
            self._set_recovery_ended()
            if await self.wait_for_stopped(self.signal_recovery_start):
                self.signal_recovery_start.clear()
                break  # service was stopped
            self.signal_recovery_start.clear()

            span: Any = None
            spans: list = []
            tracer: Optional[opentracing.Tracer] = None
            if self.app.tracer:
                tracer = self.app.tracer.get_tracer('_faust')
            if tracer is not None and self._recovery_span:
                span = tracer.start_span(
                    'recovery-thread',
                    child_of=self._recovery_span)
                self.app._span_add_default_tags(span)
                spans.extend([span, self._recovery_span])
            T = traced_from_parent_span(span)

            try:
                await self._wait(T(asyncio.sleep)(self.recovery_delay))

                if not self.tables:
                    # If there are no tables -- simply resume streams
                    await T(self._resume_streams)()
                    for _span in spans:
                        finish_span(_span)
                    continue

                self._set_recovery_started()
                self.standbys_pending = True
                # Must flush any buffers before starting rebalance.
                T(self.flush_buffers)()
                producer = cast(_App, self.app)._producer
                if producer is not None:
                    await self._wait(T(producer.flush)())

                self.log.dev('Build highwaters for active partitions')
                await self._wait(T(self._build_highwaters)(
                    consumer, assigned_active_tps,
                    active_highwaters, 'active'))

                self.log.dev('Build offsets for active partitions')
                await self._wait(T(self._build_offsets)(
                    consumer, assigned_active_tps, active_offsets, 'active'))

                for tp in assigned_active_tps:
                    if active_offsets[tp] > active_highwaters[tp]:
                        raise ConsistencyError(
                            E_PERSISTED_OFFSET.format(
                                tp,
                                active_offsets[tp],
                                active_highwaters[tp],
                            ),
                        )

                self.log.dev('Build offsets for standby partitions')
                await self._wait(T(self._build_offsets)(
                    consumer, assigned_standby_tps,
                    standby_offsets, 'standby'))

                self.log.dev('Seek offsets for active partitions')
                await self._wait(T(self._seek_offsets)(
                    consumer, assigned_active_tps, active_offsets, 'active'))

                if self.need_recovery():
                    self.log.info('Restoring state from changelog topics...')
                    T(consumer.resume_partitions)(active_tps)
                    # Resume partitions and start fetching.
                    self.log.info('Resuming flow...')
                    T(consumer.resume_flow)()
                    await T(cast(_App, self.app)._fetcher.maybe_start)()
                    T(self.app.flow_control.resume)()

                    # Wait for actives to be up to date.
                    # This signal will be set by _slurp_changelogs
                    if tracer is not None and span:
                        self._actives_span = tracer.start_span(
                            'recovery-actives',
                            child_of=span,
                            tags={'Active-Stats': self.active_stats()},
                        )
                        self.app._span_add_default_tags(span)
                    try:
                        self.signal_recovery_end.clear()
                        await self._wait(self.signal_recovery_end)
                    except Exception as exc:
                        finish_span(self._actives_span, error=exc)
                    else:
                        finish_span(self._actives_span)
                    finally:
                        self._actives_span = None

                    # recovery done.
                    self.log.info('Done reading from changelog topics')
                    T(consumer.pause_partitions)(active_tps)
                else:
                    self.log.info('Resuming flow...')
                    T(consumer.resume_flow)()
                    T(self.app.flow_control.resume)()

                self.log.info('Recovery complete')
                if span:
                    span.set_tag('Recovery-Completed', True)
                self._set_recovery_ended()

                if standby_tps:
                    self.log.info('Starting standby partitions...')

                    self.log.dev('Seek standby offsets')
                    await self._wait(
                        T(self._seek_offsets)(
                            consumer, standby_tps, standby_offsets, 'standby'))

                    self.log.dev('Build standby highwaters')
                    await self._wait(
                        T(self._build_highwaters)(
                            consumer,
                            standby_tps,
                            standby_highwaters,
                            'standby',
                        ),
                    )

                    for tp in standby_tps:
                        if standby_offsets[tp] > standby_highwaters[tp]:
                            raise ConsistencyError(
                                E_PERSISTED_OFFSET.format(
                                    tp,
                                    standby_offsets[tp],
                                    standby_highwaters[tp],
                                ),
                            )

                    if tracer is not None and span:
                        self._standbys_span = tracer.start_span(
                            'recovery-standbys',
                            child_of=span,
                            tags={'Standby-Stats': self.standby_stats()},
                        )
                        self.app._span_add_default_tags(span)
                    self.log.dev('Resume standby partitions')
                    T(consumer.resume_partitions)(standby_tps)

                # Pause all our topic partitions,
                # to make sure we don't fetch any more records from them.
                await self._wait(asyncio.sleep(0.1))  # still needed?
                await self._wait(T(self.on_recovery_completed)())
            except RebalanceAgain as exc:
                self.log.dev('RAISED REBALANCE AGAIN')
                for _span in spans:
                    finish_span(_span, error=exc)
                continue  # another rebalance started
            except ServiceStopped as exc:
                self.log.dev('RAISED SERVICE STOPPED')
                for _span in spans:
                    finish_span(_span, error=exc)
                break  # service was stopped
            except Exception as exc:
                for _span in spans:
                    finish_span(_span, error=exc)
                raise
            else:
                for _span in spans:
                    finish_span(_span)
            # restart - wait for next rebalance.
        self._set_recovery_ended()

    def _set_recovery_started(self) -> None:
        self.in_recovery = True
        self._recovery_ended_at = None
        self._recovery_started_at = monotonic()
        self._active_events_received_at.clear()
        self._standby_events_received_at.clear()
        self._processing_times.clear()
        self._last_active_event_processed_at = None

    def _set_recovery_ended(self) -> None:
        self.in_recovery = False
        self._recovery_ended_at = monotonic()
        self._active_events_received_at.clear()
        self._standby_events_received_at.clear()
        self._processing_times.clear()
        self._last_active_event_processed_at = None

    def active_remaining_seconds(self, remaining: float) -> str:
        s = self._estimated_active_remaining_secs(remaining)
        return humanize_seconds(s, now='none') if s else '???'

    def _estimated_active_remaining_secs(
            self, remaining: float) -> Optional[float]:
        processing_times = self._processing_times
        if len(processing_times) >= self.num_samples_required_for_estimate:
            mean_time = statistics.mean(processing_times)
            return (mean_time * remaining) * 1.10  # add 10%
        else:
            return None
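        # Illustrative arithmetic (editorial): with a mean spacing of 0.002s
        # between processed active events and 50_000 events remaining, the
        # estimate is 0.002 * 50_000 * 1.10 = 110 seconds, which
        # active_remaining_seconds() then renders with humanize_seconds().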

    async def _wait(self, coro: WaitArgT) -> None:
        wait_result = await self.wait_first(
            coro,
            self.signal_recovery_reset,
            self.signal_recovery_start,
        )
        if wait_result.stopped:
            # service was stopped.
            raise ServiceStopped()
        elif self.signal_recovery_start in wait_result.done:
            # another rebalance started
            raise RebalanceAgain()
        elif self.signal_recovery_reset in wait_result.done:
            raise RebalanceAgain()
        else:
            return None
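        # Note (editorial): this races the given coroutine against the
        # recovery signals: a stopped service raises ServiceStopped, while a
        # newly fired signal_recovery_start or signal_recovery_reset raises
        # RebalanceAgain so the recovery loop restarts and waits for the
        # next rebalance.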

    async def on_recovery_completed(self) -> None:
        """Call when active table recovery is completed."""
        consumer = self.app.consumer
        self.log.info('Restore complete!')
        await self.app.on_rebalance_complete.send()
        # Notify each table that recovery completed (must run on success)
        callback_coros = [
            table.on_recovery_completed(
                self.actives_for_table[table],
                self.standbys_for_table[table],
            )
            for table in self.tables.values()
        ]
        if callback_coros:
            await asyncio.wait(callback_coros)
        assignment = consumer.assignment()
        if assignment:
            self.log.info('Seek stream partitions to committed offsets.')
            await consumer.perform_seek()
        self.completed.set()
        self.log.dev('Resume stream partitions')
        consumer.resume_partitions({
            tp for tp in assignment
            if not self._is_changelog_tp(tp)
        })
        # finally make sure the fetcher is running.
        await cast(_App, self.app)._fetcher.maybe_start()
        self.tables.on_actives_ready()
        if not self.app.assignor.assigned_standbys():
            self.tables.on_standbys_ready()
        self.app.on_rebalance_end()
        self.log.info('Worker ready')

    async def _build_highwaters(self,
                                consumer: ConsumerT,
                                tps: Set[TP],
                                destination: Counter[TP],
                                title: str) -> None:
        # -- Build highwater
        highwaters = await consumer.highwaters(*tps)
        highwaters = {
            # FIXME the -1 here is because of the way we commit offsets
            tp: value - 1 if value is not None else -1
            for tp, value in highwaters.items()
        }
        self.log.info(
            'Highwater for %s changelog partitions:\n%s',
            title, self._highwater_logtable(highwaters, title=title))
        destination.clear()
        destination.update(highwaters)
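        # Note (editorial): the consumer's highwater mark points at the next
        # offset to be written, while recovery tracks the last applied
        # offset, so ``value - 1`` lines the two up; a missing highwater
        # becomes -1 (treated as an empty partition).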

    def _highwater_logtable(self, highwaters: Mapping[TP, int], *,
                            title: str) -> str:
        table_data = [
            [k.topic, str(k.partition), str(v)]
            for k, v in sorted(highwaters.items())
        ]
        return terminal.logtable(
            list(self._consolidate_table_keys(table_data)),
            title=f'Highwater - {title.capitalize()}',
            headers=['topic', 'partition', 'highwater'],
        )

    def _consolidate_table_keys(self, data: TableDataT) -> Iterator[List[str]]:
        """Format terminal log table to reduce noise from duplicate keys.

        We log tables where the first row is the name of the topic,
        and it gets noisy when that name is repeated over and over.

        This function replaces repeating topic names
        with the ditto mark.

        Note:
            Data must be sorted.
        """
        prev_key: Optional[str] = None
        for key, *rest in data:
            if prev_key is not None and prev_key == key:
                yield ['〃', *rest]  # ditto
            else:
                yield [key, *rest]
            prev_key = key
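        # Example (illustrative): rows [['t', '0', '3'], ['t', '1', '4']]
        # are yielded as ['t', '0', '3'] and ['〃', '1', '4'].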

    async def _build_offsets(self,
                             consumer: ConsumerT,
                             tps: Set[TP],
                             destination: Counter[TP],
                             title: str) -> None:
        # -- Update offsets
        # Offsets may have been compacted, need to get to the recent ones
        earliest = await consumer.earliest_offsets(*tps)
        # FIXME To be consistent with the offset -1 logic
        earliest = {tp: offset - 1 for tp, offset in earliest.items()}
        for tp in tps:
            last_value = destination[tp]
            new_value = earliest[tp]

            if last_value is None:
                destination[tp] = new_value
            elif new_value is None:
                destination[tp] = last_value
            else:
                destination[tp] = max(last_value, new_value)
        self.log.info(
            '%s offsets at start of reading:\n%s',
            title,
            self._start_offsets_logtable(destination, title=title),
        )
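        # Worked example (editorial): if the persisted offset for a partition
        # is 100 but the changelog was compacted so its earliest offset is
        # 150 (149 after the -1 adjustment), the start offset recorded here
        # becomes max(100, 149) = 149 instead of pointing into the
        # compacted range.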

    def _start_offsets_logtable(self, offsets: Mapping[TP, int], *,
                                title: str) -> str:
        table_data = [
            [k.topic, str(k.partition), str(v)]
            for k, v in sorted(offsets.items())
        ]
        return terminal.logtable(
            list(self._consolidate_table_keys(table_data)),
            title=f'Reading Starts At - {title.capitalize()}',
            headers=['topic', 'partition', 'offset'],
        )

    async def _seek_offsets(self,
                            consumer: ConsumerT,
                            tps: Set[TP],
                            offsets: Counter[TP],
                            title: str) -> None:
        # Seek to new offsets
        new_offsets = {}
        for tp in tps:
            offset = offsets[tp]
            if offset == -1:
                offset = 0
            new_offsets[tp] = offset
        # FIXME Remove this adjustment once the offset-1 discrepancy is fixed
        await consumer.seek_wait(new_offsets)
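        # Note (editorial): an offset of -1 means "nothing applied yet" under
        # the -1 convention, so the seek position is clamped to 0, the start
        # of the changelog partition.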

    @Service.task
    async def _slurp_changelogs(self) -> None:
        changelog_queue = self.tables.changelog_queue
        tp_to_table = self.tp_to_table

        active_tps = self.active_tps
        standby_tps = self.standby_tps
        active_offsets = self.active_offsets
        standby_offsets = self.standby_offsets
        active_events_received_at = self._active_events_received_at
        standby_events_received_at = self._standby_events_received_at

        buffers = self.buffers
        buffer_sizes = self.buffer_sizes
        processing_times = self._processing_times

        def _maybe_signal_recovery_end() -> None:
            if self.in_recovery and not self.active_remaining_total():
                # apply anything stuck in the buffers
                self.flush_buffers()
                self._set_recovery_ended()
                if self._actives_span is not None:
                    self._actives_span.set_tag('Actives-Ready', True)
                self.signal_recovery_end.set()
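        # Note (editorial): _maybe_signal_recovery_end() runs both when the
        # changelog queue times out and after each applied event, so the
        # recovery-end signal fires as soon as no active changes remain,
        # even if no further changelog messages arrive.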

        while not self.should_stop:
            try:
                event: EventT = await asyncio.wait_for(
                    changelog_queue.get(), timeout=5.0)
            except asyncio.TimeoutError:
                if self.should_stop:
                    return
                _maybe_signal_recovery_end()
                continue

            now = monotonic()
            message = event.message
            tp = message.tp
            offset = message.offset

            offsets: Counter[TP]
            bufsize = buffer_sizes.get(tp)
            is_active = False
            if tp in active_tps:
                is_active = True
                table = tp_to_table[tp]
                offsets = active_offsets
                if bufsize is None:
                    bufsize = buffer_sizes[tp] = table.recovery_buffer_size
                active_events_received_at[tp] = now
            elif tp in standby_tps:
                table = tp_to_table[tp]
                offsets = standby_offsets
                if bufsize is None:
                    bufsize = buffer_sizes[tp] = table.standby_buffer_size
                    standby_events_received_at[tp] = now
            else:
                continue

            seen_offset = offsets.get(tp, None)
            if seen_offset is None or offset > seen_offset:
                offsets[tp] = offset
                buf = buffers[table]
                buf.append(event)
                await table.on_changelog_event(event)
                if len(buf) >= bufsize:
                    table.apply_changelog_batch(buf)
                    buf.clear()
                    self._last_flush_at = now
                now_after = monotonic()

                if is_active:
                    last_processed_at = self._last_active_event_processed_at
                    if last_processed_at is not None:
                        processing_times.append(now_after - last_processed_at)
                        max_samples = self.num_samples_required_for_estimate
                        if len(processing_times) > max_samples:
                            processing_times.popleft()
                    self._last_active_event_processed_at = now_after

            _maybe_signal_recovery_end()

            if self.standbys_pending and not self.standby_remaining_total():
                if self._standbys_span:
                    finish_span(self._standbys_span)
                    self._standbys_span = None
                self.tables.on_standbys_ready()

    def flush_buffers(self) -> None:
        """Flush changelog buffers."""
        for table, buffer in self.buffers.items():
            table.apply_changelog_batch(buffer)
            buffer.clear()
        self._last_flush_at = monotonic()

    def need_recovery(self) -> bool:
        """Return :const:`True` if recovery is required."""
        return any(v for v in self.active_remaining().values())

    def active_remaining(self) -> Counter[TP]:
        """Return counter of remaining changes by active partition."""
        highwaters = self.active_highwaters
        offsets = self.active_offsets
        return Counter({
            tp: highwater - offsets[tp]
            for tp, highwater in highwaters.items()
            if highwater is not None and offsets[tp] is not None
        })

    def standby_remaining(self) -> Counter[TP]:
        """Return counter of remaining changes by standby partition."""
        highwaters = self.standby_highwaters
        offsets = self.standby_offsets
        return Counter({
            tp: highwater - offsets[tp]
            for tp, highwater in highwaters.items()
            if highwater >= 0 and offsets[tp] >= 0
        })
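        # Note (editorial): the filters differ slightly: active_remaining()
        # only skips partitions whose highwater or offset is unknown (None),
        # while standby_remaining() also skips the -1 "nothing applied yet"
        # marker via the >= 0 checks above.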

    def active_remaining_total(self) -> int:
        """Return number of changes remaining for actives to be up-to-date."""
        return sum(self.active_remaining().values())

    def standby_remaining_total(self) -> int:
        """Return number of changes remaining for standbys to be up-to-date."""
        return sum(self.standby_remaining().values())

    def active_stats(self) -> RecoveryStatsMapping:
        """Return current active recovery statistics."""
        offsets = self.active_offsets
        return {
            tp: RecoveryStats(highwater,
                              offsets[tp],
                              highwater - offsets[tp])
            for tp, highwater in self.active_highwaters.items()
            if offsets[tp] is not None and highwater - offsets[tp] != 0
        }

    def standby_stats(self) -> RecoveryStatsMapping:
        """Return current standby recovery statistics."""
        offsets = self.standby_offsets
        return {
            tp: RecoveryStats(highwater,
                              offsets[tp],
                              highwater - offsets[tp])
            for tp, highwater in self.standby_highwaters.items()
            if offsets[tp] is not None and highwater - offsets[tp] != 0
        }

    def _stats_to_logtable(
            self,
            title: str,
            stats: RecoveryStatsMapping) -> str:
        table_data = [
            list(map(str, [
                tp.topic,
                tp.partition,
                s.highwater,
                s.offset,
                s.remaining,
            ])) for tp, s in sorted(stats.items())
        ]
        return terminal.logtable(
            list(self._consolidate_table_keys(table_data)),
            title=title,
            headers=[
                'topic',
                'partition',
                'need offset',
                'have offset',
                'remaining',
            ],
        )

    @Service.task
    async def _publish_stats(self) -> None:
        """Emit stats (remaining to fetch) while in active recovery."""
        interval = self.stats_interval
        await self.sleep(interval)
        async for sleep_time in self.itertimer(
                interval, name='Recovery.stats'):
            if self.in_recovery:
                now = monotonic()
                stats = self.active_stats()
                num_samples = len(self._processing_times)
                if stats and \
                        num_samples >= self.num_samples_required_for_estimate:
                    remaining_total = self.active_remaining_total()
                    self.log.info(
                        'Still fetching changelog topics for recovery, '
                        'estimated time remaining %s '
                        '(total remaining=%r):\n%s',
                        self.active_remaining_seconds(remaining_total),
                        remaining_total,
                        self._stats_to_logtable(
                            'Remaining for active recovery', stats),
                    )
                elif stats:
                    await self._verify_remaining(now, stats)
                else:
                    recovery_started_at = self._recovery_started_at
                    if recovery_started_at is None:
                        self.log.error(
                            'POSSIBLE INTERNAL ERROR: '
                            'Recovery marked as started but missing '
                            'self._recovery_started_at timestamp.')
                    else:
                        secs_since_started = now - recovery_started_at
                        if secs_since_started >= 30.0:
                            # This shouldn't happen, but we want to
                            # log an error in case it does.
                            self.log.error(
                                'POSSIBLE INTERNAL ERROR: '
                                'Recovery has no remaining offsets to fetch, '
                                'but we have spent %s waiting for the worker '
                                'to transition out of recovery state...',
                                humanize_seconds(secs_since_started),
                            )

    async def _verify_remaining(
            self,
            now: float,
            stats: RecoveryStatsMapping) -> None:
        consumer = self.app.consumer
        active_events_received_at = self._active_events_received_at
        recovery_started_at = self._recovery_started_at
        if recovery_started_at is None:
            return  # we already log about this in _publish_stats
        secs_since_started = now - recovery_started_at

        last_flush_at = self._last_flush_at
        if last_flush_at is None:
            if secs_since_started >= self.flush_timeout_secs:
                self.log.warning(
                    'Recovery has not flushed buffers since '
                    'recovery started (started %s). '
                    'Current total buffer size: %r',
                    humanize_seconds_ago(secs_since_started),
                    self._current_total_buffer_size(),
                )
        else:
            secs_since_last_flush = now - last_flush_at
            if secs_since_last_flush >= self.flush_timeout_secs:
                self.log.warning(
                    'Recovery has not flushed buffers in the last %r '
                    'seconds (last flush was %s). '
                    'Current total buffer size: %r',
                    self.flush_timeout_secs,
                    humanize_seconds_ago(secs_since_last_flush),
                    self._current_total_buffer_size(),
                )

        for tp in stats:
            await self.sleep(0)
            if self.should_stop:
                break
            if not self.in_recovery:
                break
            consumer.verify_recovery_event_path(now, tp)
            secs_since_started = now - recovery_started_at

            last_event_received = active_events_received_at.get(tp)
            if last_event_received is None:
                if secs_since_started >= self.event_timeout_secs:
                    self.log.warning(
                        'No event received for active tp %r since recovery '
                        'start (started %s)',
                        tp, humanize_seconds_ago(secs_since_started),
                    )
                continue

            secs_since_received = now - last_event_received
            if secs_since_received >= self.event_timeout_secs:
                self.log.warning(
                    'No event received for active tp %r in the last %r '
                    'seconds (last event received %s)',
                    tp, self.event_timeout_secs,
                    humanize_seconds_ago(secs_since_received),
                )

    def _current_total_buffer_size(self) -> int:
        return sum(len(buf) for buf in self.buffers.values())

    def _is_changelog_tp(self, tp: TP) -> bool:
        return tp.topic in self.tables.changelog_topics
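

# Editorial sketch (not part of Faust): a minimal, self-contained
# illustration of the remaining-changes bookkeeping used by the Recovery
# service above, using a simplified TP tuple and plain Counters instead of
# the real Faust types.
if __name__ == '__main__':
    import statistics
    from collections import Counter, namedtuple

    SimpleTP = namedtuple('SimpleTP', ['topic', 'partition'])
    tp0 = SimpleTP('orders-changelog', 0)
    tp1 = SimpleTP('orders-changelog', 1)

    # Highwaters and currently applied offsets, mirroring the roles of
    # active_highwaters and active_offsets.
    highwaters = Counter({tp0: 1000, tp1: 500})
    offsets = Counter({tp0: 400, tp1: 500})

    # Same shape as active_remaining(): changes left per partition.
    remaining = Counter({
        tp: highwater - offsets[tp]
        for tp, highwater in highwaters.items()
        if highwater is not None and offsets[tp] is not None
    })
    print(remaining)                # tp0 has 600 changes left, tp1 has 0
    print(any(remaining.values()))  # True -> recovery still needed

    # Same arithmetic as _estimated_active_remaining_secs(): mean seconds
    # between processed events, times events remaining, plus 10%.
    processing_times = [0.002, 0.003, 0.002]
    estimate = (statistics.mean(processing_times)
                * sum(remaining.values()) * 1.10)
    print(f'estimated {estimate:.2f}s of active recovery remaining')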