async def multiplex(self) -> AsyncIterator[None]:
    """
    API for running the background task that feeds the individual protocol
    queues, allowing each protocol to stream only its own messages.
    """
    # We generate a new token for each time the multiplexer is used to
    # multiplex so that we can reliably cancel it without requiring the
    # master token for the multiplexer to be cancelled.
    async with self._multiplex_lock:
        multiplex_token = CancelToken(
            'multiplex',
            loop=self.cancel_token.loop,
        ).chain(self.cancel_token)

        self._multiplex_token = multiplex_token
        fut = asyncio.ensure_future(self._do_multiplexing(multiplex_token))
        try:
            yield
        finally:
            multiplex_token.trigger()
            del self._multiplex_token
            if fut.done():
                fut.result()
            else:
                fut.cancel()
                try:
                    await fut
                except asyncio.CancelledError:
                    pass
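# Hedged usage sketch (not from the source): this assumes `multiplex()` is wrapped with
# `contextlib.asynccontextmanager` and that the multiplexer exposes a per-protocol streaming
# method, called `stream_protocol_messages` here purely for illustration.  Holding the context
# open keeps the background `_do_multiplexing` task alive; leaving it triggers the per-use
# `multiplex_token`, cancelling only that task while the master `cancel_token` stays usable.
async def consume_protocol_messages(multiplexer, protocol_class, handler):
    async with multiplexer.multiplex():
        # `stream_protocol_messages` is an assumed name for the per-protocol queue reader.
        async for cmd in multiplexer.stream_protocol_messages(protocol_class):
            handler(cmd)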
def test_token_chain_trigger_last():
    token = CancelToken('token')
    token2 = CancelToken('token2')
    token3 = CancelToken('token3')
    chain = token.chain(token2).chain(token3)
    assert not chain.triggered
    token3.trigger()
    assert chain.triggered
    assert chain.triggered_token == token3
def test_token_chain_trigger_middle():
    token = CancelToken('token')
    token2 = CancelToken('token2')
    token3 = CancelToken('token3')
    intermediate_chain = token.chain(token2)
    chain = intermediate_chain.chain(token3)
    assert not chain.triggered
    token2.trigger()
    assert chain.triggered
    assert intermediate_chain.triggered
    assert chain.triggered_token == token2
    assert not token3.triggered
    assert not token.triggered
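# Hedged illustration of the pattern these tests exercise: a child service chains its own token
# onto a parent token so that triggering the parent cancels everything downstream, while the
# child can still be cancelled independently.  `do_work` is a hypothetical awaitable factory.
async def run_child_service(parent_token: CancelToken, do_work):
    service_token = CancelToken('child-service').chain(parent_token)
    try:
        # Returns do_work()'s result, or raises OperationCancelled as soon as either
        # `service_token` or `parent_token` is triggered.
        return await service_token.cancellable_wait(do_work())
    except OperationCancelled:
        return None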
async def test_finished_task_with_exceptions_is_raised_on_cancellation():
    token = CancelToken('token')
    ready = asyncio.Event()

    async def _signal_then_raise():
        ready.set()
        raise ValueError("raising from _signal_then_raise")

    # schedule in the background
    task = asyncio.ensure_future(token.cancellable_wait(_signal_then_raise()))

    # wait until the coro is running and we know it's raised an error
    await ready.wait()

    # trigger the cancel token
    token.trigger()

    with pytest.raises(ValueError, match="raising from _signal_then_raise"):
        await task
    await assert_only_current_task_not_done()
async def multiplex(self) -> AsyncIterator[None]:
    """
    API for running the background task that feeds the individual protocol
    queues, allowing each protocol to stream only its own messages.
    """
    # We generate a new token for each time the multiplexer is used to
    # multiplex so that we can reliably cancel it without requiring the
    # master token for the multiplexer to be cancelled.
    async with self._multiplex_lock:
        multiplex_token = CancelToken(
            'multiplex',
            loop=self.cancel_token.loop,
        ).chain(self.cancel_token)

        stop = asyncio.Event()
        self._multiplex_token = multiplex_token
        fut = asyncio.ensure_future(self._do_multiplexing(stop, multiplex_token))
        # wait for the multiplexing to actually start
        try:
            yield
        finally:
            #
            # Prevent corruption of the Transport:
            #
            # On exit the `Transport` can be in a few states:
            #
            # 1. IDLE: between reads
            # 2. HEADER: waiting to read the bytes for the message header
            # 3. BODY: already read the header, waiting for body bytes.
            #
            # In the IDLE case we get a clean shutdown by simply signaling
            # to `_do_multiplexing` that it should exit which is done with
            # an `asyncio.Event`
            #
            # In the HEADER case we can issue a hard stop either via
            # cancellation or the cancel token.  The read *should* be
            # interrupted without consuming any bytes from the
            # `StreamReader`.
            #
            # In the BODY case we want to give the `Transport.recv` call a
            # moment to finish reading the body after which it will be IDLE
            # and will exit via the IDLE exit mechanism.
            stop.set()

            # If the transport is waiting to read the body of the message
            # we want to give it a moment to finish that read.  Otherwise
            # this leaves the transport in a corrupt state.
            if self._transport.read_state is TransportState.BODY:
                try:
                    await asyncio.wait_for(fut, timeout=1)
                except asyncio.TimeoutError:
                    pass

            # After giving the transport an opportunity to shutdown
            # cleanly, we issue a hard shutdown, first via cancellation and
            # then via the cancel token.  This should only end up
            # corrupting the transport in the case where the header data is
            # read but the body data takes too long to arrive which should
            # be very rare and would likely indicate a malicious or broken
            # peer.
            if fut.done():
                fut.result()
            else:
                fut.cancel()
                try:
                    await fut
                except asyncio.CancelledError:
                    pass
            multiplex_token.trigger()
            del self._multiplex_token
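# Minimal sketch, under stated assumptions, of how a `_do_multiplexing(stop, token)` loop can
# honour both shutdown paths described above: the `stop` event gives the clean IDLE exit
# (checked between reads), while triggering `token` gives the hard stop via OperationCancelled.
# `read_one_message` and `feed_protocol_queues` are hypothetical stand-ins, not project APIs.
async def _do_multiplexing_sketch(stop, token, read_one_message, feed_protocol_queues):
    while not stop.is_set():
        try:
            # cancellable_wait raises OperationCancelled once `token` is triggered.
            msg = await token.cancellable_wait(read_one_message())
        except OperationCancelled:
            break
        feed_protocol_queues(msg)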
async def test_cancellable_wait_operation_cancelled(event_loop):
    token = CancelToken('token')
    token.trigger()

    with pytest.raises(OperationCancelled):
        await token.cancellable_wait(asyncio.sleep(0.02))
    await assert_only_current_task_not_done()
def test_token_single():
    token = CancelToken('token')
    assert not token.triggered
    token.trigger()
    assert token.triggered
    assert token.triggered_token == token
class Node(Service, Generic[TPeer]):
    """
    Create usable nodes by adding subclasses that define the following
    unset attributes.
    """
    _full_chain: FullChain = None
    _event_server: PeerPoolEventServer[TPeer] = None

    def __init__(self, event_bus: EndpointAPI, trinity_config: TrinityConfig) -> None:
        self.trinity_config = trinity_config
        self._base_db = DBClient.connect(trinity_config.database_ipc_path)
        self._headerdb = AsyncHeaderDB(self._base_db)
        self._jsonrpc_ipc_path: Path = trinity_config.jsonrpc_ipc_path
        self._network_id = trinity_config.network_id

        self.event_bus = event_bus
        self.master_cancel_token = CancelToken(type(self).__name__)

    async def handle_network_id_requests(self) -> None:
        async for req in self.event_bus.stream(NetworkIdRequest):
            # We are listening for all `NetworkIdRequest` events but we only send a
            # `NetworkIdResponse` to the callsite that made the request.  We do that by
            # retrieving a `BroadcastConfig` from the request via the
            # `event.broadcast_config()` API.
            await self.event_bus.broadcast(
                NetworkIdResponse(self._network_id),
                req.broadcast_config(),
            )

    _chain_config: Eth1ChainConfig = None

    @property
    def chain_config(self) -> Eth1ChainConfig:
        """
        Convenience and caching mechanism for the `ChainConfig`.
        """
        if self._chain_config is None:
            app_config = self.trinity_config.get_app_config(Eth1AppConfig)
            self._chain_config = app_config.get_chain_config()
        return self._chain_config

    @abstractmethod
    def get_chain(self) -> AsyncChainAPI:
        ...

    def get_full_chain(self) -> FullChain:
        if self._full_chain is None:
            chain_class = self.chain_config.full_chain_class
            self._full_chain = chain_class(self._base_db)
        return self._full_chain

    @abstractmethod
    def get_event_server(self) -> PeerPoolEventServer[TPeer]:
        """
        Return the ``PeerPoolEventServer`` of the node
        """
        ...

    @abstractmethod
    def get_peer_pool(self) -> BasePeerPool:
        """
        Return the PeerPool instance of the node
        """
        ...

    @abstractmethod
    def get_p2p_server(self) -> AsyncioServiceAPI:
        """
        This is the main service that will be run when calling :meth:`run`.
        It's typically responsible for syncing the chain with peer connections.
        """
        ...

    @property
    def base_db(self) -> AtomicDatabaseAPI:
        return self._base_db

    @property
    def headerdb(self) -> BaseAsyncHeaderDB:
        return self._headerdb

    async def run(self) -> None:
        with self._base_db:
            self.manager.run_daemon_task(self.handle_network_id_requests)
            self.manager.run_daemon_child_service(
                self.get_p2p_server().as_new_service())
            self.manager.run_daemon_child_service(self.get_event_server())
            try:
                await self.manager.wait_finished()
            finally:
                self.master_cancel_token.trigger()
                await self.event_bus.broadcast(
                    ShutdownRequest("Node exiting. Triggering shutdown"))
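# Hedged sketch (not from the source tree): a concrete node only has to supply the factories
# declared abstract above.  To keep the sketch free of extra imports, the concrete chain class,
# peer pool, event server and p2p server are injected through the constructor; a real subclass
# would normally hard-code its own implementations instead.
class SketchNode(Node[TPeer]):
    def __init__(self, event_bus, trinity_config, chain_class, peer_pool,
                 event_server, p2p_server) -> None:
        super().__init__(event_bus, trinity_config)
        self._chain_class = chain_class
        self._peer_pool = peer_pool
        self._event_server = event_server
        self._p2p_server = p2p_server

    def get_chain(self) -> AsyncChainAPI:
        return self._chain_class(self._base_db)

    def get_event_server(self) -> PeerPoolEventServer[TPeer]:
        return self._event_server

    def get_peer_pool(self) -> BasePeerPool:
        return self._peer_pool

    def get_p2p_server(self) -> AsyncioServiceAPI:
        return self._p2p_server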
class Router:
    logger: logging.Logger = logging.getLogger(
        'pytest_asyncio_network_simulator.router.Router')

    def __init__(self) -> None:
        self.hosts: Dict[str, 'Host'] = {}
        self.networks: Dict[str, 'Network'] = {}
        self.connections: Dict[CancelToken, 'asyncio.Future[None]'] = {}
        self.cancel_token = CancelToken('Router')
        self._run_lock = asyncio.Lock()
        self.cleaned_up = asyncio.Event()

    #
    # Connections API
    #
    def get_host(self, host: str) -> 'Host':
        from .host import Host  # noqa: F811
        if host not in self.hosts:
            self.hosts[host] = Host(host, self)
        return self.hosts[host]

    def get_network(self, name: str) -> 'Network':
        from .network import Network  # noqa: F811
        if name not in self.networks:
            self.networks[name] = Network(name, self)
        return self.networks[name]

    def get_connected_readers(self, address: Address) -> ReaderWriterPair:
        external_reader, internal_writer = direct_pipe()
        internal_reader, external_writer = addressed_pipe(address)

        token = CancelToken(str(address)).chain(self.cancel_token)
        connection = asyncio.ensure_future(_connect_streams(
            internal_reader,
            internal_writer,
            cast(AddressedTransport, external_writer.transport).queue,
            token,
        ))
        self.connections[token] = connection

        return (external_reader, external_writer)

    #
    # Run, Cancel and Cleanup API
    #
    async def run(self) -> None:
        """Await the service's _run() coroutine.

        Once _run() returns, trigger the cancel token, then call cleanup() and
        finished_callback (if one was passed).
        """
        if self.is_running:
            raise RuntimeError(
                "Cannot start the service while it's already running")
        elif self.cancel_token.triggered:
            raise RuntimeError(
                "Cannot restart a service that has already been cancelled")

        try:
            async with self._run_lock:
                await self.cancel_token.wait()
        finally:
            await self.cleanup()

    async def cleanup(self) -> None:
        """
        Run the ``_cleanup()`` coroutine and set the ``cleaned_up`` event after
        the service finishes cleanup.
        """
        if self.connections:
            await asyncio.wait(
                self.connections.values(),
                timeout=2,
                return_when=asyncio.ALL_COMPLETED,
            )
        self.cleaned_up.set()

    async def cancel(self) -> None:
        """Trigger the CancelToken and wait for the cleaned_up event to be set."""
        if self.cancel_token.triggered:
            self.logger.warning(
                "Tried to cancel %s, but it was already cancelled", self)
            return
        elif not self.is_running:
            raise RuntimeError(
                "Cannot cancel a service that has not been started")

        self.logger.debug("Cancelling %s", self)
        self.cancel_token.trigger()
        try:
            await asyncio.wait_for(
                self.cleaned_up.wait(),
                timeout=5,
            )
        except asyncio.futures.TimeoutError:
            self.logger.info(
                "Timed out waiting for %s to finish its cleanup, exiting anyway", self)
        else:
            self.logger.debug("%s finished cleanly", self)

    @property
    def is_running(self) -> bool:
        return self._run_lock.locked()
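# Hedged usage sketch: driving the Router above from a test.  `run()` is scheduled in the
# background, hosts and networks are created lazily through the getters, and `cancel()`
# triggers the token and waits for the `cleaned_up` event before the test moves on.
async def router_example():
    router = Router()
    run_task = asyncio.ensure_future(router.run())
    await asyncio.sleep(0)  # let run() start and acquire its run lock
    alice = router.get_host('192.168.1.1')  # lazily creates and caches a Host
    vpn = router.get_network('vpn')         # lazily creates and caches a Network
    # ... wire hosts together over the simulated network here ...
    await router.cancel()  # triggers the CancelToken and waits for cleanup
    await run_task         # run() returns once the token fires and cleanup() completes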
class DiscoveryProtocol(asyncio.DatagramProtocol):
    """A Kademlia-like protocol to discover RLPx nodes."""
    logger = logging.getLogger("p2p.discovery.DiscoveryProtocol")
    transport: asyncio.DatagramTransport = None
    _max_neighbours_per_packet_cache = None

    def __init__(self,
                 privkey: datatypes.PrivateKey,
                 address: kademlia.Address,
                 bootstrap_nodes: Tuple[kademlia.Node, ...]) -> None:
        self.privkey = privkey
        self.address = address
        self.bootstrap_nodes = bootstrap_nodes
        self.this_node = kademlia.Node(self.pubkey, address)
        self.kademlia = kademlia.KademliaProtocol(self.this_node, wire=self)
        self.cancel_token = CancelToken('DiscoveryProtocol')

    async def lookup_random(self, cancel_token: CancelToken) -> List[kademlia.Node]:
        return await self.kademlia.lookup_random(
            self.cancel_token.chain(cancel_token))

    def get_random_bootnode(self) -> Iterator[kademlia.Node]:
        if self.bootstrap_nodes:
            yield random.choice(self.bootstrap_nodes)
        else:
            self.logger.warning('No bootnodes available')

    def get_nodes_to_connect(self, count: int) -> Iterator[kademlia.Node]:
        return self.kademlia.routing.get_random_nodes(count)

    @property
    def pubkey(self) -> datatypes.PublicKey:
        return self.privkey.public_key

    def _get_handler(
            self, cmd: Command) -> Callable[[kademlia.Node, List[Any], bytes], None]:
        if cmd == CMD_PING:
            return self.recv_ping
        elif cmd == CMD_PONG:
            return self.recv_pong
        elif cmd == CMD_FIND_NODE:
            return self.recv_find_node
        elif cmd == CMD_NEIGHBOURS:
            return self.recv_neighbours
        else:
            raise ValueError("Unknown command: {}".format(cmd))

    def _get_max_neighbours_per_packet(self) -> int:
        if self._max_neighbours_per_packet_cache is not None:
            return self._max_neighbours_per_packet_cache
        self._max_neighbours_per_packet_cache = _get_max_neighbours_per_packet()
        return self._max_neighbours_per_packet_cache

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # we need to cast here because the signature in the base class dictates BaseTransport
        # and arguments can only be redefined contravariantly
        self.transport = cast(asyncio.DatagramTransport, transport)

    async def bootstrap(self) -> None:
        self.logger.info("bootstrapping with %s", self.bootstrap_nodes)
        try:
            await self.kademlia.bootstrap(self.bootstrap_nodes, self.cancel_token)
        except OperationCancelled as e:
            self.logger.info("Bootstrapping cancelled: %s", e)

    def datagram_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None:
        ip_address, udp_port = addr
        # XXX: For now we simply discard all v5 messages. The prefix below is what geth uses to
        # identify them:
        # https://github.com/ethereum/go-ethereum/blob/c4712bf96bc1bae4a5ad4600e9719e4a74bde7d5/p2p/discv5/udp.go#L149  # noqa: E501
        if text_if_str(to_bytes, data).startswith(b"temporary discovery v5"):
            self.logger.debug("Got discovery v5 msg, discarding")
            return
        self.receive(kademlia.Address(ip_address, udp_port), cast(bytes, data))

    def error_received(self, exc: Exception) -> None:
        self.logger.error('error received: %s', exc)

    def send(self, node: kademlia.Node, message: bytes) -> None:
        self.transport.sendto(message, (node.address.ip, node.address.udp_port))

    async def stop(self) -> None:
        self.logger.info('stopping discovery')
        self.cancel_token.trigger()
        self.transport.close()
        # We run lots of asyncio tasks so this is to make sure they all get a chance to execute
        # and exit cleanly when they notice the cancel token has been triggered.
        await asyncio.sleep(0.1)

    def receive(self, address: kademlia.Address, message: bytes) -> None:
        try:
            remote_pubkey, cmd_id, payload, message_hash = _unpack(message)
        except DefectiveMessage as e:
            self.logger.error('error unpacking message (%s) from %s: %s', message, address, e)
            return

        # As of discovery version 4, expiration is the last element for all packets, so
        # we can validate that here, but if it changes we may have to do so on the
        # handler methods.
        expiration = rlp.sedes.big_endian_int.deserialize(payload[-1])
        if time.time() > expiration:
            self.logger.debug('received message already expired')
            return

        cmd = CMD_ID_MAP[cmd_id]
        if len(payload) != cmd.elem_count:
            self.logger.error('invalid %s payload: %s', cmd.name, payload)
            return
        node = kademlia.Node(remote_pubkey, address)
        handler = self._get_handler(cmd)
        handler(node, payload, message_hash)

    def recv_pong(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The pong payload should have 3 elements: to, token, expiration
        _, token, _ = payload
        self.kademlia.recv_pong(node, token)

    def recv_neighbours(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The neighbours payload should have 2 elements: nodes, expiration
        nodes, _ = payload
        self.kademlia.recv_neighbours(node, _extract_nodes_from_payload(nodes))

    def recv_ping(self, node: kademlia.Node, _: Any, message_hash: bytes) -> None:
        self.kademlia.recv_ping(node, message_hash)

    def recv_find_node(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The find_node payload should have 2 elements: node_id, expiration
        self.logger.debug('<<< find_node from %s', node)
        node_id, _ = payload
        self.kademlia.recv_find_node(node, big_endian_to_int(node_id))

    def send_ping(self, node: kademlia.Node) -> bytes:
        version = rlp.sedes.big_endian_int.serialize(PROTO_VERSION)
        payload = [version, self.address.to_endpoint(), node.address.to_endpoint()]
        message = _pack(CMD_PING.id, payload, self.privkey)
        self.send(node, message)
        # Return the msg hash, which is used as a token to identify pongs.
        token = message[:MAC_SIZE]
        self.logger.debug('>>> ping %s (token == %s)', node, encode_hex(token))
        # XXX: This hack is needed because there are lots of parity 1.10 nodes out there that send
        # the wrong token on pong msgs (https://github.com/paritytech/parity/issues/8038).  We
        # should get rid of this once there are no longer too many parity 1.10 nodes out there.
        parity_token = keccak(message[HEAD_SIZE + 1:])
        self.kademlia.parity_pong_tokens[parity_token] = token
        return token

    def send_find_node(self, node: kademlia.Node, target_node_id: int) -> None:
        node_id = int_to_big_endian(target_node_id).rjust(kademlia.k_pubkey_size // 8, b'\0')
        self.logger.debug('>>> find_node to %s', node)
        message = _pack(CMD_FIND_NODE.id, [node_id], self.privkey)
        self.send(node, message)

    def send_pong(self, node: kademlia.Node, token: bytes) -> None:
        self.logger.debug('>>> pong %s', node)
        payload = [node.address.to_endpoint(), token]
        message = _pack(CMD_PONG.id, payload, self.privkey)
        self.send(node, message)

    def send_neighbours(self, node: kademlia.Node, neighbours: List[kademlia.Node]) -> None:
        nodes = []
        neighbours = sorted(neighbours)
        for n in neighbours:
            nodes.append(n.address.to_endpoint() + [n.pubkey.to_bytes()])

        max_neighbours = self._get_max_neighbours_per_packet()
        for i in range(0, len(nodes), max_neighbours):
            message = _pack(CMD_NEIGHBOURS.id, [nodes[i:i + max_neighbours]], self.privkey)
            self.logger.debug('>>> neighbours to %s: %s',
                              node, neighbours[i:i + max_neighbours])
            self.send(node, message)
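# Hedged wiring sketch: DiscoveryProtocol subclasses asyncio.DatagramProtocol, so it can be
# bound to a UDP socket with the standard loop.create_datagram_endpoint() API, after which
# connection_made() above stores the transport.  The privkey / address / bootstrap_nodes
# arguments are placeholders supplied by the caller.
async def run_discovery(privkey, address, bootstrap_nodes):
    loop = asyncio.get_event_loop()
    proto = DiscoveryProtocol(privkey, address, bootstrap_nodes)
    await loop.create_datagram_endpoint(
        lambda: proto, local_addr=(address.ip, address.udp_port))
    try:
        await proto.bootstrap()  # bonds with the bootnodes via the kademlia protocol
    finally:
        await proto.stop()       # triggers the cancel token and closes the transport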
class DiscoveryProtocol(asyncio.DatagramProtocol):
    """A Kademlia-like protocol to discover RLPx nodes."""
    logger = get_extended_debug_logger("p2p.discovery.DiscoveryProtocol")
    transport: asyncio.DatagramTransport = None
    _max_neighbours_per_packet_cache = None

    def __init__(self,
                 privkey: datatypes.PrivateKey,
                 address: AddressAPI,
                 bootstrap_nodes: Sequence[NodeAPI],
                 cancel_token: CancelToken) -> None:
        self.privkey = privkey
        self.address = address
        self.bootstrap_nodes = bootstrap_nodes
        self.this_node = Node(self.pubkey, address)
        self.routing = RoutingTable(self.this_node)
        self.topic_table = TopicTable(self.logger)
        self.pong_callbacks = CallbackManager()
        self.ping_callbacks = CallbackManager()
        self.neighbours_callbacks = CallbackManager()
        self.topic_nodes_callbacks = CallbackManager()
        self.parity_pong_tokens: Dict[Hash32, Hash32] = {}
        self.cancel_token = CancelToken('DiscoveryProtocol').chain(cancel_token)

    def update_routing_table(self, node: NodeAPI) -> None:
        """Update the routing table entry for the given node."""
        eviction_candidate = self.routing.add_node(node)
        if eviction_candidate:
            # This means we couldn't add the node because its bucket is full, so schedule a bond()
            # with the least recently seen node on that bucket. If the bonding fails the node will
            # be removed from the bucket and a new one will be picked from the bucket's
            # replacement cache.
            asyncio.ensure_future(self.bond(eviction_candidate))

    async def bond(self, node: NodeAPI) -> bool:
        """Bond with the given node.

        Bonding consists of pinging the node, waiting for a pong and maybe a ping as well.
        It is necessary to do this at least once before we send find_node requests to a node.
        """
        if node in self.routing:
            return True
        elif node == self.this_node:
            return False

        token = self.send_ping_v4(node)
        log_version = "v4"

        try:
            got_pong = await self.wait_pong_v4(node, token)
        except AlreadyWaitingDiscoveryResponse:
            self.logger.debug("bonding failed, awaiting %s pong from %s", log_version, node)
            return False

        if not got_pong:
            self.logger.debug("bonding failed, didn't receive %s pong from %s", log_version, node)
            self.routing.remove_node(node)
            return False

        try:
            # Give the remote node a chance to ping us before we move on and
            # start sending find_node requests. It is ok for wait_ping() to
            # timeout and return false here as that just means the remote
            # remembers us.
            await self.wait_ping(node)
        except AlreadyWaitingDiscoveryResponse:
            self.logger.debug("bonding failed, already waiting for ping")
            return False

        self.logger.debug2("bonding completed successfully with %s", node)
        self.update_routing_table(node)
        return True

    async def wait_ping(self, remote: NodeAPI) -> bool:
        """Wait for a ping from the given remote.

        This coroutine adds a callback to ping_callbacks and yields control until that callback
        is called or a timeout (k_request_timeout) occurs. At that point it returns whether or
        not a ping was received from the given node.
""" event = asyncio.Event() with self.ping_callbacks.acquire(remote, event.set): got_ping = False try: got_ping = await self.cancel_token.cancellable_wait( event.wait(), timeout=constants.KADEMLIA_REQUEST_TIMEOUT) self.logger.debug2('got expected ping from %s', remote) except asyncio.TimeoutError: self.logger.debug2('timed out waiting for ping from %s', remote) return got_ping async def wait_pong_v4(self, remote: NodeAPI, token: Hash32) -> bool: event = asyncio.Event() callback = event.set return await self._wait_pong(remote, token, event, callback) async def _wait_pong( self, remote: NodeAPI, token: Hash32, event: asyncio.Event, callback: Callable[..., Any]) -> bool: """Wait for a pong from the given remote containing the given token. This coroutine adds a callback to pong_callbacks and yields control until the given event is set or a timeout (k_request_timeout) occurs. At that point it returns whether or not a pong was received with the given pingid. """ pingid = self._mkpingid(token, remote) with self.pong_callbacks.acquire(pingid, callback): got_pong = False try: got_pong = await self.cancel_token.cancellable_wait( event.wait(), timeout=constants.KADEMLIA_REQUEST_TIMEOUT) self.logger.debug2('got expected pong with token %s', encode_hex(token)) except asyncio.TimeoutError: self.logger.debug2( 'timed out waiting for pong from %s (token == %s)', remote, encode_hex(token), ) return got_pong async def wait_neighbours(self, remote: NodeAPI) -> Tuple[NodeAPI, ...]: """Wait for a neihgbours packet from the given node. Returns the list of neighbours received. """ event = asyncio.Event() neighbours: List[NodeAPI] = [] def process(response: List[NodeAPI]) -> None: neighbours.extend(response) # This callback is expected to be called multiple times because nodes usually # split the neighbours replies into multiple packets, so we only call event.set() once # we've received enough neighbours. if len(neighbours) >= constants.KADEMLIA_BUCKET_SIZE: event.set() with self.neighbours_callbacks.acquire(remote, process): try: await self.cancel_token.cancellable_wait( event.wait(), timeout=constants.KADEMLIA_REQUEST_TIMEOUT) self.logger.debug2('got expected neighbours response from %s', remote) except asyncio.TimeoutError: self.logger.debug2( 'timed out waiting for %d neighbours from %s', constants.KADEMLIA_BUCKET_SIZE, remote, ) return tuple(n for n in neighbours if n != self.this_node) def _mkpingid(self, token: Hash32, node: NodeAPI) -> Hash32: return Hash32(token + node.pubkey.to_bytes()) def _send_find_node(self, node: NodeAPI, target_node_id: int) -> None: self.send_find_node_v4(node, target_node_id) async def lookup(self, node_id: int) -> Tuple[NodeAPI, ...]: """Lookup performs a network search for nodes close to the given target. It approaches the target by querying nodes that are closer to it on each iteration. The given target does not need to be an actual node identifier. """ nodes_asked: Set[NodeAPI] = set() nodes_seen: Set[NodeAPI] = set() async def _find_node(node_id: int, remote: NodeAPI) -> Tuple[NodeAPI, ...]: # Short-circuit in case our token has been triggered to avoid trying to send requests # over a transport that is probably closed already. 
            self.cancel_token.raise_if_triggered()
            self._send_find_node(remote, node_id)
            candidates = await self.wait_neighbours(remote)
            if not candidates:
                self.logger.debug("got no candidates from %s, returning", remote)
                return tuple()
            all_candidates = tuple(c for c in candidates if c not in nodes_seen)
            candidates = tuple(
                c for c in all_candidates
                if (not self.ping_callbacks.locked(c) and not self.pong_callbacks.locked(c))
            )
            self.logger.debug2("got %s new candidates", len(candidates))
            # Add new candidates to nodes_seen so that we don't attempt to bond with failing ones
            # in the future.
            nodes_seen.update(candidates)
            bonded = await asyncio.gather(*(self.bond(c) for c in candidates))
            self.logger.debug2("bonded with %s candidates", bonded.count(True))
            return tuple(c for c in candidates if bonded[candidates.index(c)])

        def _exclude_if_asked(nodes: Iterable[NodeAPI]) -> List[NodeAPI]:
            nodes_to_ask = list(set(nodes).difference(nodes_asked))
            return sort_by_distance(nodes_to_ask, node_id)[:constants.KADEMLIA_FIND_CONCURRENCY]

        closest = self.routing.neighbours(node_id)
        self.logger.debug("starting lookup; initial neighbours: %s", closest)
        nodes_to_ask = _exclude_if_asked(closest)
        while nodes_to_ask:
            self.logger.debug2("node lookup; querying %s", nodes_to_ask)
            nodes_asked.update(nodes_to_ask)
            next_find_node_queries = (
                _find_node(node_id, n)
                for n in nodes_to_ask
                if not self.neighbours_callbacks.locked(n)
            )
            results = await asyncio.gather(*next_find_node_queries)
            for candidates in results:
                closest.extend(candidates)
            closest = sort_by_distance(closest, node_id)[:constants.KADEMLIA_BUCKET_SIZE]
            nodes_to_ask = _exclude_if_asked(closest)

        self.logger.debug(
            "lookup finished for target %s; closest neighbours: %s", to_hex(node_id), closest)
        return tuple(closest)

    async def lookup_random(self) -> Tuple[NodeAPI, ...]:
        return await self.lookup(random.randint(0, constants.KADEMLIA_MAX_NODE_ID))

    def get_random_bootnode(self) -> Iterator[NodeAPI]:
        if self.bootstrap_nodes:
            yield random.choice(self.bootstrap_nodes)
        else:
            self.logger.warning('No bootnodes available')

    def get_nodes_to_connect(self, count: int) -> Iterator[NodeAPI]:
        return self.routing.get_random_nodes(count)

    @property
    def pubkey(self) -> datatypes.PublicKey:
        return self.privkey.public_key

    def _get_handler(self, cmd: DiscoveryCommand) -> V4_HANDLER_TYPE:
        if cmd == CMD_PING:
            return self.recv_ping_v4
        elif cmd == CMD_PONG:
            return self.recv_pong_v4
        elif cmd == CMD_FIND_NODE:
            return self.recv_find_node_v4
        elif cmd == CMD_NEIGHBOURS:
            return self.recv_neighbours_v4
        else:
            raise ValueError(f"Unknown command: {cmd}")

    def _get_max_neighbours_per_packet(self) -> int:
        if self._max_neighbours_per_packet_cache is not None:
            return self._max_neighbours_per_packet_cache
        self._max_neighbours_per_packet_cache = _get_max_neighbours_per_packet()
        return self._max_neighbours_per_packet_cache

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # we need to cast here because the signature in the base class dictates BaseTransport
        # and arguments can only be redefined contravariantly
        self.transport = cast(asyncio.DatagramTransport, transport)

    async def bootstrap(self) -> None:
        for node in self.bootstrap_nodes:
            uri = node.uri()
            pubkey, _, uri_tail = uri.partition('@')
            pubkey_head = pubkey[:16]
            pubkey_tail = pubkey[-8:]
            self.logger.debug("full-bootnode: %s", uri)
            self.logger.debug("bootnode: %s...%s@%s", pubkey_head, pubkey_tail, uri_tail)

        try:
            bonding_queries = (
                self.bond(n)
                for n in self.bootstrap_nodes
                if (not self.ping_callbacks.locked(n) and not
                    self.pong_callbacks.locked(n))
            )
            bonded = await asyncio.gather(*bonding_queries)
            if not any(bonded):
                self.logger.info("Failed to bond with bootstrap nodes %s", self.bootstrap_nodes)
                return
            await self.lookup_random()
        except OperationCancelled as e:
            self.logger.info("Bootstrapping cancelled: %s", e)

    def datagram_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> None:
        ip_address, udp_port = addr
        address = Address(ip_address, udp_port)
        self.receive(address, cast(bytes, data))

    def send(self, node: NodeAPI, message: bytes) -> None:
        self.transport.sendto(message, (node.address.ip, node.address.udp_port))

    async def stop(self) -> None:
        self.logger.info('stopping discovery')
        self.cancel_token.trigger()
        self.transport.close()
        # We run lots of asyncio tasks so this is to make sure they all get a chance to execute
        # and exit cleanly when they notice the cancel token has been triggered.
        await asyncio.sleep(0.1)

    def receive(self, address: AddressAPI, message: bytes) -> None:
        try:
            remote_pubkey, cmd_id, payload, message_hash = _unpack_v4(message)
        except DefectiveMessage as e:
            self.logger.error('error unpacking message (%s) from %s: %s', message, address, e)
            return

        # As of discovery version 4, expiration is the last element for all packets, so
        # we can validate that here, but if it changes we may have to do so on the
        # handler methods.
        expiration = rlp.sedes.big_endian_int.deserialize(payload[-1])
        if time.time() > expiration:
            self.logger.debug('received message already expired')
            return

        cmd = CMD_ID_MAP[cmd_id]
        if len(payload) != cmd.elem_count:
            self.logger.error('invalid %s payload: %s', cmd.name, payload)
            return
        node = Node(remote_pubkey, address)
        handler = self._get_handler(cmd)
        handler(node, payload, message_hash)

    def recv_pong_v4(self, node: NodeAPI, payload: Sequence[Any], _: Hash32) -> None:
        # The pong payload should have 3 elements: to, token, expiration
        _, token, _ = payload
        self.logger.debug2('<<< pong (v4) from %s (token == %s)', node, encode_hex(token))
        self.process_pong_v4(node, token)

    def recv_neighbours_v4(self, node: NodeAPI, payload: Sequence[Any], _: Hash32) -> None:
        # The neighbours payload should have 2 elements: nodes, expiration
        nodes, _ = payload
        neighbours = _extract_nodes_from_payload(node.address, nodes, self.logger)
        self.logger.debug2('<<< neighbours from %s: %s', node, neighbours)
        self.process_neighbours(node, neighbours)

    def recv_ping_v4(self, node: NodeAPI, _: Any, message_hash: Hash32) -> None:
        self.logger.debug2('<<< ping(v4) from %s', node)
        self.process_ping(node, message_hash)
        self.send_pong_v4(node, message_hash)

    def recv_find_node_v4(self, node: NodeAPI, payload: Sequence[Any], _: Hash32) -> None:
        # The find_node payload should have 2 elements: node_id, expiration
        self.logger.debug2('<<< find_node from %s', node)
        node_id, _ = payload
        if node not in self.routing:
            # FIXME: This is not correct; a node we've bonded before may have become unavailable
            # and thus removed from self.routing, but once it's back online we should accept
            # find_nodes from them.
            self.logger.debug('Ignoring find_node request from unknown node %s', node)
            return
        self.update_routing_table(node)
        found = self.routing.neighbours(big_endian_to_int(node_id))
        self.send_neighbours_v4(node, found)

    def send_ping_v4(self, node: NodeAPI) -> Hash32:
        version = rlp.sedes.big_endian_int.serialize(PROTO_VERSION)
        payload = (version, self.address.to_endpoint(), node.address.to_endpoint())
        message = _pack_v4(CMD_PING.id, payload, self.privkey)
        self.send(node, message)
        # Return the msg hash, which is used as a token to identify pongs.
        token = Hash32(message[:MAC_SIZE])
        self.logger.debug2('>>> ping (v4) %s (token == %s)', node, encode_hex(token))
        # XXX: This hack is needed because there are lots of parity 1.10 nodes out there that send
        # the wrong token on pong msgs (https://github.com/paritytech/parity/issues/8038).  We
        # should get rid of this once there are no longer too many parity 1.10 nodes out there.
        parity_token = keccak(message[HEAD_SIZE + 1:])
        self.parity_pong_tokens[parity_token] = token
        return token

    def send_find_node_v4(self, node: NodeAPI, target_node_id: int) -> None:
        node_id = int_to_big_endian(target_node_id).rjust(
            constants.KADEMLIA_PUBLIC_KEY_SIZE // 8, b'\0')
        self.logger.debug2('>>> find_node to %s', node)
        message = _pack_v4(CMD_FIND_NODE.id, tuple([node_id]), self.privkey)
        self.send(node, message)

    def send_pong_v4(self, node: NodeAPI, token: Hash32) -> None:
        self.logger.debug2('>>> pong %s', node)
        payload = (node.address.to_endpoint(), token)
        message = _pack_v4(CMD_PONG.id, payload, self.privkey)
        self.send(node, message)

    def send_neighbours_v4(self, node: NodeAPI, neighbours: List[NodeAPI]) -> None:
        nodes = []
        neighbours = sorted(neighbours)
        for n in neighbours:
            nodes.append(n.address.to_endpoint() + [n.pubkey.to_bytes()])

        max_neighbours = self._get_max_neighbours_per_packet()
        for i in range(0, len(nodes), max_neighbours):
            message = _pack_v4(
                CMD_NEIGHBOURS.id, tuple([nodes[i:i + max_neighbours]]), self.privkey)
            self.logger.debug2('>>> neighbours to %s: %s',
                               node, neighbours[i:i + max_neighbours])
            self.send(node, message)

    def process_neighbours(self, remote: NodeAPI, neighbours: List[NodeAPI]) -> None:
        """Process a neighbours response.

        Neighbours responses should only be received as a reply to a find_node, and that is only
        done as part of node lookup, so the actual processing is left to the callback from
        neighbours_callbacks, which is added (and removed after it's done or timed out) in
        wait_neighbours().
        """
        try:
            callback = self.neighbours_callbacks.get_callback(remote)
        except KeyError:
            self.logger.debug(
                'unexpected neighbours from %s, probably came too late', remote)
        else:
            callback(neighbours)

    def process_pong_v4(self, remote: NodeAPI, token: Hash32) -> None:
        """Process a pong packet.

        Pong packets should only be received as a response to a ping, so the actual processing is
        left to the callback from pong_callbacks, which is added (and removed after it's done or
        timed out) in wait_pong().
        """
        # XXX: This hack is needed because there are lots of parity 1.10 nodes out there that send
        # the wrong token on pong msgs (https://github.com/paritytech/parity/issues/8038).  We
        # should get rid of this once there are no longer too many parity 1.10 nodes out there.
        if token in self.parity_pong_tokens:
            # This is a pong from a buggy parity node, so need to lookup the actual token we're
            # expecting.
            token = self.parity_pong_tokens.pop(token)
        else:
            # This is a pong from a non-buggy node, so just cleanup self.parity_pong_tokens.
            self.parity_pong_tokens = eth_utils.toolz.valfilter(
                lambda val: val != token, self.parity_pong_tokens)

        pingid = self._mkpingid(token, remote)

        try:
            callback = self.pong_callbacks.get_callback(pingid)
        except KeyError:
            self.logger.debug('unexpected v4 pong from %s (token == %s)',
                              remote, encode_hex(token))
        else:
            callback()

    def process_ping(self, remote: NodeAPI, hash_: Hash32) -> None:
        """Process a received ping packet.

        A ping packet may come any time, unrequested, or may be prompted by us bond()ing with a
        new node.  In the former case we'll just update the sender's entry in our routing table
        and reply with a pong, whereas in the latter we'll also fire a callback from
        ping_callbacks.
        """
        if remote == self.this_node:
            self.logger.info('Invariant: received ping from this_node: %s', remote)
            return
        else:
            self.update_routing_table(remote)
        # Sometimes a ping will be sent to us as part of the bonding
        # performed the first time we see a node, and it is in those cases that
        # a callback will exist.
        try:
            callback = self.ping_callbacks.get_callback(remote)
        except KeyError:
            pass
        else:
            callback()
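# Hedged usage sketch for the v4 class above, highlighting the two API differences from the
# earlier version: the constructor chains onto a caller-supplied CancelToken, and
# lookup_random() no longer takes a token argument.  Triggering `parent_token` cancels any
# in-flight wait_ping / wait_pong_v4 / wait_neighbours calls via their cancellable_wait.
async def run_discovery_v4(privkey, address, bootstrap_nodes):
    parent_token = CancelToken('discovery-parent')
    proto = DiscoveryProtocol(privkey, address, bootstrap_nodes, parent_token)
    loop = asyncio.get_event_loop()
    await loop.create_datagram_endpoint(
        lambda: proto, local_addr=(address.ip, address.udp_port))
    try:
        await proto.bootstrap()             # bond with the bootnodes, then lookup_random()
        return await proto.lookup_random()  # iterative Kademlia search for a random node id
    finally:
        await proto.stop()                  # triggers proto.cancel_token and closes the socket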