async def test_wait_cancel_pending_tasks_on_completion(event_loop):
    """Waiting on a chain whose member token fired must not leave stray tasks behind."""
    first, second = CancelToken('token'), CancelToken('token2')
    chained = first.chain(second)
    # Firing only the chained-in token is enough for chain.wait() to return.
    second.trigger()
    await chained.wait()
    # Presumably (per the helper's name) checks that nothing but the current task
    # is still pending once wait() has returned.
    await assert_only_current_task_not_done()
async def test_wait_cancel_pending_tasks_on_completion(event_loop):
    # Build a two-token chain, fire the second token and wait on the chain;
    # afterwards the helper verifies no other tasks were left pending.
    # NOTE(review): this is an exact duplicate of an earlier test with the same
    # name, which it shadows — consider removing one of them.
    token_a = CancelToken('token')
    token_b = CancelToken('token2')
    combined = token_a.chain(token_b)
    token_b.trigger()
    await combined.wait()
    await assert_only_current_task_not_done()
def __init__(self, token: CancelToken = None) -> None:
    """Create the ``finished`` event and this object's cancel token.

    :param token: optional external token; when given it is chained onto a
        token named after this class, so triggering either one cancels us.
    """
    self.finished = asyncio.Event()
    own_token = CancelToken(type(self).__name__)
    self.cancel_token = own_token if token is None else own_token.chain(token)
def test_token_chain_trigger_last():
    """Triggering the last token of a three-token chain triggers the whole chain."""
    first, second, last = [CancelToken(name) for name in ('token', 'token2', 'token3')]
    chain = first.chain(second).chain(last)
    assert not chain.triggered
    last.trigger()
    assert chain.triggered
    # The chain reports which of its member tokens actually fired.
    assert chain.triggered_token == last
def test_token_chain_trigger_last():
    # Chain three tokens; firing only the final one must propagate to the chain.
    # NOTE(review): duplicate of an earlier test with the same name, which it shadows.
    t1 = CancelToken('token')
    t2 = CancelToken('token2')
    t3 = CancelToken('token3')
    chained = t1.chain(t2)
    chained = chained.chain(t3)
    assert not chained.triggered
    t3.trigger()
    assert chained.triggered
    assert chained.triggered_token == t3
def __init__(self, token: CancelToken=None) -> None:
    """Set up logging, the run lock, the cleanup event and the cancel token.

    A logger named ``<module>.<ClassName>`` is installed only when
    ``self.logger`` is None.

    :param token: optional external token, chained onto a token named after
        this class.
    """
    if self.logger is None:
        logger_name = '{}.{}'.format(self.__module__, self.__class__.__name__)
        self.logger = logging.getLogger(logger_name)
    self._run_lock = asyncio.Lock()
    self.cleaned_up = asyncio.Event()
    own_token = CancelToken(type(self).__name__)
    if token is not None:
        own_token = own_token.chain(token)
    self.cancel_token = own_token
def __init__(self,
             account_db: BaseDB,
             root_hash: bytes,
             peer_pool: PeerPool,
             token: CancelToken = None) -> None:
    """Prepare a state download of the trie rooted at ``root_hash``.

    :param token: optional external token, chained onto a 'StateDownloader'
        token before being handed to the base class.
    """
    base = CancelToken('StateDownloader')
    super().__init__(base if token is None else base.chain(token))
    self.peer_pool = peer_pool
    self.root_hash = root_hash
    self.scheduler = StateSync(root_hash, account_db)
    self._running_peers = set()  # type: Set[ETHPeer]
    self._peers_with_pending_requests = {}  # type: Dict[ETHPeer, float]
def __init__(self,
             chain: AsyncChain,
             chaindb: AsyncChainDB,
             db: BaseDB,
             peer_pool: PeerPool,
             token: CancelToken = None) -> None:
    """Store the chain/db handles and the peer pool used by this syncer.

    :param token: optional external token, chained onto a 'FullNodeSyncer'
        token before being handed to the base class.
    """
    syncer_token = CancelToken('FullNodeSyncer')
    super().__init__(syncer_token if token is None else syncer_token.chain(token))
    self.chain, self.chaindb, self.db, self.peer_pool = chain, chaindb, db, peer_pool
def __init__(self, shard: Shard, peer_pool: PeerPool, token: CancelToken = None) -> None:
    """Set up a syncer for ``shard`` fed by ``peer_pool``.

    :param token: optional external token; it is chained onto a token named
        "ShardSyncer" and the resulting token is handed to the base class.
    """
    cancel_token = CancelToken("ShardSyncer")
    if token is not None:
        cancel_token = cancel_token.chain(token)
    # Hand the *chained* token to the base class; passing the raw ``token`` here
    # would discard the "ShardSyncer" token that was just built.
    super().__init__(cancel_token)
    self.shard = shard
    self.peer_pool = peer_pool
    self.incoming_collation_queue: asyncio.Queue[Collation] = asyncio.Queue()
    self.collations_received_event = asyncio.Event()
    self.start_time = time.time()
async def wait_first(
        self,
        *futures: Awaitable,
        token: CancelToken = None,
        timeout: float = None) -> Any:
    """Return the result of whichever of ``futures`` completes first.

    If ``token`` is given it is chained with this service's token, so
    triggering either one cancels the wait. All futures still pending are
    cancelled before returning.

    Raises TimeoutError if ``timeout`` expires first, or OperationCancelled
    if the token chain is triggered.
    """
    chained = self.cancel_token if token is None else token.chain(self.cancel_token)
    return await wait_with_token(*futures, token=chained, timeout=timeout)
def __init__(self,
             chaindb: AsyncChainDB,
             peer_pool: PeerPool,
             token: CancelToken = None) -> None:
    """Initialise sync state and the queues used by the msg handlers.

    :param token: optional external token, chained onto a 'ChainSyncer' token
        before being handed to the base class.
    """
    syncer_token = CancelToken('ChainSyncer')
    super().__init__(syncer_token.chain(token) if token is not None else syncer_token)
    self.chaindb = chaindb
    self.peer_pool = peer_pool
    self._running_peers = set()  # type: Set[ETHPeer]
    self._syncing = False
    self._sync_complete = asyncio.Event()
    self._sync_requests = asyncio.Queue()  # type: asyncio.Queue[ETHPeer]
    self._new_headers = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
    # The msg handlers and _download_block_parts() use these to track which
    # bodies/receipts are still missing for a given chain segment.
    self._downloaded_receipts = asyncio.Queue()  # type: asyncio.Queue[List[DownloadedBlockPart]]  # noqa: E501
    self._downloaded_bodies = asyncio.Queue()  # type: asyncio.Queue[List[DownloadedBlockPart]]
def __init__(self, token: CancelToken = None, loop: asyncio.AbstractEventLoop = None) -> None:
    """Set up logging, run/cleanup primitives, child-service list and cancel token.

    :param token: optional external token, chained onto a token named after
        this class (created for ``loop``).
    :param loop: event loop stored on ``self.loop`` and passed to the base token.
    """
    if self.logger is None:
        qualified_name = '{}.{}'.format(self.__module__, self.__class__.__name__)
        self.logger = cast(TraceLogger, logging.getLogger(qualified_name))

    self._run_lock = asyncio.Lock()
    self.cleaned_up = asyncio.Event()
    self._child_services = []
    self.loop = loop

    own_token = CancelToken(type(self).__name__, loop=loop)
    self.cancel_token = own_token if token is None else own_token.chain(token)
class BasePeer:
    """A connection to a remote node over an RLPx-encrypted stream.

    Handles frame encryption/decryption and MAC checks, the base P2P handshake, and routing
    of decoded msgs either to the base-protocol handler or to ``sub_proto_msg_queue``.
    """
    logger = logging.getLogger("p2p.peer.Peer")
    conn_idle_timeout = CONN_IDLE_TIMEOUT
    reply_timeout = REPLY_TIMEOUT
    # Must be defined in subclasses. All items here must be Protocol classes representing
    # different versions of the same P2P sub-protocol (e.g. ETH, LES, etc).
    _supported_sub_protocols = []  # type: List[Type[protocol.Protocol]]
    # FIXME: Must be configurable.
    listen_port = 30303
    # Will be set upon the successful completion of a P2P handshake.
    sub_proto = None  # type: protocol.Protocol

    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: PreImage,
                 ingress_mac: PreImage,
                 chaindb: AsyncChainDB,
                 network_id: int,
                 ) -> None:
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update

    async def send_sub_proto_handshake(self):
        raise NotImplementedError()

    async def process_sub_proto_handshake(
            self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        raise NotImplementedError()

    async def do_sub_proto_handshake(self):
        """Perform the handshake for the sub-protocol agreed with the remote peer.

        Raises HandshakeFailure if the handshake is not successful.
        """
        await self.send_sub_proto_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the sub-proto handshake.
            raise HandshakeFailure(
                "{} disconnected before completing sub-proto handshake: {}".format(
                    self, msg['reason_name']))
        await self.process_sub_proto_handshake(cmd, msg)
        self.logger.debug("Finished %s handshake with %s", self.sub_proto, self.remote)

    async def do_p2p_handshake(self):
        """Perform the handshake for the P2P base protocol.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.base_protocol.send_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
            raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
                self, msg['reason_name']))
        self.process_p2p_handshake(cmd, msg)

    async def read_sub_proto_msg(
            self, cancel_token: CancelToken) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read the next sub-protocol message from the queue.

        Raises OperationCancelled if the peer has been disconnected or the given
        ``cancel_token`` is triggered.
        """
        combined_token = self.cancel_token.chain(cancel_token)
        # Pass the token by keyword: wait_with_token is also called as
        # wait_with_token(*futures, token=..., timeout=...), where a positional token would
        # be mistaken for another awaitable.
        return await wait_with_token(self.sub_proto_msg_queue.get(), token=combined_token)

    @property
    async def genesis(self) -> BlockHeader:
        genesis_hash = await self.chaindb.coro_lookup_block_hash(GENESIS_BLOCK_NUMBER)
        return await self.chaindb.coro_get_block_header_by_hash(genesis_hash)

    @property
    async def _local_chain_info(self) -> 'ChainInfo':
        genesis = await self.genesis
        head = await self.chaindb.coro_get_canonical_head()
        total_difficulty = await self.chaindb.coro_get_score(head.hash)
        return ChainInfo(
            block_number=head.block_number,
            block_hash=head.hash,
            total_difficulty=total_difficulty,
            genesis_hash=genesis.hash,
        )

    @property
    def capabilities(self) -> List[Tuple[bytes, int]]:
        return [(klass.name, klass.version) for klass in self._supported_sub_protocols]

    def get_protocol_command_for(self, msg: bytes) -> protocol.Command:
        """Return the Command corresponding to the cmd_id encoded in the given msg.

        Raises UnknownProtocolCommand if the cmd_id belongs neither to the base protocol nor
        to the negotiated sub-protocol (including before any sub-protocol is negotiated).
        """
        cmd_id = get_devp2p_cmd_id(msg)
        self.logger.debug("Got msg with cmd_id: %s", cmd_id)
        if cmd_id < self.base_protocol.cmd_length:
            proto = self.base_protocol
        elif (self.sub_proto is not None and
              cmd_id < self.sub_proto.cmd_id_offset + self.sub_proto.cmd_length):
            proto = self.sub_proto  # type: ignore
        else:
            raise UnknownProtocolCommand("No protocol found for cmd_id {}".format(cmd_id))
        return proto.cmd_by_id[cmd_id]

    async def read(self, n: int) -> bytes:
        self.logger.debug("Waiting for %s bytes from %s", n, self.remote)
        try:
            return await wait_with_token(
                self.reader.readexactly(n), token=self.cancel_token,
                timeout=self.conn_idle_timeout)
        except (asyncio.IncompleteReadError, ConnectionResetError):
            raise PeerConnectionLost("EOF reading from stream")

    async def run(self, finished_callback: Optional[Callable[['BasePeer'], None]] = None) -> None:
        try:
            await self.read_loop()
        except OperationCancelled as e:
            self.logger.debug("Peer finished: %s", e)
        except Exception:
            self.logger.error(
                "Unexpected error when handling remote msg: %s", traceback.format_exc())
        finally:
            self._finished.set()
            if finished_callback is not None:
                finished_callback(self)

    def close(self):
        """Close this peer's reader/writer streams.

        This will cause the peer to stop in case it is running.

        If the streams have already been closed, do nothing.
        """
        if self.reader.at_eof():
            return
        self.reader.feed_eof()
        self.writer.close()

    async def stop(self):
        """Disconnect from the remote and flag this peer as finished.

        If the peer is already flagged as finished, do nothing.
        """
        if self._finished.is_set():
            return
        self.cancel_token.trigger()
        self.close()
        await self._finished.wait()
        self.logger.debug("Stopped %s", self)

    async def read_loop(self):
        while True:
            try:
                cmd, msg = await self.read_msg()
            except (PeerConnectionLost, TimeoutError, asyncio.TimeoutError) as e:
                # Catch both timeout types: before Python 3.11 asyncio.TimeoutError is not the
                # builtin TimeoutError, and wait_with_token may raise either.
                self.logger.info(
                    "%s stopped responding (%s), disconnecting", self.remote, repr(e))
                return

            self.process_msg(cmd, msg)

    async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        header_data = await self.read(HEADER_LEN + MAC_LEN)
        header = self.decrypt_header(header_data)
        frame_size = self.get_frame_size(header)
        # The frame_size specified in the header does not include the padding to 16-byte boundary,
        # so need to do this here to ensure we read all the frame's data.
        read_size = roundup_16(frame_size)
        frame_data = await self.read(read_size + MAC_LEN)
        msg = self.decrypt_body(frame_data, frame_size)
        cmd = self.get_protocol_command_for(msg)
        decoded_msg = cmd.decode(msg)
        self.logger.debug("Successfully decoded %s msg: %s", cmd, decoded_msg)
        return cmd, decoded_msg

    def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Handle the base protocol (P2P) messages."""
        if isinstance(cmd, Disconnect):
            msg = cast(Dict[str, Any], msg)
            self.logger.debug(
                "%s disconnected; reason given: %s", self, msg['reason_name'])
            self.close()
        elif isinstance(cmd, Ping):
            self.base_protocol.send_pong()
        elif isinstance(cmd, Pong):
            # Currently we don't do anything when we get a pong, but eventually we should
            # update the last time we heard from a peer in our DB (which doesn't exist yet).
            pass
        else:
            raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))

    def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        self.sub_proto_msg_queue.put_nowait((cmd, msg))

    def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd.proto, P2PProtocol):
            self.handle_p2p_msg(cmd, msg)
        else:
            self.handle_sub_proto_msg(cmd, msg)

    def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        msg = cast(Dict[str, Any], msg)
        if not isinstance(cmd, Hello):
            self.disconnect(DisconnectReason.other)
            raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
        remote_capabilities = msg['capabilities']
        self.sub_proto = self.select_sub_protocol(remote_capabilities)
        if self.sub_proto is None:
            self.disconnect(DisconnectReason.useless_peer)
            raise HandshakeFailure(
                "No matching capabilities between us ({}) and {} ({}), disconnecting".format(
                    self.capabilities, self.remote, remote_capabilities))
        self.logger.debug(
            "Finished P2P handshake with %s, using sub-protocol %s",
            self.remote, self.sub_proto)

    def encrypt(self, header: bytes, frame: bytes) -> bytes:
        if len(header) != HEADER_LEN:
            raise ValueError("Unexpected header length: {}".format(len(header)))

        header_ciphertext = self.aes_enc.update(header)
        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), header_ciphertext))
        header_mac = self.egress_mac.digest()[:HEADER_LEN]

        frame_ciphertext = self.aes_enc.update(frame)
        self.egress_mac.update(frame_ciphertext)
        fmac_seed = self.egress_mac.digest()[:HEADER_LEN]

        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), fmac_seed))
        frame_mac = self.egress_mac.digest()[:HEADER_LEN]

        return header_ciphertext + header_mac + frame_ciphertext + frame_mac

    def decrypt_header(self, data: bytes) -> bytes:
        if len(data) != HEADER_LEN + MAC_LEN:
            raise ValueError("Unexpected header length: {}".format(len(data)))

        header_ciphertext = data[:HEADER_LEN]
        header_mac = data[HEADER_LEN:]
        mac_secret = self.ingress_mac.digest()[:HEADER_LEN]
        aes = self.mac_enc(mac_secret)[:HEADER_LEN]
        self.ingress_mac.update(sxor(aes, header_ciphertext))
        expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
        if not bytes_eq(expected_header_mac, header_mac):
            raise AuthenticationError('Invalid header mac')
        return self.aes_dec.update(header_ciphertext)

    def decrypt_body(self, data: bytes, body_size: int) -> bytes:
        read_size = roundup_16(body_size)
        if len(data) < read_size + MAC_LEN:
            raise ValueError('Insufficient body length; Got {}, wanted {}'.format(
                len(data), (read_size + MAC_LEN)))

        frame_ciphertext = data[:read_size]
        frame_mac = data[read_size:read_size + MAC_LEN]

        self.ingress_mac.update(frame_ciphertext)
        fmac_seed = self.ingress_mac.digest()[:MAC_LEN]
        self.ingress_mac.update(sxor(self.mac_enc(fmac_seed), fmac_seed))
        expected_frame_mac = self.ingress_mac.digest()[:MAC_LEN]
        if not bytes_eq(expected_frame_mac, frame_mac):
            raise AuthenticationError('Invalid frame mac')
        return self.aes_dec.update(frame_ciphertext)[:body_size]

    def get_frame_size(self, header: bytes) -> int:
        # The frame size is encoded in the header as a 3-byte int, so before we unpack we need
        # to prefix it with an extra byte.
        encoded_size = b'\x00' + header[:3]
        (size,) = struct.unpack(b'>I', encoded_size)
        return size

    def send(self, header: bytes, body: bytes) -> None:
        cmd_id = rlp.decode(body[:1], sedes=sedes.big_endian_int)
        self.logger.debug("Sending msg with cmd_id: %s", cmd_id)
        self.writer.write(self.encrypt(header, body))

    def disconnect(self, reason: DisconnectReason) -> None:
        """Send a disconnect msg to the remote node and stop this Peer.

        :param reason: An item from the DisconnectReason enum.
        :raises ValueError: if ``reason`` is not a DisconnectReason item.
        """
        # Validate before touching ``reason.value``: a non-enum reason would otherwise raise
        # AttributeError instead of the intended ValueError.
        if not isinstance(reason, DisconnectReason):
            raise ValueError(
                "Reason must be an item of DisconnectReason, got {}".format(reason))
        self.logger.debug("Disconnecting from remote peer; reason: %s", reason.value)
        self.base_protocol.send_disconnect(reason.value)
        self.close()

    def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
                            ) -> protocol.Protocol:
        """Select the sub-protocol to use when talking to the remote.

        Find the highest version of our supported sub-protocols that is also supported by the
        remote and return an instance of it (with the appropriate cmd_id offset), or None when
        we have no capability in common with the remote.
        """
        matching_capabilities = set(self.capabilities).intersection(remote_capabilities)
        if not matching_capabilities:
            # max() would raise ValueError here; callers check for None instead.
            return None
        _, highest_matching_version = max(matching_capabilities, key=operator.itemgetter(1))
        offset = self.base_protocol.cmd_length
        for proto_class in self._supported_sub_protocols:
            if proto_class.version == highest_matching_version:
                return proto_class(self, offset)
        return None

    def __str__(self):
        return "{} {}".format(self.__class__.__name__, self.remote)
def test_token_chain_event_loop_mismatch():
    """Chaining tokens bound to different event loops must raise EventLoopMismatch."""
    other_loop = asyncio.new_event_loop()
    try:
        token = CancelToken('token')
        token2 = CancelToken('token2', loop=other_loop)
        with pytest.raises(EventLoopMismatch):
            token.chain(token2)
    finally:
        # The extra loop is owned by this test; close it so it doesn't leak.
        other_loop.close()
class ChainSyncer(PeerPoolSubscriber):
    """Sync the local chain from connected ETH peers.

    Headers are fetched from the peer being synced with; block bodies and receipts for each
    header segment are requested in batches spread across every peer in the pool.
    """
    logger = logging.getLogger("p2p.chain.ChainSyncer")
    # We'll only sync if we are connected to at least min_peers_to_sync.
    min_peers_to_sync = 2
    # TODO: Instead of a fixed timeout, we should use a variable one that gets adjusted based on
    # the round-trip times from our download requests.
    _reply_timeout = 60

    def __init__(self,
                 chaindb: AsyncChainDB,
                 peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        self.chaindb = chaindb
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.cancel_token = CancelToken('ChainSyncer')
        if token is not None:
            self.cancel_token = self.cancel_token.chain(token)
        self._running_peers = set()  # type: Set[ETHPeer]
        self._syncing = False
        self._sync_complete = asyncio.Event()
        self._sync_requests = asyncio.Queue()  # type: asyncio.Queue[ETHPeer]
        self._new_headers = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        self._downloaded_receipts = asyncio.Queue()  # type: asyncio.Queue[List[bytes]]
        self._downloaded_bodies = asyncio.Queue()  # type: asyncio.Queue[List[Tuple[bytes, bytes]]]

    def register_peer(self, peer: BasePeer) -> None:
        """Start handling ``peer``'s msgs and queue a sync request for the highest-TD peer."""
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))
        highest_td_peer = max(
            [cast(ETHPeer, p) for p in self.peer_pool.peers],
            key=operator.attrgetter('head_td'))
        self._sync_requests.put_nowait(highest_td_peer)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            pending_msgs = peer.sub_proto_msg_queue.qsize()
            if pending_msgs:
                self.logger.debug(
                    "Read %s msg from %s's queue; %d msgs pending", cmd, peer, pending_msgs)

            # Our handle_msg() method runs cpu-intensive tasks in sub-processes so that the main
            # loop can keep processing msgs, and that's why we use ensure_future() instead of
            # awaiting for it to finish here.
            asyncio.ensure_future(self.handle_msg(peer, cmd, msg))

    async def handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                         msg: protocol._DecodedMsgType) -> None:
        try:
            await self._handle_msg(peer, cmd, msg)
        except OperationCancelled:
            # Silently swallow OperationCancelled exceptions because we run unsupervised (i.e.
            # with ensure_future()). Our caller will also get an OperationCancelled anyway, and
            # there it will be handled.
            pass
        except Exception:
            self.logger.exception("Unexpected error when processing msg from %s", peer)

    async def run(self) -> None:
        while True:
            peer_or_finished = await wait_with_token(
                self._sync_requests.get(), self._sync_complete.wait(),
                token=self.cancel_token)

            if self._sync_complete.is_set():
                return

            # Since self._sync_complete is not set, peer_or_finished can only be a ETHPeer
            # instance.
            asyncio.ensure_future(self.sync(peer_or_finished))

    async def sync(self, peer: ETHPeer) -> None:
        if self._syncing:
            self.logger.debug(
                "Got a NewBlock or a new peer, but already syncing so doing nothing")
            return
        elif len(self._running_peers) < self.min_peers_to_sync:
            self.logger.warning(
                "Connected to less peers (%d) than the minimum (%d) required to sync, "
                "doing nothing", len(self._running_peers), self.min_peers_to_sync)
            return

        self._syncing = True
        try:
            await self._sync(peer)
        except OperationCancelled:
            pass
        except Exception:
            # We run unsupervised (via ensure_future() in run()), so log here or the error
            # would only surface as an unretrieved task exception.
            self.logger.exception("Unexpected error when syncing with %s", peer)
        finally:
            self._syncing = False

    async def _sync(self, peer: ETHPeer) -> None:
        head = await self.chaindb.coro_get_canonical_head()
        head_td = await self.chaindb.coro_get_score(head.hash)
        if peer.head_td <= head_td:
            self.logger.info(
                "Head TD (%d) announced by %s not higher than ours (%d), not syncing",
                peer.head_td, peer, head_td)
            return

        self.logger.info("Starting sync with %s", peer)
        # FIXME: Fetch a batch of headers, in reverse order, starting from our current head, and
        # find the common ancestor between our chain and the peer's.
        start_at = max(0, head.block_number - eth.MAX_HEADERS_FETCH)
        while True:
            self.logger.info("Fetching chain segment starting at #%d", start_at)
            peer.sub_proto.send_get_block_headers(start_at, eth.MAX_HEADERS_FETCH, reverse=False)
            try:
                headers = await wait_with_token(
                    self._new_headers.get(), peer.wait_until_finished(),
                    token=self.cancel_token,
                    timeout=self._reply_timeout)
            except (TimeoutError, asyncio.TimeoutError):
                self.logger.warning(
                    "Timeout waiting for header batch from %s, aborting sync", peer)
                await peer.stop()
                break

            if peer.is_finished():
                self.logger.info("%s disconnected, aborting sync", peer)
                break

            if not headers:
                # Nothing to import and start_at cannot advance, so re-requesting the same
                # segment would loop forever.
                self.logger.info(
                    "Got no headers from %s for chain segment starting at #%d, aborting sync",
                    peer, start_at)
                break

            self.logger.info("Got headers segment starting at #%d", start_at)

            # TODO: Process headers for consistency.

            await self._download_block_parts(
                [header for header in headers if not _is_body_empty(header)],
                self.request_bodies,
                self._downloaded_bodies,
                _body_key,
                'body')

            self.logger.info("Got block bodies for chain segment starting at #%d", start_at)

            missing_receipts = [header for header in headers if not _is_receipts_empty(header)]
            # Post-Byzantium blocks may have identical receipt roots (e.g. when they have the same
            # number of transactions and all succeed/failed: ropsten blocks 2503212 and 2503284),
            # so we do this to avoid requesting the same receipts multiple times.
            missing_receipts = list(unique(missing_receipts, key=_receipts_key))
            await self._download_block_parts(
                missing_receipts,
                self.request_receipts,
                self._downloaded_receipts,
                _receipts_key,
                'receipt')

            self.logger.info("Got block receipts for chain segment starting at #%d", start_at)

            for header in headers:
                await self.chaindb.coro_persist_header(header)
                start_at = header.block_number + 1

            self.logger.info("Imported chain segment, new head: #%d", start_at - 1)
            head = await self.chaindb.coro_get_canonical_head()
            if head.hash == peer.head_hash:
                self.logger.info("Chain sync with %s completed", peer)
                self._sync_complete.set()
                break

    async def _download_block_parts(
            self,
            headers: List[BlockHeader],
            request_func: Callable[[List[BlockHeader]], int],
            download_queue: 'Union[asyncio.Queue[List[bytes]], asyncio.Queue[List[Tuple[bytes, bytes]]]]',  # noqa: E501
            key_func: Callable[[BlockHeader], Union[bytes, Tuple[bytes, bytes]]],
            part_name: str) -> None:
        """Request the given parts until every header's key has been received via download_queue."""
        missing = headers.copy()
        # The ETH protocol doesn't guarantee that we'll get all body parts requested, so we need
        # to keep track of the number of pending replies and missing items to decide when to retry
        # them. See request_receipts() for more info.
        pending_replies = request_func(missing)
        while missing:
            if pending_replies == 0:
                pending_replies = request_func(missing)

            try:
                downloaded = await wait_with_token(
                    download_queue.get(), token=self.cancel_token, timeout=self._reply_timeout)
            except (TimeoutError, asyncio.TimeoutError):
                pending_replies = request_func(missing)
                continue
            pending_replies -= 1

            downloaded_keys = set(downloaded)
            unexpected = downloaded_keys.difference(key_func(header) for header in missing)
            for item in unexpected:
                self.logger.warning("Got unexpected %s: %s", part_name, item)

            missing = [header for header in missing if key_func(header) not in downloaded_keys]

    def _request_block_parts(
            self,
            headers: List[BlockHeader],
            request_func: Callable[[ETHPeer, List[BlockHeader]], None]) -> int:
        """Split headers into one batch per connected peer and send each via request_func.

        Returns the number of requests made: 0 when there is nothing to request or no peer to
        ask (the caller then retries after its reply timeout).
        """
        peers = list(self.peer_pool.peers)
        if not headers or not peers:
            return 0
        length = math.ceil(len(headers) / len(peers))
        batches = list(partition_all(length, headers))
        for peer, batch in zip(peers, batches):
            request_func(cast(ETHPeer, peer), batch)
        return len(batches)

    def _send_get_block_bodies(self, peer: ETHPeer, headers: List[BlockHeader]) -> None:
        self.logger.debug("Requesting %d block bodies to %s", len(headers), peer)
        peer.sub_proto.send_get_block_bodies([header.hash for header in headers])

    def _send_get_receipts(self, peer: ETHPeer, headers: List[BlockHeader]) -> None:
        self.logger.debug("Requesting %d block receipts to %s", len(headers), peer)
        peer.sub_proto.send_get_receipts([header.hash for header in headers])

    def request_bodies(self, headers: List[BlockHeader]) -> int:
        """Ask our peers for bodies for the given headers.

        See request_receipts() for details of how this is done.
        """
        return self._request_block_parts(headers, self._send_get_block_bodies)

    def request_receipts(self, headers: List[BlockHeader]) -> int:
        """Ask our peers for receipts for the given headers.

        We partition the given list of headers in batches and request each to one of our connected
        peers. This is done because geth enforces a byte-size cap when replying to a GetReceipts
        msg, and we then need to re-request the items that didn't fit, so by splitting the
        requests across all our peers we reduce the likelihood of having to make multiple
        serialized requests to ask for missing items (which happens quite frequently in practice).

        Returns the number of requests made.
        """
        return self._request_block_parts(headers, self._send_get_receipts)

    async def wait_until_finished(self) -> None:
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        self.logger.info("Waiting for %d running peers to finish", len(self._running_peers))
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            await asyncio.sleep(0.1)
        else:
            self.logger.info("Waited too long for peers to finish, exiting anyway")

    async def stop(self) -> None:
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        await self.wait_until_finished()

    async def _handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                          msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd, eth.BlockHeaders):
            msg = cast(List[BlockHeader], msg)
            if msg:
                self.logger.debug(
                    "Got BlockHeaders from %d to %d",
                    msg[0].block_number, msg[-1].block_number)
            else:
                # Still queue the empty reply so _sync() can react instead of timing out.
                self.logger.debug("Got empty BlockHeaders from %s", peer)
            self._new_headers.put_nowait(msg)
        elif isinstance(cmd, eth.BlockBodies):
            await self._handle_block_bodies(peer, cast(List[eth.BlockBody], msg))
        elif isinstance(cmd, eth.Receipts):
            await self._handle_block_receipts(peer, cast(List[List[eth.Receipt]], msg))
        elif isinstance(cmd, eth.NewBlock):
            msg = cast(Dict[str, Any], msg)
            header = msg['block'][0]
            actual_head = header.parent_hash
            actual_td = msg['total_difficulty'] - header.difficulty
            if actual_td > peer.head_td:
                peer.head_hash = actual_head
                peer.head_td = actual_td
                self._sync_requests.put_nowait(peer)
        elif isinstance(cmd, eth.Transactions):
            # TODO: Figure out what to do with those.
            pass
        elif isinstance(cmd, eth.NewBlockHashes):
            # TODO: Figure out what to do with those.
            pass
        else:
            # TODO: There are other msg types we'll want to handle here, but for now just log them
            # as a warning so we don't forget about it.
            self.logger.warning("Got unexpected msg: %s (%s)", cmd, msg)

    async def _handle_block_receipts(
            self, peer: ETHPeer, receipts: List[List[eth.Receipt]]) -> None:
        self.logger.debug("Got Receipts for %d blocks from %s", len(receipts), peer)
        loop = asyncio.get_event_loop()
        iterator = map(make_trie_root_and_nodes, receipts)
        # Build the tries off the event loop via run_in_executor.
        receipts_tries = await wait_with_token(
            loop.run_in_executor(None, list, iterator), token=self.cancel_token)
        for receipt_root, trie_dict_data in receipts_tries:
            await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
        receipt_roots = [receipt_root for receipt_root, _ in receipts_tries]
        self._downloaded_receipts.put_nowait(receipt_roots)

    async def _handle_block_bodies(self, peer: ETHPeer, bodies: List[eth.BlockBody]) -> None:
        self.logger.debug("Got Bodies for %d blocks from %s", len(bodies), peer)
        loop = asyncio.get_event_loop()
        iterator = map(make_trie_root_and_nodes, [body.transactions for body in bodies])
        # Build the tries off the event loop via run_in_executor.
        transactions_tries = await wait_with_token(
            loop.run_in_executor(None, list, iterator), token=self.cancel_token)
        body_keys = []
        for (body, (tx_root, trie_dict_data)) in zip(bodies, transactions_tries):
            await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
            uncles_hash = await self.chaindb.coro_persist_uncles(body.uncles)
            body_keys.append((tx_root, uncles_hash))
        self._downloaded_bodies.put_nowait(body_keys)
class DiscoveryProtocol(asyncio.DatagramProtocol):
    """A Kademlia-like protocol to discover RLPx nodes."""
    logger = logging.getLogger("p2p.discovery.DiscoveryProtocol")
    # Set by connection_made(); None until then.
    transport: asyncio.DatagramTransport = None
    _max_neighbours_per_packet_cache = None

    def __init__(self, privkey: datatypes.PrivateKey, address: kademlia.Address,
                 bootstrap_nodes: List[str]) -> None:
        self.privkey = privkey
        self.address = address
        self.bootstrap_nodes = [kademlia.Node.from_uri(node) for node in bootstrap_nodes]
        self.this_node = kademlia.Node(self.pubkey, address)
        self.kademlia = kademlia.KademliaProtocol(self.this_node, wire=self)
        self.cancel_token = CancelToken('DiscoveryProtocol')

    async def lookup_random(self, cancel_token: CancelToken) -> List[kademlia.Node]:
        """Look up a random node, aborting if either our token or ``cancel_token`` fires."""
        return await self.kademlia.lookup_random(self.cancel_token.chain(cancel_token))

    def get_random_nodes(self, count: int) -> Generator[kademlia.Node, None, None]:
        """Return ``count`` random nodes from our routing table."""
        return self.kademlia.routing.get_random_nodes(count)

    @property
    def pubkey(self) -> datatypes.PublicKey:
        return self.privkey.public_key

    def _get_handler(self, cmd: Command) -> Callable[[kademlia.Node, List[Any], bytes], None]:
        """Return the recv_* method for the given command.

        Raises ValueError for a command we don't know how to handle.
        """
        if cmd == CMD_PING:
            return self.recv_ping
        elif cmd == CMD_PONG:
            return self.recv_pong
        elif cmd == CMD_FIND_NODE:
            return self.recv_find_node
        elif cmd == CMD_NEIGHBOURS:
            return self.recv_neighbours
        else:
            raise ValueError("Unknown command: {}".format(cmd))

    def _get_max_neighbours_per_packet(self):
        # Computed lazily once per instance and cached afterwards.
        if self._max_neighbours_per_packet_cache is not None:
            return self._max_neighbours_per_packet_cache
        self._max_neighbours_per_packet_cache = _get_max_neighbours_per_packet()
        return self._max_neighbours_per_packet_cache

    def connection_made(self, transport):
        self.transport = transport

    async def bootstrap(self):
        """Wait for the transport to be set, then bootstrap kademlia from our bootnodes."""
        while self.transport is None:
            # FIXME: Instead of sleeping here to wait until connection_made() is called to set
            # .transport we should instead only call it after we know it's been set.
            await asyncio.sleep(1)
        self.logger.debug("bootstrapping with %s", self.bootstrap_nodes)
        await self.kademlia.bootstrap(self.bootstrap_nodes, self.cancel_token)

    # FIXME: Enable type checking here once we have a mypy version that
    # includes the fix for https://github.com/python/typeshed/pull/1740
    def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:  # type: ignore
        ip_address, udp_port = addr
        # XXX: For now we simply discard all v5 messages. The prefix below is what geth uses to
        # identify them:
        # https://github.com/ethereum/go-ethereum/blob/c4712bf96bc1bae4a5ad4600e9719e4a74bde7d5/p2p/discv5/udp.go#L149  # noqa: E501
        if text_if_str(to_bytes, data).startswith(b"temporary discovery v5"):
            self.logger.debug("Got discovery v5 msg, discarding")
            return

        self.receive(kademlia.Address(ip_address, udp_port), data)  # type: ignore

    def error_received(self, exc: Exception) -> None:
        self.logger.error('error received: %s', exc)

    def send(self, node: kademlia.Node, message: bytes) -> None:
        self.transport.sendto(message, (node.address.ip, node.address.udp_port))

    async def stop(self):
        """Trigger our cancel token, close the transport and give tasks time to exit."""
        self.logger.info('stopping discovery')
        self.cancel_token.trigger()
        # The transport only exists once connection_made() has been called.
        if self.transport is not None:
            self.transport.close()
        # We run lots of asyncio tasks so this is to make sure they all get a chance to execute
        # and exit cleanly when they notice the cancel token has been triggered.
        await asyncio.sleep(1)

    def receive(self, address: kademlia.Address, message: bytes) -> None:
        """Validate a raw discovery packet from ``address`` and dispatch it to its handler.

        Packets that can't be unpacked, carry an unknown command id, have the wrong number of
        payload elements, or have expired are logged and dropped. This data comes straight
        off the network, so malformed input must not raise.
        """
        try:
            remote_pubkey, cmd_id, payload, message_hash = _unpack(message)
        except DefectiveMessage as e:
            self.logger.error('error unpacking message (%s) from %s: %s', message, address, e)
            return

        try:
            cmd = CMD_ID_MAP[cmd_id]
        except KeyError:
            self.logger.error('unknown command id %s in message from %s', cmd_id, address)
            return

        # Validate the payload length before touching its elements, so an empty/short payload
        # can't cause an IndexError below.
        if len(payload) != cmd.elem_count:
            self.logger.error('invalid %s payload: %s', cmd.name, payload)
            return

        # As of discovery version 4, expiration is the last element for all packets, so
        # we can validate that here, but if it changes we may have to do so on the
        # handler methods.
        expiration = rlp.sedes.big_endian_int.deserialize(payload[-1])
        if time.time() > expiration:
            self.logger.error('received message already expired')
            return

        node = kademlia.Node(remote_pubkey, address)
        handler = self._get_handler(cmd)
        handler(node, payload, message_hash)

    def recv_pong(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The pong payload should have 3 elements: to, token, expiration
        _, token, _ = payload
        self.kademlia.recv_pong(node, token)

    def recv_neighbours(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The neighbours payload should have 2 elements: nodes, expiration
        nodes, _ = payload
        self.kademlia.recv_neighbours(node, _extract_nodes_from_payload(nodes))

    def recv_ping(self, node: kademlia.Node, _: Any, message_hash: bytes) -> None:
        self.kademlia.recv_ping(node, message_hash)

    def recv_find_node(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The find_node payload should have 2 elements: node_id, expiration
        self.logger.debug('<<< find_node from %s', node)
        node_id, _ = payload
        self.kademlia.recv_find_node(node, big_endian_to_int(node_id))

    def send_ping(self, node: kademlia.Node) -> bytes:
        """Send a ping to ``node`` and return the token used to match its pong."""
        version = rlp.sedes.big_endian_int.serialize(PROTO_VERSION)
        payload = [version, self.address.to_endpoint(), node.address.to_endpoint()]
        message = _pack(CMD_PING.id, payload, self.privkey)
        self.send(node, message)
        # Return the msg hash, which is used as a token to identify pongs.
        token = message[:MAC_SIZE]
        self.logger.debug('>>> ping %s (token == %s)', node, encode_hex(token))
        return token

    def send_find_node(self, node: kademlia.Node, target_node_id: int) -> None:
        # Left-pad the target id with zero bytes to the size of a public key.
        node_id = int_to_big_endian(
            target_node_id).rjust(kademlia.k_pubkey_size // 8, b'\0')
        self.logger.debug('>>> find_node to %s', node)
        message = _pack(CMD_FIND_NODE.id, [node_id], self.privkey)
        self.send(node, message)

    def send_pong(self, node: kademlia.Node, token: bytes) -> None:
        self.logger.debug('>>> pong %s', node)
        payload = [node.address.to_endpoint(), token]
        message = _pack(CMD_PONG.id, payload, self.privkey)
        self.send(node, message)

    def send_neighbours(self, node: kademlia.Node, neighbours: List[kademlia.Node]) -> None:
        """Send ``neighbours`` to ``node``, split across as many packets as needed."""
        nodes = []
        neighbours = sorted(neighbours)
        for n in neighbours:
            nodes.append(n.address.to_endpoint() + [n.pubkey.to_bytes()])

        max_neighbours = self._get_max_neighbours_per_packet()
        for i in range(0, len(nodes), max_neighbours):
            message = _pack(CMD_NEIGHBOURS.id, [nodes[i:i + max_neighbours]], self.privkey)
            self.logger.debug('>>> neighbours to %s: %s',
                              node, neighbours[i:i + max_neighbours])
            self.send(node, message)
class BasePeer(metaclass=ABCMeta):
    """An RLPx peer speaking the P2P base protocol plus one negotiated sub-protocol.

    Subclasses must define _supported_sub_protocols and implement the sub-protocol handshake
    methods.
    """
    logger = logging.getLogger("p2p.peer.Peer")
    conn_idle_timeout = CONN_IDLE_TIMEOUT
    reply_timeout = REPLY_TIMEOUT
    # Must be defined in subclasses. All items here must be Protocol classes representing
    # different versions of the same P2P sub-protocol (e.g. ETH, LES, etc).
    _supported_sub_protocols = []  # type: List[Type[protocol.Protocol]]
    # FIXME: Must be configurable.
    listen_port = 30303
    # Will be set upon the successful completion of a P2P handshake.
    sub_proto = None  # type: protocol.Protocol

    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: sha3.keccak_256,
                 ingress_mac: sha3.keccak_256,
                 chaindb: AsyncChainDB,
                 network_id: int,
                 ) -> None:
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update

    @abstractmethod
    async def send_sub_proto_handshake(self):
        raise NotImplementedError("Must be implemented by subclasses")

    @abstractmethod
    async def process_sub_proto_handshake(
            self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        raise NotImplementedError("Must be implemented by subclasses")

    async def do_sub_proto_handshake(self):
        """Perform the handshake for the sub-protocol agreed with the remote peer.

        Raises HandshakeFailure if the handshake is not successful.
        """
        await self.send_sub_proto_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the sub-proto handshake.
            raise HandshakeFailure(
                "{} disconnected before completing sub-proto handshake: {}".format(
                    self, msg['reason_name']))
        await self.process_sub_proto_handshake(cmd, msg)
        self.logger.debug("Finished %s handshake with %s", self.sub_proto, self.remote)

    async def do_p2p_handshake(self):
        """Perform the handshake for the P2P base protocol.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.base_protocol.send_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
            raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
                self, msg['reason_name']))
        self.process_p2p_handshake(cmd, msg)

    async def read_sub_proto_msg(
            self, cancel_token: CancelToken) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read the next sub-protocol message from the queue.

        Raises OperationCancelled if the peer has been disconnected.
        """
        combined_token = self.cancel_token.chain(cancel_token)
        return await wait_with_token(self.sub_proto_msg_queue.get(), token=combined_token)

    @property
    async def genesis(self) -> BlockHeader:
        genesis_hash = await self.chaindb.coro_lookup_block_hash(GENESIS_BLOCK_NUMBER)
        return await self.chaindb.coro_get_block_header_by_hash(genesis_hash)

    @property
    async def _local_chain_info(self) -> 'ChainInfo':
        genesis = await self.genesis
        head = await self.chaindb.coro_get_canonical_head()
        total_difficulty = await self.chaindb.coro_get_score(head.hash)
        return ChainInfo(
            block_number=head.block_number,
            block_hash=head.hash,
            total_difficulty=total_difficulty,
            genesis_hash=genesis.hash,
        )

    @property
    def capabilities(self) -> List[Tuple[str, int]]:
        """(name, version) pairs of all the sub-protocols we support."""
        return [(klass.name, klass.version) for klass in self._supported_sub_protocols]

    def get_protocol_command_for(self, msg: bytes) -> protocol.Command:
        """Return the Command corresponding to the cmd_id encoded in the given msg.

        Raises UnknownProtocolCommand if no protocol covers the cmd_id, including the case
        where the sub-protocol has not been negotiated yet.
        """
        cmd_id = get_devp2p_cmd_id(msg)
        self.logger.debug("Got msg with cmd_id: %s", cmd_id)
        if cmd_id < self.base_protocol.cmd_length:
            proto = self.base_protocol
        elif (self.sub_proto is not None and
                cmd_id < self.sub_proto.cmd_id_offset + self.sub_proto.cmd_length):
            proto = self.sub_proto  # type: ignore
        else:
            raise UnknownProtocolCommand("No protocol found for cmd_id {}".format(cmd_id))
        return proto.cmd_by_id[cmd_id]

    async def read(self, n: int) -> bytes:
        """Read exactly n bytes from the remote.

        Raises PeerConnectionLost if the connection is closed or reset.
        """
        self.logger.debug("Waiting for %s bytes from %s", n, self.remote)
        try:
            return await wait_with_token(
                self.reader.readexactly(n), token=self.cancel_token, timeout=self.conn_idle_timeout)
        except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError) as e:
            raise PeerConnectionLost(repr(e))

    async def run(self, finished_callback: Optional[Callable[['BasePeer'], None]] = None) -> None:
        """Run the read loop until it stops, is cancelled, or fails.

        Afterwards, always closes the streams, flags this peer as finished and invokes
        finished_callback (if given) with this peer.
        """
        try:
            await self.read_loop()
        except OperationCancelled as e:
            self.logger.debug("Peer finished: %s", e)
        except Exception:
            self.logger.exception("Unexpected error when handling remote msg")
        finally:
            self.close()
            self._finished.set()
            if finished_callback is not None:
                finished_callback(self)

    def is_finished(self) -> bool:
        return self._finished.is_set()

    async def wait_until_finished(self) -> bool:
        return await self._finished.wait()

    def close(self):
        """Close this peer's reader/writer streams.

        This will cause the peer to stop in case it is running.

        If the streams have already been closed, do nothing.
        """
        if self.reader.at_eof():
            return
        self.reader.feed_eof()
        self.writer.close()

    async def stop(self):
        """Disconnect from the remote and flag this peer as finished.

        If the peer is already flagged as finished, do nothing.
        """
        if self._finished.is_set():
            return
        self.cancel_token.trigger()
        await self._finished.wait()
        self.logger.debug("Stopped %s", self)

    async def read_loop(self):
        """Read and process msgs until the remote stops responding or disconnects."""
        while True:
            try:
                cmd, msg = await self.read_msg()
            except (PeerConnectionLost, TimeoutError) as e:
                self.logger.info(
                    "%s stopped responding (%s), disconnecting", self.remote, repr(e))
                return

            try:
                self.process_msg(cmd, msg)
            except RemoteDisconnected as e:
                self.logger.info("%s disconnected: %s", self, e)
                return

    async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read, authenticate, decrypt and decode the next msg from the remote."""
        header_data = await self.read(HEADER_LEN + MAC_LEN)
        header = self.decrypt_header(header_data)
        frame_size = self.get_frame_size(header)
        # The frame_size specified in the header does not include the padding to 16-byte boundary,
        # so need to do this here to ensure we read all the frame's data.
        read_size = roundup_16(frame_size)
        frame_data = await self.read(read_size + MAC_LEN)
        msg = self.decrypt_body(frame_data, frame_size)
        cmd = self.get_protocol_command_for(msg)
        loop = asyncio.get_event_loop()
        # Decoding runs in the default executor so it doesn't block the event loop.
        decoded_msg = await wait_with_token(
            loop.run_in_executor(None, cmd.decode, msg), token=self.cancel_token)
        self.logger.debug("Successfully decoded %s msg: %s", cmd, decoded_msg)
        return cmd, decoded_msg

    def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Handle the base protocol (P2P) messages."""
        if isinstance(cmd, Disconnect):
            msg = cast(Dict[str, Any], msg)
            raise RemoteDisconnected(msg['reason_name'])
        elif isinstance(cmd, Ping):
            self.base_protocol.send_pong()
        elif isinstance(cmd, Pong):
            # Currently we don't do anything when we get a pong, but eventually we should
            # update the last time we heard from a peer in our DB (which doesn't exist yet).
            pass
        else:
            raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))

    def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        self.sub_proto_msg_queue.put_nowait((cmd, msg))

    def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        if cmd.is_base_protocol:
            self.handle_p2p_msg(cmd, msg)
        else:
            self.handle_sub_proto_msg(cmd, msg)

    def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Process the remote's Hello msg and select the sub-protocol to use.

        Sends a disconnect and raises HandshakeFailure if the msg isn't a Hello or if there
        are no matching capabilities.
        """
        msg = cast(Dict[str, Any], msg)
        if not isinstance(cmd, Hello):
            self.disconnect(DisconnectReason.other)
            raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
        remote_capabilities = msg['capabilities']
        try:
            self.sub_proto = self.select_sub_protocol(remote_capabilities)
        except NoMatchingPeerCapabilities:
            self.disconnect(DisconnectReason.useless_peer)
            raise HandshakeFailure(
                "No matching capabilities between us ({}) and {} ({}), disconnecting".format(
                    self.capabilities, self.remote, remote_capabilities))
        self.logger.debug(
            "Finished P2P handshake with %s, using sub-protocol %s",
            self.remote, self.sub_proto)

    def encrypt(self, header: bytes, frame: bytes) -> bytes:
        """Encrypt header and frame, returning them with their egress MACs appended."""
        if len(header) != HEADER_LEN:
            raise ValueError("Unexpected header length: {}".format(len(header)))

        header_ciphertext = self.aes_enc.update(header)
        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), header_ciphertext))
        header_mac = self.egress_mac.digest()[:HEADER_LEN]

        frame_ciphertext = self.aes_enc.update(frame)
        self.egress_mac.update(frame_ciphertext)
        fmac_seed = self.egress_mac.digest()[:HEADER_LEN]

        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), fmac_seed))
        frame_mac = self.egress_mac.digest()[:HEADER_LEN]

        return header_ciphertext + header_mac + frame_ciphertext + frame_mac

    def decrypt_header(self, data: bytes) -> bytes:
        """Verify the header MAC and return the decrypted header.

        Raises ValueError on a wrong length and AuthenticationError on a bad MAC.
        """
        if len(data) != HEADER_LEN + MAC_LEN:
            raise ValueError("Unexpected header length: {}".format(len(data)))

        header_ciphertext = data[:HEADER_LEN]
        header_mac = data[HEADER_LEN:]
        mac_secret = self.ingress_mac.digest()[:HEADER_LEN]
        aes = self.mac_enc(mac_secret)[:HEADER_LEN]
        self.ingress_mac.update(sxor(aes, header_ciphertext))
        expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
        if not bytes_eq(expected_header_mac, header_mac):
            raise AuthenticationError('Invalid header mac')
        return self.aes_dec.update(header_ciphertext)

    def decrypt_body(self, data: bytes, body_size: int) -> bytes:
        """Verify the frame MAC and return the decrypted body, stripped of padding.

        Raises ValueError if data is too short and AuthenticationError on a bad MAC.
        """
        read_size = roundup_16(body_size)
        if len(data) < read_size + MAC_LEN:
            raise ValueError('Insufficient body length; Got {}, wanted {}'.format(
                len(data), (read_size + MAC_LEN)))

        frame_ciphertext = data[:read_size]
        frame_mac = data[read_size:read_size + MAC_LEN]

        self.ingress_mac.update(frame_ciphertext)
        fmac_seed = self.ingress_mac.digest()[:MAC_LEN]
        self.ingress_mac.update(sxor(self.mac_enc(fmac_seed), fmac_seed))
        expected_frame_mac = self.ingress_mac.digest()[:MAC_LEN]
        if not bytes_eq(expected_frame_mac, frame_mac):
            raise AuthenticationError('Invalid frame mac')
        return self.aes_dec.update(frame_ciphertext)[:body_size]

    def get_frame_size(self, header: bytes) -> int:
        # The frame size is encoded in the header as a 3-byte int, so before we unpack we need
        # to prefix it with an extra byte.
        encoded_size = b'\x00' + header[:3]
        (size,) = struct.unpack(b'>I', encoded_size)
        return size

    def send(self, header: bytes, body: bytes) -> None:
        cmd_id = rlp.decode(body[:1], sedes=sedes.big_endian_int)
        self.logger.debug("Sending msg with cmd_id: %s", cmd_id)
        self.writer.write(self.encrypt(header, body))

    def disconnect(self, reason: DisconnectReason) -> None:
        """Send a disconnect msg to the remote node and stop this Peer.

        :param reason: An item from the DisconnectReason enum.
        :raises ValueError: if ``reason`` is not a DisconnectReason.
        """
        # Validate first: the log below reads reason.value, which a non-enum reason may lack.
        if not isinstance(reason, DisconnectReason):
            raise ValueError(
                "Reason must be an item of DisconnectReason, got {}".format(reason))
        self.logger.debug("Disconnecting from remote peer; reason: %s", reason.value)
        self.base_protocol.send_disconnect(reason.value)
        self.close()

    def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
                            ) -> protocol.Protocol:
        """Select the sub-protocol to use when talking to the remote.

        Find the highest version of our supported sub-protocols that is also supported by the
        remote and return an instance of it (with the appropriate cmd_id offset).

        Raises NoMatchingPeerCapabilities if none of our supported protocols match one of the
        remote's protocols.
        """
        matching_capabilities = set(self.capabilities).intersection(remote_capabilities)
        if not matching_capabilities:
            raise NoMatchingPeerCapabilities()
        _, highest_matching_version = max(matching_capabilities, key=operator.itemgetter(1))
        offset = self.base_protocol.cmd_length
        for proto_class in self._supported_sub_protocols:
            if proto_class.version == highest_matching_version:
                return proto_class(self, offset)
        raise NoMatchingPeerCapabilities()

    def __str__(self):
        return "{} {}".format(self.__class__.__name__, self.remote)
class StateDownloader(PeerPoolSubscriber):
    """Download all state trie nodes under a given root hash from the peers in a PeerPool."""
    logger = logging.getLogger("p2p.state.StateDownloader")
    _total_processed_nodes = 0
    _report_interval = 10  # Number of seconds between progress reports.
    _reply_timeout = 20  # seconds
    _start_time = None  # type: float
    _total_timeouts = 0

    def __init__(self,
                 state_db: BaseDB,
                 root_hash: bytes,
                 peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.root_hash = root_hash
        self.scheduler = StateSync(root_hash, state_db)
        # Maps node keys we've requested to the time of the request. This must be a
        # per-instance dict: as a class attribute it would be shared by (and mutated through)
        # every StateDownloader instance.
        self._pending_nodes = {}  # type: Dict[Any, float]
        self._running_peers = set()  # type: Set[ETHPeer]
        self._peers_with_pending_requests = {}  # type: Dict[ETHPeer, float]
        self.cancel_token = CancelToken('StateDownloader')
        if token is not None:
            self.cancel_token = self.cancel_token.chain(token)

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))

    @property
    def idle_peers(self) -> List[ETHPeer]:
        """Peers in the pool that have no outstanding request from us."""
        peers = set([cast(ETHPeer, peer) for peer in self.peer_pool.peers])
        return list(peers.difference(self._peers_with_pending_requests))

    async def get_idle_peer(self) -> ETHPeer:
        """Wait until there is an idle peer and return a randomly chosen one."""
        while not self.idle_peers:
            self.logger.debug("Waiting for an idle peer...")
            await wait_with_token(asyncio.sleep(0.02), token=self.cancel_token)
        return secrets.choice(self.idle_peers)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            # Run self._handle_msg() with ensure_future() instead of awaiting for it so that we
            # can keep consuming msgs while _handle_msg() performs cpu-intensive tasks in separate
            # processes.
            asyncio.ensure_future(self._handle_msg(peer, cmd, msg))

    async def _handle_msg(
            self, peer: ETHPeer, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Feed NodeData replies into the scheduler, ignoring any other msg."""
        loop = asyncio.get_event_loop()
        if isinstance(cmd, eth.NodeData):
            self.logger.debug("Got %d NodeData entries from %s", len(msg), peer)

            # Check before we remove because sometimes a reply may come after our timeout and in
            # that case we won't be expecting it anymore.
            if peer in self._peers_with_pending_requests:
                self._peers_with_pending_requests.pop(peer)

            node_keys = await loop.run_in_executor(None, list, map(keccak, msg))
            for node_key, node in zip(node_keys, msg):
                self._total_processed_nodes += 1
                try:
                    self.scheduler.process([(node_key, node)])
                except SyncRequestAlreadyProcessed:
                    # This means we received a node more than once, which can happen when we
                    # retry after a timeout.
                    pass
                # A node may be received more than once, so pop() with a default value.
                self._pending_nodes.pop(node_key, None)
        else:
            # We ignore everything that is not a NodeData when doing a StateSync.
            self.logger.debug("Ignoring %s msg while doing a StateSync", cmd)

    async def stop(self):
        """Trigger our cancel token, unsubscribe and wait for all peer handlers to exit."""
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        while self._running_peers:
            self.logger.debug("Waiting for %d running peers to finish", len(self._running_peers))
            await asyncio.sleep(0.1)

    async def request_nodes(self, node_keys: List[bytes]) -> None:
        """Request the given trie nodes, in batches of eth.MAX_STATE_FETCH, from idle peers."""
        batches = list(partition_all(eth.MAX_STATE_FETCH, node_keys))
        for batch in batches:
            peer = await self.get_idle_peer()
            now = time.time()
            for node_key in batch:
                self._pending_nodes[node_key] = now
            self.logger.debug("Requesting %d trie nodes to %s", len(batch), peer)
            peer.sub_proto.send_get_node_data(batch)
            self._peers_with_pending_requests[peer] = now

    async def _periodically_retry_timedout(self):
        """Re-request nodes whose requests timed out, until cancelled."""
        while True:
            now = time.time()
            # First, update our list of peers with pending requests by removing those for which a
            # request timed out. This loop mutates the dict, so we iterate on a copy of it.
            for peer, last_req_time in list(self._peers_with_pending_requests.items()):
                if now - last_req_time > self._reply_timeout:
                    self._peers_with_pending_requests.pop(peer)

            # Now re-send requests for nodes that timed out.
            oldest_request_time = now
            timed_out = []
            for node_key, req_time in self._pending_nodes.items():
                if now - req_time > self._reply_timeout:
                    timed_out.append(node_key)
                elif req_time < oldest_request_time:
                    oldest_request_time = req_time
            if timed_out:
                self.logger.debug("Re-requesting %d trie nodes", len(timed_out))
                self._total_timeouts += len(timed_out)
                try:
                    await self.request_nodes(timed_out)
                except OperationCancelled:
                    break

            # Finally, sleep until the time our oldest request is scheduled to timeout.
            now = time.time()
            sleep_duration = (oldest_request_time + self._reply_timeout) - now
            try:
                await wait_with_token(asyncio.sleep(sleep_duration), token=self.cancel_token)
            except OperationCancelled:
                break

    async def run(self):
        """Fetch all trie nodes starting from self.root_hash.

        Nodes are fed into self.scheduler (a StateSync over the state_db given to __init__),
        which presumably takes care of persisting them.

        Raises OperationCancelled if we're interrupted before that is completed.
        """
        self._start_time = time.time()
        self.logger.info("Starting state sync for root hash %s", encode_hex(self.root_hash))
        asyncio.ensure_future(self._periodically_report_progress())
        asyncio.ensure_future(self._periodically_retry_timedout())
        while self.scheduler.has_pending_requests:
            # This ensures we yield control and give _handle_msg() a chance to process any nodes
            # we may have received already, also ensuring we exit when our cancel token is
            # triggered.
            await wait_with_token(asyncio.sleep(0), token=self.cancel_token)

            requests = self.scheduler.next_batch(eth.MAX_STATE_FETCH)
            if not requests:
                # Although we frequently yield control above, to let our msg handler process
                # received nodes (scheduling new requests), there may be cases when the
                # pending nodes take a while to arrive thus causing the scheduler to run out
                # of new requests for a while.
                self.logger.info("Scheduler queue is empty, sleeping a bit")
                await wait_with_token(asyncio.sleep(0.5), token=self.cancel_token)
                continue

            await self.request_nodes([request.node_key for request in requests])

        self.logger.info("Finished state sync with root hash %s", encode_hex(self.root_hash))

    async def _periodically_report_progress(self):
        """Log sync progress every _report_interval seconds, until cancelled."""
        while True:
            now = time.time()
            elapsed = now - self._start_time
            # This task first runs right after run() sets _start_time, and time.time() has
            # limited resolution, so elapsed may be 0 here. Dividing would raise
            # ZeroDivisionError and kill this reporting task.
            nodes_per_second = self._total_processed_nodes / elapsed if elapsed > 0 else 0
            self.logger.info("====== State sync progress ========")
            self.logger.info("Nodes processed: %d", self._total_processed_nodes)
            self.logger.info("Nodes processed per second (average): %d", nodes_per_second)
            self.logger.info("Nodes committed to DB: %d", self.scheduler.committed_nodes)
            self.logger.info(
                "Nodes requested but not received yet: %d", len(self._pending_nodes))
            self.logger.info(
                "Nodes scheduled but not requested yet: %d", len(self.scheduler.requests))
            self.logger.info("Total nodes timed out: %d", self._total_timeouts)
            try:
                await wait_with_token(asyncio.sleep(self._report_interval), token=self.cancel_token)
            except OperationCancelled:
                break
class ShardSyncer(BaseService, PeerPoolSubscriber):
    """Sync collations of a single shard with the peers in a PeerPool.

    Collations received from peers are funnelled through incoming_collation_queue. If they
    belong to our shard and aren't available locally yet, they are added to the shard and
    re-broadcast to all peers.
    """
    logger = logging.getLogger("p2p.sharding.ShardSyncer")

    def __init__(self, shard: Shard, peer_pool: PeerPool, token: CancelToken) -> None:
        # The base service's __init__ already sets self.cancel_token to a token named after
        # this class, chained with `token` when one is given, so we don't build a second one.
        super().__init__(token)

        self.shard = shard
        self.peer_pool = peer_pool

        self.incoming_collation_queue: asyncio.Queue[
            Collation] = asyncio.Queue()

        # Set and immediately cleared each time a collation is added, waking any waiters.
        self.collations_received_event = asyncio.Event()

        self.start_time = time.time()

    async def _run(self) -> None:
        """Consume incoming collations forever, adding new ones for our shard and relaying them."""
        self.peer_pool.subscribe(self)
        while True:
            collation = await wait_with_token(
                self.incoming_collation_queue.get(),
                token=self.cancel_token)

            if collation.shard_id != self.shard.shard_id:
                self.logger.debug(
                    "Ignoring received collation belonging to wrong shard")
                continue
            if self.shard.get_availability(
                    collation.header) is Availability.AVAILABLE:
                self.logger.debug("Ignoring already available collation")
                continue

            self.logger.debug("Adding collation %s to shard", collation)
            self.shard.add_collation(collation)
            for peer in self.peer_pool.peers:
                cast(ShardingPeer, peer).send_collations([collation])

            self.collations_received_event.set()
            self.collations_received_event.clear()

    async def _cleanup(self) -> None:
        self.peer_pool.unsubscribe(self)

    def propose(self) -> Collation:
        """Broadcast a new collation to the network, add it to the local shard, and return it."""
        # create collation for current period
        period = self.get_current_period()
        body = zpad_right(str(self).encode("utf-8"), COLLATION_SIZE)
        header = CollationHeader(self.shard.shard_id, calc_chunk_root(body), period, b"\x11" * 20)
        collation = Collation(header, body)

        self.logger.debug("Proposing collation %s", collation)

        # add collation to local chain
        self.shard.add_collation(collation)

        # broadcast collation
        for peer in self.peer_pool.peers:
            cast(ShardingPeer, peer).send_collations([collation])

        return collation

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ShardingPeer, peer)))

    async def handle_peer(self, peer: ShardingPeer) -> None:
        """Forward collations received by the peer into our incoming queue until cancelled."""
        while True:
            try:
                collation = await wait_with_token(
                    peer.incoming_collation_queue.get(),
                    token=self.cancel_token)
                await wait_with_token(
                    self.incoming_collation_queue.put(collation),
                    token=self.cancel_token)
            except OperationCancelled:
                break  # stop handling peer if cancel token is triggered

    def get_current_period(self):
        """Number of whole COLLATION_PERIODs elapsed since this syncer was created."""
        # TODO: get this from main chain
        return int((time.time() - self.start_time) // COLLATION_PERIOD)