Example #1
    def __init__(self,
                 chain: AsyncChain,
                 db: AsyncHeaderDB,
                 peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        self.complete_token = CancelToken(
            'trinity.sync.common.BaseHeaderChainSyncer.SyncCompleted')
        if token is None:
            master_service_token = self.complete_token
        else:
            master_service_token = token.chain(self.complete_token)
        super().__init__(master_service_token)
        self.chain = chain
        self.db = db
        self.peer_pool = peer_pool
        self._handler = PeerRequestHandler(self.db, self.logger,
                                           self.cancel_token)
        self._syncing = False
        self._sync_complete = asyncio.Event()
        self._sync_requests: asyncio.Queue[HeaderRequestingPeer] = asyncio.Queue()

        # pending queue size should be big enough to avoid starving the processing consumers, but
        # small enough to avoid wasteful over-requests before post-processing can happen
        max_pending_headers = ETHPeer.max_headers_fetch * 8
        self.header_queue = TaskQueue(max_pending_headers,
                                      attrgetter('block_number'))
Example #2
async def test_wait_cancel_pending_tasks_on_completion(event_loop):
    token = CancelToken('token')
    token2 = CancelToken('token2')
    chain = token.chain(token2)
    event_loop.call_soon(token2.trigger)
    await chain.wait()
    await assert_only_current_task_not_done()
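`assert_only_current_task_not_done` is referenced throughout these tests but never shown. A plausible sketch of what such a helper checks, written here as an assumption rather than the actual test fixture:

import asyncio


async def assert_only_current_task_not_done() -> None:
    # Every scheduled task except the one currently running should
    # have finished; a leaked `wait()` task would fail this assert.
    for task in asyncio.all_tasks():
        if task is asyncio.current_task():
            continue
        assert task.done(), f"task still pending: {task!r}"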
Example #3
def test_token_chain_trigger_last():
    token = CancelToken('token')
    token2 = CancelToken('token2')
    token3 = CancelToken('token3')
    chain = token.chain(token2).chain(token3)
    assert not chain.triggered
    token3.trigger()
    assert chain.triggered
    assert chain.triggered_token == token3
Example #4
def test_token_chain_trigger_middle():
    token = CancelToken('token')
    token2 = CancelToken('token2')
    token3 = CancelToken('token3')
    intermediate_chain = token.chain(token2)
    chain = intermediate_chain.chain(token3)
    assert not chain.triggered
    token2.trigger()
    assert chain.triggered
    assert intermediate_chain.triggered
    assert chain.triggered_token == token2
    assert not token3.triggered
    assert not token.triggered
Example #5
    def __init__(self,
                 transport: TransportAPI,
                 base_protocol: BaseP2PProtocol,
                 protocols: Sequence[ProtocolAPI],
                 token: CancelToken = None,
                 max_queue_size: int = 4096) -> None:
        self.logger = get_logger('p2p.multiplexer.Multiplexer')
        if token is None:
            loop = None
        else:
            loop = token.loop
        base_token = CancelToken(f'multiplexer[{transport.remote}]', loop=loop)

        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)

        self._transport = transport
        # the base `p2p` protocol instance.
        self._base_protocol = base_protocol

        # the sub-protocol instances
        self._protocols = protocols

        # Lock to ensure that multiple call sites cannot concurrently stream
        # messages.
        self._multiplex_lock = asyncio.Lock()

        # Lock management on a per-protocol basis to ensure we only have one
        # stream consumer for each protocol.
        self._protocol_locks = {
            type(protocol): asyncio.Lock()
            for protocol
            in self.get_protocols()
        }

        # Each protocol gets a queue where messages for the individual protocol
        # are placed when streamed from the transport
        self._protocol_queues = {
            type(protocol): asyncio.Queue(max_queue_size)
            for protocol
            in self.get_protocols()
        }

        self._msg_counts = collections.defaultdict(int)
        self._last_msg_time = 0
Example #6
    def __init__(self,
                 token: CancelToken = None,
                 loop: asyncio.AbstractEventLoop = None) -> None:
        self.events = ServiceEvents()
        self._run_lock = asyncio.Lock()
        self._child_services = WeakSet()
        self._tasks = WeakSet()
        self._finished_callbacks = []

        self._loop = loop

        base_token = CancelToken(type(self).__name__, loop=loop)

        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)
Example #7
async def test_token_wait_finishing_cleans_up_chained_token_waits():
    parent = CancelToken('parent')
    child = parent.chain(CancelToken('child'))

    # this schedules both the child's `wait()` and the parent's `wait()`
    fut = asyncio.ensure_future(child.wait())
    # yield for a moment to let them spin up.
    await asyncio.sleep(0.01)

    # ensure that there are some not-done tasks
    with pytest.raises(AssertionError):
        await assert_only_current_task_not_done()

    child.trigger()

    await asyncio.wait_for(fut, timeout=0.01)

    await assert_only_current_task_not_done()
Example #8
    def __init__(self,
                 token: CancelToken = None,
                 loop: asyncio.AbstractEventLoop = None) -> None:
        if self.logger is None:
            self.logger = cast(
                TraceLogger,
                logging.getLogger(self.__module__ + '.' +
                                  self.__class__.__name__))

        self._run_lock = asyncio.Lock()
        self.cleaned_up = asyncio.Event()
        self._child_services = []

        self.loop = loop
        base_token = CancelToken(type(self).__name__, loop=loop)

        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)
Example #9
async def test_cancelling_token_wait_cleans_up_chained_token_waits():
    parent = CancelToken('parent')
    child = parent.chain(CancelToken('child'))

    # this schedules both the child's `wait()` and the parent's `wait()`
    fut = asyncio.ensure_future(child.wait())
    # yield for a moment to let them spin up.
    await asyncio.sleep(0.01)

    # ensure that there are some not-done tasks
    with pytest.raises(AssertionError):
        await assert_only_current_task_not_done()

    # cancel the wait (which should also properly cancel the parent wait)
    fut.cancel()
    try:
        await fut
    except asyncio.CancelledError:
        pass
    await assert_only_current_task_not_done()
Example #10
    async def wait_first(self,
                         *awaitables: Awaitable[_TReturn],
                         token: CancelToken = None,
                         timeout: float = None) -> _TReturn:
        """
        Wait for the first awaitable to complete, unless we timeout or the token chain is triggered.

        The given token is chained with this service's token, so triggering either will cancel
        this.

        Returns the result of the first one to complete.

        Raises TimeoutError if we timeout or OperationCancelled if the token chain is triggered.

        All pending futures are cancelled before returning.
        """
        if token is None:
            token_chain = self.cancel_token
        else:
            token_chain = token.chain(self.cancel_token)
        return await token_chain.cancellable_wait(*awaitables, timeout=timeout)
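A hedged usage sketch for `wait_first`, with illustrative names (`first_response`, `fetch_headers` and `fetch_bodies` are not from the source):

from cancel_token import OperationCancelled  # assumed import


async def first_response(service, fetch_headers, fetch_bodies):
    # Returns whichever awaitable finishes first; the loser is
    # cancelled. TimeoutError is raised after 5 seconds, and
    # OperationCancelled if the service's token chain triggers.
    try:
        return await service.wait_first(
            fetch_headers(),
            fetch_bodies(),
            timeout=5,
        )
    except OperationCancelled:
        return None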
Example #11
def test_token_chain_event_loop_mismatch():
    token = CancelToken('token')
    token2 = CancelToken('token2', loop=asyncio.new_event_loop())
    with pytest.raises(EventLoopMismatch):
        token.chain(token2)
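Example #11 shows that tokens bound to different event loops cannot be chained. A minimal sketch of the fix, using the `loop` keyword seen in Examples #5 and #6 to bind both tokens to the same loop before chaining:

import asyncio

loop = asyncio.new_event_loop()
token = CancelToken('token', loop=loop)
token2 = CancelToken('token2', loop=loop)

# Both tokens share the same loop, so chaining no longer raises
# EventLoopMismatch.
chain = token.chain(token2)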
Example #12
class DiscoveryProtocol(asyncio.DatagramProtocol):
    """A Kademlia-like protocol to discover RLPx nodes."""
    logger = logging.getLogger("p2p.discovery.DiscoveryProtocol")
    transport: asyncio.DatagramTransport = None
    _max_neighbours_per_packet_cache = None

    def __init__(self, privkey: datatypes.PrivateKey,
                 address: kademlia.Address,
                 bootstrap_nodes: Tuple[kademlia.Node, ...]) -> None:
        self.privkey = privkey
        self.address = address
        self.bootstrap_nodes = bootstrap_nodes
        self.this_node = kademlia.Node(self.pubkey, address)
        self.kademlia = kademlia.KademliaProtocol(self.this_node, wire=self)
        self.cancel_token = CancelToken('DiscoveryProtocol')

    async def lookup_random(self,
                            cancel_token: CancelToken) -> List[kademlia.Node]:
        return await self.kademlia.lookup_random(
            self.cancel_token.chain(cancel_token))

    def get_random_bootnode(self) -> Iterator[kademlia.Node]:
        if self.bootstrap_nodes:
            yield random.choice(self.bootstrap_nodes)
        else:
            self.logger.warning('No bootnodes available')

    def get_nodes_to_connect(self, count: int) -> Iterator[kademlia.Node]:
        return self.kademlia.routing.get_random_nodes(count)

    @property
    def pubkey(self) -> datatypes.PublicKey:
        return self.privkey.public_key

    def _get_handler(
            self,
            cmd: Command) -> Callable[[kademlia.Node, List[Any], bytes], None]:
        if cmd == CMD_PING:
            return self.recv_ping
        elif cmd == CMD_PONG:
            return self.recv_pong
        elif cmd == CMD_FIND_NODE:
            return self.recv_find_node
        elif cmd == CMD_NEIGHBOURS:
            return self.recv_neighbours
        else:
            raise ValueError("Unknwon command: {}".format(cmd))

    def _get_max_neighbours_per_packet(self) -> int:
        if self._max_neighbours_per_packet_cache is not None:
            return self._max_neighbours_per_packet_cache
        self._max_neighbours_per_packet_cache = (
            _get_max_neighbours_per_packet())
        return self._max_neighbours_per_packet_cache

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # we need to cast here because the signature in the base class dictates BaseTransport
        # and arguments can only be redefined contravariantly
        self.transport = cast(asyncio.DatagramTransport, transport)

    async def bootstrap(self) -> None:
        self.logger.info("boostrapping with %s", self.bootstrap_nodes)
        try:
            await self.kademlia.bootstrap(self.bootstrap_nodes,
                                          self.cancel_token)
        except OperationCancelled as e:
            self.logger.info("Bootstrapping cancelled: %s", e)

    def datagram_received(self, data: Union[bytes, Text],
                          addr: Tuple[str, int]) -> None:
        ip_address, udp_port = addr
        # XXX: For now we simply discard all v5 messages. The prefix below is what geth uses to
        # identify them:
        # https://github.com/ethereum/go-ethereum/blob/c4712bf96bc1bae4a5ad4600e9719e4a74bde7d5/p2p/discv5/udp.go#L149  # noqa: E501
        if text_if_str(to_bytes, data).startswith(b"temporary discovery v5"):
            self.logger.debug("Got discovery v5 msg, discarding")
            return

        self.receive(kademlia.Address(ip_address, udp_port), cast(bytes, data))

    def error_received(self, exc: Exception) -> None:
        self.logger.error('error received: %s', exc)

    def send(self, node: kademlia.Node, message: bytes) -> None:
        self.transport.sendto(message,
                              (node.address.ip, node.address.udp_port))

    async def stop(self) -> None:
        self.logger.info('stopping discovery')
        self.cancel_token.trigger()
        self.transport.close()
        # We run lots of asyncio tasks so this is to make sure they all get a chance to execute
        # and exit cleanly when they notice the cancel token has been triggered.
        await asyncio.sleep(0.1)

    def receive(self, address: kademlia.Address, message: bytes) -> None:
        try:
            remote_pubkey, cmd_id, payload, message_hash = _unpack(message)
        except DefectiveMessage as e:
            self.logger.error('error unpacking message (%s) from %s: %s',
                              message, address, e)
            return

        # As of discovery version 4, expiration is the last element for all packets, so
        # we can validate that here, but if it changes we may have to do so on the
        # handler methods.
        expiration = rlp.sedes.big_endian_int.deserialize(payload[-1])
        if time.time() > expiration:
            self.logger.debug('received message already expired')
            return

        cmd = CMD_ID_MAP[cmd_id]
        if len(payload) != cmd.elem_count:
            self.logger.error('invalid %s payload: %s', cmd.name, payload)
            return
        node = kademlia.Node(remote_pubkey, address)
        handler = self._get_handler(cmd)
        handler(node, payload, message_hash)

    def recv_pong(self, node: kademlia.Node, payload: List[Any],
                  _: bytes) -> None:
        # The pong payload should have 3 elements: to, token, expiration
        _, token, _ = payload
        self.kademlia.recv_pong(node, token)

    def recv_neighbours(self, node: kademlia.Node, payload: List[Any],
                        _: bytes) -> None:
        # The neighbours payload should have 2 elements: nodes, expiration
        nodes, _ = payload
        self.kademlia.recv_neighbours(node, _extract_nodes_from_payload(nodes))

    def recv_ping(self, node: kademlia.Node, _: Any,
                  message_hash: bytes) -> None:
        self.kademlia.recv_ping(node, message_hash)

    def recv_find_node(self, node: kademlia.Node, payload: List[Any],
                       _: bytes) -> None:
        # The find_node payload should have 2 elements: node_id, expiration
        self.logger.debug('<<< find_node from %s', node)
        node_id, _ = payload
        self.kademlia.recv_find_node(node, big_endian_to_int(node_id))

    def send_ping(self, node: kademlia.Node) -> bytes:
        version = rlp.sedes.big_endian_int.serialize(PROTO_VERSION)
        payload = [
            version,
            self.address.to_endpoint(),
            node.address.to_endpoint()
        ]
        message = _pack(CMD_PING.id, payload, self.privkey)
        self.send(node, message)
        # Return the msg hash, which is used as a token to identify pongs.
        token = message[:MAC_SIZE]
        self.logger.debug('>>> ping %s (token == %s)', node, encode_hex(token))
        # XXX: This hack is needed because there are lots of parity 1.10 nodes out there that send
        # the wrong token on pong msgs (https://github.com/paritytech/parity/issues/8038). We
        # should get rid of this once there are no longer too many parity 1.10 nodes out there.
        parity_token = keccak(message[HEAD_SIZE + 1:])
        self.kademlia.parity_pong_tokens[parity_token] = token
        return token

    def send_find_node(self, node: kademlia.Node, target_node_id: int) -> None:
        node_id = int_to_big_endian(target_node_id).rjust(
            kademlia.k_pubkey_size // 8, b'\0')
        self.logger.debug('>>> find_node to %s', node)
        message = _pack(CMD_FIND_NODE.id, [node_id], self.privkey)
        self.send(node, message)

    def send_pong(self, node: kademlia.Node, token: bytes) -> None:
        self.logger.debug('>>> pong %s', node)
        payload = [node.address.to_endpoint(), token]
        message = _pack(CMD_PONG.id, payload, self.privkey)
        self.send(node, message)

    def send_neighbours(self, node: kademlia.Node,
                        neighbours: List[kademlia.Node]) -> None:
        nodes = []
        neighbours = sorted(neighbours)
        for n in neighbours:
            nodes.append(n.address.to_endpoint() + [n.pubkey.to_bytes()])

        max_neighbours = self._get_max_neighbours_per_packet()
        for i in range(0, len(nodes), max_neighbours):
            message = _pack(CMD_NEIGHBOURS.id, [nodes[i:i + max_neighbours]],
                            self.privkey)
            self.logger.debug('>>> neighbours to %s: %s', node,
                              neighbours[i:i + max_neighbours])
            self.send(node, message)
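Finally, a sketch of how a caller might combine its own token with the protocol's, in the style of `lookup_random` above (the surrounding coroutine and its name are illustrative):

async def discover_nodes(proto: DiscoveryProtocol) -> None:
    caller_token = CancelToken('caller')
    try:
        # lookup_random chains proto.cancel_token with caller_token,
        # so either proto.stop() or caller_token.trigger() ends the
        # lookup.
        nodes = await proto.lookup_random(caller_token)
    except OperationCancelled:
        return
    for node in nodes:
        proto.send_ping(node)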