async def test_wait_empty_queue():
    """get() on an empty queue must block until wait() gives up with a timeout."""
    q = TaskQueue()
    try:
        await wait(q.get())
    except asyncio.TimeoutError:
        pass
    else:
        # pytest.fail() instead of ``assert False``: plain asserts vanish under ``python -O``,
        # which would make this test pass no matter what get() did.
        pytest.fail("should not return from get() when nothing is available on queue")
async def test_get_nowait(tasks, get_size, expected_tasks):
    """
    get_nowait() immediately hands out a batch of up to ``get_size`` tasks; completing that
    batch removes exactly those tasks from the queue.
    """
    q = TaskQueue()
    # wrap in wait() like the other tests, so a misbehaving add() fails instead of hanging
    await wait(q.add(tasks))

    # use a distinct name so the ``tasks`` parameter (everything added) stays available
    batch, fetched = q.get_nowait(get_size)

    assert fetched == expected_tasks

    q.complete(batch, fetched)

    assert all(task not in q for task in fetched)
    # tasks that were never handed out were not completed, so they must still be queued
    assert all(task in q for task in tasks if task not in fetched)
    def __init__(self,
                 chain: AsyncChain,
                 db: AsyncHeaderDB,
                 peer_pool: AnyPeerPool,
                 token: CancelToken = None) -> None:
        """
        Keep references to the chain, header DB and peer pool, and set up the sync state.

        :param token: optional cancel token, handed unchanged to the base class initialiser.
        """
        super().__init__(token)
        self.chain, self.db, self.peer_pool = chain, db, peer_pool

        # sync bookkeeping: peers waiting to be synced with, the syncer of the sync in
        # progress (None while idle) and the target hash of the last finished sync
        self._sync_requests: asyncio.Queue[HeaderRequestingPeer] = asyncio.Queue()
        self._peer_header_syncer: 'PeerHeaderSyncer' = None
        self._last_target_header_hash = None

        self._handler = PeerRequestHandler(self.db, self.logger, self.cancel_token)

        # Bound the pending headers: large enough that processing consumers are never starved,
        # small enough that we don't over-request before post-processing catches up.
        self.header_queue = TaskQueue(
            HLSPeer.max_headers_fetch * 8,
            attrgetter('block_number'),
        )
async def test_unfinished_tasks_readded():
    """Tasks left uncompleted in a batch are handed out again by the next get()."""
    queue = TaskQueue()
    await wait(queue.add((2, 1, 3)))

    first_batch, _ = await wait(queue.get())
    # finish only task 2; tasks 1 and 3 must come back
    queue.complete(first_batch, (2, ))

    _, retried = await wait(queue.get())
    assert retried == (1, 3)
async def test_cannot_complete_batch_with_wrong_task():
    """complete() rejects tasks that were not handed out in the given batch."""
    queue = TaskQueue()
    await wait(queue.add((1, 2)))
    batch_id, _ = await wait(queue.get())

    # none of these tasks belong to the batch
    with pytest.raises(ValidationError):
        queue.complete(batch_id, (3, 4))

    # a partly-valid completion is rejected too, leaving the valid task incomplete
    with pytest.raises(ValidationError):
        queue.complete(batch_id, (1, 3))

    assert 1 in queue
async def test_cannot_complete_batch_unless_pending():
    """complete() only accepts the batch id returned by an earlier get()."""
    queue = TaskQueue()
    await wait(queue.add((1, 2)))

    # no batch id at all is rejected
    with pytest.raises(ValidationError):
        queue.complete(None, (1, 2))
    assert 1 in queue

    batch_id, _ = await wait(queue.get())

    # a batch id that was never issued is rejected as well
    with pytest.raises(ValidationError):
        queue.complete(batch_id + 1, (1, 2))
    assert 1 in queue
async def test_queue_size_reset_after_complete():
    """Tasks count against maxsize until completed; completing them frees room for add()."""
    q = TaskQueue(maxsize=2)

    await wait(q.add((1, 2)))

    batch, tasks = await wait(q.get())

    # there should not be room to add another task: the two in-progress tasks still
    # occupy both slots, so add() blocks until wait() times out
    try:
        await wait(q.add((3, )))
    except asyncio.TimeoutError:
        pass
    else:
        # pytest.fail() rather than ``assert False``, which ``python -O`` strips out
        pytest.fail("should not be able to add task past maxsize")

    # do imaginary work here, then complete it all

    q.complete(batch, tasks)

    # there should be room to add more now
    await wait(q.add((3, )))
async def test_queue_get_cap(start_tasks, get_max, expected, remainder):
    """get(max) returns at most ``max`` tasks; any remainder is left for a later get()."""
    q = TaskQueue()

    await wait(q.add(start_tasks))

    batch, tasks = await wait(q.get(get_max))
    assert tasks == expected

    if remainder:
        _, tasks2 = await wait(q.get())
        assert tasks2 == remainder
    else:
        # nothing is left, so get() must block until wait() times out
        try:
            _, tasks2 = await wait(q.get())
        except asyncio.TimeoutError:
            pass
        else:
            # pytest.fail() rather than ``assert False``, which ``python -O`` strips out
            pytest.fail(f"No more tasks to get, but got {tasks2!r}")
async def test_queue_contains_task_until_complete(tasks):
    """A task counts as "in" the queue from add() until its batch is completed."""
    # NOTE(review): order_fn=id presumably lets the parametrized (possibly unorderable)
    # task objects be prioritised — confirm against the parametrization.
    queue = TaskQueue(order_fn=id)
    head = tasks[0]

    assert head not in queue

    await wait(queue.add(tasks))
    assert head in queue

    # still present while it is in progress
    batch, in_progress = await wait(queue.get())
    assert head in queue

    queue.complete(batch, in_progress)
    assert head not in queue
async def test_two_pending_adds_one_release():
    """
    Two add() calls block on a full queue (maxsize=2). Completing one batch lets the
    blocked adds make progress, and get() keeps handing tasks out in priority order.
    The asyncio.sleep() calls give the background add() coroutines a chance to run.
    """
    q = TaskQueue(2)

    asyncio.ensure_future(q.add((3, 1, 2)))

    # wait for ^ to run and pause
    await asyncio.sleep(0)
    # note that the highest-priority items are queued first: with the default ordering
    # that is the lowest values (1 and 2); 3 waits for a free slot
    assert 1 in q
    assert 2 in q
    assert 3 not in q

    # two tasks are queued, none are started
    assert len(q) == 2
    assert q.num_in_progress() == 0

    asyncio.ensure_future(q.add((0, 4)))
    # wait for ^ to run and pause (the queue is full, so nothing from it is added yet)
    await asyncio.sleep(0)

    # task consumer 1 gets the two queued tasks
    batch, tasks = await wait(q.get())
    assert tasks == (1, 2)

    # both tasks started; in-progress tasks still count towards len(q)
    assert len(q) == 2
    assert q.num_in_progress() == 2

    q.complete(batch, tasks)

    # tasks are drained, but the blocked adds haven't had a chance to run yet...
    assert q.num_in_progress() == 0
    assert len(q) == 0

    await asyncio.sleep(0.01)

    # Now the blocked adds have run and filled both free slots
    assert q.num_in_progress() == 0
    assert len(q) == 2

    # task consumer 2 gets the next two, in priority order
    batch, tasks = await wait(q.get())

    assert len(tasks) == 2

    assert tasks == (0, 3)

    assert q.num_in_progress() == 2
    assert len(q) == 2

    # clean up, so the still-blocked add() of task 4 can complete
    q.complete(batch, tasks)

    # All current tasks finished
    assert q.num_in_progress() == 0

    await asyncio.sleep(0)

    # only task 4 remains
    assert q.num_in_progress() == 0
    assert len(q) == 1
async def test_cannot_add_single_non_tuple_task():
    """add() takes a tuple of tasks; a bare task is rejected."""
    queue = TaskQueue()
    bare_task = 1
    with pytest.raises(ValidationError):
        await wait(queue.add(bare_task))
async def test_invalid_priority_order(order_fn):
    """An unusable order_fn makes add() raise ValidationError."""
    queue = TaskQueue(order_fn=order_fn)
    single_task = (1, )

    with pytest.raises(ValidationError):
        await wait(queue.add(single_task))
async def test_valid_priority_order(order_fn):
    """An acceptable order_fn lets a task be added; only the absence of an error is checked."""
    queue = TaskQueue(order_fn=order_fn)
    single_task = (1, )

    # must not raise (add() presumably checks that the priority key is sortable)
    await wait(queue.add(single_task))
async def test_custom_priority_order():
    """A custom order_fn decides the order in which get() returns tasks."""
    # negating the key puts larger numbers first
    queue = TaskQueue(maxsize=4, order_fn=lambda task: -task)

    await wait(queue.add((2, 1, 3)))
    _, ordered = await wait(queue.get())
    assert ordered == (3, 2, 1)
async def test_default_priority_order():
    """Without an order_fn, get() returns tasks in ascending order."""
    queue = TaskQueue(maxsize=4)
    await wait(queue.add((2, 1, 3)))
    _, ordered = await wait(queue.get())
    assert ordered == (1, 2, 3)
async def test_cannot_readd_same_task():
    """A task that is already in the queue cannot be added again."""
    queue = TaskQueue()
    await queue.add((1, 2))
    duplicate = (2,)
    with pytest.raises(ValidationError):
        await queue.add(duplicate)
def test_get_nowait_queuefull(get_size):
    """get_nowait() on an empty queue raises instead of blocking."""
    empty_queue = TaskQueue()
    # NOTE(review): the expected error is asyncio.QueueFull, not QueueEmpty; this presumably
    # mirrors what TaskQueue.get_nowait raises — confirm before changing either side.
    with pytest.raises(asyncio.QueueFull):
        empty_queue.get_nowait(get_size)
async def test_unlimited_queue_by_default():
    """With no maxsize, a very large add() finishes before wait() times out."""
    queue = TaskQueue()
    many_tasks = tuple(range(100001))
    await wait(queue.add(many_tasks))
async def test_no_asyncio_exception_leaks(operations, queue_size, add_size, get_size, event_loop):
    """
    Interleave concurrent add / get / complete calls on one queue and check none of them raises.

    This could be made much more general, at the cost of simplicity.
    For now, this mimics real usage enough to hopefully catch the big issues.

    Some examples for more generality:

    - different get sizes on each call
    - complete varying amounts of tasks at each call
    """

    async def getter(queue, num_tasks, get_event, complete_event, cancel_token):
        with trap_operation_cancelled():
            # wait to run the get
            await cancel_token.cancellable_wait(get_event.wait())

            batch, tasks = await cancel_token.cancellable_wait(
                queue.get(num_tasks)
            )
            get_event.clear()

            # wait to run the completion
            await cancel_token.cancellable_wait(complete_event.wait())

            queue.complete(batch, tasks)
            complete_event.clear()

    async def adder(queue, add_size, add_event, cancel_token):
        with trap_operation_cancelled():
            # wait to run the add
            await cancel_token.cancellable_wait(add_event.wait())

            await cancel_token.cancellable_wait(
                queue.add(tuple(random.randint(0, 2 ** 32) for _ in range(add_size)))
            )
            add_event.clear()

    async def operation_order(operations, events, cancel_token):
        # each operation is (index into ``events``, pause): release that step, optionally
        # yielding to the loop so the released coroutine gets to run
        for operation_id, pause in operations:
            events[operation_id].set()
            if pause:
                await asyncio.sleep(0)

        await asyncio.sleep(0)
        # unblock any coroutine still waiting, so every task finishes
        cancel_token.trigger()

    q = TaskQueue(queue_size)
    events = tuple(Event() for _ in range(6))
    add_event, add2_event, get_event, get2_event, complete_event, complete2_event = events
    cancel_token = CancelToken('end test')

    # asyncio.wait() needs Tasks/Futures: passing bare coroutines is deprecated since
    # Python 3.8 and rejected with a TypeError since 3.11.
    workers = [
        asyncio.ensure_future(coro)
        for coro in (
            getter(q, get_size, get_event, complete_event, cancel_token),
            getter(q, get_size, get2_event, complete2_event, cancel_token),
            adder(q, add_size, add_event, cancel_token),
            adder(q, add_size, add2_event, cancel_token),
            operation_order(operations, events, cancel_token),
        )
    ]
    done, pending = await asyncio.wait(workers, return_when=asyncio.FIRST_EXCEPTION)

    try:
        for task in done:
            exc = task.exception()
            if exc:
                raise exc

        assert not pending
    finally:
        # if we fail early, don't leave the remaining coroutines running into later tests
        for task in pending:
            task.cancel()
Example #20
0
class BaseHeaderChainSyncer(BaseService, PeerSubscriber):
    """
    Sync with the Ethereum network by fetching/storing block headers.

    Here, the run() method will execute the sync loop until our local head is the same as the one
    with the highest TD announced by any of our peers.

    Sync requests are queued by :meth:`register_peer`. :meth:`sync` serves them one at a time
    and pushes every new batch of headers onto :attr:`header_queue`.
    Subclasses must implement :meth:`_handle_msg`.
    """
    # We'll only sync if we are connected to at least min_peers_to_sync.
    min_peers_to_sync = 1
    # headers fetched by sync() and awaiting post-processing, ordered by block_number
    header_queue: TaskQueue[BlockHeader]

    def __init__(self,
                 chain: AsyncChain,
                 db: AsyncHeaderDB,
                 peer_pool: AnyPeerPool,
                 token: CancelToken = None) -> None:
        super().__init__(token)
        self.chain = chain
        self.db = db
        self.peer_pool = peer_pool
        self._handler = PeerRequestHandler(self.db, self.logger, self.cancel_token)
        # peers we've been asked to sync with; consumed one at a time by _run()
        self._sync_requests: asyncio.Queue[HeaderRequestingPeer] = asyncio.Queue()
        # syncer of the sync currently in progress, or None when not syncing
        self._peer_header_syncer: 'PeerHeaderSyncer' = None
        # the latest header hash of the peer on the most recently finished sync
        self._last_target_header_hash = None

        # pending queue size should be big enough to avoid starving the processing consumers, but
        # small enough to avoid wasteful over-requests before post-processing can happen
        max_pending_headers = HLSPeer.max_headers_fetch * 8
        self.header_queue = TaskQueue(max_pending_headers, attrgetter('block_number'))

    @property
    def msg_queue_maxsize(self) -> int:
        # This is a rather arbitrary value, but when the sync is operating normally we never see
        # the msg queue grow past a few hundred items, so this should be a reasonable limit for
        # now.
        return 2000

    def get_target_header_hash(self) -> Hash32:
        """
        Return the target header hash of the current sync, or of the last finished one if
        no sync is running.

        :raise ValidationError: if no sync has run yet.
        """
        if self._peer_header_syncer is None and self._last_target_header_hash is None:
            raise ValidationError("Cannot check the target hash before a sync has run")
        elif self._peer_header_syncer is not None:
            return self._peer_header_syncer.get_target_header_hash()
        else:
            return self._last_target_header_hash

    def register_peer(self, peer: BasePeer) -> None:
        """
        Request a sync whenever a peer is registered.

        The new ``peer`` itself is not used: we always queue the pool's highest-TD peer.
        """
        self._sync_requests.put_nowait(cast(HeaderRequestingPeer, self.peer_pool.highest_td_peer))

    async def _handle_msg_loop(self) -> None:
        """Pull messages off msg_queue and handle each one in its own background task."""
        while self.is_operational:
            peer, cmd, msg = await self.wait(self.msg_queue.get())
            # Our handle_msg() method runs cpu-intensive tasks in sub-processes so that the main
            # loop can keep processing msgs, and that's why we use self.run_task() instead of
            # awaiting for it to finish here.
            self.run_task(self.handle_msg(cast(HeaderRequestingPeer, peer), cmd, msg))

    async def handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
                         msg: protocol._DecodedMsgType) -> None:
        """Run :meth:`_handle_msg`, logging unexpected errors instead of propagating them."""
        try:
            await self._handle_msg(peer, cmd, msg)
        except OperationCancelled:
            # Silently swallow OperationCancelled exceptions because otherwise they'll be caught
            # by the except below and treated as unexpected.
            pass
        except Exception:
            self.logger.exception("Unexpected error when processing msg from %s", peer)

    async def _run(self) -> None:
        """Start the message loop, then launch a sync task for every queued sync request."""
        self.run_task(self._handle_msg_loop())
        with self.subscribe(self.peer_pool):
            while self.is_operational:
                try:
                    peer = await self.wait(self._sync_requests.get())
                except OperationCancelled:
                    # In the case of a fast sync, we return once the sync is completed, and our
                    # caller must then run the StateDownloader.
                    return
                else:
                    self.run_task(self.sync(peer))

    @property
    def _syncing(self) -> bool:
        return self._peer_header_syncer is not None

    @contextmanager
    def _get_peer_header_syncer(self, peer: HeaderRequestingPeer) -> Iterator['PeerHeaderSyncer']:
        """
        Run a PeerHeaderSyncer for ``peer`` as a child service for the duration of a ``with``.

        On exit, the child syncer is cancelled unless the block ended with OperationCancelled,
        which is swallowed. Its target header hash is then remembered in
        ``_last_target_header_hash``, and the "syncing" state is always cleared.

        :raise ValidationError: if a header sync is already in progress.
        """
        if self._syncing:
            raise ValidationError("Cannot sync headers from two peers at the same time")

        syncer = PeerHeaderSyncer(
            self.chain,
            self.db,
            peer,
            self.cancel_token,
        )
        self._peer_header_syncer = syncer
        self.run_child_service(syncer)
        try:
            yield syncer
        except OperationCancelled:
            pass
        except Exception:
            # Previously only the success path cancelled the child. Any other error left the
            # syncer running in the background, orphaned once _peer_header_syncer was reset.
            syncer.cancel_nowait()
            raise
        else:
            syncer.cancel_nowait()
        finally:
            # Clear the syncing flag before anything that might raise. Otherwise a failing
            # get_target_header_hash() would leave us "syncing" forever, and every later
            # sync() call would return early.
            self._peer_header_syncer = None
            self.logger.info("Header Sync with %s ended", peer)
            self._last_target_header_hash = syncer.get_target_header_hash()

    async def sync(self, peer: HeaderRequestingPeer) -> None:
        """
        Sync headers from ``peer`` and push each batch of not-yet-queued headers onto
        :attr:`header_queue`.

        Does nothing if a sync is already running, or if we have fewer than
        ``min_peers_to_sync`` peers.
        """
        if self._syncing:
            self.logger.debug(
                "Got a NewBlock or a new peer, but already syncing so doing nothing")
            return
        elif len(self.peer_pool) < self.min_peers_to_sync:
            self.logger.info(
                "Connected to less peers (%d) than the minimum (%d) required to sync, "
                "doing nothing", len(self.peer_pool), self.min_peers_to_sync)
            return

        with self._get_peer_header_syncer(peer) as syncer:
            async for header_batch in syncer.next_header_batch():
                # skip headers the queue already holds (queued or in progress)
                new_headers = tuple(h for h in header_batch if h not in self.header_queue)
                await self.wait(self.header_queue.add(new_headers))

    @abstractmethod
    async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
                          msg: protocol._DecodedMsgType) -> None:
        """Handle a single message received from ``peer``; implemented by subclasses."""
        raise NotImplementedError("Must be implemented by subclasses")